@@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
76
76
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
77
77
78
78
Examples:
79
+ >>> import torch
80
+ >>> from torchrl.envs import ChessEnv
81
+ >>> _ = torch.manual_seed(0)
79
82
>>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
83
+ >>> print(env)
84
+ TransformedEnv(
85
+ env=ChessEnv(),
86
+ transform=ActionMask(keys=['action', 'action_mask']))
80
87
>>> r = env.reset()
81
- >>> env.rand_step(r)
88
+ >>> print( env.rand_step(r) )
82
89
TensorDict(
83
90
fields={
84
91
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
92
+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
85
93
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
86
94
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
87
95
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
88
96
next: TensorDict(
89
97
fields={
98
+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
90
99
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
91
- fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
100
+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
92
101
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
93
102
pgn: NonTensorData(data=[Event "?"]
94
103
[Site "?"]
@@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
97
106
[White "?"]
98
107
[Black "?"]
99
108
[Result "*"]
100
- 1. b3 *, batch_size=torch.Size([]), device=None),
109
+
110
+ 1. f4 *, batch_size=torch.Size([]), device=None),
101
111
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
102
- san: NonTensorData(data=b3 , batch_size=torch.Size([]), device=None),
112
+ san: NonTensorData(data=f4 , batch_size=torch.Size([]), device=None),
103
113
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
104
114
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
105
115
batch_size=torch.Size([]),
@@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
112
122
[White "?"]
113
123
[Black "?"]
114
124
[Result "*"]
125
+
115
126
*, batch_size=torch.Size([]), device=None),
116
- san: NonTensorData(data=[SAN][START] , batch_size=torch.Size([]), device=None),
127
+ san: NonTensorData(data=<start> , batch_size=torch.Size([]), device=None),
117
128
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
118
129
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
119
130
batch_size=torch.Size([]),
120
131
device=None,
121
132
is_shared=False)
122
- >>> env.rollout(1000)
133
+ >>> print( env.rollout(1000) )
123
134
TensorDict(
124
135
fields={
125
- action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
126
- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
136
+ action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
137
+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
138
+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127
139
fen: NonTensorStack(
128
140
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
129
- batch_size=torch.Size([352 ]),
141
+ batch_size=torch.Size([96 ]),
130
142
device=None),
131
- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
143
+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
132
144
next: TensorDict(
133
145
fields={
134
- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
146
+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
147
+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
135
148
fen: NonTensorStack(
136
- ['rnbqkbnr/pppppppp/8/8/8/N7 /PPPPPPPP/R1BQKBNR b K ...,
137
- batch_size=torch.Size([352 ]),
149
+ ['rnbqkbnr/pppppppp/8/8/8/5N2 /PPPPPPPP/RNBQKB1R b ...,
150
+ batch_size=torch.Size([96 ]),
138
151
device=None),
139
- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
152
+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
140
153
pgn: NonTensorStack(
141
154
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
142
- batch_size=torch.Size([352 ]),
155
+ batch_size=torch.Size([96 ]),
143
156
device=None),
144
- reward: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
157
+ reward: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
145
158
san: NonTensorStack(
146
- ['Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', 'g6 ', 'd4 ', 'd6' ...,
147
- batch_size=torch.Size([352 ]),
159
+ ['Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8 ', 'Na3 ', 'Ra ...,
160
+ batch_size=torch.Size([96 ]),
148
161
device=None),
149
- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
150
- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
151
- batch_size=torch.Size([352 ]),
162
+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163
+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164
+ batch_size=torch.Size([96 ]),
152
165
device=None,
153
166
is_shared=False),
154
167
pgn: NonTensorStack(
155
168
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
156
- batch_size=torch.Size([352 ]),
169
+ batch_size=torch.Size([96 ]),
157
170
device=None),
158
171
san: NonTensorStack(
159
- ['[SAN][START] ', 'Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', ...,
160
- batch_size=torch.Size([352 ]),
172
+ ['<start> ', 'Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8', ...,
173
+ batch_size=torch.Size([96 ]),
161
174
device=None),
162
- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163
- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164
- batch_size=torch.Size([352 ]),
175
+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
176
+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
177
+ batch_size=torch.Size([96 ]),
165
178
device=None,
166
179
is_shared=False)
167
180
@@ -227,13 +240,15 @@ def _legal_moves_to_index(
227
240
[self ._san_moves .index (board .san (m )) for m in board .legal_moves ],
228
241
dtype = torch .int64 ,
229
242
)
230
-
243
+ mask = None
231
244
if return_mask :
232
- return self ._move_index_to_mask (indices )
245
+ mask = self ._move_index_to_mask (indices )
233
246
if pad :
234
247
indices = torch .nn .functional .pad (
235
248
indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
236
249
)
250
+ if return_mask :
251
+ return indices , mask
237
252
return indices
238
253
239
254
@classmethod
@@ -371,16 +386,19 @@ def _reset(self, tensordict=None):
371
386
dest .set ("pgn" , pgn )
372
387
dest .set ("turn" , turn )
373
388
if self .include_legal_moves :
374
- moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
375
- dest .set ("legal_moves" , moves_idx )
389
+ moves_idx = self ._legal_moves_to_index (
390
+ board = self .board , pad = True , return_mask = self .mask_actions
391
+ )
376
392
if self .mask_actions :
377
- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
393
+ moves_idx , mask = moves_idx
394
+ dest .set ("action_mask" , mask )
395
+ dest .set ("legal_moves" , moves_idx )
378
396
elif self .mask_actions :
379
397
dest .set (
380
398
"action_mask" ,
381
399
self ._legal_moves_to_index (
382
400
board = self .board , pad = True , return_mask = True
383
- ),
401
+ )[ 1 ] ,
384
402
)
385
403
386
404
if self .pixels :
@@ -527,16 +545,19 @@ def _step(self, tensordict):
527
545
dest .set ("san" , san )
528
546
529
547
if self .include_legal_moves :
530
- moves_idx = self ._legal_moves_to_index (board = board , pad = True )
531
- dest .set ("legal_moves" , moves_idx )
548
+ moves_idx = self ._legal_moves_to_index (
549
+ board = board , pad = True , return_mask = self .mask_actions
550
+ )
532
551
if self .mask_actions :
533
- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
552
+ moves_idx , mask = moves_idx
553
+ dest .set ("action_mask" , mask )
554
+ dest .set ("legal_moves" , moves_idx )
534
555
elif self .mask_actions :
535
556
dest .set (
536
557
"action_mask" ,
537
558
self ._legal_moves_to_index (
538
559
board = self .board , pad = True , return_mask = True
539
- ),
560
+ )[ 1 ] ,
540
561
)
541
562
542
563
turn = torch .tensor (board .turn )
0 commit comments