From efd17b2f749e572cb9d819bd0ad84b233314c56f Mon Sep 17 00:00:00 2001 From: aryamanpandya99 Date: Mon, 4 Dec 2023 14:35:22 -0800 Subject: [PATCH] modified: agent.py modified: mcts.py modified: models.py modified: test.ipynb --- agent.py | 6 +- mcts.py | 2 +- models.py | 2 +- test.ipynb | 350 +++++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 332 insertions(+), 28 deletions(-) diff --git a/agent.py b/agent.py index aa01082..a6ee83b 100644 --- a/agent.py +++ b/agent.py @@ -136,7 +136,6 @@ def retrain_nn(self, policy_pred, value_pred = neural_network(x_train) policy_loss = policy_loss_fn(policy_train, policy_pred) - print(f"shape value train: {value_train.shape}, shape value pred: {value_pred.shape}") value_loss = value_loss_fn(value_train, value_pred) combined_loss = policy_loss + value_loss @@ -161,14 +160,15 @@ def play_game(self, while not self.game.getGameEnded(board=game_state, player=player): stacked_tensor = torch.tensor(stacked_frames, dtype = torch.float32).unsqueeze(0) if player == 1: - print(stacked_tensor.shape) policy, _ = network_a(stacked_tensor) else: policy, _ = network_b(stacked_tensor) valid_moves = self.game.getValidMoves(game_state, player) ones_indices = np.where(valid_moves == 1)[0] - action = np.random.choice(ones_indices) + mask = torch.zeros_like(policy.squeeze(), dtype=torch.bool) + mask[torch.tensor(ones_indices)] = True + policy[~mask] = 0 _, action = torch.max(policy, dim=-1) game_state, player = self.game.getNextState( game_state, diff --git a/mcts.py b/mcts.py index ad7c374..ae88c71 100644 --- a/mcts.py +++ b/mcts.py @@ -215,7 +215,7 @@ def apv_mcts( cannonical_board = game.getCanonicalForm(node.state, player=player) policy, _ = model(input_tensor) #print(f"value: {val.shape}") - policy = policy.cpu().detach().numpy().squeeze(0) + policy = policy.cpu().detach().numpy() possible_actions = game.getValidMoves(node.state, player=player) policy *= possible_actions diff --git a/models.py b/models.py index fbcda3f..b6b968e 100644 --- a/models.py +++ b/models.py @@ -85,4 +85,4 @@ def forward(self, state) -> tuple[np.array, int]: s = self.value_head_conv(s) val = self.value_head_linear(s).squeeze() - return pi, val + return pi.squeeze(0), val diff --git a/test.ipynb b/test.ipynb index 97bf0b7..5dab853 100644 --- a/test.ipynb +++ b/test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 17, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -39,7 +39,7 @@ " [ 0, 0, 0, 0, 0, 0, 0, 0]])" ] }, - "execution_count": 19, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -68,7 +68,7 @@ "numpy.ndarray" ] }, - "execution_count": 21, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -95,7 +95,7 @@ " [ 0, 0, 0, 0, 0, 0, 0, 0]])" ] }, - "execution_count": 22, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -106,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -125,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -159,7 +159,7 @@ " [0, 0, 0, 0, 0, 0, 0, 0]])" ] }, - "execution_count": 27, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -170,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -186,7 +186,7 @@ " [ 0, 0, 0, 0, 0, 0, 0, 0]])" ] }, - "execution_count": 28, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -197,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -208,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -217,7 +217,7 @@ "torch.Size([8, 8])" ] }, - "execution_count": 30, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -331,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -372,6 +372,310 @@ " board_tensor[:, :, -1] = player_plane\n", " return board_tensor" ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.zeros((8,8))" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "b = np.ones((8,8))" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.]])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "c = np.stack([a, b], axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 8, 8)" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "c.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "result_array = np.concatenate((c, b[np.newaxis, ...]), axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.]]])" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_array" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 8, 8)" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_array.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "torch_a = torch.tensor(a, dtype=torch.float16)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 8])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_a.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "torch_a = torch_a.unsqueeze(3)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 1, 8, 1])" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_a.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "torch_a = torch_a.squeeze()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 8])" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_a.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float16)" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch_a" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,\n", + " 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7]),\n", + " tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,\n", + " 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,\n", + " 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]))" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.where(torch_a==0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {