
Commit

changed model arch to include dropouts for both heads
apandy02 committed Dec 12, 2023
1 parent ef25a8a commit a4a4b25
Showing 4 changed files with 47 additions and 26 deletions.
28 changes: 19 additions & 9 deletions agent.py
@@ -62,12 +62,17 @@ def train(self,
)
# keep a copy of the current network for evaluation
old_network = copy.deepcopy(current_network)

state_old = current_network.state_dict().__str__()
print(f"old network params: {old_network.parameters()}")
policy_loss, value_loss = self.retrain_nn(
neural_network=current_network,
train_data=train_episodes,
train_batch_size=train_batch_size
)
state_curr = current_network.state_dict().__str__()
print(f"curr: {current_network.parameters()}")
if state_curr == state_old:
print("network not updating")
# note: figure out if the following assignment makes sense
current_network = self.evaluate_networks(
current_network,
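The added debug check compares the str() of two state_dict snapshots; PyTorch truncates large tensors when printing, so that comparison can miss small weight changes, and printing .parameters() only shows a generator object rather than values. A self-contained sketch of a stricter check; the helper name params_changed is invented for illustration and is not part of this commit:

```python
# Hypothetical helper (not in this commit): compare state_dict tensors directly
# instead of their string representations.
import torch

def params_changed(before: dict, after: dict) -> bool:
    """True if any parameter tensor differs between the two state_dict snapshots."""
    return any(not torch.equal(before[key], after[key]) for key in before)

# Sketch of use around retrain_nn:
#   before = {k: v.clone() for k, v in current_network.state_dict().items()}
#   ... retrain_nn(...) ...
#   if not params_changed(before, current_network.state_dict()):
#       print("network not updating")
```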
@@ -174,20 +179,21 @@ def play_game(self,
result (bool)
"""
game_state = self.game.getInitBoard()
player = 1
player = np.random.choice([-1,1])
print(f"starting player: {player}")
game_state = self.game.getCanonicalForm(game_state, player)
stacked_frames = self.mcts.no_history_model_input(game_state, current_player=player)
while not self.game.getGameEnded(board=game_state, player=player):
print(f"stacked frames: \n{stacked_frames}")
# print(f"stacked frames: \n{stacked_frames}")
stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).to(self.device).unsqueeze(0)
#print(f"Player {player}, state: \n{game_state}")
print(f"player: {player}")
print(f"Player {player}, state: \n{game_state}")
#print(f"player: {player}")
if player == 1:
policy, _ = network_a(stacked_tensor)
print(f"network_a policy: {policy}")
#print(f"network_a policy: {policy}")
else:
policy, _ = network_b(stacked_tensor)
print(f"network_b policy: {policy}")
#print(f"network_b policy: {policy}")
valid_moves = self.game.getValidMoves(game_state, player)
ones_indices = np.where(valid_moves == 1)[0]
mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool)
@@ -223,13 +229,12 @@ def self_play(self, model: torch.nn.Module, num_episodes: int):
episodes (list)
"""
train_episodes = []

for _ in range(num_episodes):
self.mcts = MCTS(self.game)
game_states = []
game_state = self.game.getInitBoard()
player = 1
game_state = self.game.getCanonicalForm(game_state, player=player)
self.mcts = MCTS(self.game)
while not self.game.getGameEnded(board=game_state, player=player):
pi = self.mcts.apv_mcts(
canonical_root_state=game_state,
@@ -251,8 +256,13 @@ def self_play(self, model: torch.nn.Module, num_episodes: int):
game_state = self.game.getCanonicalForm(game_state, player=player)

game_result = self.game.getGameEnded(board=game_state, player=player)
last_player = player
#print(f"last_player: {last_player}")
for state, pi, player in game_states:
stacked_frames = self.mcts.no_history_model_input(state, current_player=player)
#print(f"player: {player}, last_player = {last_player}")
if player != last_player:
train_episodes.append((stacked_frames, pi, -game_result))
train_episodes.append((stacked_frames, pi, game_result))

return train_episodes
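The hunk above adds last_player bookkeeping so the final game_result can be stored relative to the player to move in each recorded state. For reference, a self-contained sketch of the usual AlphaZero sign convention, under the assumption that game_result is reported from last_player's perspective; the names are illustrative, not the committed implementation:

```python
# Illustrative only: assign the self-play outcome to each stored state relative
# to the player who was to move in that state.
def value_targets(players_to_move: list[int], last_player: int, game_result: float) -> list[float]:
    """Keep the sign for states seen from last_player's side, flip it otherwise."""
    return [game_result if p == last_player else -game_result for p in players_to_move]

# Example: players to move were [1, -1, 1, -1] and the game ended +1 for player -1
# (the last to move), so player 1's states get -1 and player -1's states get +1:
assert value_targets([1, -1, 1, -1], last_player=-1, game_result=1.0) == [-1.0, 1.0, -1.0, 1.0]
```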
4 changes: 2 additions & 2 deletions main.py
@@ -35,10 +35,10 @@ def main():
lr=learning_rate,
weight_decay=l2_reg
)
agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=50, game=game, c_uct=1, device=device, mcts=MCTS(game))
agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=25, game=game, c_uct=1, device=device, mcts=MCTS(game))

agent.train(
train_batch_size=32,
train_batch_size=16,
neural_network=neural_network,
num_episodes=100,
num_epochs=100
6 changes: 4 additions & 2 deletions mcts.py
@@ -73,7 +73,7 @@ def split_player_boards(self, board: np.ndarray) -> tuple[np.ndarray, np.ndarray
"""
player_a = np.maximum(board, 0)
player_b = board.copy()
player_b[player_b < 0] = 1
player_b = -1 * np.minimum(player_b, 0)

return player_a, player_b
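The one-line replacement above changes how the opponent's plane is built. A minimal, self-contained example on a toy board, assuming the +1 / -1 / 0 encoding that split_player_boards expects:

```python
import numpy as np

# Toy board: +1 = current player, -1 = opponent, 0 = empty.
board = np.array([[ 1, -1,  0],
                  [ 0,  1, -1],
                  [-1,  0,  1]])

player_a = np.maximum(board, 0)        # current player's plane: 1s only at +1 cells
player_b = -1 * np.minimum(board, 0)   # new line: opponent's plane, 1s only at -1 cells

# The replaced version (player_b = board.copy(); player_b[player_b < 0] = 1) left
# the current player's +1 entries in player_b as well, so both players' pieces
# ended up marked in the "opponent" plane.
print(player_b)
# [[0 1 0]
#  [0 0 1]
#  [1 0 0]]
```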

@@ -193,6 +193,7 @@ def apv_mcts(
input_array = self.no_history_model_input(board_arr=state,
current_player=player)
path.append((input_array, state_string, action))
#print(f'Player: {-player}, state: \n{state}, action: {action}, next_state:\n{next_state}')
next_state = self.game.getCanonicalForm(next_state, player=player)
state = next_state
state_string = self.game.stringRepresentation(state)
@@ -202,6 +203,7 @@ def apv_mcts(
input_tensor = torch.tensor(input_array, dtype=torch.float32).to(device).unsqueeze(0)
game_ended = self.game.getGameEnded(state, player=player)
if not game_ended :
model.eval()
policy, value = model(input_tensor)
policy = policy.cpu().detach().numpy()
possible_actions = self.game.getValidMoves(state, player=player)
@@ -216,7 +218,7 @@ def apv_mcts(
else:
value = game_ended


#print(f"value: {value}")
# backpropagation phase
for _, state_string, action in reversed(path):
if state_string in self.num_visits_s:
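The added model.eval() puts the BatchNorm layers into inference mode for the single-position expansion lookup, so they use running statistics rather than statistics of a one-sample batch. A minimal sketch of how that evaluation could be wrapped, assuming the same (policy, value) interface used in apv_mcts; the torch.no_grad() context is an illustrative addition, not part of this commit:

```python
import torch

def evaluate_leaf(model: torch.nn.Module, input_tensor: torch.Tensor):
    """Run one inference for MCTS expansion without tracking gradients."""
    model.eval()                      # BatchNorm uses running stats instead of batch stats
    with torch.no_grad():             # no autograd graph is needed during search
        policy, value = model(input_tensor)
    return policy.cpu().numpy(), float(value)

# Usage mirrors the hunk above:
#   policy, value = evaluate_leaf(model, input_tensor)
```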
35 changes: 22 additions & 13 deletions models.py
@@ -6,6 +6,7 @@

from torch import nn
import numpy as np
import torch.nn.functional as F


class OthelloNN(nn.Module):
@@ -37,31 +38,37 @@ def __init__(self) -> None:
nn.ReLU(),
)
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
nn.BatchNorm2d(128),
nn.ReLU(),
)
self.conv4 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
nn.BatchNorm2d(128),
nn.ReLU(),
)

self.fc1 = nn.Sequential(
nn.Linear(2048, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
)

self.fc2 = nn.Sequential(
nn.Linear(1024, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
)


self.policy_head = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 8 * 8, 65), # Flatten the conv output and connect to a Linear layer
nn.Linear(512, 65), # Flatten the conv output and connect to a Linear layer
nn.Softmax(dim=1)
)

self.value_head_conv = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1),
nn.BatchNorm2d(1),
nn.ReLU(),
)

self.value_head_linear = nn.Sequential(
nn.Flatten(), nn.Linear(64, 1), nn.ReLU(), nn.Tanh()
self.value_head = nn.Sequential(
nn.Flatten(), nn.Linear(512, 1), nn.ReLU(), nn.Tanh()
)

def forward(self, state) -> tuple[np.array, int]:
@@ -80,9 +87,11 @@ def forward(self, state) -> tuple[np.array, int]:
s = self.conv2(s)
s = self.conv3(s)
s = self.conv4(s)
s = s.view(-1, 2048)
s = F.dropout(self.fc1(s))
s = F.dropout(self.fc2(s))

pi = self.policy_head(s)
s = self.value_head_conv(s)
val = self.value_head_linear(s).squeeze()
val = self.value_head(s).squeeze()

return pi.squeeze(0), val
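The reworked trunk flattens the conv stack into the new fully connected layers, which then feed both heads. A quick, self-contained shape check, assuming conv1 and conv2 (outside this hunk) preserve the 8x8 board: with kernel_size=3 and no padding, conv3 and conv4 shrink 8x8 to 6x6 and then 4x4, so 128 * 4 * 4 = 2048 matches s.view(-1, 2048) and nn.Linear(2048, 1024). One side note: torch.nn.functional.dropout defaults to training=True, so the new F.dropout calls stay active even after model.eval() unless invoked as F.dropout(x, training=self.training).

```python
import torch
from torch import nn

# Only the tail of the trunk touched by this commit; conv1/conv2 are assumed
# to keep the spatial size at 8x8.
tail = nn.Sequential(
    nn.Conv2d(128, 128, kernel_size=3),  # conv3: 8x8 -> 6x6
    nn.Conv2d(128, 128, kernel_size=3),  # conv4: 6x6 -> 4x4
)
x = torch.randn(2, 128, 8, 8)
print(tail(x).flatten(1).shape)          # torch.Size([2, 2048]) -> fc1 input
```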
