Skip to content

Commit

Permalink
modified: agent.py
Browse files Browse the repository at this point in the history
	modified:   mcts.py
	modified:   models.py
	modified:   test.ipynb
  • Loading branch information
apandy02 committed Dec 4, 2023
1 parent 532c97d commit efd17b2
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 28 deletions.
6 changes: 3 additions & 3 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def retrain_nn(self,
policy_pred, value_pred = neural_network(x_train)

policy_loss = policy_loss_fn(policy_train, policy_pred)
print(f"shape value train: {value_train.shape}, shape value pred: {value_pred.shape}")
value_loss = value_loss_fn(value_train, value_pred)
combined_loss = policy_loss + value_loss

Expand All @@ -161,14 +160,15 @@ def play_game(self,
while not self.game.getGameEnded(board=game_state, player=player):
stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).unsqueeze(0)
if player == 1:
print(stacked_tensor.shape)
policy, _ = network_a(stacked_tensor)
else:
policy, _ = network_b(stacked_tensor)

valid_moves = self.game.getValidMoves(game_state, player)
ones_indices = np.where(valid_moves == 1)[0]
action = np.random.choice(ones_indices)
mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool)
mask[torch.tensor(ones_indices)] = True
policy[~mask] = 0
_, action = torch.max(policy, dim=-1)
game_state, player = self.game.getNextState(
game_state,
Expand Down
2 changes: 1 addition & 1 deletion mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def apv_mcts(
cannonical_board = game.getCanonicalForm(node.state, player=player)
policy, _ = model(input_tensor)
#print(f"value: {val.shape}")
policy = policy.cpu().detach().numpy().squeeze(0)
policy = policy.cpu().detach().numpy()
possible_actions = game.getValidMoves(node.state, player=player)
policy *= possible_actions

Expand Down
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def forward(self, state) -> tuple[np.array, int]:
s = self.value_head_conv(s)
val = self.value_head_linear(s).squeeze()

return pi, val
return pi.squeeze(0), val
Loading

0 comments on commit efd17b2

Please sign in to comment.