Commit

modified: main.py
apandy02 committed Sep 18, 2024
1 parent a9a5d79 commit b1e8e96
Showing 3 changed files with 24 additions and 19 deletions.
main.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def main():
agent.train(
train_batch_size=16,
neural_network=neural_network,
num_episodes=100,
num_episodes=5,
num_epochs=100,
)

src/agent.py (29 changes: 13 additions & 16 deletions)
@@ -64,6 +64,7 @@ def train(
train_episodes = self.self_play(
model=current_network, num_episodes=num_episodes
)
print(f"SELF PLAY {_}")
# keep a copy of the current network for evaluation
old_network = copy.deepcopy(current_network)
state_old = str(current_network.state_dict())
@@ -120,9 +121,9 @@ def evaluate_networks(
curr_network_wins += 1

win_rate = curr_network_wins / num_games
# if network a wins most games, return network a
if win_rate > threshold:
logging.info(f"Current network win rate: {win_rate:.2f}")

if win_rate > threshold:
return curr_network

# else return network b
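
The change above replaces the old comment-plus-check with an explicit log of the win rate before the threshold test. A minimal standalone sketch of that gating logic, with hypothetical numbers standing in for the real head-to-head results and network objects (only the log-then-threshold flow mirrors the diff):

    import logging

    logging.basicConfig(level=logging.INFO)

    # Hypothetical results; in the real method these come from playing
    # num_games games between the current and the previous network.
    curr_network_wins = 13
    num_games = 20
    threshold = 0.55  # hypothetical acceptance threshold

    win_rate = curr_network_wins / num_games
    logging.info(f"Current network win rate: {win_rate:.2f}")

    # Keep the newly trained network only if it clears the threshold;
    # otherwise fall back to the previous one.
    chosen = "curr_network" if win_rate > threshold else "old_network"
    print(chosen)
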
@@ -149,11 +150,12 @@ def retrain_nn(
dataloader = self.batch_episodes(train_data, batch_size=train_batch_size)
policy_losses_total = 0
value_losses_total = 0

for x_train, policy_train, value_train in dataloader:
policy_pred, value_pred = neural_network(x_train)

policy_loss = policy_loss_fn(policy_train, policy_pred)
value_loss = value_loss_fn(value_train, value_pred)
policy_loss = policy_loss_fn(policy_pred, policy_train)
value_loss = value_loss_fn(value_pred, value_train)

policy_losses_total += policy_loss.item()
value_losses_total += value_loss.item()
combined_loss = policy_loss + value_loss
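
The argument swap above matters because PyTorch loss modules follow the loss_fn(input, target) convention, and several of them are not symmetric in their arguments. A minimal sketch under assumed loss choices and tensor shapes (the actual policy_loss_fn / value_loss_fn definitions are outside this diff):

    import torch
    import torch.nn as nn

    # Assumed losses; the real ones are defined elsewhere in the repository.
    policy_loss_fn = nn.CrossEntropyLoss()  # accepts soft probability targets (PyTorch >= 1.10)
    value_loss_fn = nn.MSELoss()

    batch_size, num_actions = 16, 65  # hypothetical sizes
    policy_pred = torch.randn(batch_size, num_actions)           # raw logits from the network
    policy_train = torch.softmax(torch.randn(batch_size, num_actions), dim=-1)  # MCTS target distribution
    value_pred = torch.randn(batch_size, 1)
    value_train = torch.empty(batch_size, 1).uniform_(-1, 1)     # game outcomes in [-1, 1]

    # Predictions first, targets second, matching the corrected call order.
    policy_loss = policy_loss_fn(policy_pred, policy_train)
    value_loss = value_loss_fn(value_pred, value_train)
    combined_loss = policy_loss + value_loss

If the policy loss is cross-entropy-like, reversing the arguments would treat the target distribution as logits and the logits as targets, so the reordering is more than cosmetic.
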
@@ -191,23 +193,22 @@ def play_game(self, network_a: torch.nn.Module, network_b: torch.nn.Module) -> b
.unsqueeze(0)
)
print(f"Player {player}, state: \n{game_state}")
# print(f"player: {player}")
if player == 1:
policy, _ = network_a(stacked_tensor)
# print(f"network_a policy: {policy}")

else:
policy, _ = network_b(stacked_tensor)
# print(f"network_b policy: {policy}")

valid_moves = self.game.getValidMoves(game_state, player)
ones_indices = np.where(valid_moves == 1)[0]
mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool)
mask[torch.tensor(ones_indices)] = True

policy[~mask] = 0
# print(policy)
action = torch.argmax(policy, dim=-1)

game_state, player = self.game.getNextState(game_state, player, action)
game_state = self.game.getCanonicalForm(game_state, player)
# print(f"Player after move {player}, state: \n{game_state}")
stacked_frames = self.mcts.no_history_model_input(
game_state, current_player=player
)
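
The loop above zeroes out illegal entries in the network's policy before taking the argmax. A minimal standalone sketch of that masking step, using made-up shapes (65 actions for an 8x8 board plus a pass move is an assumption here, and the squeeze is added so the 1-D mask lines up):

    import numpy as np
    import torch

    num_actions = 65                                             # hypothetical action-space size
    policy = torch.softmax(torch.randn(1, num_actions), dim=-1)  # stand-in network output
    valid_moves = np.zeros(num_actions, dtype=np.int8)
    valid_moves[[3, 12, 40]] = 1                                 # hypothetical legal actions

    # Boolean mask over actions, built from the indices of legal moves.
    ones_indices = np.where(valid_moves == 1)[0]
    mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool)
    mask[torch.tensor(ones_indices)] = True

    # Zero out illegal moves, then pick the highest-probability legal action.
    policy = policy.squeeze()
    policy[~mask] = 0
    action = torch.argmax(policy, dim=-1)
    print(int(action))
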
@@ -233,6 +234,7 @@ def self_play(self, model: torch.nn.Module, num_episodes: int):
episodes (list)
"""
train_episodes = []

for _ in range(num_episodes):
self.mcts = MCTS(self.game)
game_states = []
@@ -251,27 +253,22 @@ def self_play(self, model: torch.nn.Module, num_episodes: int):
valid_moves = self.game.getValidMoves(game_state, player)
sum_pi = np.sum(pi)

if sum_pi > 1e-8: # Check if sum is greater than a small threshold
if sum_pi > 1e-8:
pi = pi / sum_pi
else:
logging.info("uniform distribution")
# If sum is too small, use uniform distribution over valid moves
pi = valid_moves.astype(float) / np.sum(valid_moves)

action = np.random.choice(len(pi), p=pi)
game_state, player = self.game.getNextState(game_state, player, action)
game_state = self.game.getCanonicalForm(game_state, player=player)
print("game_state: \n", game_state)

game_result = self.game.getGameEnded(board=game_state, player=player)
last_player = player
# print(f"last_player: {last_player}")

for state, pi, player in game_states:
stacked_frames = self.mcts.no_history_model_input(
state, current_player=player
)
# print(f"player: {player}, last_player = {last_player}")
if player != last_player:
train_episodes.append((stacked_frames, pi, -game_result))
train_episodes.append((stacked_frames, pi, game_result))
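
Two details of the self-play loop above are easy to miss: the MCTS policy is renormalised, falling back to a uniform distribution over legal moves when it has collapsed to near-zero mass, and the terminal result is stored with a flipped sign for positions whose player was not the last mover. A minimal sketch of both, with hypothetical values:

    import numpy as np

    pi = np.zeros(4)                       # hypothetical MCTS policy with zero mass
    valid_moves = np.array([1, 0, 1, 0])   # hypothetical legality mask

    sum_pi = np.sum(pi)
    if sum_pi > 1e-8:
        pi = pi / sum_pi                   # renormalise the visit distribution
    else:
        # Degenerate case: uniform distribution over the legal moves.
        pi = valid_moves.astype(float) / np.sum(valid_moves)

    action = np.random.choice(len(pi), p=pi)

    # Value targets from each position's point of view: positions whose
    # player differs from the last mover receive the negated outcome.
    game_result, last_player = 1, -1       # hypothetical terminal result and last mover
    for player in (1, -1):
        target = -game_result if player != last_player else game_result
        print(player, target)
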
src/mcts.py (12 changes: 10 additions & 2 deletions)
@@ -45,15 +45,17 @@ def apv_mcts(
backpropagate that estimated value.
"""
input_array = np.zeros((3, 8, 8))
depths = []
for _ in range(num_iterations):
state = canonical_root_state
player = 1
state_string = self.game.stringRepresentation(state)
path = []

state, state_string = self._find_leaf_or_terminal(
state, state_string, depth = self._find_leaf_or_terminal(
state, player, path, uct_c, state_string
)
depths.append(depth)

input_tensor = (
torch.tensor(input_array, dtype=torch.float32).to(device).unsqueeze(0)
@@ -70,6 +72,9 @@

value = self._backpropagate(path, value)

avg_depth = np.mean(depths)
logging.info(f"Average MCTS depth: {avg_depth:.2f}")

root = path[0][1]
visits = [x ** (1 / temp) for x in self.num_visits_s_a[root]]
q_vals = [x ** (1 / temp) for x in self.q_s_a[root]]
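
The new depth bookkeeping in apv_mcts records how far each simulation descends before reaching a leaf or terminal state and logs the mean, a cheap signal of how deep the search gets past the root. A minimal sketch of that accounting, with a stand-in for _find_leaf_or_terminal (the real method walks the tree; only the depth plumbing is shown):

    import logging
    import numpy as np

    logging.basicConfig(level=logging.INFO)

    def find_leaf_or_terminal_stub(rng: np.random.Generator) -> int:
        """Hypothetical stand-in that only reports how deep a simulation went."""
        return int(rng.integers(1, 10))

    rng = np.random.default_rng(0)
    depths = []
    for _ in range(25):                    # hypothetical num_iterations
        depths.append(find_leaf_or_terminal_stub(rng))

    avg_depth = np.mean(depths)
    logging.info(f"Average MCTS depth: {avg_depth:.2f}")
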
@@ -93,7 +98,9 @@ def _find_leaf_or_terminal(self, state, player, path, uct_c, state_string):
Returns:
state: the state of the game at the end of the traversal
state_string: the string representation of the current state
depth: the depth of the traversal
"""
depth = 0
while (state_string in self.prior_probability) and not self.game.getGameEnded(
state, player=player
):
@@ -122,8 +129,9 @@ def _find_leaf_or_terminal(self, state, player, path, uct_c, state_string):
next_state = self.game.getCanonicalForm(next_state, player=player)
state = next_state
state_string = self.game.stringRepresentation(state)
depth += 1

return state, state_string
return state, state_string, depth

def select_action(self, state_string, c, valid_moves: np.ndarray) -> int:
"""
