
Commit

tuning params for end to end training. No more SW bugs
apandy02 committed Dec 5, 2023
1 parent bdcc3c2 commit 6e811b3
Showing 4 changed files with 37 additions and 18 deletions.
34 changes: 23 additions & 11 deletions agent.py
@@ -23,7 +23,8 @@ def __init__(
num_simulations: int,
optimizer,
game: Game,
c_uct: float) -> None:
c_uct: float,
device) -> None:

self.c_parameter = c_uct
self.num_simulations = num_simulations
@@ -32,6 +33,7 @@ def __init__(
# notes: dynamics differences. Each game object instance
# is a game instance. Need to figure out how to re-initialize..?
self.game = game
self.device = device

def train(self,
neural_network: torch.nn.Module,
@@ -52,18 +54,17 @@ def train(self,
Returns:
current_network (torch.nn.Module)
"""
logging.info("Beginning agent AlphaZeroNano training")
current_network = neural_network

for _ in range(num_epochs):
logging.info("Epoch: %s/%s", _, num_epochs)
train_episodes = self.self_play(
model=current_network,
num_episodes=num_episodes
)
# keep a copy of the current network for evaluation
old_network = copy.deepcopy(current_network)

self.retrain_nn(
policy_loss, value_loss = self.retrain_nn(
neural_network=current_network,
train_data=train_episodes,
train_batch_size=train_batch_size
@@ -75,6 +76,8 @@
10
)

logging.info("Epoch: %s/%s value_loss: %s, policy_loss: %s", _, num_epochs, value_loss, policy_loss)

return current_network

def evaluate_networks(self,
@@ -131,18 +134,27 @@ def retrain_nn(self,
dataloader = self.batch_episodes(
train_data, batch_size=train_batch_size
)

policy_losses_total = 0
value_losses_total = 0
for x_train, policy_train, value_train in dataloader:
policy_pred, value_pred = neural_network(x_train)

policy_loss = policy_loss_fn(policy_train, policy_pred)
value_loss = value_loss_fn(value_train, value_pred)
policy_losses_total+=policy_loss.item()
value_losses_total+=value_loss.item()
combined_loss = policy_loss + value_loss

self.optimizer.zero_grad()
combined_loss.backward()
self.optimizer.step()


policy_loss_avg = policy_losses_total / len(dataloader)
value_losses_avg = value_losses_total / len(dataloader)

return policy_loss_avg, value_losses_avg

def play_game(self,
network_a: torch.nn.Module,
network_b: torch.nn.Module) -> bool:
@@ -158,7 +170,7 @@ def play_game(self,
player = 1
stacked_frames = no_history_model_input(game_state, current_player=player)
while not self.game.getGameEnded(board=game_state, player=player):
stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).unsqueeze(0)
stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).to(self.device).unsqueeze(0)
if player == 1:
policy, _ = network_a(stacked_tensor)
else:
@@ -205,8 +217,8 @@ def self_play(self, model: torch.nn.Module, num_episodes: int):
root_state=game_state,
model=model,
num_iterations=self.num_simulations,
c=self.c_parameter,
history_length=3
c=self.c_parameter,
device=self.device
)

game_states.append((game_state, policy, player))
@@ -244,9 +256,9 @@ def batch_episodes(self, train_data: list, batch_size: int):
before converting to a tensor.
"""
states, policies, results = zip(*train_data)
states_tensor = torch.tensor(np.array(states), dtype=torch.float32)
policies_tensor = torch.tensor(np.array(policies), dtype=torch.float32)
results_tensor = torch.tensor(np.array(results), dtype=torch.float32)
states_tensor = torch.tensor(np.array(states), dtype=torch.float32).to(self.device)
policies_tensor = torch.tensor(np.array(policies), dtype=torch.float32).to(self.device)
results_tensor = torch.tensor(np.array(results), dtype=torch.float32).to(self.device)

dataset = TensorDataset(states_tensor, policies_tensor, results_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
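The retrain_nn change above accumulates per-batch policy and value losses and returns their per-epoch averages so train() can log them. A minimal standalone sketch of that pattern, where retrain_epoch and its arguments are illustrative stand-ins rather than the repo's API:

import torch

def retrain_epoch(model, dataloader, optimizer, policy_loss_fn, value_loss_fn):
    # Accumulate per-batch losses so the caller can log per-epoch averages.
    policy_total, value_total = 0.0, 0.0
    for x, policy_target, value_target in dataloader:
        policy_pred, value_pred = model(x)
        policy_loss = policy_loss_fn(policy_pred, policy_target)
        value_loss = value_loss_fn(value_pred, value_target)

        optimizer.zero_grad()
        (policy_loss + value_loss).backward()   # single backward pass on the combined loss
        optimizer.step()

        policy_total += policy_loss.item()
        value_total += value_loss.item()
    return policy_total / len(dataloader), value_total / len(dataloader)

Note that len(dataloader) is the number of batches, so the returned values are means over batches rather than over individual samples.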
9 changes: 5 additions & 4 deletions main.py
@@ -22,7 +22,9 @@ def main():
"""
main
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
neural_network = OthelloNN()
neural_network = neural_network.to(device)
game = OthelloGame(n=8)
learning_rate = 3e-4
l2_reg = 1e-3
@@ -32,14 +34,13 @@
lr=learning_rate,
weight_decay=l2_reg
)

agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=5, game=game, c_uct=0.1)
agent = AlphaZeroNano(optimizer=actor_optimizer,num_simulations=25, game=game, c_uct=0.1, device=device)

agent.train(
train_batch_size=32,
neural_network=neural_network,
num_episodes=5,
num_epochs=1000
num_episodes=100,
num_epochs=100
)


4 changes: 2 additions & 2 deletions mcts.py
@@ -155,7 +155,7 @@ def apv_mcts(
root_state,
model: torch.nn.Module(),
num_iterations: int,
history_length: int,
device,
c: float,
temp = 1):
"""
@@ -204,7 +204,7 @@
# expansion phase
# for our leaf node, expand by adding possible children
# from that game state to node.children
input_tensor = torch.tensor(input_array, dtype=torch.float32).unsqueeze(0)
input_tensor = torch.tensor(input_array, dtype=torch.float32).to(device).unsqueeze(0)
if not game.getGameEnded(node.state, player=player):
cannonical_board = game.getCanonicalForm(node.state, player=player)
policy, _ = model(input_tensor)
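Across agent.py, main.py, and mcts.py the same device is now threaded from main() into the agent and into apv_mcts, so tensors built from NumPy boards land on the same device as the model before each forward pass. A small self-contained sketch of that pattern; TinyNet and the board shape here are placeholders, not the project's OthelloNN or its input format:

import numpy as np
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class TinyNet(torch.nn.Module):
    # Placeholder two-head network standing in for the real model.
    def __init__(self):
        super().__init__()
        self.policy_head = torch.nn.Linear(64, 65)
        self.value_head = torch.nn.Linear(64, 1)

    def forward(self, x):
        flat = x.flatten(1)
        return self.policy_head(flat), torch.tanh(self.value_head(flat))

model = TinyNet().to(device)                    # parameters live on the device
board = np.zeros((8, 8), dtype=np.float32)      # stand-in for a board / stacked frames
x = torch.tensor(board, dtype=torch.float32).to(device).unsqueeze(0)
policy, value = model(x)                        # input and parameters now match

Moving each tensor with .to(device) before the forward call mirrors the edits in play_game, batch_episodes, and apv_mcts, where CPU-only inputs would otherwise trigger a device-mismatch error against a CUDA model.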
8 changes: 7 additions & 1 deletion test.ipynb
@@ -675,7 +675,13 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import pandas as pd\n",
"data = {\n",
" \"a\": [True, False]\n",
"}\n",
"a_df = pd.DataFrame(data)"
]
}
],
"metadata": {

