
Gridworlds

Grid

rlbook.gridworlds.grids.Grid

Base grid class with JAX jit-related helper methods.

Attributes:

    n_rows: number of rows.
    n_cols: number of columns.
    actions: actions that can be taken in the grid.
    v_init: initial state values.

Source code in src/rlbook/gridworlds/grids.py
# module imports (reconstructed; required by the source shown below)
from abc import ABCMeta, abstractmethod

import jax.numpy as jnp
from jax import jit
from jax.scipy.signal import correlate2d
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, Float, Int

class Grid(metaclass=ABCMeta):
    """Base grid class with jax jit related helper methods.

    Attributes:
        n_rows: number of rows.
        n_cols: number of columns.
        actions: actions that can be taken in the grid.
        v_init: initial state values.
    """

    def __init__(
        self,
        n_rows: int = 5,
        n_cols: int = 5,
    ):
        """
        Args:
            n_rows: number of rows.
            n_cols: number of columns.
        """
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.actions = jnp.array([[-1, 1, 0, 0], [0, 0, 1, -1]])
        self.v_init = jnp.zeros((self.n_rows, self.n_cols))

    @property
    @abstractmethod
    def policy(self): ...

    @abstractmethod
    def reward(self): ...

    def tree_flatten(self):
        """Jax flatten method for serialization, required to jit class methods."""
        children = (
            self.special_states_rewards,
            self.R,
            self.P,
            self.actions,
            self.v_init,
        )  # arrays and dynamic values
        # static values (non-arrays)
        aux_data = {
            "special_states": self.special_states,
            "special_states_prime": self.special_states_prime,
            "n_rows": self.n_rows,
            "n_cols": self.n_cols,
        }

        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Jax unflatten method for deserialization, required to jit class methods."""
        grid = cls(
            aux_data["special_states"],
            aux_data["special_states_prime"],
            children[0],
            R=children[1],
            P=children[2],
            n_rows=aux_data["n_rows"],
            n_cols=aux_data["n_cols"],
        )
        grid.actions = children[3]
        grid.v_init = children[4]

        return grid

__init__(n_rows=5, n_cols=5)

Parameters:

    n_rows (int): number of rows. Default: 5
    n_cols (int): number of columns. Default: 5

Source code in src/rlbook/gridworlds/grids.py
def __init__(
    self,
    n_rows: int = 5,
    n_cols: int = 5,
):
    """
    Args:
        n_rows: number of rows.
        n_cols: number of columns.
    """
    self.n_rows = n_rows
    self.n_cols = n_cols
    self.actions = jnp.array([[-1, 1, 0, 0], [0, 0, 1, -1]])
    self.v_init = jnp.zeros((self.n_rows, self.n_cols))

tree_flatten()

Jax flatten method for serialization, required to jit class methods.

Source code in src/rlbook/gridworlds/grids.py
def tree_flatten(self):
    """Jax flatten method for serialization, required to jit class methods."""
    children = (
        self.special_states_rewards,
        self.R,
        self.P,
        self.actions,
        self.v_init,
    )  # arrays and dynamic values
    # static values (non-arrays)
    aux_data = {
        "special_states": self.special_states,
        "special_states_prime": self.special_states_prime,
        "n_rows": self.n_rows,
        "n_cols": self.n_cols,
    }

    return (children, aux_data)

tree_unflatten(aux_data, children) classmethod

Jax unflatten method for deserialization, required to jit class methods.

Source code in src/rlbook/gridworlds/grids.py
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Jax unflatten method for deserialization, required to jit class methods."""
    grid = cls(
        aux_data["special_states"],
        aux_data["special_states_prime"],
        children[0],
        R=children[1],
        P=children[2],
        n_rows=aux_data["n_rows"],
        n_cols=aux_data["n_cols"],
    )
    grid.actions = children[3]
    grid.v_init = children[4]

    return grid
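
Because subclasses are registered with @register_pytree_node_class, jitted methods can take self directly: JAX flattens the instance into its array children plus static aux_data, traces the arrays, and rebuilds the instance via tree_unflatten. A minimal round-trip check (illustrative; grid is any constructed subclass instance):

from jax.tree_util import tree_flatten, tree_unflatten

leaves, treedef = tree_flatten(grid)     # array children: rewards, R, P, actions, v_init
grid2 = tree_unflatten(treedef, leaves)  # rebuilds an equivalent instance
assert grid2.n_rows == grid.n_rows       # static aux_data survives the round trip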

RandomGrid

rlbook.gridworlds.grids.RandomGrid

Bases: Grid

RandomGrid class for estimating state values in a gridworld using a random policy.

Source code in src/rlbook/gridworlds/grids.py
@register_pytree_node_class
class RandomGrid(Grid):
    """RandomGrid class for estimating state values in a gridworld using a random policy."""

    def __init__(
        self,
        special_states: list[list[int]],
        special_states_prime: list[list[int]],
        special_states_rewards: Int[Array, "1 {len(special_states)}"],
        n_rows: int = 5,
        n_cols: int = 5,
        R: Float[Array, "n_rows n_cols"] | None = None,
        P: Float[Array, "3 3"] | None = None,
    ):
        """
        Args:
            special_states: row and column indices of the special states, given as [rows, cols].
              e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1
              and special state B at row 0 and column 3.
            special_states_prime: list of special states prime rows and columns, see previous special_states example.
            special_states_rewards: jax array of rewards for special states. Note: not a list!
            n_rows: number of rows.
            n_cols: number of columns.
            R: jax array specifying rewards for all states when taking a random policy.
            P: jax array specifying a conv kernel for a random policy.
        """
        super().__init__(n_rows=n_rows, n_cols=n_cols)
        self.special_states = special_states
        self.special_states_prime = special_states_prime
        self.special_states_rewards = special_states_rewards

        self.P = self.policy if P is None else P
        self.R = self.reward if R is None else R

    @property
    def policy(self) -> Float[Array, "3 3"]:
        """
        Define random policy conv kernel with equal probability of taking each action:

        P = Array([[0,    0.25, 0   ],
                   [0.25, 0,    0.25],
                   [0,    0.25, 0   ]])
        """
        policy = jnp.zeros((3, 3))
        policy = policy.at[self.actions[0] + 1, self.actions[1] + 1].set(0.25)

        return policy

    @property
    def reward(self) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """Provides reward for all states in grid when following a random policy"""
        R = correlate2d(
            jnp.pad(self.v_init, pad_width=(1, 1), constant_values=-1),
            self.P,
            mode="valid",
        )
        R = R.at[self.special_states[0], self.special_states[1]].set(
            self.special_states_rewards
        )

        return R

    @jit
    def state_value(
        self,
        v: Float[Array, "n_rows n_cols"],
        R: Float[Array, "n_rows n_cols"],
        P: Float[Array, "3 3"],
        special_states: list[list[int]],
        special_states_prime: list[list[int]],
        special_states_rewards: Float[Array, "1 {len(special_states)}"],
        discount: float = 0.9,
    ) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """State value function for estimating state values in a gridworld using a random policy"""
        # Update states
        vp = (
            R
            + correlate2d(
                jnp.pad(v, pad_width=(1, 1), mode="edge"),
                P,
                mode="valid",
            )
            * discount
        )

        # Update special states
        vp = vp.at[special_states[0], special_states[1]].set(
            v[special_states_prime[0], special_states_prime[1]] * discount
            + special_states_rewards
        )

        return vp

    def estimate_state_value(
        self, n_iter: int = 1000
    ) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """Estimate state values in a gridworld using a random policy"""
        v = self.v_init
        for _ in range(n_iter):
            v = self.state_value(
                v,
                self.R,
                self.P,
                self.special_states,
                self.special_states_prime,
                self.special_states_rewards,
            )
        return v
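
A minimal construction sketch (the coordinates and rewards below follow the classic 5x5 example from Sutton & Barto and are illustrative, not required values):

import jax.numpy as jnp
from rlbook.gridworlds.grids import RandomGrid

# [rows, cols] layout: A at (0, 1) jumps to (4, 1) for +10; B at (0, 3) jumps to (2, 3) for +5
grid = RandomGrid(
    special_states=[[0, 0], [1, 3]],
    special_states_prime=[[4, 2], [1, 3]],
    special_states_rewards=jnp.array([10.0, 5.0]),
)
v = grid.estimate_state_value(n_iter=1000)
print(jnp.round(v, 1))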

policy property

Define random policy conv kernel with equal probability of taking each action:

P = Array([[0,    0.25, 0   ],
           [0.25, 0,    0.25],
           [0,    0.25, 0   ]])
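
The kernel can be rebuilt and sanity-checked standalone (an illustrative snippet mirroring the property's construction):

import jax.numpy as jnp

actions = jnp.array([[-1, 1, 0, 0], [0, 0, 1, -1]])  # row/col offsets for up, down, right, left
policy = jnp.zeros((3, 3)).at[actions[0] + 1, actions[1] + 1].set(0.25)
assert float(policy.sum()) == 1.0  # the four moves are equiprobable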

reward property

Provides reward for all states in grid when following a random policy

__init__(special_states, special_states_prime, special_states_rewards, n_rows=5, n_cols=5, R=None, P=None)

Parameters:

    special_states (list[list[int]]): row and column indices of the special states, given as [rows, cols]. e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1 and special state B at row 0 and column 3. Required.
    special_states_prime (list[list[int]]): row and column indices of the special states prime, in the same [rows, cols] layout. Required.
    special_states_rewards (Int[Array, '1 {len(special_states)}']): jax array of rewards for special states. Note: not a list! Required.
    n_rows (int): number of rows. Default: 5
    n_cols (int): number of columns. Default: 5
    R (Float[Array, 'n_rows n_cols'] | None): jax array specifying rewards for all states when taking a random policy. Default: None
    P (Float[Array, '3 3'] | None): jax array specifying a conv kernel for a random policy. Default: None

Source code in src/rlbook/gridworlds/grids.py
def __init__(
    self,
    special_states: list[list[int]],
    special_states_prime: list[list[int]],
    special_states_rewards: Int[Array, "1 {len(special_states)}"],
    n_rows: int = 5,
    n_cols: int = 5,
    R: Float[Array, "n_rows n_cols"] | None = None,
    P: Float[Array, "3 3"] | None = None,
):
    """
    Args:
        special_states: row and column indices of the special states, given as [rows, cols].
          e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1
          and special state B at row 0 and column 3.
        special_states_prime: list of special states prime rows and columns, see previous special_states example.
        special_states_rewards: jax array of rewards for special states. Note: not a list!
        n_rows: number of rows.
        n_cols: number of columns.
        R: jax array specifying rewards for all states when taking a random policy.
        P: jax array specifying a conv kernel for a random policy.
    """
    super().__init__(n_rows=n_rows, n_cols=n_cols)
    self.special_states = special_states
    self.special_states_prime = special_states_prime
    self.special_states_rewards = special_states_rewards

    self.P = self.policy if P is None else P
    self.R = self.reward if R is None else R

estimate_state_value(n_iter=1000)

Estimate state values in a gridworld using a random policy

Source code in src/rlbook/gridworlds/grids.py
def estimate_state_value(
    self, n_iter: int = 1000
) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
    """Estimate state values in a gridworld using a random policy"""
    v = self.v_init
    for _ in range(n_iter):
        v = self.state_value(
            v,
            self.R,
            self.P,
            self.special_states,
            self.special_states_prime,
            self.special_states_rewards,
        )
    return v
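
Because state_value is jitted and pure in its array arguments, the Python loop above can in principle be replaced by jax.lax.fori_loop to keep the whole iteration on-device. A minimal sketch under that assumption (estimate_state_value_fori is a hypothetical helper, not part of the library; grid is a constructed RandomGrid):

from jax import lax

def estimate_state_value_fori(grid, n_iter=1000):
    def body(_, v):
        # one Bellman expectation sweep, same arguments as the loop above
        return grid.state_value(
            v, grid.R, grid.P,
            grid.special_states, grid.special_states_prime,
            grid.special_states_rewards,
        )
    return lax.fori_loop(0, n_iter, body, grid.v_init)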

state_value(v, R, P, special_states, special_states_prime, special_states_rewards, discount=0.9)

State value function for estimating state values in a gridworld using a random policy
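
In math, each call performs one sweep of iterative policy evaluation under the equiprobable policy (special states are then overridden, as in the code), where \(\gamma\) is the discount argument and the sum over actions is realized by the 3x3 kernel P:

\[
v_{k+1}(s) = \sum_{a} \pi(a \mid s)\,\bigl[r(s, a) + \gamma\, v_k(s')\bigr]
\]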

Source code in src/rlbook/gridworlds/grids.py
@jit
def state_value(
    self,
    v: Float[Array, "n_rows n_cols"],
    R: Float[Array, "n_rows n_cols"],
    P: Float[Array, "3 3"],
    special_states: list[list[int]],
    special_states_prime: list[list[int]],
    special_states_rewards: Float[Array, "1 {len(special_states)}"],
    discount: float = 0.9,
) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
    """State value function for estimating state values in a gridworld using a random policy"""
    # Update states
    vp = (
        R
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            P,
            mode="valid",
        )
        * discount
    )

    # Update special states
    vp = vp.at[special_states[0], special_states[1]].set(
        v[special_states_prime[0], special_states_prime[1]] * discount
        + special_states_rewards
    )

    return vp

OptimalGrid

rlbook.gridworlds.grids.OptimalGrid

Bases: Grid

OptimalGrid class for estimating state values in a gridworld using an optimal policy

Source code in src/rlbook/gridworlds/grids.py
@register_pytree_node_class
class OptimalGrid(Grid):
    """OptimalGrid class for estimating state values in a gridworld using an optimal policy"""

    def __init__(
        self,
        special_states: list[list[int]],
        special_states_prime: list[list[int]],
        special_states_rewards: Int[Array, "1 {len(special_states)}"],
        n_rows: int = 5,
        n_cols: int = 5,
        R: Float[Array, "4 n_rows n_cols"] | None = None,
        P: Float[Array, "4 3 3"] | None = None,
    ):
        """
        Args:
            special_states: row and column indices of the special states, given as [rows, cols].
              e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1
              and special state B at row 0 and column 3.
            special_states_prime: list of special states prime rows and columns, see previous special_states example.
            special_states_rewards: jax array of rewards for special states. Note: not a list!
            n_rows: number of rows.
            n_cols: number of columns.
            R: jax array specifying rewards for all states when taking an optimal policy.
            P: jax array specifying a conv kernel for an optimal policy.
        """
        super().__init__(n_rows=n_rows, n_cols=n_cols)
        self.special_states = special_states
        self.special_states_prime = special_states_prime
        self.special_states_rewards = special_states_rewards

        self.P = self.policy if P is None else P
        self.R = self.reward(self.v_init, self.P) if R is None else R

    @property
    def policy(self) -> Float[Array, "4 3 3"]:
        """
        Define policy conv kernel as a 3d array:

        P = Array([[[0., 1., 0.],  # action up only
                    [0., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],  # action left only
                    [1., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],  # action down only
                    [0., 0., 0.],
                    [0., 1., 0.]],

                   [[0., 0., 0.],  # action right only
                    [0., 0., 1.],
                    [0., 0., 0.]]])

        """
        policy = jnp.zeros((4, 3, 3))
        policy = policy.at[[0, 1, 2, 3], [0, 1, 2, 1], [1, 0, 1, 2]].set(1)

        return policy

    def reward(
        self, v: Float[Array, "{self.n_rows} {self.n_cols}"], P: Float[Array, "4 3 3"]
    ) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """Provides reward for all states in grid when following an optimal policy"""
        R = jnp.zeros((4, self.n_rows, self.n_cols))

        # Policy up only
        R = R.at[0, :, :].set(
            correlate2d(
                jnp.pad(v, pad_width=(1, 1), constant_values=-1),
                P[0, :, :],
                mode="valid",
            )
        )

        # Policy left only
        R = R.at[1, :, :].set(
            correlate2d(
                jnp.pad(v, pad_width=(1, 1), constant_values=-1),
                P[1, :, :],
                mode="valid",
            )
        )

        # Policy down only
        R = R.at[2, :, :].set(
            correlate2d(
                jnp.pad(v, pad_width=(1, 1), constant_values=-1),
                P[2, :, :],
                mode="valid",
            )
        )

        # Policy right only
        R = R.at[3, :, :].set(
            correlate2d(
                jnp.pad(v, pad_width=(1, 1), constant_values=-1),
                P[3, :, :],
                mode="valid",
            )
        )

        # Set special state rewards
        R = R.at[:, self.special_states[0], self.special_states[1]].set(
            self.special_states_rewards
        )

        return R

    @jit
    def state_value(
        self,
        v: Float[Array, "n_rows n_cols"],
        R: Float[Array, "4 n_rows n_cols"],
        P: Float[Array, "4 3 3"],
        special_states: list[list[int]],
        special_states_prime: list[list[int]],
        special_states_rewards: Float[Array, "1 {len(special_states)}"],
        discount: float = 0.9,
    ) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """State value function for estimating state values in a gridworld using an optimal policy"""

        vp = jnp.zeros((4, self.n_rows, self.n_cols))

        # Policy up only
        vp = vp.at[0, :, :].set(
            R[0, :, :]
            + correlate2d(
                jnp.pad(v, pad_width=(1, 1), mode="edge"),
                P[0, :, :],
                mode="valid",
            )
            * discount
        )

        # Policy left only
        vp = vp.at[1, :, :].set(
            R[1, :, :]
            + correlate2d(
                jnp.pad(v, pad_width=(1, 1), mode="edge"),
                P[1, :, :],
                mode="valid",
            )
            * discount
        )

        # Policy down only
        vp = vp.at[2, :, :].set(
            R[2, :, :]
            + correlate2d(
                jnp.pad(v, pad_width=(1, 1), mode="edge"),
                P[2, :, :],
                mode="valid",
            )
            * discount
        )

        # Policy right only
        vp = vp.at[3, :, :].set(
            R[3, :, :]
            + correlate2d(
                jnp.pad(v, pad_width=(1, 1), mode="edge"),
                P[3, :, :],
                mode="valid",
            )
            * discount
        )

        # Update special states
        vp = vp.at[:, special_states[0], special_states[1]].set(
            v[special_states_prime[0], special_states_prime[1]] * discount
            + special_states_rewards
        )

        return jnp.max(vp, axis=0)

    def estimate_state_value(
        self, n_iter: int = 1000
    ) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
        """Estimate state values in a gridworld using an optimal policy"""
        v = self.v_init
        for _ in range(n_iter):
            v = self.state_value(
                v,
                self.R,
                self.P,
                self.special_states,
                self.special_states_prime,
                self.special_states_rewards,
            )
        return v
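
Usage mirrors RandomGrid; a sketch with the same illustrative special states:

import jax.numpy as jnp
from rlbook.gridworlds.grids import OptimalGrid

grid = OptimalGrid(
    special_states=[[0, 0], [1, 3]],
    special_states_prime=[[4, 2], [1, 3]],
    special_states_rewards=jnp.array([10.0, 5.0]),
)
v_star = grid.estimate_state_value(n_iter=1000)  # optimal state values via value iteration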

policy property

Define policy conv kernel as a 3d array

P = Array([[[0., 1., 0.],  # action up only
            [0., 0., 0.],
            [0., 0., 0.]],

           [[0., 0., 0.],  # action left only
            [1., 0., 0.],
            [0., 0., 0.]],

           [[0., 0., 0.],  # action down only
            [0., 0., 0.],
            [0., 1., 0.]],

           [[0., 0., 0.],  # action right only
            [0., 0., 1.],
            [0., 0., 0.]]])

__init__(special_states, special_states_prime, special_states_rewards, n_rows=5, n_cols=5, R=None, P=None)

Parameters:

    special_states (list[list[int]]): row and column indices of the special states, given as [rows, cols]. e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1 and special state B at row 0 and column 3. Required.
    special_states_prime (list[list[int]]): row and column indices of the special states prime, in the same [rows, cols] layout. Required.
    special_states_rewards (Int[Array, '1 {len(special_states)}']): jax array of rewards for special states. Note: not a list! Required.
    n_rows (int): number of rows. Default: 5
    n_cols (int): number of columns. Default: 5
    R (Float[Array, '4 n_rows n_cols'] | None): jax array specifying rewards for all states when taking an optimal policy. Default: None
    P (Float[Array, '4 3 3'] | None): jax array specifying a conv kernel for an optimal policy. Default: None

Source code in src/rlbook/gridworlds/grids.py
def __init__(
    self,
    special_states: list[list[int]],
    special_states_prime: list[list[int]],
    special_states_rewards: Int[Array, "1 {len(special_states)}"],
    n_rows: int = 5,
    n_cols: int = 5,
    R: Float[Array, "4 n_rows n_cols"] | None = None,
    P: Float[Array, "4 3 3"] | None = None,
):
    """
    Args:
        special_states: row and column indices of the special states, given as [rows, cols].
          e.g. [[0, 0], [1, 3]] places special state A at row 0 and column 1
          and special state B at row 0 and column 3.
        special_states_prime: list of special states prime rows and columns, see previous special_states example.
        special_states_rewards: jax array of rewards for special states. Note: not a list!
        n_rows: number of rows.
        n_cols: number of columns.
        R: jax array specifying rewards for all states when taking an optimal policy.
        P: jax array specifying a conv kernel for an optimal policy.
    """
    super().__init__(n_rows=n_rows, n_cols=n_cols)
    self.special_states = special_states
    self.special_states_prime = special_states_prime
    self.special_states_rewards = special_states_rewards

    self.P = self.policy if P is None else P
    self.R = self.reward(self.v_init, self.P) if R is None else R

estimate_state_value(n_iter=1000)

Estimate state values in a gridworld using an optimal policy

Source code in src/rlbook/gridworlds/grids.py
def estimate_state_value(
    self, n_iter: int = 1000
) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
    """Estimate state values in a gridworld using an optimal policy"""
    v = self.v_init
    for _ in range(n_iter):
        v = self.state_value(
            v,
            self.R,
            self.P,
            self.special_states,
            self.special_states_prime,
            self.special_states_rewards,
        )
    return v

reward(v, P)

Provides reward for all states in grid when following an optimal policy

Source code in src/rlbook/gridworlds/grids.py
def reward(
    self, v: Float[Array, "{self.n_rows} {self.n_cols}"], P: Float[Array, "4 3 3"]
) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
    """Provides reward for all states in grid when following an optimal policy"""
    R = jnp.zeros((4, self.n_rows, self.n_cols))

    # Policy up only
    R = R.at[0, :, :].set(
        correlate2d(
            jnp.pad(v, pad_width=(1, 1), constant_values=-1),
            P[0, :, :],
            mode="valid",
        )
    )

    # Policy left only
    R = R.at[1, :, :].set(
        correlate2d(
            jnp.pad(v, pad_width=(1, 1), constant_values=-1),
            P[1, :, :],
            mode="valid",
        )
    )

    # Policy down only
    R = R.at[2, :, :].set(
        correlate2d(
            jnp.pad(v, pad_width=(1, 1), constant_values=-1),
            P[2, :, :],
            mode="valid",
        )
    )

    # Policy right only
    R = R.at[3, :, :].set(
        correlate2d(
            jnp.pad(v, pad_width=(1, 1), constant_values=-1),
            P[3, :, :],
            mode="valid",
        )
    )

    # Set special state rewards
    R = R.at[:, self.special_states[0], self.special_states[1]].set(
        self.special_states_rewards
    )

    return R

state_value(v, R, P, special_states, special_states_prime, special_states_rewards, discount=0.9)

State value function for estimating state values in a gridworld using an optimal policy
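
In math, each call performs one sweep of value iteration; the max over the four action channels is the jnp.max(vp, axis=0) at the end:

\[
v_{k+1}(s) = \max_{a}\,\bigl[r(s, a) + \gamma\, v_k(s')\bigr]
\]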

Source code in src/rlbook/gridworlds/grids.py
@jit
def state_value(
    self,
    v: Float[Array, "n_rows n_cols"],
    R: Float[Array, "4 n_rows n_cols"],
    P: Float[Array, "4 3 3"],
    special_states: list[list[int]],
    special_states_prime: list[list[int]],
    special_states_rewards: Float[Array, "1 {len(special_states)}"],
    discount: float = 0.9,
) -> Float[Array, "{self.n_rows} {self.n_cols}"]:
    """State value function for estimating state values in a gridworld using an optimal policy"""

    vp = jnp.zeros((4, self.n_rows, self.n_cols))

    # Policy up only
    vp = vp.at[0, :, :].set(
        R[0, :, :]
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            P[0, :, :],
            mode="valid",
        )
        * discount
    )

    # Policy left only
    vp = vp.at[1, :, :].set(
        R[1, :, :]
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            P[1, :, :],
            mode="valid",
        )
        * discount
    )

    # Policy down only
    vp = vp.at[2, :, :].set(
        R[2, :, :]
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            P[2, :, :],
            mode="valid",
        )
        * discount
    )

    # Policy right only
    vp = vp.at[3, :, :].set(
        R[3, :, :]
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            P[3, :, :],
            mode="valid",
        )
        * discount
    )

    # Update special states
    vp = vp.at[:, special_states[0], special_states[1]].set(
        v[special_states_prime[0], special_states_prime[1]] * discount
        + special_states_rewards
    )

    return jnp.max(vp, axis=0)
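
Once the values have converged, a greedy action map can be read off by repeating the four per-action backups and taking an argmax instead of the max. A sketch under that assumption (greedy_actions is a hypothetical helper, not part of the library; the special-state override is omitted for brevity):

import jax.numpy as jnp
from jax.scipy.signal import correlate2d

def greedy_actions(grid, v, discount=0.9):
    # one backup per action channel, exactly as state_value builds vp
    vp = jnp.stack([
        grid.R[a]
        + correlate2d(
            jnp.pad(v, pad_width=(1, 1), mode="edge"),
            grid.P[a],
            mode="valid",
        )
        * discount
        for a in range(4)
    ])
    return jnp.argmax(vp, axis=0)  # 0=up, 1=left, 2=down, 3=right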