-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 62ceee3
Showing
4 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from collections import deque | ||
|
||
|
||
PLAYER_X = 1 | ||
PLAYER_O = -1 | ||
NO_PLAYER = 0 | ||
|
||
STR_MATRIX = { | ||
PLAYER_X: 'X', | ||
PLAYER_O: 'O', | ||
NO_PLAYER: '-' | ||
} | ||
|
||
ROWS = 3 | ||
BOARD_SIZE = ROWS*ROWS | ||
|
||
LOSS = 0.0 | ||
DRAW = 0.5 | ||
WIN = 1.0 | ||
|
||
|
||
class BaseBoard: | ||
"""Defines the general structure which a board implementation must implement""" | ||
def __init__(self): | ||
raise NotImplementedError | ||
|
||
def __str__(self): | ||
raise NotImplementedError | ||
|
||
def __copy__(self): | ||
raise NotImplementedError | ||
|
||
def make_move(self, move): | ||
raise NotImplementedError | ||
|
||
def take_move(self): | ||
raise NotImplementedError | ||
|
||
def get_moves(self): | ||
raise NotImplementedError | ||
|
||
def get_result(self, player_jm): | ||
raise NotImplementedError | ||
|
||
|
||
class Board: | ||
def __init__(self): | ||
self.pos = [0] * BOARD_SIZE | ||
self.side = PLAYER_X | ||
self.playerJustMoved = PLAYER_O | ||
self.history = deque() | ||
|
||
def __str__(self): | ||
lines = [] | ||
for combo in zip(*[self.pos[i::ROWS] for i in range(ROWS)]): | ||
lines.extend(['{:<5}'.format(STR_MATRIX[elem]) for elem in combo]) | ||
lines.append('\n') | ||
return ''.join(lines) | ||
|
||
def __copy__(self): | ||
_b = Board() | ||
_b.pos = self.pos[:] # copy list | ||
_b.side = self.side # todo remove this, not needed since player just moved | ||
_b.playerJustMoved = self.playerJustMoved | ||
_b.history = self.history.copy() # todo copying deque is too slow | ||
return _b | ||
|
||
def make_move(self, move): | ||
assert move in self.get_moves(), 'Position is already occupied' | ||
|
||
self.pos[move] = self.side | ||
self.side = -self.side # change side to move | ||
self.playerJustMoved = -self.playerJustMoved | ||
self.history.append(move) | ||
|
||
def take_move(self): | ||
move = self.history.pop() | ||
self.pos[move] = NO_PLAYER | ||
self.side = -self.side # change side to move | ||
self.playerJustMoved = -self.playerJustMoved | ||
|
||
def get_moves(self): | ||
return [idx for idx, value in enumerate(self.pos) if value == NO_PLAYER] | ||
|
||
def get_result(self, player_jm): | ||
cols_combo = [self.pos[i::ROWS] for i in range(ROWS)] | ||
rows_combo = list(zip(*cols_combo)) | ||
# print(cols_combo) | ||
# print(row s_combo) | ||
|
||
for i in range(ROWS): | ||
# Sum a row and a column | ||
row_result, col_result = sum(rows_combo[i]), sum(cols_combo[i]) | ||
|
||
# Check if sum of values of a row is not equal to number of rows i.e. all 1s or all -1s | ||
if abs(row_result) == ROWS: | ||
return WIN if int(row_result / ROWS) == player_jm else LOSS | ||
|
||
if abs(col_result) == ROWS: | ||
return WIN if int(col_result / ROWS) == player_jm else LOSS | ||
|
||
# Sum values on Right diagonal | ||
# Look at right Diagonal | ||
# exclude last element since it is not part of the diagonal | ||
# i.e. if you have [1, 2, 3, | ||
# 4, 5, 6, | ||
# 7 ,8 ,9] then right diagonal is [3, 5, 7] | ||
# i.e. starting from the right corner the diagonal is formed by every second number | ||
# (3, 5, 7), however this will also result in 9 being included which it should not be | ||
# therefore we remove it | ||
result = sum(self.pos[ROWS - 1::ROWS - 1][:-1]) | ||
if abs(result) == ROWS: | ||
return WIN if int(result / ROWS) == player_jm else LOSS | ||
|
||
# Left diagonal | ||
result = sum(self.pos[::ROWS + 1]) | ||
if abs(result) == ROWS: | ||
return WIN if int(result / ROWS) == player_jm else LOSS | ||
|
||
# Lastly check if no available squares are on the board => TIE | ||
if sum([abs(elem) for elem in self.pos]) == BOARD_SIZE: | ||
return DRAW |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from math import sqrt, log | ||
|
||
|
||
class Node: | ||
""" A node in the game tree. Note wins is always from the viewpoint of playerJustMoved. | ||
Crashes if state not specified. | ||
""" | ||
|
||
def __init__(self, move=None, parent=None, state=None): | ||
self.move = move # the move that got us to this node - "None" for the root node | ||
self.parentNode = parent # "None" for the root node | ||
self.childNodes = [] | ||
self.wins = 0 | ||
self.visits = 0 | ||
self.untriedMoves = state.get_moves() # future child nodes | ||
self.playerJustMoved = state.playerJustMoved # the only part of the state that the Node needs later | ||
|
||
def uct_select_child(self): | ||
""" Use the UCB1 formula to select a child node. Often a constant UCTK is applied so we have | ||
lambda c: c.wins/c.visits + UCTK * sqrt(2*log(self.visits)/c.visits to vary the amount of | ||
exploration versus exploitation. | ||
""" | ||
s = sorted(self.childNodes, key=lambda c: c.wins / c.visits + sqrt(2 * log(self.visits) / c.visits))[-1] | ||
return s | ||
|
||
def add_child(self, m, s): | ||
""" Remove m from untriedMoves and add a new child node for this move. | ||
Return the added child node | ||
""" | ||
n = Node(move=m, parent=self, state=s) | ||
self.untriedMoves.remove(m) | ||
self.childNodes.append(n) | ||
return n | ||
|
||
def update(self, result): | ||
""" Update this node - one additional visit and result additional wins. result must be from | ||
the viewpoint of playerJustmoved. | ||
""" | ||
self.visits += 1 | ||
self.wins += result | ||
|
||
def __repr__(self): | ||
return "[M:" + str(self.move) + " W/V:" + str(self.wins) + "/" + str(self.visits) + " U:" + str( | ||
self.untriedMoves) + "]" | ||
|
||
def convert_tree_to_string(self, indent): | ||
s = self.get_indent_string(indent) + str(self) | ||
for c in self.childNodes: | ||
s += c.convert_tree_to_string(indent + 1) | ||
return s | ||
|
||
@staticmethod | ||
def get_indent_string(indent): | ||
s = "\n" | ||
for i in range(1, indent + 1): | ||
s += "| " | ||
return s | ||
|
||
def convert_children_to_string(self): | ||
s = "" | ||
for c in self.childNodes: | ||
s += str(c) + "\n" | ||
return s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import random | ||
import time | ||
from multiprocessing import Queue, Process | ||
from operator import itemgetter | ||
|
||
from ultimate_tttoe.board import Board | ||
from ultimate_tttoe.node import Node | ||
|
||
|
||
def uct_multi(rootstate_: Board, itermax, verbose): | ||
moves = rootstate_.get_moves() | ||
|
||
if len(moves) == 1: # if only 1 move is possible don't bother searching anything | ||
return moves[0] | ||
|
||
avg_iters = itermax // len(moves) | ||
queue = Queue() | ||
|
||
processes = [] | ||
for move in moves: | ||
current_state = rootstate_.__copy__() | ||
current_state.make_move(move) | ||
p = Process(target=uct, args=(queue, move, current_state, avg_iters, verbose)) | ||
p.start() | ||
processes.append(p) | ||
|
||
for process in processes: | ||
process.join() | ||
# for move in moves: | ||
# state = rootstate_.__copy__() | ||
# state.make_move(move) | ||
# uct(queue, move, state, avg_iters, verbose) | ||
|
||
results = [] | ||
while not queue.empty(): | ||
move, wins, visits = queue.get() | ||
results.append((move, wins/visits)) | ||
|
||
# the score here refers to the score of the best enemy reply -> we choose a move which leads to a best enemy reply | ||
# with the least score | ||
best_move, score = sorted(results, key=itemgetter(1))[0] | ||
return best_move | ||
|
||
|
||
def rand_choice(x): # fastest way to get random item from list | ||
return x[int(random.random() * len(x))] | ||
|
||
|
||
def uct(queue: Queue, move_origin, rootstate, itermax, verbose=False): | ||
""" Conduct a UCT search for itermax iterations starting from rootstate. | ||
Return the best move from the rootstate. | ||
Assumes 2 alternating players (player 1 starts), with game results in the range [0.0, 1.0].""" | ||
|
||
rootnode = Node(state=rootstate) | ||
|
||
state = rootstate | ||
for i in range(itermax): | ||
node = rootnode | ||
moves_to_root = 0 | ||
|
||
# Select | ||
while not node.untriedMoves and node.childNodes: # node is fully expanded and non-terminal | ||
node = node.uct_select_child() | ||
state.make_move(node.move) | ||
moves_to_root += 1 | ||
|
||
# Expand | ||
if node.untriedMoves: # if we can expand (i.e. state/node is non-terminal) | ||
m = rand_choice(node.untriedMoves) | ||
state.make_move(m) | ||
moves_to_root += 1 | ||
node = node.add_child(m, state) # add child and descend tree | ||
|
||
# Rollout - this can often be made orders of magnitude quicker using a state.GetRandomMove() function | ||
while state.get_result(state.side) is None: # while state is non-terminal | ||
state.make_move(rand_choice(state.get_moves())) | ||
moves_to_root += 1 | ||
|
||
# Backpropagate | ||
while node is not None: # backpropagate from the expanded node and work back to the root node | ||
# state is terminal. Update node with result from POV of node.playerJustMoved | ||
result = state.get_result(node.playerJustMoved) | ||
node.update(result) | ||
node = node.parentNode | ||
|
||
for _ in range(moves_to_root): | ||
state.take_move() | ||
|
||
# Output some information about the tree - can be omitted | ||
# if verbose: | ||
# print(rootnode.convert_tree_to_string(0)) | ||
# else: | ||
# print(rootnode.convert_children_to_string()) | ||
|
||
# return sorted(rootnode.childNodes, key=lambda c: c.visits)[-1].move # return the move that was most visited | ||
bestNode = sorted(rootnode.childNodes, key=lambda c: c.visits)[-1] | ||
queue.put((move_origin, bestNode.wins, bestNode.visits)) | ||
|
||
|
||
def uct_play_game(): | ||
""" Play a sample game between two UCT players where each player gets a different number | ||
of UCT iterations (= simulations = tree nodes). | ||
""" | ||
state = Board() | ||
|
||
while state.get_result(state.side) is None: | ||
print(state) | ||
start = time.time() | ||
m = uct_multi(rootstate_=state, itermax=50000, verbose=False) # play with values for itermax and verbose = True | ||
print('Time it took', time.time() - start) | ||
print("Best Move: ", m, "\n") | ||
state.make_move(m) | ||
print(state) | ||
if state.get_result(state.playerJustMoved) == 1.0: | ||
print("Player " + str(state.playerJustMoved) + " wins!") | ||
elif state.get_result(state.playerJustMoved) == 0.0: | ||
print("Player " + str(-state.playerJustMoved) + " wins!") | ||
else: | ||
print("Nobody wins!") | ||
|
||
|
||
def user_play(): | ||
state = Board() | ||
|
||
while state.get_result(state.side) is None: | ||
print(state) | ||
move = int(input('Enter move:')) | ||
state.make_move(move) | ||
print(state) | ||
start = time.time() | ||
m = uct_multi(rootstate_=state, itermax=50000, verbose=False) # play with values for itermax and verbose = True | ||
print('Time it took', time.time() - start) | ||
print("Best Move: ", m, "\n") | ||
state.make_move(m) | ||
print(state) | ||
if state.get_result(state.playerJustMoved) == 1.0: | ||
print("Player " + str(state.playerJustMoved) + " wins!") | ||
elif state.get_result(state.playerJustMoved) == 0.0: | ||
print("Player " + str(-state.playerJustMoved) + " wins!") | ||
else: | ||
print("Nobody wins!") | ||
|
||
|
||
if __name__ == "__main__": | ||
""" Play a single game to the end using UCT for both players. | ||
""" | ||
# uct_play_game() | ||
user_play() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from ultimate_tttoe.board import * | ||
|
||
|
||
class UltimateBoard(BaseBoard): | ||
def __init__(self): | ||
pass | ||
|
||
def __str__(self): | ||
pass | ||
|
||
def __copy__(self): | ||
pass | ||
|
||
def make_move(self, move): | ||
pass | ||
|
||
def take_move(self): | ||
pass | ||
|
||
def get_moves(self): | ||
pass | ||
|
||
def get_result(self, player_jm): | ||
pass |