From a4a4b254d0889590d5dc17fba1c5bbcc405e6f76 Mon Sep 17 00:00:00 2001 From: aryamanpandya99 Date: Tue, 12 Dec 2023 15:11:24 -0500 Subject: [PATCH] changed model arch to include dropouts for both heads --- agent.py | 28 +++++++++++++++++++--------- main.py | 4 ++-- mcts.py | 6 ++++-- models.py | 35 ++++++++++++++++++++++------------- 4 files changed, 47 insertions(+), 26 deletions(-) diff --git a/agent.py b/agent.py index 685d344..0137f30 100644 --- a/agent.py +++ b/agent.py @@ -62,12 +62,17 @@ def train(self, ) # keep a copy of the current network for evaluation old_network = copy.deepcopy(current_network) - + state_old = current_network.state_dict().__str__() + print(f"old network params: {old_network.parameters()}") policy_loss, value_loss = self.retrain_nn( neural_network=current_network, train_data=train_episodes, train_batch_size=train_batch_size ) + state_curr = current_network.state_dict().__str__() + print(f"curr: {current_network.parameters()}") + if state_curr == state_old: + print("network not updating") # note: figure out if the following assignment makes sense current_network = self.evaluate_networks( current_network, @@ -174,20 +179,21 @@ def play_game(self, result (bool) """ game_state = self.game.getInitBoard() - player = 1 + player = np.random.choice([-1,1]) + print(f"starting player: {player}") game_state = self.game.getCanonicalForm(game_state, player) stacked_frames = self.mcts.no_history_model_input(game_state, current_player=player) while not self.game.getGameEnded(board=game_state, player=player): - print(f"stacked frames: \n{stacked_frames}") + # print(f"stacked frames: \n{stacked_frames}") stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).to(self.device).unsqueeze(0) - #print(f"Player {player}, state: \n{game_state}") - print(f"player: {player}") + print(f"Player {player}, state: \n{game_state}") + #print(f"player: {player}") if player == 1: policy, _ = network_a(stacked_tensor) - print(f"network_a policy: {policy}") + #print(f"network_a policy: {policy}") else: policy, _ = network_b(stacked_tensor) - print(f"network_b policy: {policy}") + #print(f"network_b policy: {policy}") valid_moves = self.game.getValidMoves(game_state, player) ones_indices = np.where(valid_moves == 1)[0] mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool) @@ -223,13 +229,12 @@ def self_play(self, model: torch.nn.Module, num_episodes: int): episodes (list) """ train_episodes = [] - for _ in range(num_episodes): + self.mcts = MCTS(self.game) game_states = [] game_state = self.game.getInitBoard() player = 1 game_state = self.game.getCanonicalForm(game_state, player=player) - self.mcts = MCTS(self.game) while not self.game.getGameEnded(board=game_state, player=player): pi = self.mcts.apv_mcts( canonical_root_state=game_state, @@ -251,8 +256,13 @@ def self_play(self, model: torch.nn.Module, num_episodes: int): game_state = self.game.getCanonicalForm(game_state, player=player) game_result = self.game.getGameEnded(board=game_state, player=player) + last_player = player + #print(f"last_player: {last_player}") for state, pi, player in game_states: stacked_frames = self.mcts.no_history_model_input(state, current_player=player) + #print(f"player: {player}, last_player = {last_player}") + if player != last_player: + train_episodes.append((stacked_frames, pi, -game_result)) train_episodes.append((stacked_frames, pi, game_result)) return train_episodes diff --git a/main.py b/main.py index 47b736e..61e6945 100644 --- a/main.py +++ b/main.py @@ -35,10 +35,10 @@ def main(): 
lr=learning_rate, weight_decay=l2_reg ) - agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=50, game=game, c_uct=1, device=device, mcts=MCTS(game)) + agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=25, game=game, c_uct=1, device=device, mcts=MCTS(game)) agent.train( - train_batch_size=32, + train_batch_size=16, neural_network=neural_network, num_episodes=100, num_epochs=100 diff --git a/mcts.py b/mcts.py index a7644a4..40ab077 100644 --- a/mcts.py +++ b/mcts.py @@ -73,7 +73,7 @@ def split_player_boards(self, board: np.ndarray) -> tuple[np.ndarray, np.ndarray """ player_a = np.maximum(board, 0) player_b = board.copy() - player_b[player_b < 0] = 1 + player_b = -1 * np.minimum(player_b, 0) return player_a, player_b @@ -193,6 +193,7 @@ def apv_mcts( input_array = self.no_history_model_input(board_arr=state, current_player=player) path.append((input_array, state_string, action)) + #print(f'Player: {-player}, state: \n{state}, action: {action}, next_state:\n{next_state}') next_state = self.game.getCanonicalForm(next_state, player=player) state = next_state state_string = self.game.stringRepresentation(state) @@ -202,6 +203,7 @@ def apv_mcts( input_tensor = torch.tensor(input_array, dtype=torch.float32).to(device).unsqueeze(0) game_ended = self.game.getGameEnded(state, player=player) if not game_ended : + model.eval() policy, value = model(input_tensor) policy = policy.cpu().detach().numpy() possible_actions = self.game.getValidMoves(state, player=player) @@ -216,7 +218,7 @@ def apv_mcts( else: value = game_ended - + #print(f"value: {value}") # backpropagation phase for _, state_string, action in reversed(path): if state_string in self.num_visits_s: diff --git a/models.py b/models.py index b6b968e..ef4af27 100644 --- a/models.py +++ b/models.py @@ -6,6 +6,7 @@ from torch import nn import numpy as np +import torch.nn.functional as F class OthelloNN(nn.Module): @@ -37,31 +38,37 @@ def __init__(self) -> None: nn.ReLU(), ) self.conv3 = nn.Sequential( - nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3), nn.BatchNorm2d(128), nn.ReLU(), ) self.conv4 = nn.Sequential( - nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3), nn.BatchNorm2d(128), nn.ReLU(), ) + self.fc1 = nn.Sequential( + nn.Linear(2048, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + ) + + self.fc2 = nn.Sequential( + nn.Linear(1024, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + ) + self.policy_head = nn.Sequential( nn.Flatten(), - nn.Linear(128 * 8 * 8, 65), # Flatten the conv output and connect to a Linear layer + nn.Linear(512, 65), # Flatten the conv output and connect to a Linear layer nn.Softmax(dim=1) ) - self.value_head_conv = nn.Sequential( - nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1), - nn.BatchNorm2d(1), - nn.ReLU(), - ) - - self.value_head_linear = nn.Sequential( - nn.Flatten(), nn.Linear(64, 1), nn.ReLU(), nn.Tanh() + self.value_head = nn.Sequential( + nn.Flatten(), nn.Linear(512, 1), nn.ReLU(), nn.Tanh() ) def forward(self, state) -> tuple[np.array, int]: @@ -80,9 +87,11 @@ def forward(self, state) -> tuple[np.array, int]: s = self.conv2(s) s = self.conv3(s) s = self.conv4(s) + s = s.view(-1, 2048) + s = F.dropout(self.fc1(s)) + s = F.dropout(self.fc2(s)) pi = self.policy_head(s) - s = self.value_head_conv(s) - val = self.value_head_linear(s).squeeze() + val = self.value_head(s).squeeze() return pi.squeeze(0), val
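
For reference, below is a minimal self-contained sketch of the new fully connected trunk and the two heads introduced in models.py above. The class name DropoutHeadsSketch and the 128 x 4 x 4 dummy input are illustrative assumptions only: the repository's OthelloNN also contains conv1/conv2 layers this patch does not touch, and 128 x 4 x 4 = 2048 is the flatten size implied by removing padding from conv3/conv4 on an 8x8 board. Unlike the patched forward(), the sketch forwards training=self.training to F.dropout, since F.dropout defaults to training=True and would otherwise stay active even after the model.eval() call added in mcts.apv_mcts.

import torch
from torch import nn
import torch.nn.functional as F


class DropoutHeadsSketch(nn.Module):
    """Illustrative stand-in for the new FC trunk and heads (not the full OthelloNN)."""

    def __init__(self) -> None:
        super().__init__()
        # 128 channels * 4 * 4 spatial = 2048, matching nn.Linear(2048, 1024).
        self.fc1 = nn.Sequential(nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU())
        self.policy_head = nn.Sequential(nn.Flatten(), nn.Linear(512, 65), nn.Softmax(dim=1))
        self.value_head = nn.Sequential(nn.Flatten(), nn.Linear(512, 1), nn.ReLU(), nn.Tanh())

    def forward(self, conv_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        s = conv_features.view(-1, 2048)
        # F.dropout is active by default regardless of module mode; forwarding
        # self.training lets model.eval() (as called before inference in
        # mcts.apv_mcts) actually disable dropout here.
        s = F.dropout(self.fc1(s), training=self.training)
        s = F.dropout(self.fc2(s), training=self.training)
        return self.policy_head(s), self.value_head(s).squeeze()


if __name__ == "__main__":
    heads = DropoutHeadsSketch().eval()         # eval: BatchNorm uses running stats, dropout off
    pi, val = heads(torch.randn(2, 128, 4, 4))  # dummy conv4-style feature map
    print(pi.shape, val.shape)                  # torch.Size([2, 65]) torch.Size([2])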