Skip to content

Commit 848c251

Browse files
committed
Update
[ghstack-poisoned]
2 parents 19d45a8 + 72ddbac commit 848c251

File tree

1 file changed

+58
-37
lines changed

1 file changed

+58
-37
lines changed

torchrl/envs/custom/chess.py

+58 -37
Original file line number | Diff line number | Diff line change
@@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
7676
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
7777
7878
Examples:
79+
>>> import torch
80+
>>> from torchrl.envs import ChessEnv
81+
>>> _ = torch.manual_seed(0)
7982
>>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
83+
>>> print(env)
84+
TransformedEnv(
85+
env=ChessEnv(),
86+
transform=ActionMask(keys=['action', 'action_mask']))
8087
>>> r = env.reset()
81-
>>> env.rand_step(r)
88+
>>> print(env.rand_step(r))
8289
TensorDict(
8390
fields={
8491
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
92+
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
8593
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
8694
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
8795
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
8896
next: TensorDict(
8997
fields={
98+
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
9099
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
91-
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
100+
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
92101
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
93102
pgn: NonTensorData(data=[Event "?"]
94103
[Site "?"]
@@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
97106
[White "?"]
98107
[Black "?"]
99108
[Result "*"]
100-
1. b3 *, batch_size=torch.Size([]), device=None),
109+
110+
1. f4 *, batch_size=torch.Size([]), device=None),
101111
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
102-
san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None),
112+
san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None),
103113
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
104114
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
105115
batch_size=torch.Size([]),
@@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
112122
[White "?"]
113123
[Black "?"]
114124
[Result "*"]
125+
115126
*, batch_size=torch.Size([]), device=None),
116-
san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None),
127+
san: NonTensorData(data=<start>, batch_size=torch.Size([]), device=None),
117128
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
118129
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
119130
batch_size=torch.Size([]),
120131
device=None,
121132
is_shared=False)
122-
>>> env.rollout(1000)
133+
>>> print(env.rollout(1000))
123134
TensorDict(
124135
fields={
125-
action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
126-
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
136+
action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
137+
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
138+
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127139
fen: NonTensorStack(
128140
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
129-
batch_size=torch.Size([352]),
141+
batch_size=torch.Size([96]),
130142
device=None),
131-
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
143+
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
132144
next: TensorDict(
133145
fields={
134-
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
146+
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
147+
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
135148
fen: NonTensorStack(
136-
['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K...,
137-
batch_size=torch.Size([352]),
149+
['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ...,
150+
batch_size=torch.Size([96]),
138151
device=None),
139-
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
152+
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
140153
pgn: NonTensorStack(
141154
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
142-
batch_size=torch.Size([352]),
155+
batch_size=torch.Size([96]),
143156
device=None),
144-
reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False),
157+
reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False),
145158
san: NonTensorStack(
146-
['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'...,
147-
batch_size=torch.Size([352]),
159+
['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra...,
160+
batch_size=torch.Size([96]),
148161
device=None),
149-
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
150-
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
151-
batch_size=torch.Size([352]),
162+
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
163+
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
164+
batch_size=torch.Size([96]),
152165
device=None,
153166
is_shared=False),
154167
pgn: NonTensorStack(
155168
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
156-
batch_size=torch.Size([352]),
169+
batch_size=torch.Size([96]),
157170
device=None),
158171
san: NonTensorStack(
159-
['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ...,
160-
batch_size=torch.Size([352]),
172+
['<start>', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',...,
173+
batch_size=torch.Size([96]),
161174
device=None),
162-
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
163-
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
164-
batch_size=torch.Size([352]),
175+
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
176+
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
177+
batch_size=torch.Size([96]),
165178
device=None,
166179
is_shared=False)
167180
@@ -227,13 +240,15 @@ def _legal_moves_to_index(
227240
[self._san_moves.index(board.san(m)) for m in board.legal_moves],
228241
dtype=torch.int64,
229242
)
230-
243+
mask = None
231244
if return_mask:
232-
return self._move_index_to_mask(indices)
245+
mask = self._move_index_to_mask(indices)
233246
if pad:
234247
indices = torch.nn.functional.pad(
235248
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
236249
)
250+
if return_mask:
251+
return indices, mask
237252
return indices
238253

239254
@classmethod
@@ -371,16 +386,19 @@ def _reset(self, tensordict=None):
371386
dest.set("pgn", pgn)
372387
dest.set("turn", turn)
373388
if self.include_legal_moves:
374-
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
375-
dest.set("legal_moves", moves_idx)
389+
moves_idx = self._legal_moves_to_index(
390+
board=self.board, pad=True, return_mask=self.mask_actions
391+
)
376392
if self.mask_actions:
377-
dest.set("action_mask", self._move_index_to_mask(moves_idx))
393+
moves_idx, mask = moves_idx
394+
dest.set("action_mask", mask)
395+
dest.set("legal_moves", moves_idx)
378396
elif self.mask_actions:
379397
dest.set(
380398
"action_mask",
381399
self._legal_moves_to_index(
382400
board=self.board, pad=True, return_mask=True
383-
),
401+
)[1],
384402
)
385403

386404
if self.pixels:
@@ -527,16 +545,19 @@ def _step(self, tensordict):
527545
dest.set("san", san)
528546

529547
if self.include_legal_moves:
530-
moves_idx = self._legal_moves_to_index(board=board, pad=True)
531-
dest.set("legal_moves", moves_idx)
548+
moves_idx = self._legal_moves_to_index(
549+
board=board, pad=True, return_mask=self.mask_actions
550+
)
532551
if self.mask_actions:
533-
dest.set("action_mask", self._move_index_to_mask(moves_idx))
552+
moves_idx, mask = moves_idx
553+
dest.set("action_mask", mask)
554+
dest.set("legal_moves", moves_idx)
534555
elif self.mask_actions:
535556
dest.set(
536557
"action_mask",
537558
self._legal_moves_to_index(
538559
board=self.board, pad=True, return_mask=True
539-
),
560+
)[1],
540561
)
541562

542563
turn = torch.tensor(board.turn)

0 commit comments

Comments
 (0)