From ddbc8169d0263d83412715d163a0f5817d1f8cb5 Mon Sep 17 00:00:00 2001 From: QueensGambit Date: Wed, 26 Dec 2018 22:30:39 +0100 Subject: [PATCH] updates for version 0.3.1: Higher NPS 250->300, enabled transposition table, added a time management regime by spending less time on obvious moves, added opening guard moves to avoid exploration of moves < 5% for a given number of moves in the opening, added an increasing cpuct value as described by a recent DeepMind publication --- DeepCrazyhouse/src/domain/__init__.py | 0 .../src/domain/abstract_cls/_GameState.py | 16 +- .../src/domain/abstract_cls/__init__.py | 0 .../src/domain/agent/NeuralNetAPI.py | 8 +- DeepCrazyhouse/src/domain/agent/README.md | 0 DeepCrazyhouse/src/domain/agent/__init__.py | 0 .../src/domain/agent/player/MCTSAgent.py | 720 ++++++++++++++---- .../src/domain/agent/player/RawNetAgent.py | 28 +- .../src/domain/agent/player/_Agent.py | 92 +-- .../src/domain/agent/player/__init__.py | 0 .../agent/player/util/NetPredService.py | 89 ++- .../src/domain/agent/player/util/Node.py | 131 ++-- .../src/domain/agent/player/util/__init__.py | 0 .../src/domain/crazyhouse/GameState.py | 51 +- .../src/domain/crazyhouse/__init__.py | 0 .../src/domain/crazyhouse/constants.py | 0 .../domain/crazyhouse/input_representation.py | 48 +- .../crazyhouse/output_representation.py | 2 +- .../src/domain/neural_net/__init__.py | 0 .../architectures/AlphaZeroResnet.py | 0 .../domain/neural_net/architectures/Rise.py | 0 .../neural_net/architectures/__init__.py | 0 .../neural_net/architectures/builder_util.py | 0 .../architectures/rise_builder_util.py | 0 DeepCrazyhouse/src/domain/util.py | 5 + .../runtime/{Colorer.py => ColorLogger.py} | 40 +- DeepCrazyhouse/src/runtime/__init__.py | 0 .../src/samples/MCTS_eval_demo.ipynb | 20 +- crazyara.py | 342 +++++++-- setup.py | 16 - 30 files changed, 1135 insertions(+), 473 deletions(-) mode change 100644 => 100755 DeepCrazyhouse/src/domain/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/abstract_cls/_GameState.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/abstract_cls/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/NeuralNetAPI.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/README.md mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/_Agent.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/util/NetPredService.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/util/Node.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/agent/player/util/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/crazyhouse/GameState.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/crazyhouse/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/crazyhouse/constants.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/crazyhouse/input_representation.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/crazyhouse/output_representation.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/neural_net/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/neural_net/architectures/AlphaZeroResnet.py mode change 100644 => 100755
DeepCrazyhouse/src/domain/neural_net/architectures/Rise.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/neural_net/architectures/__init__.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/neural_net/architectures/builder_util.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/neural_net/architectures/rise_builder_util.py mode change 100644 => 100755 DeepCrazyhouse/src/domain/util.py rename DeepCrazyhouse/src/runtime/{Colorer.py => ColorLogger.py} (78%) mode change 100644 => 100755 mode change 100644 => 100755 DeepCrazyhouse/src/runtime/__init__.py mode change 100644 => 100755 crazyara.py delete mode 100644 setup.py diff --git a/DeepCrazyhouse/src/domain/__init__.py b/DeepCrazyhouse/src/domain/__init__.py old mode 100644 new mode 100755 diff --git a/DeepCrazyhouse/src/domain/abstract_cls/_GameState.py b/DeepCrazyhouse/src/domain/abstract_cls/_GameState.py old mode 100644 new mode 100755 index 4f8904ab..2bd40b64 --- a/DeepCrazyhouse/src/domain/abstract_cls/_GameState.py +++ b/DeepCrazyhouse/src/domain/abstract_cls/_GameState.py @@ -17,7 +17,7 @@ def __init__(self, board): self.board = board self._fen_dic = {} - def apply_move(self, move: chess.Move, remember_state=False): + def apply_move(self, move: chess.Move): #, remember_state=False): self.board.push(move) def get_state_planes(self): @@ -52,5 +52,19 @@ def get_board_fen(self): return self.board.fen() #return self.board.fen().rsplit(' ', 1)[0] + def get_transposition_key(self): + """ + Returns an identifier key for the current board state excluding move counters. + Calling ._transposition_key() is faster than .fen() + :return: + """ + return self.board._transposition_key() + def new_game(self): raise NotImplementedError + + def get_halfmove_counter(self): + return self.board.halfmove_clock + + def get_fullmove_number(self): + return self.board.fullmove_number diff --git a/DeepCrazyhouse/src/domain/abstract_cls/__init__.py b/DeepCrazyhouse/src/domain/abstract_cls/__init__.py old mode 100644 new mode 100755 diff --git a/DeepCrazyhouse/src/domain/agent/NeuralNetAPI.py b/DeepCrazyhouse/src/domain/agent/NeuralNetAPI.py old mode 100644 new mode 100755 index c47112aa..d2d74e49 --- a/DeepCrazyhouse/src/domain/agent/NeuralNetAPI.py +++ b/DeepCrazyhouse/src/domain/agent/NeuralNetAPI.py @@ -1,6 +1,5 @@ import logging import numpy as np -import DeepCrazyhouse.src.runtime.Colorer import time import json import glob @@ -79,6 +78,13 @@ def __init__(self, ctx='cpu', batch_size=1): grad_req='null', force_rebind=True) self.executor.copy_params_from(arg_params, aux_params) + self.executors = [] + for i in range(batch_size): + executor = sym.simple_bind(ctx=self.ctx, data=(i+1, NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH), + grad_req='null', force_rebind=True) + executor.copy_params_from(arg_params, aux_params) + self.executors.append(executor) + def get_executor(self): """ Returns the executor object used for inference diff --git a/DeepCrazyhouse/src/domain/agent/README.md b/DeepCrazyhouse/src/domain/agent/README.md old mode 100644 new mode 100755 diff --git a/DeepCrazyhouse/src/domain/agent/__init__.py b/DeepCrazyhouse/src/domain/agent/__init__.py old mode 100644 new mode 100755 diff --git a/DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py b/DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py old mode 100644 new mode 100755 index 2a7600b8..9b48f0d4 --- a/DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py +++ b/DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py @@ -4,14 +4,23 @@ @project: crazy_ara_refactor @author: 
queensgambit -Please describe what the content of this file is about +The MCTSAgent runs playouts/simulations in the search tree and updates the node statistics. +The final move is chosen according to the visit count of each direct child node. +One playout is defined as expanding one new node in the tree. In the case of chess this means evaluating a new board position. + +If the evaluation for one move takes too long on your hardware you can decrease the value for: + nb_playouts_empty_pockets and nb_playouts_filled_pockets. + +For more details and the mathematical equations please take a look at src/domain/agent/README.md as well as the +official DeepMind-papers. """ import numpy as np + from DeepCrazyhouse.src.domain.crazyhouse.output_representation import get_probs_of_move_list, value_to_centipawn from DeepCrazyhouse.src.domain.agent.NeuralNetAPI import NeuralNetAPI from copy import deepcopy -from multiprocessing import Barrier, Pipe +from multiprocessing import Pipe import logging from DeepCrazyhouse.src.domain.agent.player.util.NetPredService import NetPredService from DeepCrazyhouse.src.domain.agent.player.util.Node import Node @@ -19,9 +28,12 @@ from time import time from DeepCrazyhouse.src.domain.agent.player._Agent import _Agent from DeepCrazyhouse.src.domain.crazyhouse.GameState import GameState +import collections +from DeepCrazyhouse.src.domain.crazyhouse.constants import NB_CHANNELS_FULL, BOARD_WIDTH, BOARD_HEIGHT, NB_LABELS +import math import cProfile, pstats, io -from numba import jit +DTYPE = np.float def profile(fnc): @@ -48,22 +60,13 @@ def inner(*args, **kwargs): class MCTSAgent(_Agent): - def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_pockets=256, + def __init__(self, nets: [NeuralNetAPI], threads=16, batch_size=8, playouts_empty_pockets=256, playouts_filled_pockets=512, cpuct=1, dirichlet_epsilon=.25, - dirichlet_alpha=0.2, max_search_depth=15, temperature=0., clip_quantil=0., + dirichlet_alpha=0.2, max_search_depth=15, temperature=0., temperature_moves=4, q_value_weight=0., virtual_loss=3, verbose=True, min_movetime=100, check_mate_in_one=False, - enable_timeout=False): + use_pruning=True, use_oscillating_cpuct=True, use_time_management=True, opening_guard_moves=0): """ Constructor of the MCTSAgent. - The MCTSAgent runs playouts/simulations in the search tree and updates the node statistics. - The final move is chosen according to the visit count of each direct child node. - One playout is defined as expanding one new node in the tree. In the case of chess this means evaluating a new board position. - - If the evaluation for one move takes too long on your hardware you can decrease the value for: - nb_playouts_empty_pockets and nb_playouts_filled_pockets. - - For more details and the mathematical equations please take a look at src/domain/agent/README.md as well as the - official DeepMind-papers. :param net: NeuralNetAPI handle which is used to communicate with the neural network :param threads: Number of threads to evaluate the nodes in parallel @@ -90,6 +93,8 @@ def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_p If 0. -> Deterministic policy. The move is chosen with the highest probability If 1. -> Pure random sampling policy. The move is sampled from the posterior without any scaling being applied. + :param temperature_moves: Number of fullmoves in which the temperature parameter will be applied. + Otherwise the temperature will be set to 0 for deterministic play. 
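To make temperature_moves concrete, here is a minimal stand-alone sketch (illustrative names, not the patch's exact code) of turning root visit counts, or the posterior policy built from them, into a move-selection policy once the temperature is applied:

import numpy as np

def policy_from_visits(visits, temperature=0.07):
    # visits: 1-D array of visit counts for the root's child moves
    if temperature <= 0.01:
        # deterministic play: one-hot on the most visited child
        policy = np.zeros_like(visits, dtype=float)
        policy[visits.argmax()] = 1.0
        return policy
    # exponential rescaling followed by renormalization
    policy = visits.astype(float) ** (1.0 / temperature)
    return policy / policy.sum()

# example: visits [120, 30, 10] -> a near-deterministic policy
print(policy_from_visits(np.array([120, 30, 10])))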
:param clip_quantil: A quantile clipping parameter with range [0., 1.]. All accumulated low percentages for moves are set to 0. This makes sure that very unlikely moves (blunders) are clipped after the exponential scaling. @@ -105,9 +110,14 @@ def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_p option is disabled because it costs too much nps in relation to its benefit. :param enable_timeout: Decides whether to enable a timeout if a batch didn't occur under 1 second for the NetPredService. + :param use_time_management: If set to true the mcts will spend less time on "obvious" moves and allocate a time + buffer for more critical moves. + :param opening_guard_moves: Number of moves for which the exploration is limited. Moves which have a prior + probability < 5% are clipped and not evaluated. If 0 no clipping will be done in the + opening. """ - super().__init__(temperature, clip_quantil, verbose) + super().__init__(temperature, temperature_moves, verbose) # the root node contains all references to its child nodes self.root_node = None @@ -119,10 +129,15 @@ def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_p self.node_lookup = {} # get the network reference - self.net = net + self.nets = nets self.virtual_loss = virtual_loss self.cpuct_init = cpuct + + if cpuct < 0.01 or cpuct > 10: + raise Exception('You might have confused centi-cpuct with cpuct. ' + 'The requested cpuct is beyond the reasonable range: cpuct should be > 0.01 and < 10.') + self.cpuct = cpuct self.max_search_depth = max_search_depth self.threads = threads @@ -147,8 +162,6 @@ def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_p self.my_pipe_endings.append(ending1) pip_endings_external.append(ending2) - self.net_pred_service = NetPredService(pip_endings_external, self.net, batch_size, enable_timeout) - self.nb_playouts_empty_pockets = playouts_empty_pockets self.nb_playouts_filled_pockets = playouts_filled_pockets @@ -166,7 +179,32 @@ def __init__(self, net: NeuralNetAPI, threads=16, batch_size=8, playouts_empty_p # number of nodes before the evaluate_board_state() call are stored here to measure the nps correctly self.total_nodes_pre_search = None - def evaluate_board_state(self, state_in: GameState): + # allocate shared memory for communicating with the network prediction service + self.batch_state_planes = np.zeros((self.threads, NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH), DTYPE) + self.batch_value_results = np.zeros(self.threads, DTYPE) + self.batch_policy_results = np.zeros((self.threads, NB_LABELS), DTYPE) + + # initialize the NetworkPredictionService and give the pointers to the shared memory + self.net_pred_services = [] + nb_pipes = self.threads // len(nets) + + # create multiple gpu-access points + for i, net in enumerate(nets): + net_pred_service = NetPredService(pip_endings_external[i*nb_pipes:(i+1)*nb_pipes], net, batch_size, self.batch_state_planes, + self.batch_value_results, self.batch_policy_results) + self.net_pred_services.append(net_pred_service) + + self.transposition_table = collections.Counter() + self.send_batches = False + self.root_node_prior_policy = None + + self.use_pruning = use_pruning + self.use_oscillating_cpuct = use_oscillating_cpuct + self.time_buffer_ms = 0 + self.use_time_management = use_time_management + self.opening_guard_moves = opening_guard_moves + + def evaluate_board_state(self, state: GameState): """ Analyzes the current board state.
This is the main method which get called by the uci interface or analysis request. @@ -178,30 +216,36 @@ def evaluate_board_state(self, state_in: GameState): # store the time at which the search started self.t_start_eval = time() - # create a deepcopy of the state in order not to change the given input parameter - state = deepcopy(state_in) - # check if the net prediction service has already been started - if self.net_pred_service.running is False: + if self.net_pred_services[0].running is False: # start the prediction daemon thread - self.net_pred_service.start() + for net_pred_service in self.net_pred_services: + net_pred_service.start() # receive a list of all possible legal move in the current board position - legal_moves = list(state.get_legal_moves()) + legal_moves = state.get_legal_moves() # consistency check if len(legal_moves) == 0: raise Exception('The given board state has no legal move available') # check first if the the current tree can be reused - board_fen = state.get_board_fen() - if board_fen in self.node_lookup: - self.root_node = self.node_lookup[board_fen] + key = state.get_transposition_key() + (state.get_fullmove_number(),) + + if self.use_pruning is False and key in self.node_lookup: + #if key in self.node_lookup: + + self.root_node = self.node_lookup[key] + logging.debug('Reuse the search tree. Number of nodes in search tree: %d', self.root_node.nb_total_expanded_child_nodes) self.total_nodes_pre_search = deepcopy(self.root_node.n_sum) + + # reset potential good nodes for the root + # self.root_node.q[self.root_node.q < 1.1] = 0 + self.root_node.q[self.root_node.q < 0] = self.root_node.q.max() - 0.25 + else: - logging.debug("The given board position wasn't found in the search tree.") logging.debug("Starting a brand new search tree...") self.root_node = None self.total_nodes_pre_search = 0 @@ -221,21 +265,85 @@ def evaluate_board_state(self, state_in: GameState): # run a single expansion on the root node self._expand_root_node_multiple_moves(state, legal_moves) + # opening guard + if state.get_fullmove_number() <= self.opening_guard_moves: #100: #7: #10: + self.root_node.q[self.root_node.p < 5e-2] = -9999 + #elif len(legal_moves) > 50: + # self.root_node.q[self.root_node.p < 1e-3] = -9999 + # conduct the mcts-search based on the given settings max_depth_reached = self._run_mcts_search(state) t_elapsed = time() - self.t_start_eval print('info string move overhead is %dms' % (t_elapsed*1000 - self.movetime_ms)) + #xth_n_max = self.get_xth_max(10) + #print('xth_n-max: ', xth_n_max) # receive the policy vector based on the MCTS search - p_vec_small = self.root_node.get_mcts_policy(self.q_value_weight) + p_vec_small = self.root_node.get_mcts_policy(self.q_value_weight) #, xth_n_max=xth_n_max, is_root=True) + + + # experimental + """ + orig_q = np.array(self.root_node.q) + #indices = self.root_node.n.max() > clip_fac + candidate_child = p_vec_small.argmax() #self.get_2nd_max() + latest, indices = self.get_last_q_values(candidate_child) + + + # ensure that the q value for the end node are properly set + #if len(indices) > 0: + #self.root_node.w[indices] += (self.root_node.n[indices]/1) * latest[indices] + #self.root_node.q[indices] = self.root_node.w[indices] / (self.root_node.n[indices] + (self.root_node.n[indices]/1)) + if True: #self.root_node.q[candidate_child] < 0: # and latest[candidate_child] + 0.5 < self.root_node.q[candidate_child]: + #self.root_node.q[candidate_child] = (self.root_node.q[candidate_child] + latest[candidate_child]) + 
#self.root_node.q[latest[self.root_node.thresh_idcs_root] < self.root_node.q[self.root_node.thresh_idcs_root]] = -1 + #print('q - shape', self.root_node.q.shape) + #print('latest - shape', latest.shape) + #print('thresh - shape', self.root_node.thresh_idcs_root.shape) + sel_indices = latest < self.root_node.q + sel_indices[np.invert(self.root_node.thresh_idcs_root)] = False + + + #print('sel indices -shape', len(sel_indices)) + self.root_node.q[sel_indices] = (latest[sel_indices] + self.root_node.q[sel_indices]) / 2 #-1 + + sel_indices = np.invert(sel_indices) + sel_indices[self.root_node.thresh_idcs_root] = False + self.root_node.q[sel_indices] = (latest[sel_indices] + self.root_node.q[sel_indices]) / 2 + + #prior_child = self.root_node.p.argmax() + #if latest[prior_child] > self.root_node.q[prior_child]: + # self.root_node.q[prior_child] = (latest[prior_child] + self.root_node.q[prior_child]) / 2 #latest[prior_child] + + #self.root_node.q[indices] = self.root_node.q[indices] * (latest[indices] + 1) + #self.root_node.q[indices] += latest[indices] + + # receive the policy vector based on the MCTS search + p_vec_small = self.root_node.get_mcts_policy(self.q_value_weight) #, xth_n_max=xth_n_max, is_root=True) + """ + + #max_n = self.root_node.n.max() + #latest[self.root_node.n < max_n / 2] = -1 + #latest += 1 + #latest /= sum(latest) + #if latest.max() > 0: + # p_vec_small[latest < 0] = 0 + #p_vec_small = p_vec_small + latest + #p_vec_small[p_vec_small < 0] = 0 + #p_vec_small[p_vec_small > 1] = 1 + + #p_vec_small /= sum(p_vec_small) + + #if self.use_pruning is False: # store the current root in the lookup table - self.node_lookup[state.get_board_fen()] = self.root_node + self.node_lookup[key] = self.root_node # select the q-value according to the mcts best child value - best_child_idx = self.root_node.get_mcts_policy(self.q_value_weight).argmax() + best_child_idx = p_vec_small.argmax() value = self.root_node.q[best_child_idx] + #value = orig_q[best_child_idx] lst_best_moves, _ = self.get_calculated_line() str_moves = self._mv_list_to_str(lst_best_moves) @@ -243,21 +351,31 @@ def evaluate_board_state(self, state_in: GameState): # show the best calculated line node_searched = int(self.root_node.n_sum - self.total_nodes_pre_search) # In uci the depth is given using half-moves notation also called plies - logging.debug('Update info') time_e = time() - self.t_start_eval - print('info score cp %d depth %d nodes %d time %d nps %d pv%s' % ( value_to_centipawn(value), - max_depth_reached, - node_searched, time_e*1000, - node_searched/max(1,time_e), str_moves)) if len(legal_moves) != len(p_vec_small): - raise Exception('Legal move list %s with length %s is uncompatible to policy vector %s' - ' with shape %s for board state %s' % (legal_moves, len(legal_moves), - p_vec_small, p_vec_small.shape, state_in)) + raise Exception('Legal move list %s with length %s is uncompatible to policy vector %s' + ' with shape %s for board state %s and nodes legal move list: %s' % + (legal_moves, len(legal_moves), + p_vec_small, p_vec_small.shape, state, + self.root_node.legal_moves)) + + # define the remaining return variables + cp = value_to_centipawn(value) + depth = max_depth_reached + nodes = node_searched + time_elapsed_s = time_e*1000 + nps = node_searched/time_e + pv = str_moves + + # print out the score as a debug message if verbose it set to true + # the file crazyara.py will print the chosen line to the std output + if self.verbose is True: + score = "score cp %d depth %d nodes %d time %d nps %d pv %s" % (cp, 
depth, nodes, time_elapsed_s, nps, pv) + logging.info('info string %s' % score) + + return value, legal_moves, p_vec_small, cp, depth, nodes, time_elapsed_s, nps, pv - return value, legal_moves, p_vec_small - - @jit def _expand_root_node_multiple_moves(self, state, legal_moves): """ Checks if the current root node can be found in the look-up table. @@ -273,7 +391,7 @@ def _expand_root_node_multiple_moves(self, state, legal_moves): # start a brand new tree state_planes = state.get_state_planes() - [value, policy_vec] = self.net.predict_single(state_planes) + [value, policy_vec] = self.nets[0].predict_single(state_planes) # extract a sparse policy vector with normalized probabilities p_vec_small = get_probs_of_move_list(policy_vec, legal_moves, state.is_white_to_move()) @@ -284,9 +402,8 @@ def _expand_root_node_multiple_moves(self, state, legal_moves): str_legal_moves = '' # create a new root node - self.root_node = Node(value, p_vec_small, legal_moves, str_legal_moves, is_leaf) + self.root_node = Node(value, p_vec_small, legal_moves, str_legal_moves, is_leaf, clip_low_visit=False) - @jit def _expand_root_node_single_move(self, state, legal_moves): """ Expands the current root in the case that there's only a single move available. @@ -297,12 +414,14 @@ def _expand_root_node_single_move(self, state, legal_moves): :return: """ - # set value 0 as a dummy value - value = 0 + # request the value prediction for the current position + state_planes = state.get_state_planes() + [value, _] = self.nets[0].predict_single(state_planes) + # we can create the move probability vector without the NN this time p_vec_small = np.array([1], np.float32) # create a new root node - self.root_node = Node(value, p_vec_small, legal_moves, str(state.get_legal_moves())) + self.root_node = Node(value, p_vec_small, legal_moves, str(state.get_legal_moves()), clip_low_visit=False) # check the child node if it doesn't exist already if self.root_node.child_nodes[0] is None: @@ -322,18 +441,19 @@ def _expand_root_node_single_move(self, state, legal_moves): p_vec_small_child = None # check if you can claim a draw - it's assumed that the draw is always claimed - elif state.is_draw() is True: + elif self.can_claim_threefold_repetition(state.get_transposition_key(), [0]) or\ + state.get_pythonchess_board().can_claim_fifty_moves() is True: value = 0 is_leaf = True legal_moves_child = [] p_vec_small_child = None else: - legal_moves_child = list(state_child.get_legal_moves()) + legal_moves_child = state_child.get_legal_moves() # start a brand new prediction for the child state_planes = state_child.get_state_planes() - [value, policy_vec] = self.net.predict_single(state_planes) + [value, policy_vec] = self.nets[0].predict_single(state_planes) # extract a sparse policy vector with normalized probabilities p_vec_small_child = get_probs_of_move_list(policy_vec, legal_moves_child, @@ -346,6 +466,10 @@ def _expand_root_node_single_move(self, state, legal_moves): # connect the child to the root self.root_node.child_nodes[0] = child_node + # assign the value of the root node as the q-value for the child + # here we must invert the value because it's the value prediction of the next state + self.root_node.q[0] = -value + def _run_mcts_search(self, state): """ Runs a new or continues the mcts on the current search tree.
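The sign flip in self.root_node.q[0] = -value is the same negamax convention used along the whole backup path; a minimal sketch with illustrative field names, assuming each node stores cumulative values w, visit counts n and mean values q:

def backup(path, leaf_value):
    # path: list of (node, child_idx) pairs from the root down to the expanded leaf
    # leaf_value: value prediction from the view of the player to move at the leaf
    value = leaf_value
    for node, child_idx in reversed(path):
        value = -value          # the point of view changes every half-move
        node.w[child_idx] += value
        node.n[child_idx] += 1
        node.q[child_idx] = node.w[child_idx] / node.n[child_idx]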
@@ -357,6 +481,9 @@ def _run_mcts_search(self, state): # clear the look up table self.node_lookup = {} + # safe the prior policy of the root node + self.root_node_prior_policy = deepcopy(self.root_node.p) + # apply dirichlet noise to the prior probabilities in order to ensure # that every move can possibly be visited self.root_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon, @@ -373,79 +500,151 @@ def _run_mcts_search(self, state): nb_playouts = self.nb_playouts_empty_pockets else: nb_playouts = self.nb_playouts_filled_pockets + #self.temperature_current = 0 + + #if self.root_node.v > 0.65: + # fac = 0.1 + #else: + # fac = 0.02 #2 + # iterate through all children and add dirichlet if there exists any + for child_node in self.root_node.child_nodes: + if child_node is not None: + # add dirichlet noise to a the child nodes of the root node + child_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon * 0.05, #02, + alpha=self.dirichlet_alpha) + #child_node.q[child_node.q < 0] = child_node.q.max() - 0.25 + + t_elapsed_ms = 0 - t_elapsed = 0 cur_playouts = 0 old_time = time() cpuct_init = self.cpuct + decline = True + + move_step = self.movetime_ms / 2 + move_update = move_step + move_update_2 = self.movetime_ms * 0.9 + + #nb_playouts_update_step = 4000 + #nb_playouts_update = 4000 + + #self.hard_clipping = True + + if self.use_time_management is True: + time_checked = False + time_checked_early = False + else: + time_checked = True + time_checked_early = True + + consistent_check = False #False + consistent_check_playouts = 2048 + while max_depth_reached < self.max_search_depth and \ cur_playouts < nb_playouts and \ - t_elapsed * 1000 < self.movetime_ms: # and np.abs(self.root_node.q.mean()) < 0.99: + t_elapsed_ms < self.movetime_ms: # and np.abs(self.root_node.q.mean()) < 0.99: - # Test about decreasing CPUCT value - self.cpuct -= 0.005 #2 #5 #1 #np.random.randint(1,5) #0.005 - if self.cpuct < 1.3: # 5: - self.cpuct = 1.3 # 5 + if self.use_oscillating_cpuct is True: + # Test about decreasing CPUCT value + if decline is True: + self.cpuct -= 0.01 + else: + self.cpuct += 0.01 + if self.cpuct < cpuct_init * .5: + decline = False + elif self.cpuct > cpuct_init: + decline = True + """ + if cur_playouts >= nb_playouts_update: + #print('UPDATE') + self.root_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon, + alpha=self.dirichlet_alpha) + # iterate through all children and add dirichlet if there exists any + for child_node in self.root_node.child_nodes: + if child_node is not None: + # test of adding dirichlet noise to a new node + child_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon * fac, + alpha=self.dirichlet_alpha) + nb_playouts_update += nb_playouts_update_step + #move_update += move_step + """ # start searching + with ThreadPoolExecutor(max_workers=self.threads) as executor: for i in range(self.threads): # calculate the thread id based on the current playout futures.append(executor.submit(self._run_single_playout, state=state, - parent_node=self.root_node, pipe_id=i, depth=1, mv_list=[])) + parent_node=self.root_node, pipe_id=i, depth=1, chosen_nodes=[])) cur_playouts += self.threads time_show_info = time() - old_time - # store the mean of all value predictions in this variable - # mean_value = 0 - for i, f in enumerate(futures): - cur_value, cur_depth, mv_list = f.result() - - # sum up all values - # mean_value += cur_value + cur_value, cur_depth, chosen_nodes = f.result() if cur_depth > 
max_depth_reached: max_depth_reached = cur_depth # Print the explored line of the last line for every x seconds if verbose is true if self.verbose and time_show_info > 0.5 and i == len(futures) - 1: + mv_list = self._create_mv_list(chosen_nodes) str_moves = self._mv_list_to_str(mv_list) - #logging.debug('Update: %d' % cur_depth) - print('info score cp %d depth %d nodes %d pv%s' % ( + print('info score cp %d depth %d nodes %d pv %s' % ( value_to_centipawn(cur_value), cur_depth, self.root_node.n_sum, str_moves)) - #print('info cpuct: %.2f' % self.cpuct) + logging.debug('Update info') old_time = time() - """ - # Show only current best line - # Print every second if verbose is true - if self.verbose and time_show_info > 1: - # select the q-value according to the mcts best child value - best_child_idx = self.root_node.get_mcts_policy(self.q_value_weight).argmax() - cur_value = self.root_node.q[best_child_idx] - - lst_best_moves, _ = self.get_calculated_line() - str_moves = self._mv_list_to_str(lst_best_moves) - print('info score cp %d depth %d nodes %d pv%s' % ( - value_to_centipawn(cur_value), len(lst_best_moves), self.root_node.n_sum, str_moves)) - old_time = time() - """ - # update the current search time t_elapsed = time() - self.t_start_eval - if self.verbose and time_show_info > 1: + t_elapsed_ms = t_elapsed * 1000 + if time_show_info > 1: node_searched = int(self.root_node.n_sum - self.total_nodes_pre_search) - print('info nps %d time %d' % (int((node_searched / t_elapsed)), t_elapsed * 1000)) - + print('info nps %d time %d' % (int((node_searched / t_elapsed)), t_elapsed_ms)) + + if time_checked_early is False and t_elapsed_ms > move_update: + #node, _, _, child_idx = self._select_node_based_on_mcts_policy(self.root_node) + if self.root_node.p.max() > 0.9 and self.root_node.p.argmax() == self.root_node.q.argmax(): + self.time_buffer_ms += (self.movetime_ms - t_elapsed_ms) * 0.9 + print('info early break up') + break + else: + time_checked_early = True + + if consistent_check is False and cur_playouts > consistent_check_playouts and self.root_node_prior_policy.max() > np.partition(self.root_node_prior_policy.flatten(), -2)[-2] + 0.3: + print('Consistency check') + if self.root_node.get_mcts_policy(self.q_value_weight).argmax() == self.root_node_prior_policy.argmax(): + self.time_buffer_ms += (self.movetime_ms - t_elapsed_ms) * 0.9 + print('info early break up') + break + else: + consistent_check = True + + if self.time_buffer_ms > 2500 and time_checked is False and t_elapsed_ms > move_update_2 and self.root_node.q[self.root_node.n.argmax()] < self.root_node.v + 0.01: + print('info increase time') + time_checked = True + #for child_node in self.root_node.child_nodes: + # if child_node is not None: + # # test of adding dirichlet noise to a new node + # child_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon * .5, + # alpha=self.dirichlet_alpha) + time_boni = self.time_buffer_ms / 4 + # increase the movetime + self.time_buffer_ms -= time_boni + self.movetime_ms += (time_boni) * 0.75 + self.root_node.v = self.root_node.q[self.root_node.n.argmax()] + if self.time_buffer_ms < 0: + self.movetime_ms += self.time_buffer_ms + self.time_buffer_ms = 0 + #if self.root_node.q[child_idx] < 0: + # self.hard_clipping = False self.cpuct = cpuct_init return max_depth_reached - def perform_action(self, state: GameState, verbose=True): + def perform_action(self, state_in: GameState, verbose=True): """ Return a value, best move with according to the mcts search. 
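Condensed to its core, the early-stopping rule used above reads roughly as follows (illustrative names; the real code additionally runs a consistency check against the prior policy): at half of the allotted movetime an "obvious" move ends the search early, and most of the unused time is banked in a buffer for more critical positions.

def maybe_stop_early(prior_policy, q_values, movetime_ms, elapsed_ms, time_buffer_ms):
    # an obvious move: one dominant prior that also carries the best q-value
    if elapsed_ms > movetime_ms / 2 and prior_policy.max() > 0.9 \
            and prior_policy.argmax() == q_values.argmax():
        time_buffer_ms += (movetime_ms - elapsed_ms) * 0.9  # bank 90% of the savings
        return True, time_buffer_ms
    return False, time_buffer_ms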
This method is used when using the mcts agent as a player. :param state: Requested board position for the next move :return: value - Value prediction in the current player's view from [-1,1]: -1 -> 100% lost, +1 -> 100% won selected_move - Python chess move object of the selected move confidence - Probability value for the selected move in the probability distribution selected_child_idx - Child index which corresponds to the selected child """ - value, selected_move, confidence, selected_child_idx = super().perform_action(state) - - # apply the selected mve on the current board state in order to create a lookup table for future board states - state.apply_move(selected_move) - - # select the q value for the child which leads to the best calculated line - value = self.root_node.q[selected_child_idx] - - # select the next node - node = self.root_node.child_nodes[selected_child_idx] - - # store the reference links for all possible child future child to the node lookup table - for idx, mv in enumerate(state.get_legal_moves()): - state_future = deepcopy(state) - state_future.apply_move(mv) - - # store the current child node with it's board fen as the hash-key if the child node has already been expanded - if node is not None and idx < node.nb_direct_child_nodes and node.child_nodes[idx] is not None: - self.node_lookup[state_future.get_board_fen()] = node.child_nodes[idx] + # create a deepcopy of the state in order not to change the given input parameter + state = deepcopy(state_in) - return value, selected_move, confidence, selected_child_idx + return super().perform_action(state) - #@profile - def _run_single_playout(self, state: GameState, parent_node: Node, pipe_id=0, depth=1, mv_list=[]): + def _run_single_playout(self, state: GameState, parent_node: Node, pipe_id=0, depth=1, chosen_nodes=[]): """ This function works recursively until a leaf or terminal node is reached. It ends by backpropagating the value of the new expanded node or by propagating the value of a terminal state. :param state: Current game-state for the evaluation. This state differs between the threads :param parent_node: Current parent-node of the selected node. In the first expansion this is the root node. :param depth: Current depth for the evaluation. Depth is increased by 1 for every recursive call - :param mv_list: List of moves which have been taken in the current path. For each selected child node this list + :param chosen_nodes: List of child node indices which have been chosen in the current path, i.e. all nodes that + this thread has explored with respect to the root node. For each selected child node this list is expanded by one move recursively. :return: -value: The inverse value prediction of the current board state.
The flipping by -1 each turn is needed because the point of view changes each half-move depth: Current depth reach by this evaluation mv_list: List of moves which have been selected """ - # create a deepcopy of the state for all future recursive calls if it's the first of function - # call of _run_single_playout() - if depth == 1: - state = deepcopy(state) - # select a legal move on the chess board node, move, child_idx = self._select_node(parent_node) @@ -513,69 +690,118 @@ def _run_single_playout(self, state: GameState, parent_node: Node, pipe_id=0, de # the effect of virtual loss will be undone if the playout is over parent_node.apply_virtual_loss_to_child(child_idx, self.virtual_loss) + if depth == 1: + state = GameState(deepcopy(state.get_pythonchess_board())) + # apply the selected move on the board state.apply_move(move) # append the selected move to the move list - mv_list.append(move) + # append the chosen child idx to the chosen_nodes list + chosen_nodes.append(child_idx) if node is None: - # get the board-fen which is used as an identifier for the board positions in the look-up table - board_fen = state.get_board_fen() + # get the transposition-key which is used as an identifier for the board positions in the look-up table + transposition_key = state.get_transposition_key() # check if the addressed fen exist in the look-up table - if board_fen in self.node_lookup: + # note: It's important to use also the halfmove-counter here, otherwise the system can create an infinite + # feed-back-loop + key = transposition_key + (state.get_fullmove_number(),) + use_tran_table = True + + node_varified = False + if use_tran_table is True and key in self.node_lookup: + #if self.check_for_duplicate(transposition_key, chosen_nodes) is False: # get the node from the look-up list - node = self.node_lookup[board_fen] + node = self.node_lookup[key] + + # make sure that you don't connect to a node with lower visits + if node.n_sum > parent_node.n_sum: + node_varified = True + + if node_varified is True: with parent_node.lock: # setup a new connection from the parent to the child parent_node.child_nodes[child_idx] = node + #logging.debug('found key: %s' % state.get_board_fen()) # get the prior value from the leaf node which has already been expanded value = node.v + # receive a free available pipe - my_pipe = self.my_pipe_endings[pipe_id] - my_pipe.send(state.get_state_planes()) - #this pipe waits for the predictions of the network inference service - [_, _] = my_pipe.recv() + #my_pipe = self.my_pipe_endings[pipe_id] + #my_pipe.send(state.get_state_planes()) + # this pipe waits for the predictions of the network inference service + #[_, _] = my_pipe.recv() # get the value from the leaf node (the current function is called recursively) - #value, depth, mv_list = self._run_single_playout(state, node, pipe_id, depth+1, mv_list) - + # value, depth, mv_list = self._run_single_playout(state, node, pipe_id, depth+1, mv_list) else: # expand and evaluate the new board state (the node wasn't found in the look-up table) # its value will be backpropagated through the tree and flipped after every layer - # receive a free available pipe - my_pipe = self.my_pipe_endings[pipe_id] #.pop() - #logging.debug('thread %d request' % pipe_id) - my_pipe.send(state.get_state_planes()) - # this pipe waits for the predictions of the network inference service - [value, policy_vec] = my_pipe.recv() + my_pipe = self.my_pipe_endings[pipe_id] + + if self.send_batches is True: + my_pipe.send(state.get_state_planes()) + # this pipe 
waits for the predictions of the network inference service + [value, policy_vec] = my_pipe.recv() + else: + state_planes = state.get_state_planes() + self.batch_state_planes[pipe_id] = state_planes + + my_pipe.send(pipe_id) + + result_channel = my_pipe.recv() + + value = np.array(self.batch_value_results[result_channel]) + policy_vec = np.array(self.batch_policy_results[result_channel]) # initialize is_leaf by default to false is_leaf = False + # check if the current player has won the game # (we don't need to check for is_lost() because the game is already over # if the current player checkmated his opponent) - if state.is_won() is True: - value = -1 - is_leaf = True - legal_moves = [] - p_vec_small = None + is_won = False + #is_check = False + + if state.is_check() is True: + # enhance checking nodes + #if depth == 1: + # parent_node.p.mean() + # with parent_node.lock: + # if parent_node.p[child_idx] < 0.1: + # parent_node.p[child_idx] = 0.1 + #is_check = True + if state.is_won() is True: + is_won = True + + if is_won is True: + value = -1 + is_leaf = True + legal_moves = [] + p_vec_small = None + # establish a mate in one connection in order to stop exploring different alternatives + parent_node.mate_child_idx = child_idx + # get the value from the leaf node (the current function is called recursively) # check if you can claim a draw - its assumed that the draw is always claimed - elif False: #state.is_draw() is True: TODO: Create more performant implementation + elif self.can_claim_threefold_repetition(transposition_key, chosen_nodes) or \ + state.get_pythonchess_board().can_claim_fifty_moves() is True: + #raise Exception('Threefold!') value = 0 is_leaf = True legal_moves = [] p_vec_small = None else: # get the current legal move of its board state - legal_moves = list(state.get_legal_moves()) + legal_moves = state.get_legal_moves() + if len(legal_moves) < 1: raise Exception('No legal move is available for state: %s' % state) @@ -587,50 +813,108 @@ def _run_single_playout(self, state: GameState, parent_node: Node, pipe_id=0, de except KeyError: raise Exception('Key Error for state: %s' % state) + #if state.get_board_fen() == 'r1b3k1/ppq2pP1/3n1Ppp/4Q2N/4B3/P1P3bP/2P1nPPr/4rB1K/PRPNp w - - 0 36': + # print('found it! > is won %d' % is_won) + # convert all legal moves to a string if the option check_mate_in_one was enabled if self.check_mate_in_one is True: str_legal_moves = str(state.get_legal_moves()) else: str_legal_moves = '' + # clip the visit nodes for all nodes in the search tree except the director opp. 
move + clip_low_visit = self.use_pruning and depth != 1 + # create a new node - new_node = Node(value, p_vec_small, legal_moves, str_legal_moves, is_leaf) + new_node = Node(value, p_vec_small, legal_moves, str_legal_moves, is_leaf, transposition_key, clip_low_visit) + + if depth == 1: + + # disable uncertain moves from being visited by giving them a very bad score + if is_leaf is False and self.use_pruning is True: + if self.root_node_prior_policy[child_idx] < 1e-3 and value * -1 < self.root_node.v: + with parent_node.lock: + value = 99 + + if parent_node.v > 0.65: # and state.are_pocket_empty(): #and pipe_id == 0: + # test of adding dirichlet noise to a new node + fac = 0.25 + if len(parent_node.legal_moves) < 20: + fac *= 5 + new_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon*fac, + alpha=self.dirichlet_alpha) - #if is_leaf is False: - #if depth == 2 and pipe_id == 0: - # # test of adding dirichlet noise to a new node - # new_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon/3, alpha=self.dirichlet_alpha) + if value < 0: # and state.are_pocket_empty(): #and pipe_id == 0: + # test of adding dirichlet noise to a new node + new_node.apply_dirichlet_noise_to_prior_policy(epsilon=self.dirichlet_epsilon * .02, + alpha=self.dirichlet_alpha) - # include a reference to the new node in the look-up table - self.node_lookup[board_fen] = new_node + if self.use_pruning is False: + # include a reference to the new node in the look-up table + self.node_lookup[key] = new_node with parent_node.lock: # add the new node to its parent parent_node.child_nodes[child_idx] = new_node - # check if the new node has a mate_in_one connection (if yes overwrite the network prediction) - if new_node.mate_child_idx is not None: - value = 1 - # check if we have reached a leaf node elif node.is_leaf is True: value = node.v - # receive a free available pipe - my_pipe = self.my_pipe_endings[pipe_id] #.pop() - #logging.debug('thread %d request' % pipe_id) - my_pipe.send(state.get_state_planes()) - # this pipe waits for the predictions of the network inference service - [_, _] = my_pipe.recv() else: # get the value from the leaf node (the current function is called recursively) - value, depth, mv_list = self._run_single_playout(state, node, pipe_id, depth+1, mv_list) + value, depth, chosen_nodes = self._run_single_playout(state, node, pipe_id, depth + 1, chosen_nodes) # revert the virtual loss and apply the predicted value by the network to the node parent_node.revert_virtual_loss_and_update(child_idx, self.virtual_loss, -value) # we invert the value prediction for the parent of the above node layer because the player's turn is flipped every turn - return -value, depth, mv_list + return -value, depth, chosen_nodes + + def check_for_duplicate(self, transposition_key, chosen_nodes): + + node = self.root_node.child_nodes[chosen_nodes[0]] + + # iterate over all accessed nodes during the current search of the thread and check for same transposition key + for node_idx in chosen_nodes[1:-1]: + if node.transposition_key == transposition_key: + #print('DUPLICATE CHECK = TRUE! ') + return True + node = node.child_nodes[node_idx] + if node is None: + break + + return False + + + def can_claim_threefold_repetition(self, transposition_key, chosen_nodes): + """ + Checks if a three fold repetition event can be claimed in the current search path. + This method makes use of the class transposition table and checks for board occurrences in the local search path + of the current thread as well. 
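The repetition test combines two sources of occurrences, as this small stand-alone model (illustrative names) shows: positions already played in the actual game, tracked in a collections.Counter keyed by transposition key, plus repetitions inside the current simulated line. A draw can be claimed once the position occurred twice before, i.e. is about to appear a third time.

import collections

def can_claim_threefold(transposition_table, key, occurrences_in_search_path):
    # transposition_table counts how often each position occurred in the real game;
    # occurrences_in_search_path counts repetitions inside the current thread's line
    return transposition_table[key] + occurrences_in_search_path >= 2

game_positions = collections.Counter()
game_positions.update(['pos_a'])                          # seen once in the game
print(can_claim_threefold(game_positions, 'pos_a', 1))    # True: third occurrence upcoming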
+ + :param transposition_key: Transposition key which defines the board state by all it's pieces and pocket state. + The move counter is disregarded. + :param chosen_nodes: List of integer indices which correspond to the child node indices chosen from the + root node downwards. + :return: True, if threefold repetition can be claimed, else False + """ + + # set the number of occurrences by default to 0 + search_occurrence_counter = 0 + + node = self.root_node.child_nodes[chosen_nodes[0]] + + # iterate over all accessed nodes during the current search of the thread and check for same transposition key + for node_idx in chosen_nodes[1:-1]: + if node.transposition_key == transposition_key: + search_occurrence_counter += 1 + node = node.child_nodes[node_idx] + if node is None: + break + + # use all occurrences in the class transposition table as well as the locally found equalities + return (self.transposition_table[transposition_key] + search_occurrence_counter) >= 2 def _select_node(self, parent_node: Node): """ @@ -649,9 +933,24 @@ def _select_node(self, parent_node: Node): else: # find the move according to the q- and u-values for each move + if self.use_oscillating_cpuct is False: + pb_c_base = 19652 + pb_c_init = self.cpuct + + cpuct = math.log((parent_node.n_sum + pb_c_base + 1) / + pb_c_base) + pb_c_init + else: + cpuct = self.cpuct + # calculate the current u values # it's not worth to save the u values as a node attribute because u is updated every time n_sum changes - u = self.cpuct * parent_node.p * (np.sqrt(parent_node.n_sum) / (1 + parent_node.n)) + u = cpuct * parent_node.p * (np.sqrt(parent_node.n_sum) / (1 + parent_node.n)) + + #if depth == 1 and self.hard_clipping is True and self.use_pruning is True: # and id <= (self.threads//2+1): #and id % 2 != 0: + # u[parent_node.thresh_idcs_root] = -9999 #1 + #if self.use_pruning is True and depth >= 2: # and depth >= 2: and id % 3 != 0: + # u[parent_node.thresh_idcs] = -9999 + child_idx = (parent_node.q + u).argmax() node = parent_node.child_nodes[child_idx] @@ -660,19 +959,20 @@ def _select_node(self, parent_node: Node): return node, move, child_idx - def _select_node_based_on_mcts_policy(self, parent_node: Node): + def _select_node_based_on_mcts_policy(self, parent_node: Node, is_root=False): """ Selects the next node based on the mcts policy which is used to predict the final best move. :param parent_node: Node from which to select the next child. 
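For reference, here is the selection rule above in isolation: the growing exploration constant follows the formula from DeepMind's AlphaZero paper with pb_c_base = 19652, while the array names and the cpuct_init default are illustrative.

import math
import numpy as np

def select_child(q, n, p, n_sum, cpuct_init=1.25, pb_c_base=19652):
    # the exploration constant grows slowly with the parent's total visit count
    cpuct = math.log((n_sum + pb_c_base + 1) / pb_c_base) + cpuct_init
    u = cpuct * p * (math.sqrt(n_sum) / (1 + n))
    return int(np.argmax(q + u))

# q: mean action values, n: visit counts, p: prior policy of a single node
print(select_child(np.array([0.1, 0.0]), np.array([10, 2]), np.array([0.6, 0.4]), 12))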
:return: """ - child_idx = parent_node.get_mcts_policy(self.q_value_weight).argmax() + + child_idx = parent_node.get_mcts_policy(self.q_value_weight, is_root=is_root).argmax() nb_visits = parent_node.n[child_idx] move = parent_node.legal_moves[child_idx] - return parent_node.child_nodes[child_idx], move, nb_visits + return parent_node.child_nodes[child_idx], move, nb_visits, child_idx def show_next_pred_line(self): best_moves = [] @@ -685,6 +985,54 @@ def show_next_pred_line(self): best_moves.append(move) return best_moves + def get_2nd_max(self): + n_child = self.root_node.n.argmax() + n_max = self.root_node.n[n_child] + self.root_node.n[n_child] = 0 + + second_max = self.root_node.n.max() + self.root_node.n[n_child] = n_max + + return second_max + + def get_xth_max(self, xth_node): + if len(self.root_node.n) < xth_node: + return self.root_node.n.min() + else: + return np.sort(self.root_node.n)[-xth_node] + + def get_last_q_values(self, second_max=0, clip_fac=0.25): + """ + Returns the values of the last node in the caluclated lines according to the mcts search for the most + visited nodes + :return: + """ + + q_future = np.zeros(self.root_node.nb_direct_child_nodes) + + indices = [] + for i in range(self.root_node.nb_direct_child_nodes): + if self.root_node.n[i] >= self.root_node.n.max() * 0.33: #i == second_max: # #second_max: + node = self.root_node.child_nodes[i] + print(self.root_node.legal_moves[i].uci(), end=' ') + turn = 1 + final_node = node + move = self.root_node.legal_moves[i] + + while node is not None and node.is_leaf is False and node.n_sum > 3: + final_node = node + print(move.uci() + ' ', end='') + node, move, _, _ = self._select_node_based_on_mcts_policy(node) + turn *= -1 + + if final_node is not None: + q_future[i] = final_node.v + indices.append(i) + q_future[i] *= turn + print(q_future[i]) + + return q_future, indices + def get_calculated_line(self): """ Prints out the best search line estimated for both players on the given board state. @@ -698,10 +1046,12 @@ def get_calculated_line(self): lst_nb_visits = [] # start at the root node node = self.root_node + is_root = True while node is not None and node.is_leaf is False: # go deep through the tree by always selecting the best move for both players - node, move, nb_visits = self._select_node_based_on_mcts_policy(node) + node, move, nb_visits, _ = self._select_node_based_on_mcts_policy(node, is_root) + is_root = False lst_best_moves.append(move) lst_nb_visits.append(nb_visits) return lst_best_moves, lst_nb_visits @@ -712,11 +1062,28 @@ def _mv_list_to_str(self, lst_moves): :param lst_moves: List chess.Moves objects :return: String representing each move in the list """ - str_moves = "" - for mv in lst_moves: + str_moves = lst_moves[0].uci() + + for mv in lst_moves[1:]: str_moves += " " + mv.uci() + return str_moves + def _create_mv_list(self, lst_chosen_nodes: [int]): + """ + Creates a movement list given the child node indices from the root node onwards. + :param lst_chosen_nodes: List of chosen nodes + :return: mv_list - List of python chess moves + """ + mv_list = [] + node = self.root_node + + for child_idx in lst_chosen_nodes: + mv = node.legal_moves[child_idx] + node = node.child_nodes[child_idx] + mv_list.append(mv) + return mv_list + def update_movetime(self, time_ms_per_move): """ Update move time allocation. 
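The calculated line printed by get_calculated_line() boils down to repeatedly following the strongest child; a minimal sketch, assuming each node exposes n (visit counts), legal_moves and child_nodes (the real code ranks children with get_mcts_policy() rather than raw visits):

def principal_variation(root):
    # follow the most visited child until a leaf or an unexpanded node is reached
    moves = []
    node = root
    while node is not None and not node.is_leaf:
        child_idx = node.n.argmax()
        moves.append(node.legal_moves[child_idx])
        node = node.child_nodes[child_idx]
    return moves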
@@ -724,3 +1091,20 @@ def update_movetime(self, time_ms_per_move): :return: """ self.movetime_ms = time_ms_per_move + + def set_max_search_depth(self, max_search_depth: int): + """ + Assigns a new maximum search depth for the next search + :param max_search_depth: Specifier of the search depth + :return: + """ + self.max_search_depth = max_search_depth + + def update_tranposition_table(self, transposition_key): + """ + + :param transposition_key: (gamestate.get_transposition_key(),) + :return: + """ + + self.transposition_table.update(transposition_key) diff --git a/DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py b/DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py old mode 100644 new mode 100755 index 4df18b25..b5e585b0 --- a/DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py +++ b/DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py @@ -13,32 +13,38 @@ from DeepCrazyhouse.src.domain.agent.NeuralNetAPI import NeuralNetAPI from DeepCrazyhouse.src.domain.crazyhouse.output_representation import get_probs_of_move_list, value_to_centipawn from time import time - +import sys class RawNetAgent(_Agent): - def __init__(self, net: NeuralNetAPI, temperature=0., clip_quantil=0., verbose=True): - super().__init__(temperature, clip_quantil, verbose) + def __init__(self, net: NeuralNetAPI, temperature=0., temperature_moves=4, verbose=True): + super().__init__(temperature, temperature_moves, verbose) self._net = net - def evaluate_board_state(self, state: _GameState, verbose=True): + def evaluate_board_state(self, state: _GameState): """ :param state: :return: """ + t_start_eval = time() pred_value, pred_policy = self._net.predict_single(state.get_state_planes()) legal_moves = list(state.get_legal_moves()) + p_vec_small = get_probs_of_move_list(pred_policy, legal_moves, state.is_white_to_move()) - if verbose is True: - # use the move with the highest probability as the best move for logging - instinct_move = legal_moves[p_vec_small.argmax()] + # use the move with the highest probability as the best move for logging + instinct_move = legal_moves[p_vec_small.argmax()] - # show the best calculated line - print('info score cp %d depth %d nodes %d time %d pv %s' % ( - value_to_centipawn(pred_value), 1, 1, (time() - t_start_eval) * 1000, instinct_move.uci())) + # define the remaining return variables + time_e = (time() - t_start_eval) + cp = value_to_centipawn(pred_value) + depth = 1 + nodes = 1 + time_elapsed_s = time_e * 1000 + nps = nodes/time_e + pv = instinct_move.uci() - return pred_value, legal_moves, p_vec_small + return pred_value, legal_moves, p_vec_small, cp, depth, nodes, time_elapsed_s, nps, pv diff --git a/DeepCrazyhouse/src/domain/agent/player/_Agent.py b/DeepCrazyhouse/src/domain/agent/player/_Agent.py old mode 100644 new mode 100755 index ff8674f4..40c9cb90 --- a/DeepCrazyhouse/src/domain/agent/player/_Agent.py +++ b/DeepCrazyhouse/src/domain/agent/player/_Agent.py @@ -17,10 +17,11 @@ class _Agent: The greedy agent always performs the first legal move with the highest move probability """ - def __init__(self, temperature=0., clip_quantil=0., verbose=True): + def __init__(self, temperature=0, temperature_moves=4, verbose=True): self.temperature = temperature - self.p_vec_small = None - self.clip_quantil = clip_quantil + self.temperature_current = temperature + self.temperature_moves = temperature_moves + #self.p_vec_small = None self.verbose = verbose def evaluate_board_state(self, state: _GameState): @@ -29,76 +30,57 @@ def evaluate_board_state(self, state: _GameState): def 
perform_action(self, state: _GameState): # the first step is to call you policy agent to evaluate the given position - value, legal_moves, self.p_vec_small = self.evaluate_board_state(state) + value, legal_moves, p_vec_small, cp, depth, nodes, time_elapsed_s, nps, pv = self.evaluate_board_state(state) - if len(legal_moves) != len(self.p_vec_small): - raise Exception('Legal move list %s is uncompatible to policy vector %s' % (legal_moves, self.p_vec_small)) + if len(legal_moves) != len(p_vec_small): + raise Exception('Legal move list %s is uncompatible to policy vector %s' % (legal_moves, p_vec_small)) + + if state.get_fullmove_number() <= self.temperature_moves: + self.temperature_current = self.temperature + else: + self.temperature_current = 0 if len(legal_moves) == 1: selected_move = legal_moves[0] confidence = 1. idx = 0 else: - if self.temperature <= 0.01: - idx = self.p_vec_small.argmax() + if self.temperature_current <= 0.01: + idx = p_vec_small.argmax() else: - self._apply_temperature_to_policy() - self._apply_quantil_clipping() - idx = np.random.choice(range(len(legal_moves)), p=self.p_vec_small) + p_vec_small = self._apply_temperature_to_policy(p_vec_small) + idx = np.random.choice(range(len(legal_moves)), p=p_vec_small) selected_move = legal_moves[idx] - confidence = self.p_vec_small[idx] + confidence = p_vec_small[idx] - return value, selected_move, confidence, idx - - def _apply_quantil_clipping(self): - """ + if value > 0: + # check for draw and decline if value is greater 0 + state_future = deepcopy(state) + state_future.apply_move(selected_move) + if state_future.get_pythonchess_board().can_claim_threefold_repetition() is True: + p_vec_small[idx] = 0 + idx = p_vec_small.argmax() + selected_move = legal_moves[idx] + confidence = p_vec_small[idx] - :param p_vec_small: - :param clip_quantil: - :return: - """ + return value, selected_move, confidence, idx, cp, depth, nodes, time_elapsed_s, nps, pv - if self.clip_quantil > 0: - # remove the lower percentage values in order to avoid strange blunders for moves with low confidence - p_vec_small_clipped = deepcopy(self.p_vec_small) - - # get the sorted indices in ascending order - idx_order = np.argsort(self.p_vec_small) - # create a quantil tank which measures how much quantil power is left - quantil_tank = self.clip_quantil - - # iterate over the indices (ascending) and apply the quantil clipping to it - for idx in idx_order: - if quantil_tank >= p_vec_small_clipped[idx]: - # remove the prob from the quantil tank - quantil_tank -= p_vec_small_clipped[idx] - # clip the index to 0 - p_vec_small_clipped[idx] = 0 - else: - # the target prob is greate than the current quantil tank - p_vec_small_clipped[idx] -= quantil_tank - # stop the for loop - break - - # renormalize the policy - p_vec_small_clipped /= p_vec_small_clipped.sum() - - # apply the changes - self.p_vec_small = p_vec_small_clipped - - def _apply_temperature_to_policy(self): + def _apply_temperature_to_policy(self, p_vec_small): """ :return: """ # treat very small temperature value as a deterministic policy - if self.temperature <= 0.01: - p_vec_one_hot = np.zeros_like(self.p_vec_small) - p_vec_one_hot[np.argmax(self.p_vec_small)] = 1. - self.p_vec_small = p_vec_one_hot + if self.temperature_current <= 0.01: + p_vec_one_hot = np.zeros_like(p_vec_small) + p_vec_one_hot[np.argmax(p_vec_small)] = 1. 
+ p_vec_small = p_vec_one_hot else: # apply exponential scaling - self.p_vec_small = np.power(self.p_vec_small, 1/self.temperature) + p_vec_small = p_vec_small ** (1/self.temperature_current) # renormalize the values to probabilities again - self.p_vec_small /= self.p_vec_small.sum() + p_vec_small /= p_vec_small.sum() + + return p_vec_small + diff --git a/DeepCrazyhouse/src/domain/agent/player/__init__.py b/DeepCrazyhouse/src/domain/agent/player/__init__.py old mode 100644 new mode 100755 diff --git a/DeepCrazyhouse/src/domain/agent/player/util/NetPredService.py b/DeepCrazyhouse/src/domain/agent/player/util/NetPredService.py old mode 100644 new mode 100755 index 63db223e..dde29ab3 --- a/DeepCrazyhouse/src/domain/agent/player/util/NetPredService.py +++ b/DeepCrazyhouse/src/domain/agent/player/util/NetPredService.py @@ -15,17 +15,25 @@ import numpy as np from DeepCrazyhouse.src.domain.crazyhouse.output_representation import NB_LABELS, LABELS from time import time +import cython class NetPredService: - def __init__(self, pipe_endings: [connection], net: NeuralNetAPI, batch_size, enable_timeout=False): + def __init__(self, pipe_endings: [connection], net: NeuralNetAPI, batch_size, batch_state_planes: np.ndarray, + batch_value_results: np.ndarray, batch_policy_results: np.ndarray): """ :param pipe_endings: List of pipe endings which are used for communicating with the thread workers. :param net: Neural Network API object which provides the reference for the neural network. :param batch_size: Constant batch_size used for inference. - :param enable_timeout: Decides wether to enable a timout if a batch didn't occur under 1 second. + :param batch_state_planes: Shared numpy memory in which all threads set their state plane request for the + prediction service. Each thread has its own channel. + :param batch_value_results: Shared numpy memory in which the value results of all threads are stored. + Each thread has its own channel. + :param batch_policy_results: Shared numpy memory in which the policy results of all threads are stored. + Each thread has its own channel. + #:param enable_timeout: Decides whether to enable a timeout if a batch didn't occur under 1 second.
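The shared-memory protocol described by these parameters can be modeled in isolation as follows; the array sizes and the single worker/service pair are illustrative placeholders, not the patch's exact setup. Every worker owns one row ("channel") in each array, sends only its channel index through the pipe, and the service answers with that index once the results are written.

import numpy as np
from multiprocessing import Pipe

nb_threads, nb_labels = 4, 2272
batch_state_planes = np.zeros((nb_threads, 34, 8, 8))   # one request slot per worker
batch_value_results = np.zeros(nb_threads)
batch_policy_results = np.zeros((nb_threads, nb_labels))

worker_end, service_end = Pipe()
# worker side: write the planes into the own channel, then send the channel id
batch_state_planes[0] = 1.0
worker_end.send(0)
# service side: read the id, run inference, store the results, acknowledge
channel = service_end.recv()
batch_value_results[channel] = 0.42    # stands in for the real network output
service_end.send(channel)
print(batch_value_results[worker_end.recv()])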
""" self.net = net self.my_pipe_endings = pipe_endings @@ -34,14 +42,23 @@ def __init__(self, pipe_endings: [connection], net: NeuralNetAPI, batch_size, en self.thread_inference = Thread(target=self._provide_inference, args=(pipe_endings,), daemon=True) self.batch_size = batch_size - self.time_start = None - self.timeout_second = 1 - #self.enable_timeout = enable_timeout + self.batch_state_planes = batch_state_planes + self.batch_value_results = batch_value_results + self.batch_policy_results = batch_policy_results + + #@cython.boundscheck(False) + #@cython.wraparound(False) def _provide_inference(self, pipe_endings): print('provide inference...') - #use_random = False + #use_random = True + + #cdef double[:, :, :, ::1] batch_state_planes_view = self.batch_state_planes + #cdef double[::1] batch_value_results_view = self.batch_value_results + #cdef double[:, ::1] batch_policy_results = self.batch_policy_results + + send_batches = False #True while self.running is True: @@ -49,26 +66,48 @@ def _provide_inference(self, pipe_endings): if filled_pipes: - if True or len(filled_pipes) >= self.batch_size: + if True or len(filled_pipes) >= self.batch_size: # 1 + + if send_batches is True: + planes_batch = [] + pipes_pred_output = [] + + for pipe in filled_pipes[:self.batch_size]: + while pipe.poll(): + planes_batch.append(pipe.recv()) + pipes_pred_output.append(pipe) - planes_batch = [] - pipes_pred_output = [] + # logging.debug('planes_batch length: %d %d' % (len(planes_batch), len(filled_pipes))) + state_planes_mxnet = mx.nd.array(planes_batch, ctx=self.net.get_ctx()) + else: + planes_ids = [] + pipes_pred_output = [] - for pipe in filled_pipes[:self.batch_size]: - while pipe.poll(): - planes_batch.append(pipe.recv()) - pipes_pred_output.append(pipe) + for pipe in filled_pipes[:self.batch_size]: + while pipe.poll(): + planes_ids.append(pipe.recv()) + pipes_pred_output.append(pipe) - #logging.debug('planes_batch length: %d %d' % (len(planes_batch), len(filled_pipes))) - planes_batch = mx.nd.array(planes_batch, ctx=self.net.get_ctx()) + #logging.debug('planes_batch length: %d %d' % (len(planes_batch), len(filled_pipes))) + state_planes_mxnet = mx.nd.array(self.batch_state_planes[planes_ids], ctx=self.net.get_ctx()) - #pred = self.net.get_executor().forward(is_train=False, data=planes_batch) - pred = self.net.get_net()(planes_batch) + + #print(len(state_planes_mxnet)) + executor = self.net.executors[len(state_planes_mxnet)-1] + pred = executor.forward(is_train=False, data=state_planes_mxnet) + #pred = self.net.get_net()(state_planes_mxnet) + #print('pred: %.3f' % (time()-t_s)*1000) + #t_s = time() value_preds = pred[0].asnumpy() + # renormalize to [0,1] + #value_preds += 1 + #value_preds /= 2 + # for the policy prediction we still have to apply the softmax activation # because it's not done by the neural net + #policy_preds = pred[1].softmax().asnumpy() policy_preds = pred[1].softmax().asnumpy() #if use_random is True: @@ -77,10 +116,20 @@ def _provide_inference(self, pipe_endings): # send the predictions back to the according workers for i, pipe in enumerate(pipes_pred_output): - pipe.send([value_preds[i], policy_preds[i]]) - # reset the timer - self.time_start = time() + if send_batches is True: + pipe.send([value_preds[i], policy_preds[i]]) + else: + # get the according channel index for setting the result + channel_idx = planes_ids[i] + + # set the value result + self.batch_value_results[channel_idx] = value_preds[i] + self.batch_policy_results[channel_idx] = policy_preds[i] + # give the thread 
diff --git a/DeepCrazyhouse/src/domain/agent/player/util/Node.py b/DeepCrazyhouse/src/domain/agent/player/util/Node.py
old mode 100644
new mode 100755
index 5b15527f..f9b12e39
--- a/DeepCrazyhouse/src/domain/agent/player/util/Node.py
+++ b/DeepCrazyhouse/src/domain/agent/player/util/Node.py
@@ -7,17 +7,18 @@
 Helper class which stores the statistics of all nodes in the search tree.
 """
 
-from numba import jit
 from threading import Lock
 import chess
 import numpy as np
-import logging
 from copy import deepcopy
+from collections import deque
 
+QSIZE = 100
 
 class Node:
 
-    def __init__(self, value, p_vec_small: np.ndarray, legal_moves: [chess.Move], str_legal_moves: str, is_leaf=False):
+    def __init__(self, value, p_vec_small: np.ndarray, legal_moves: [chess.Move], str_legal_moves: str, is_leaf=False,
+                 transposition_key=None, clip_low_visit=True):
 
         # lock object for this node to protect its member variables
         self.lock = Lock()
@@ -29,55 +30,54 @@
             self.nb_direct_child_nodes = 0
         else:
             # specify the number of direct child nodes from this node
-            self.nb_direct_child_nodes = np.array(len(p_vec_small)) #, np.uint32)
+            self.nb_direct_child_nodes = np.array(len(p_vec_small))
 
             # prior probability of selecting each child, which is estimated by the neural network
-            self.p = p_vec_small #np.zeros(self.nb_direct_child_nodes, np.float32)
+            self.p = p_vec_small
 
             # possible legal moves from this node on, which represent the edges
             self.legal_moves = legal_moves
 
             # stores the number of all direct children and all grand children which have already been expanded
-            self.nb_total_expanded_child_nodes = np.array(0) #, np.uint32)
+            self.nb_total_expanded_child_nodes = np.array(0)
 
             # visit count of all its child nodes
-            self.n = np.zeros(self.nb_direct_child_nodes) #, np.int32)
+            self.n = np.zeros(self.nb_direct_child_nodes)
 
             # total action value estimated by MCTS for each child node
-            self.w = np.zeros(self.nb_direct_child_nodes) #, np.float32)
+            self.w = np.zeros(self.nb_direct_child_nodes)
+            #self.w = np.ones(self.nb_direct_child_nodes) * -0.01 #1
 
             # q: combined action value which is calculated by averaging over all action values
             # u: exploration metric for each child node
             # (the q and u values are stacked into 1 list in order to speed-up the argmax() operation
-            self.q = np.zeros(self.nb_direct_child_nodes) #, np.float32)
-            #self.q_u = np.stack((q, u))
-            #np.concatenate((q, u))
+            #self.q = np.zeros(self.nb_direct_child_nodes)
+            self.q = np.ones(self.nb_direct_child_nodes) * -1
+
+            if is_leaf is False:
+                if clip_low_visit is True:
+                    self.q[p_vec_small < 1e-3] = -9999
+                #else:
+                #    self.thresh_idcs_root = p_vec_small < 5e-2
 
             # number of total visits to this node
             # we initialize with 1 because if the node was created it must have been visited
-            self.n_sum = np.array(1) #, #np.int32)
+            self.n_sum = 1
 
         # check if there's a possible mate on the board; if yes, create a quick link to the mate move
         mate_mv_idx_str = str_legal_moves.find('#')
-        #logging.debug('legal_moves: %s' % str(str_legal_moves))
-        #logging.debug('mate_mv_idx_str: %d' % mate_mv_idx_str)
         if mate_mv_idx_str != -1:
             # -1 means that no mate move has been found
             # find the corresponding index of the move in the legal_moves generator list
             # here we count the ',' which represents the move index
            mate_mv_idx = str_legal_moves[:mate_mv_idx_str].count(',')
             # quick reference path to a child node which leads to mate
-            self.mate_child_idx = mate_mv_idx #legal_moves[mate_mv_idx]
-            # overwrite the number of direct child nodes to 1
-            #self.nb_direct_child_nodes = np.array(1) #, np.uint32)
-            #logging.debug('set mate in one connection')
+            self.mate_child_idx = mate_mv_idx
         else:
             # no direct mate move is possible so set the reference to None
             self.mate_child_idx = None
 
         # stores the number of all possible expandable child nodes
-        self.nb_expandable_child_nodes = np.array(self.nb_direct_child_nodes) #, np.uint32)
-
-        #assert self.nb_direct_child_nodes > 0
+        # self.nb_expandable_child_nodes = np.array(self.nb_direct_child_nodes)
 
         # list of all child nodes which are described by each board position
         # the children are ordered in the same way as the legal_move generator output
@@ -86,18 +86,14 @@ def __init__(self, value, p_vec_small: np.ndarray, legal_moves: [chess.Move], st
 
         # determine if the node is a leaf node; this avoids checking for state.is_draw() or .state.is_won()
         self.is_leaf = is_leaf
 
-        ''' TODO: Delete
-        def update_u_for_child(self, child_idx, cpuct):
-        """
-        Updates the u parameter via the formula given in the AlphaZero paper for a given child index
-        :param child_idx: Child index to update
-        :param cpuct: cpuct constant to apply (cpuct manages the amount of exploration)
-        :return:
-        """
-        self.q_u[child_idx] = cpuct * self.p[child_idx] * (np.sqrt(self.n_sum) / (1 + self.n[child_idx]))
-        '''
+        # store a unique identifier for the board state excluding the move counter for this node
+        self.transposition_key = transposition_key
+
+        #self.replay_buffer = deque([0] * 512)
+        #self.q_freash = np.zeros(self.nb_direct_child_nodes)
+        #self.w_freash = np.zeros(self.nb_direct_child_nodes)
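The '#' shortcut works because the string dump of python-chess's legal move generator lists the moves in SAN, in the same order as iteration, and SAN marks checkmating moves with '#'; counting the commas before the match therefore yields the move's index. A sketch with a hypothetical back-rank-mate position:

import chess

board = chess.Board("6k1/5ppp/8/8/8/8/5PPP/R5K1 w - - 0 1")
legal_moves = list(board.legal_moves)
# e.g. "<LegalMoveGenerator at 0x... (Ra8#, Rb1, Rc1, ...)>"
str_legal_moves = str(board.legal_moves)

mate_mv_idx_str = str_legal_moves.find('#')   # position of the mate marker
if mate_mv_idx_str != -1:
    mate_child_idx = str_legal_moves[:mate_mv_idx_str].count(',')  # commas before it = move index
    assert legal_moves[mate_child_idx] == chess.Move.from_uci("a1a8")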
-    def get_mcts_policy(self, q_value_weight=.65):
+    def get_mcts_policy(self, q_value_weight=.65, clip_low_visit_nodes=True, is_root=False, xth_n_max=0):
         """
         Calculates the fine-tuned policies based on the MCTS search.
         These policies should be better than the initial policy predicted by the raw network.
@@ -109,31 +105,54 @@ def get_mcts_policy(self, q_value_weight=.65):
         :return: Pruned policy vector based on the MCTS search
         """
 
-        assert 0 <= q_value_weight <= 1.
-
-        clip_low_visit_nodes = True
+        if clip_low_visit_nodes is True and q_value_weight > 0:
 
-        if clip_low_visit_nodes is True:
+            #q_value_weight -= 1e-4 * self.n_sum
+            #q_value_weight = max(q_value_weight, 0.01)
 
             visit = deepcopy(self.n)
             value = deepcopy((self.q + 1))
+            #values_confident = self.p[]
 
             if visit.max() > 0:
-                visit = self.n / self.n.sum()
+                max_visits = visit.max()
+
+                #if is_root is True:
+                #    if self.n_sum > 2000 and self.thresh_idcs_root.max() == 1:
+                #        if visit[self.thresh_idcs_root].max() < visit[np.invert(self.thresh_idcs_root)].max() * 1.5:
+                #            visit[self.thresh_idcs_root] = 0
+                #            #print('clipped nodes')
 
                 # mask out nodes that haven't been visited much
-                thresh_idces = visit < max_visits * 0.5 #0.33 #0.5 #.33
+                #thresh_idces = visit < max(max_visits * 0.33, xth_n_max)
+                thresh_idces = visit < max_visits * 0.33 #, xth_n_max)
 
                 # normalize to a sum of 1
-                value /= value.sum()
                 value[thresh_idces] = 0
+                #visit[thresh_idces] = 0
+
+                # renormalize to 1
+                visit /= visit.sum()
+
+                #value *= self.p
+                value /= value.sum()
+
+                # use prior policy
+                #init_p = deepcopy(self.p)
+                #init_p[value < value.max() * 0.2] = 0
+                #visit += self.p
 
                 policy = ((1-q_value_weight) * visit + q_value_weight * value)
+
+                #if is_root is True:
+                #    indices = (self.q < self.v) * self.thresh_idcs_root
+                #    policy[indices] = 0
+
                 return policy / sum(policy)
             else:
                 return visit
-
         elif q_value_weight > 0:
             # disable the q values if there's at least one child which wasn't explored
             if None in self.child_nodes:
@@ -145,9 +164,9 @@
             return policy
         else:
             if max(self.n) == 1:
-                policy = (self.n + 0.05 * self.p)#/ self.n_sum
+                policy = (self.n + 0.05 * self.p)
             else:
-                policy = (self.n - 0.05 * self.p) #/ self.n_sum
+                policy = (self.n - 0.05 * self.p)
 
             return policy / sum(policy)
 
     def apply_dirichlet_noise_to_prior_policy(self, epsilon=0.25, alpha=0.15):
         """
@@ -161,8 +180,9 @@
         """
 
         if self.is_leaf is False:
-            dirichlet_noise = np.random.dirichlet([alpha] * self.nb_direct_child_nodes)
-            self.p = (1 - epsilon) * self.p + epsilon * dirichlet_noise
+            with self.lock:
+                dirichlet_noise = np.random.dirichlet([alpha] * self.nb_direct_child_nodes)
+                self.p = (1 - epsilon) * self.p + epsilon * dirichlet_noise
 
     def apply_virtual_loss_to_child(self, child_idx, virtual_loss):
@@ -177,16 +197,31 @@
             # make it look like one has lost X games from this node forward where X is the virtual loss value
             self.w[child_idx] -= virtual_loss
             self.q[child_idx] = self.w[child_idx] / self.n[child_idx]
-            #parent_node.update_u_for_child(child_idx, self.cpuct)
+
+            # use queue
+            #self.q[child_idx] = self.w[child_idx] / min(self.n[child_idx], QSIZE)
 
     def revert_virtual_loss_and_update(self, child_idx, virtual_loss, value):
         # revert the virtual loss effect and apply the backpropagated value of its child node
         with self.lock:
+            self.n_sum -= virtual_loss - 1
+
+            #factor = max(self.n[child_idx] // 1000, 1)
+            #fac = (self.n[child_idx]+1) ** 0.2
+
+            self.n[child_idx] -= virtual_loss - 1
+
             self.w[child_idx] += virtual_loss + value
-            self.q[child_idx] = self.w[child_idx] / self.n[child_idx]
-            #parent_node.update_u_for_child(child_idx, self.cpuct)
-            self.nb_total_expanded_child_nodes += 1
-            self.nb_expandable_child_nodes += self.nb_direct_child_nodes
+            self.q[child_idx] = self.w[child_idx] / self.n[child_idx]
+
+            #self.nb_total_expanded_child_nodes += 1
+            #self.nb_expandable_child_nodes += self.nb_direct_child_nodes
+
+            #last_value = self.child_nodes[child_idx].replay_buffer.popleft()
+            #self.child_nodes[child_idx].replay_buffer.append(value)
+
+            # use queue
+            #self.w_freash[child_idx] += virtual_loss + value - last_value
+            #self.q_freash[child_idx] = self.w[child_idx] / QSIZE # min(self.n[child_idx], QSIZE)
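The apply/revert arithmetic cancels exactly to one real visit, assuming the apply step increments n and n_sum by virtual_loss in its unchanged context (the standard virtual loss scheme). A worked trace with illustrative numbers:

n, w, n_sum, virtual_loss, value = 10.0, 4.0, 11, 3, 0.5

# apply_virtual_loss_to_child: pretend this thread already lost 3 games here
n_sum += virtual_loss
n += virtual_loss
w -= virtual_loss
q = w / n                    # temporarily pessimistic -> other threads avoid this child

# revert_virtual_loss_and_update: undo the 3 fake losses, keep 1 real visit
n_sum -= virtual_loss - 1
n -= virtual_loss - 1
w += virtual_loss + value
q = w / n

assert (n, n_sum, w) == (11.0, 12, 4.5)   # exactly one extra visit with backed-up value 0.5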
diff --git a/DeepCrazyhouse/src/domain/agent/player/util/__init__.py b/DeepCrazyhouse/src/domain/agent/player/util/__init__.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/crazyhouse/GameState.py b/DeepCrazyhouse/src/domain/crazyhouse/GameState.py
old mode 100644
new mode 100755
index 5052d991..3867c58d
--- a/DeepCrazyhouse/src/domain/crazyhouse/GameState.py
+++ b/DeepCrazyhouse/src/domain/crazyhouse/GameState.py
@@ -3,6 +3,8 @@
 from DeepCrazyhouse.src.domain.crazyhouse.input_representation import board_to_planes
 from DeepCrazyhouse.src.domain.abstract_cls._GameState import _GameState
 import numpy as np
+import collections
+
 
 class GameState(_GameState):
 
@@ -12,12 +14,12 @@
         self._fen_dic = {}
         self._board_occ = 0
 
-    def apply_move(self, move: chess.Move, remember_state=False):
+    def apply_move(self, move: chess.Move):  #, remember_state=False):
         # apply the move on the board
         self.board.push(move)
 
-        if remember_state is True:
-            self._remember_board_state()
+        #if remember_state is True:
+        #    self._remember_board_state()
 
     def get_state_planes(self):
         return board_to_planes(self.board, board_occ=self._board_occ, normalize=True)
@@ -28,8 +30,15 @@
     def is_draw(self):
         # check if you can claim a draw - it's assumed that the draw is always claimed
-        return self.board.can_claim_fifty_moves() #can_claim_draw()
-        #return self.board.can_claim_threefold_repetition()
+        return self.can_claim_threefold_repetition() or self.board.can_claim_fifty_moves()
+        #return self.board.can_claim_draw()
+
+    def can_claim_threefold_repetition(self):
+        """
+        Custom implementation of the threefold-repetition check which uses the board_occ variable.
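The board_occ shortcut replaces python-chess's expensive replay-based repetition check with a simple occurrence counter keyed by the move-counter-free transposition key, mirroring the transposition_table the engine keeps. A sketch of the idea with a hypothetical driver loop (python-chess, knight shuffling until the position repeats a third time):

import collections
import chess
import chess.variant

transposition_table = collections.Counter()
board = chess.variant.CrazyhouseBoard()

for uci in ("e2e4", "e7e5", "g1f3", "b8c6", "f3g1", "c6b8",
            "g1f3", "b8c6", "f3g1", "c6b8", "g1f3", "b8c6"):
    board.push(chess.Move.from_uci(uci))
    transposition_table.update((board._transposition_key(),))

# the current position already occurred after moves 4 and 8
board_occ = transposition_table[board._transposition_key()] - 1
print(board_occ >= 2)   # True: the third occurrence makes the claim legal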
+        :return: True if the claim is legal, else False
+        """
+        return self._board_occ >= 2
 
     def is_won(self):
         # only a is_won() and no is_lost() function is needed because the game is over
@@ -37,7 +46,11 @@ def is_won(self):
         return self.board.is_checkmate()
 
     def get_legal_moves(self):
-        return self.board.legal_moves
+        #return list(self.board.legal_moves)
+        legal_moves = []
+        for mv in self.board.generate_legal_moves():
+            legal_moves.append(mv)
+        return legal_moves
 
     def is_white_to_move(self):
         return self.board.turn
@@ -52,19 +65,19 @@ def new_game(self):
 
     def set_fen(self, fen, remember_state=True):
         self.board.set_fen(fen)
 
-        if remember_state is True:
-            self._remember_board_state()
-
-    def _remember_board_state(self):
-        fen = self.board.board_fen()
-        if fen in self._fen_dic:
-            # create a new entry in the dictionary if there exists one
-            self._fen_dic[fen] += 1
-        else:
-            # create a new entry in the dictionary if there exists one
-            self._fen_dic[fen] = 1
-        # receive the number of occurrence given the fen list
-        self._board_occ = self._fen_dic[fen] - 1
+        #if remember_state is True:
+        #    self._remember_board_state()
+
+    #def _remember_board_state(self):
+        # calculate the transposition key
+        #    transposition_key = self.get_transposition_key()
+        # update the number of board occurrences
+        #self._board_occ = self._transposition_table[transposition_key]
+        # increase the counter for this transposition key
+        #    self._transposition_table.update((transposition_key,))
+
+    def is_check(self):
+        return self.board.is_check()
 
     def are_pocket_empty(self):
         """
diff --git a/DeepCrazyhouse/src/domain/crazyhouse/__init__.py b/DeepCrazyhouse/src/domain/crazyhouse/__init__.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/crazyhouse/constants.py b/DeepCrazyhouse/src/domain/crazyhouse/constants.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/crazyhouse/input_representation.py b/DeepCrazyhouse/src/domain/crazyhouse/input_representation.py
old mode 100644
new mode 100755
index 4a677374..48f8cd6a
--- a/DeepCrazyhouse/src/domain/crazyhouse/input_representation.py
+++ b/DeepCrazyhouse/src/domain/crazyhouse/input_representation.py
@@ -69,6 +69,8 @@ def board_to_planes(board, board_occ=0, normalize=True):
     :return: planes - the plane representation of the current board state
     """
 
+    # TODO: Remove board.mirror() for black by addressing the corresponding color channel
+
     # (I) Define the Input Representation for one position
     planes_pos = np.zeros((NB_CHANNELS_POS, BOARD_HEIGHT, BOARD_WIDTH))
     planes_const = np.zeros((NB_CHANNELS_CONST, BOARD_HEIGHT, BOARD_WIDTH))
@@ -82,14 +84,20 @@
         board = board.mirror()
 
     # Fill in the piece positions
-    board_piece_map = board.piece_map()
-    for pos in board_piece_map:
-        p_char = str(board_piece_map[pos])
-        channel = P_MAP[p_char]
-        row, col = get_row_col(pos)
-        # set the bit at the right position
-        planes_pos[channel, row, col] = 1
 
+    # iterate over both colors, starting with WHITE
+    for z, color in enumerate(chess.COLORS):
+        # PIECE_TYPES is an integer list in python-chess
+        for piece_type in chess.PIECE_TYPES:
+            # define the channel by the piece type (the input representation uses the same ordering as python-chess)
+            # we add an offset for the black pieces
+            # note that we subtract 1 because in python-chess the PAWN has index 1 and not 0
+            channel = (piece_type-1) + z * len(chess.PIECE_TYPES)
+            # iterate over the piece mask and retrieve every occupied square from it
+            for pos in board.pieces(piece_type, color):
+                row, col = get_row_col(pos)
+                # set the bit at the right position
+                planes_pos[channel, row, col] = 1
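A worked example of the channel formula above, channel = (piece_type - 1) + z * 6, with z = 0 for WHITE and z = 1 for BLACK (python-chess piece types, where PAWN = 1):

import chess

def piece_channel(piece_type, color):
    z = 0 if color == chess.WHITE else 1
    return (piece_type - 1) + z * len(chess.PIECE_TYPES)

print(piece_channel(chess.PAWN, chess.WHITE))    # 0  -> first plane
print(piece_channel(chess.KNIGHT, chess.BLACK))  # 7  -> second black plane
print(piece_channel(chess.KING, chess.BLACK))    # 11 -> last of the 12 piece planes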
 # (II) Fill in the Repetition Data
 # a game to test out if everything is working correctly is: https://lichess.org/jkItXBWy#73
@@ -114,22 +122,15 @@
         planes_pos[ch + 5, :, :] = board.pockets[chess.BLACK].count(p_type)
 
     # (III) Fill in the promoted pieces
+    # iterate over all promoted pieces according to the mask and set the corresponding bit
+    ch = CHANNEL_MAPPING_POS['promo']
+    for pos in chess.SquareSet(board.promoted):
+        row, col = get_row_col(pos)
 
-    bb_pos = chess.BB_A1
-
-    # iterate over all board fields and check if there is a positive result for the binary & operation
-    board_promoted = board.promoted
-    for pos in range(0, 64):
-        if board_promoted & bb_pos > 0:
-            ch = CHANNEL_MAPPING_POS['promo']
-            row, col = get_row_col(pos)
-
-            if board.piece_at(pos).color == chess.WHITE:
-                planes_pos[ch, row, col] = 1
-            else:
-                planes_pos[ch + 1, row, col] = 1
-        # for each new square the value is doubled
-        bb_pos *= 2
+        if board.piece_at(pos).color == chess.WHITE:
+            planes_pos[ch, row, col] = 1
+        else:
+            planes_pos[ch + 1, row, col] = 1
 
     # (III.2) En Passant Square
     # mark the square where an en-passant capture is possible
@@ -189,7 +190,8 @@
         board = board.mirror()
 
     if normalize is True:
-        planes = normalize_input_planes(planes)
+        planes *= MATRIX_NORMALIZER
+        #planes = normalize_input_planes(planes)
 
     # return the plane representation of the given board
     return planes
diff --git a/DeepCrazyhouse/src/domain/crazyhouse/output_representation.py b/DeepCrazyhouse/src/domain/crazyhouse/output_representation.py
old mode 100644
new mode 100755
index 47f9e3d2..8a9535b8
--- a/DeepCrazyhouse/src/domain/crazyhouse/output_representation.py
+++ b/DeepCrazyhouse/src/domain/crazyhouse/output_representation.py
@@ -208,7 +208,7 @@ def value_to_centipawn(value):
     :return:
     """
 
-    if np.absolute(value) == 1.:
+    if np.absolute(value) >= 1.:
         # return a constant if the given value is 1 (otherwise log would result in infinity)
         return np.sign(value) * 9999
     else:
diff --git a/DeepCrazyhouse/src/domain/neural_net/__init__.py b/DeepCrazyhouse/src/domain/neural_net/__init__.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/neural_net/architectures/AlphaZeroResnet.py b/DeepCrazyhouse/src/domain/neural_net/architectures/AlphaZeroResnet.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/neural_net/architectures/Rise.py b/DeepCrazyhouse/src/domain/neural_net/architectures/Rise.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/neural_net/architectures/__init__.py b/DeepCrazyhouse/src/domain/neural_net/architectures/__init__.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/neural_net/architectures/builder_util.py b/DeepCrazyhouse/src/domain/neural_net/architectures/builder_util.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/neural_net/architectures/rise_builder_util.py b/DeepCrazyhouse/src/domain/neural_net/architectures/rise_builder_util.py
old mode 100644
new mode 100755
diff --git a/DeepCrazyhouse/src/domain/util.py b/DeepCrazyhouse/src/domain/util.py
old mode 100644
new mode 100755
index 713a9d88..e14e944f
--- a/DeepCrazyhouse/src/domain/util.py
+++ b/DeepCrazyhouse/src/domain/util.py
@@ -147,6 +147,11 @@ def normalize_input_planes(x):
     return x
 
 
+# use a constant matrix for normalization to allow broadcast operations
+MATRIX_NORMALIZER = np.ones((NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH))
+MATRIX_NORMALIZER = normalize_input_planes(MATRIX_NORMALIZER)
+
+
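The precomputed matrix trades the per-channel loop of normalize_input_planes for a single vectorized multiply on every call: the scaling factors are baked into a constant array once, at import time. A rough sketch of the equivalence (the channel count is an assumption for illustration):

import numpy as np

NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH = 34, 8, 8  # assumed constants

normalizer = np.ones((NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH))
# ... normalizer = normalize_input_planes(normalizer) fills in the
# per-channel scaling factors exactly once

planes = np.random.rand(NB_CHANNELS_FULL, BOARD_HEIGHT, BOARD_WIDTH)
planes_normalized = planes * normalizer   # one elementwise broadcast per call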
 def unnormalize_input_planes(x):
     """
     Reverts normalization back to integer values. Works in place.
diff --git a/DeepCrazyhouse/src/runtime/Colorer.py b/DeepCrazyhouse/src/runtime/ColorLogger.py
old mode 100644
new mode 100755
similarity index 78%
rename from DeepCrazyhouse/src/runtime/Colorer.py
rename to DeepCrazyhouse/src/runtime/ColorLogger.py
index 17e2cddb..83114adb
--- a/DeepCrazyhouse/src/runtime/Colorer.py
+++ b/DeepCrazyhouse/src/runtime/ColorLogger.py
@@ -1,8 +1,5 @@
 """
 @file: Colorer.py
-Created on 10.06.18
-@project: DeepCrazyhouse
-@author: queensgambit
 
 Script which allows colored logging output across platforms.
 The script is based on this post and was slightly adjusted:
@@ -127,26 +124,27 @@ def new(*args):
     return new
 
 
-if platform.system() == 'Windows':
-    # Windows does not support ANSI escapes and we are using API calls to set the console color
-    logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
-else:
-    # all non-Windows platforms are supporting ANSI escapes so we use them
-    logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
+def enable_color_logging(debug_lvl=logging.DEBUG):
+    if platform.system() == 'Windows':
+        # Windows does not support ANSI escapes and we are using API calls to set the console color
+        logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
+    else:
+        # all non-Windows platforms support ANSI escapes so we use them
+        logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
 
-root = logging.getLogger()
-root.setLevel(logging.DEBUG)
+    root = logging.getLogger()
+    root.setLevel(debug_lvl)
 
-ch = logging.StreamHandler(sys.stdout)
-ch.setLevel(logging.DEBUG)
-# FORMAT = '[%(asctime)-s][%(name)-s][\033[1m%(levelname)-7s\033[0m] %(message)-s'
-# FORMAT='%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
+    ch = logging.StreamHandler(sys.stdout)
+    ch.setLevel(debug_lvl)
+    # FORMAT = '[%(asctime)-s][%(name)-s][\033[1m%(levelname)-7s\033[0m] %(message)-s'
+    # FORMAT='%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
 
-# FORMAT from https://github.com/xolox/python-coloredlogs
-FORMAT = '%(asctime)s %(name)s[%(process)d] \033[1m%(levelname)s\033[0m %(message)s'
+    # FORMAT from https://github.com/xolox/python-coloredlogs
+    FORMAT = '%(asctime)s %(name)s[%(process)d] \033[1m%(levelname)s\033[0m %(message)s'
 
-# FORMAT="%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
-formatter = logging.Formatter(FORMAT, "%Y-%m-%d %H:%M:%S")
+    # FORMAT="%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
+    formatter = logging.Formatter(FORMAT, "%Y-%m-%d %H:%M:%S")
 
-ch.setFormatter(formatter)
-root.addHandler(ch)
+    ch.setFormatter(formatter)
+    root.addHandler(ch)
diff --git a/DeepCrazyhouse/src/runtime/__init__.py b/DeepCrazyhouse/src/runtime/__init__.py
old mode 100644
new mode 100755
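Wrapping the handler setup in a function means importing the module no longer mutates global logging state as a side effect; callers opt in explicitly, as crazyara.py now does at startup:

from DeepCrazyhouse.src.runtime.ColorLogger import enable_color_logging
import logging

enable_color_logging()                        # default level: logging.DEBUG
# or with a custom level:
# enable_color_logging(debug_lvl=logging.INFO)
logging.info("colored, timestamped output")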
diff --git a/DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb b/DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb
index 9d10d1f0..11dc5769 100644
--- a/DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb
+++ b/DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb
@@ -38,7 +38,7 @@
     "import numpy as np\n",
     "import sys\n",
     "sys.path.insert(0,'../../../')\n",
-    "import DeepCrazyhouse.src.runtime.Colorer\n",
+    "from DeepCrazyhouse.src.runtime.ColorLogger import enable_color_logging\n",
     "from DeepCrazyhouse.src.domain.agent.NeuralNetAPI import NeuralNetAPI\n",
     "from DeepCrazyhouse.src.domain.agent.player.MCTSAgent import MCTSAgent\n",
     "from DeepCrazyhouse.src.domain.agent.player.RawNetAgent import RawNetAgent\n",
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "net = NeuralNetAPI(ctx='gpu', batch_size=batch_size)"
+    "net = NeuralNetAPI(ctx='cpu', batch_size=batch_size)"
    ]
   },
@@ -81,10 +81,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "mcts_agent = MCTSAgent(net, threads=16, playouts_empty_pockets=4096*5, playouts_filled_pockets=4096*5,\n",
-    "                       cpuct=3, dirichlet_epsilon=.1, dirichlet_alpha=0.2, batch_size=batch_size, q_value_weight=0,#.5, #99,\n",
-    "                       max_search_depth=40, temperature=0., clip_quantil=0., virtual_loss=3, verbose=False,\n",
-    "                       min_movetime=20000, check_mate_in_one=False, enable_timeout=False)"
+    "mcts_agent = MCTSAgent([net], threads=16, playouts_empty_pockets=4096*5, playouts_filled_pockets=4096*5,\n",
+    "                       cpuct=1, dirichlet_epsilon=.1, dirichlet_alpha=0.2, batch_size=batch_size, q_value_weight=0,#.5, #99,\n",
+    "                       max_search_depth=40, temperature=0., virtual_loss=3, verbose=True,\n",
+    "                       min_movetime=20000, check_mate_in_one=False)"
    ]
   },
@@ -110,7 +110,7 @@
    "#fen = 'rn2N2k/pp5p/3pp1pN/3p4/5P2/3P1p2/PP3RPP/RN4K1/QQprbbpbb b - - 1 30'\n",
    "\n",
    "#fen = '3R1b2/1bP1kp2/3Npn1p/3p4/5p2/5N1b/PPP1QP1P/3R1RK1/QPpprnpbp b - - 0 29'\n",
-   "#fen = 'rn2N2k/pp5p/3pp1pN/3p4/3q1P2/3P1p2/PP3PPP/RN3RK1[Qrbbpbb] b - - 3 30'\n",
+   "fen = 'rn2N2k/pp5p/3pp1pN/3p4/3q1P2/3P1p2/PP3PPP/RN3RK1[Qrbbpbb] b - - 3 30'\n",
    "#fen = 'q6r/p2P1pkp/1p1b1n2/2p2B2/8/6n1/PPP2KPp/R4R2/PNNRPBPbqpp w - - 2 26'\n",
    "#fen ='2kr1b2/1bp2p1p/p3pP1p/1p5Q/5B2/3B1p2/PPP2PrP/R4R1K/QNpnnnp w - - 0 18'\n",
    "#fen = 'q6r/p2P1pkp/1p1b1n2/2p2B2/8/6n1/PPP2KPp/R4R2/PNNRPBPbqpp w - - 50 26'\n",
@@ -155,6 +155,8 @@
    "#fen = 'r1bk3r/ppppbpQp/4p3/8/4n3/4P2N/PPPP2PP/R1Bq1BKR/PNNp b - - 2 13'\n",
    "#fen = 'r1bqkbnr/ppp2ppp/3p4/8/3QP3/2N4p/PPP2PPP/R1B1KB1R/PNn w KQkq - 1 7'\n",
    "#fen = 'r1bq1rk1/ppp2pp1/2np1n1p/2b1p1B1/2B1P3/2NP1N2/PPP2PPP/R2Q1RK1/ w - - 14 8'\n",
+   "fen = 'rnb1kb1r/ppp1pppp/5n2/q7/8/2N2N2/PPPP1PPP/R1BQKB1R/Pp w KQkq - 0 5'\n",
+   "fen = 'rnb2b1r/ppp1pkpp/5n2/2q5/3N2p1/2N5/PPPP1PPP/R1BQK2R/PPb w KQ - 0 8'\n",
    "board.set_fen(fen)\n",
    "#board = board.mirror()\n",
    "\n",
@@ -229,7 +231,7 @@
    "outputs": [],
    "source": [
     "t_s = time()\n",
-    "value, legal_moves, p_vec_small = raw_agent.evaluate_board_state(state)\n",
+    "pred_value, legal_moves, p_vec_small, cp, depth, nodes, time_elapsed_s, nps, pv = raw_agent.evaluate_board_state(state)\n",
     "print('Elapsed time: %.4fs' % (time()-t_s))"
    ]
   },
@@ -256,7 +258,7 @@
    "outputs": [],
    "source": [
     "t_s = time()\n",
-    "value, legal_moves, p_vec_small = mcts_agent.evaluate_board_state(state)\n",
+    "pred_value, legal_moves, p_vec_small, cp, depth, nodes, time_elapsed_s, nps, pv = mcts_agent.evaluate_board_state(state)\n",
     "print('Elapsed time: %.4fs' % (time()-t_s))"
    ]
   },
diff --git a/crazyara.py b/crazyara.py
old mode 100644
new mode 100755
index d8c04477..c1b9ef79
--- a/crazyara.py
+++ b/crazyara.py
@@ -14,20 +14,37 @@
 import chess.pgn
 import traceback
-
+import collections
+import numpy as np
+# import the ColorLogger to get a nicer logging printout
+from DeepCrazyhouse.src.runtime.ColorLogger import enable_color_logging
+enable_color_logging()
 
 # Constants
 MIN_SEARCH_TIME_MS = 100
+MAX_SEARCH_TIME_MS = 10e10
 INC_FACTOR = 7
 INC_DIV = 8
 MIN_MOVES_LEFT = 10
 MAX_BAD_POS_VALUE = -0.10  # when the position eval [-1.0 to 1.0] is equal to or worse than this, extend the time
 MOVES_LEFT_INCREMENT = 10  # used to reduce the movetime in the opening
+# this is the assumed "maximum" blitz game length for calculating a constant movetime;
+# after 80% of this game length a new time management regime starts which is based on the movetime left
+BLITZ_GAME_LENGTH = 50
+# use less time in the opening (defined by "max_move_num_to_reduce_movetime") by using a portion of the constant move time
+MV_TIME_OPENING_PORTION = 0.7
+# this variable is intended to increase the variance in the moves played by using a slightly different amount of time
+# for each move
+RANDOM_MV_TIME_PORTION = 0.1
+
+# enable this variable if you want to see debug messages in certain environments, like the lichess.org api
+ENABLE_LICHESS_DEBUG_MSG = True
 
 client = {
     'name': 'CrazyAra',
-    'version': '0.2.0',
-    'authors': 'Johannes Czech, Moritz Willig, Alena Beyer et al.'
+    'version': '0.3.1',
+    'authors': 'Johannes Czech, Moritz Willig, Alena Beyer'
 }
@@ -49,12 +66,13 @@
 "          (__.'/   /` .'`                 1 ////__////__////__////__/      \n" \
 "           (_.'/ /`  /`                      a  b  c  d  e  f  g  h        \n" \
 "             _|.' /`                                                       \n" \
-"jgs.-` __.'|  Developers: Johannes Czech, Moritz Willig, Alena Beyer et al. \n" \
+"jgs.-` __.'|  Developers: Johannes Czech, Moritz Willig, Alena Beyer       \n" \
 "    .-'||  |  Source-Code: QueensGambit/CrazyAra (GPLv3-License)           \n" \
 "       \\_`/   Inspiration: A0-paper by Silver, Hubert, Schrittwieser et al.\n" \
 "              ASCII-Art: Joan G. Stark, Chappell, Burton                   \n"
 
 log_file_path = "CrazyAra-log.txt"
+score_file_path = "score-log.txt"
 
 try:
     log_file = open(log_file_path, 'w')
@@ -66,21 +84,40 @@
     print(traceback_text)
 
 
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
+
+
+def print_if_debug(string):
+    if ENABLE_LICHESS_DEBUG_MSG is True:
+        eprint("[debug] " + string)
+
+
 def log_print(text: str):
     print(text)
+    print_if_debug(text)
     if log_file:
         log_file.write("< %s\n" % text)
         log_file.flush()
 
 
+def write_score_to_file(score: str):
+    #score_file = open(score_file_path, 'w')
+
+    with open(score_file_path, 'w') as f:
+        f.seek(0)
+        f.write(score)
+        f.truncate()
+
+
 def log(text: str):
     if log_file:
         log_file.write("> %s\n" % text)
         log_file.flush()
 
 
-print(INTRO_PART1, end="")
-print(INTRO_PART2, end="")
+eprint(INTRO_PART1, end="")
+eprint(INTRO_PART2, end="")
 
 # GLOBAL VARIABLES
 mcts_agent = None
@@ -88,7 +125,9 @@ def log(text: str):
 gamestate = None
 setup_done = False
 bestmove_value = None
+constant_move_time = None
 engine_played_move = 0
+score = None
 
 # SETTINGS
 s = {
@@ -99,13 +138,16 @@ def log(text: str):
     "use_raw_network": False,
     "threads": 16,
     "batch_size": 8,
+    "neural_net_services": 2,
     "playouts_empty_pockets": 8192,
     "playouts_filled_pockets": 8192,
-    "centi_cpuct": 300,
-    "centi_dirichlet_epsilon": 10,
+    "centi_cpuct": 250,
+    "centi_dirichlet_epsilon": 25,
     "centi_dirichlet_alpha": 20,
     "max_search_depth": 40,
-    "centi_temperature": 0,
+    "centi_temperature": 7,
+    "temperature_moves": 4,
+    "opening_guard_moves": 7,
     "centi_clip_quantil": 0,
     "virtual_loss": 3,
     "centi_q_value_weight": 70,
@@ -113,9 +155,11 @@ def log(text: str):
     "move_overhead_ms": 300,
     "moves_left": 40,
     "extend_time_on_bad_position": True,
-    "max_move_num_to_reduce_movetime": 0,
+    "max_move_num_to_reduce_movetime": 4,
     "check_mate_in_one": False,
-    "enable_timeout": False,
+    "use_pruning": True,
+    "use_oscillating_cpuct": False,
+    "use_time_management": True,
     "verbose": False
 }
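The centi_* naming convention keeps the UCI spin options integral; the values are divided by 100 when the agents are constructed. A quick sketch of the conversion using the defaults above:

s = {"centi_cpuct": 250, "centi_dirichlet_epsilon": 25, "centi_temperature": 7}

cpuct = s["centi_cpuct"] / 100                         # 2.5, as passed to MCTSAgent
dirichlet_epsilon = s["centi_dirichlet_epsilon"] / 100 # 0.25
temperature = s["centi_temperature"] / 100             # 0.07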
 
 def setup_network():
@@ -142,41 +186,58 @@ def setup_network():
 
     # check for valid parameter setup and do auto-corrections if possible
     param_validity_check()
 
-    net = NeuralNetAPI(ctx=s['context'], batch_size=s['batch_size'])
-    rawnet_agent = RawNetAgent(net, temperature=s['centi_temperature'], clip_quantil=s['centi_clip_quantil'])
+    nets = []
+    for i in range(s['neural_net_services']):
+        nets.append(NeuralNetAPI(ctx=s['context'], batch_size=s['batch_size']))
 
-    mcts_agent = MCTSAgent(net, cpuct=s['centi_cpuct'] / 100, playouts_empty_pockets=s['playouts_empty_pockets'],
+    rawnet_agent = RawNetAgent(nets[0], temperature=s['centi_temperature'] / 100, temperature_moves=s['temperature_moves'])
+
+    mcts_agent = MCTSAgent(nets, cpuct=s['centi_cpuct'] / 100, playouts_empty_pockets=s['playouts_empty_pockets'],
                            playouts_filled_pockets=s['playouts_filled_pockets'], max_search_depth=s['max_search_depth'],
                            dirichlet_alpha=s['centi_dirichlet_alpha'] / 100, q_value_weight=s['centi_q_value_weight'] / 100,
                            dirichlet_epsilon=s['centi_dirichlet_epsilon'] / 100, virtual_loss=s['virtual_loss'],
-                           threads=s['threads'], temperature=s['centi_temperature'] / 100, verbose=s['verbose'],
-                           clip_quantil=s['centi_clip_quantil'] / 100, min_movetime=MIN_SEARCH_TIME_MS,
+                           threads=s['threads'], temperature=s['centi_temperature'] / 100,
+                           temperature_moves=s['temperature_moves'], verbose=s['verbose'],
+                           min_movetime=MIN_SEARCH_TIME_MS,
                            batch_size=s['batch_size'], check_mate_in_one=s['check_mate_in_one'],
-                           enable_timeout=s['enable_timeout'])
+                           use_pruning=s['use_pruning'], use_oscillating_cpuct=s['use_oscillating_cpuct'],
+                           use_time_management=s['use_time_management'], opening_guard_moves=s['opening_guard_moves'])
 
     gamestate = GameState()
 
     setup_done = True
 
 
-def param_validity_check():
+def validity_with_threads(optname: str):
     """
-    Handles some possible issues when giving an illegal batch_size and number of threads combination.
+    Checks the given parameter for consistency with the number of threads.
+    :param optname: Option name
     :return:
     """
 
-    if s['batch_size'] > s['threads']:
+    if s[optname] > s['threads']:
         log_print('info string The given %s %d is higher than the number of threads %d. '
                   'The maximum legal value is the number of threads (here: %d) '
                   % (optname, s[optname], s['threads'], s['threads']))
        s[optname] = s['threads']
        log_print('info string The %s was reduced to %d' % (optname, s[optname]))
 
-    if s['threads'] % s['batch_size'] != 0:
+    if s['threads'] % s[optname] != 0:
         log_print('info string You requested an illegal combination of threads %d and %s %d.'
                   ' The %s must be a divisor of the number of threads' % (s['threads'], optname, s[optname], optname))
        divisor = s['threads'] // s[optname]
        s[optname] = s['threads'] // divisor
        log_print('info string The %s was changed to %d' % (optname, s[optname]))
 
 
+def param_validity_check():
+    """
+    Handles some possible issues when giving an illegal batch_size and number of threads combination.
+    :return:
+    """
+
+    validity_with_threads('batch_size')
+    validity_with_threads('neural_net_services')
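A quick trace of the auto-correction for an assumed configuration of threads=16 with batch_size=6:

threads, batch_size = 16, 6

# batch_size does not exceed threads, so only the divisor rule fires:
if threads % batch_size != 0:
    divisor = threads // batch_size     # 2
    batch_size = threads // divisor     # 8 -> 16 % 8 == 0, now legal
print(batch_size)  # 8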
 
 def perform_action(cmd_list):
@@ -191,6 +252,8 @@ def perform_action(cmd_list):
     global rawnet_agent
     global bestmove_value
     global engine_played_move
+    global constant_move_time
+    global score
 
     movetime_ms = MIN_SEARCH_TIME_MS
     tc_type = None
@@ -214,6 +277,9 @@
             my_time = btime
             my_inc = binc
 
+        if constant_move_time is None:
+            constant_move_time = (my_time + BLITZ_GAME_LENGTH * my_inc) / BLITZ_GAME_LENGTH
+
         # TC with period (traditional) like 40/60 or 40 moves in 60 sec repeating
         if 'movestogo' in cmd_list:
             tc_type = 'traditional'
@@ -229,7 +295,14 @@
                 moves_left = s['moves_left']
             moves_left = adjust_moves_left(moves_left, tc_type, bestmove_value)
 
-        movetime_ms = max(my_time/moves_left + INC_FACTOR*my_inc//INC_DIV - s['move_overhead_ms'], MIN_SEARCH_TIME_MS)
+        if tc_type == 'blitz' and engine_played_move < BLITZ_GAME_LENGTH * .8:
+            movetime_ms = constant_move_time + (np.random.rand()-0.5) * RANDOM_MV_TIME_PORTION * constant_move_time
+
+            if engine_played_move < s['max_move_num_to_reduce_movetime']:
+                # avoid spending too much time in the opening
+                movetime_ms *= MV_TIME_OPENING_PORTION
+        else:
+            movetime_ms = max(my_time/moves_left + INC_FACTOR*my_inc//INC_DIV - s['move_overhead_ms'], MIN_SEARCH_TIME_MS)
 
     # movetime in UCI protocol, go movetime x, search exactly x mseconds
     # UCI protocol: http://wbec-ridderkerk.nl/html/UCIProtocol.html
@@ -238,17 +311,53 @@
 
     mcts_agent.update_movetime(movetime_ms)
     log_print('info string Time for this move is %dms' % movetime_ms)
+    log_print('info string Requested pos: %s' % gamestate)
+
+    # assign search depth
+    try:
+        # we try to extract the search depth from the cmd list
+        depth_idx = cmd_list.index("depth") + 1
+        mcts_agent.set_max_search_depth(int(cmd_list[depth_idx]))
+        # increase the movetime to the maximum to make sure to reach the given depth
+        movetime_ms = MAX_SEARCH_TIME_MS
+        mcts_agent.update_movetime(movetime_ms)
+    except ValueError:
+        # the given command wasn't found in the command list
+        pass
+
+    # strongly reduce the Dirichlet noise for very short move times
+    if movetime_ms < 1000:
+        mcts_agent.dirichlet_epsilon = 0.1
+    elif movetime_ms < 7000:
+        # moderately reduce the noise for short move times
+        mcts_agent.dirichlet_epsilon = .2
 
     if s['use_raw_network'] or movetime_ms <= s['threshold_time_for_raw_net_ms']:
         log_print('info string Using raw network for fast mode...')
-        value, selected_move, confidence, _ = rawnet_agent.perform_action(gamestate)
+        value, selected_move, confidence, _, cp, depth, nodes, time_elapsed_s, nps, pv = rawnet_agent.perform_action(gamestate)
     else:
-        value, selected_move, confidence, _ = mcts_agent.perform_action(gamestate)
+        value, selected_move, confidence, _, cp, depth, nodes, time_elapsed_s, nps, pv = mcts_agent.perform_action(gamestate)
+
+    score = "score cp %d depth %d nodes %d time %d nps %d pv %s" % (cp, depth, nodes, time_elapsed_s, nps, pv)
+    if ENABLE_LICHESS_DEBUG_MSG:
+        try:
+            write_score_to_file(score)
+        except Exception:
+            pass
+    # print out the search information
+    log_print('info %s' % score)
 
     # Save the bestmove value [-1.0 to 1.0] to modify the next movetime
     bestmove_value = float(value)
     engine_played_move += 1
 
+    # apply CrazyAra's selected move to the global gamestate
+    if gamestate.get_pythonchess_board().is_legal(selected_move):
+        # apply the last move CrazyAra played
+        _apply_move(selected_move)
+    else:
+        raise Exception('all_ok is false! - crazyara_last_move')
+
     log_print('bestmove %s' % selected_move.uci())
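Two worked examples of the time rules above, with illustrative clock values for a 3+2 blitz game (all values in milliseconds):

my_time, my_inc = 180000, 2000

# blitz regime for the first 80% of the assumed game length:
BLITZ_GAME_LENGTH = 50
constant_move_time = (my_time + BLITZ_GAME_LENGTH * my_inc) / BLITZ_GAME_LENGTH
print(constant_move_time)   # 5600.0 ms per move, jittered by +-5% via RANDOM_MV_TIME_PORTION

# traditional fallback: 40 moves left, 300 ms move overhead
MIN_SEARCH_TIME_MS, INC_FACTOR, INC_DIV = 100, 7, 8
moves_left, move_overhead_ms = 40, 300
movetime_ms = max(my_time / moves_left + INC_FACTOR * my_inc // INC_DIV - move_overhead_ms,
                  MIN_SEARCH_TIME_MS)
print(movetime_ms)          # 4500.0 + 1750 - 300 = 5950.0 ms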
@@ -259,28 +368,83 @@ def setup_gamestate(cmd_list):
     :param cmd_list: Input command list arguments
     :return:
     """
-    #artificial_max_game_len = 30
+
+    global gamestate
+    global mcts_agent
 
     position_type = cmd_list[1]
-    if position_type == "startpos":
-        gamestate.new_game()
-    else:
-        fen = " ".join(cmd_list[2:8])
-        gamestate.set_fen(fen)
 
     if 'moves' in cmd_list:
+        # position startpos moves e2e4 g8f6
         if position_type == 'startpos':
             mv_list = cmd_list[3:]
         else:
             # position fen rn2N2k/pp5p/3pp1pN/3p4/3q1P2/3P1p2/PP3PPP/RN3RK1/Qrbbpbb b - - 3 27 moves d4f2 f1f2
             mv_list = cmd_list[9:]
 
-        for move in mv_list:
-            gamestate.apply_move(chess.Move.from_uci(move))
-        #if len(mv_list)//2 > artificial_max_game_len:
-        #    log_print('info string Setting fullmove_number to %d' % artificial_max_game_len)
-        #    gamestate.get_pythonchess_board().fullmove_number = artificial_max_game_len
+        # try to apply the opponent's last move to the board state
+
+        if len(mv_list) > 0:
+            # the move the opponent just played is the last move in the list
+            opponent_last_move = chess.Move.from_uci(mv_list[-1])
+            if gamestate.get_pythonchess_board().is_legal(opponent_last_move):
+                # apply the last move the opponent played
+                _apply_move(opponent_last_move)
+                mv_compatible = True
+            else:
+                log_print('info string all_ok is false! - opponent_last_move %s' % opponent_last_move)
+                mv_compatible = False
+        else:
+            mv_compatible = False
+
+        if not mv_compatible:
+            log_print("info string The given last two moves couldn't be applied to the previous board-state.")
+            log_print("info string Rebuilding the game from scratch...")
+
+            # create a new game state from scratch
+            if position_type == "startpos":
+                new_game()
+            else:
+                fen = " ".join(cmd_list[2:8])
+                gamestate.set_fen(fen)
+
+            for move in mv_list:
+                _apply_move(chess.Move.from_uci(move))
+        else:
+            log_print("info string Move compatible")
+    else:
+        if position_type == 'fen':
+            fen = " ".join(cmd_list[2:8])
+            gamestate.set_fen(fen)
+            mcts_agent.update_tranposition_table((gamestate.get_transposition_key(),))
+            #log_print("info string Added %s - count %d" % (gamestate.get_board_fen(),
+            #                                               mcts_agent.transposition_table[gamestate.get_transposition_key()]))
+
+
+def _apply_move(selected_move: chess.Move):
+    """
+    Applies the given move to the gamestate and updates the transposition table of the environment.
+    :param selected_move: Move in python-chess format
+    :return:
+    """
+    global gamestate
+    global mcts_agent
+
+    gamestate.apply_move(selected_move)
+    mcts_agent.update_tranposition_table((gamestate.get_transposition_key(),))
+    # log_print("info string Added %s - count %d" % (gamestate.get_board_fen(),
+    #                                                mcts_agent.transposition_table[
+    #                                                    gamestate.get_transposition_key()]))
+
+
+def new_game():
+    global gamestate
+    global mcts_agent
+    log_print("info string >> New Game")
+    gamestate.new_game()
+    mcts_agent.transposition_table = collections.Counter()
+    mcts_agent.time_buffer_ms = 0
+    mcts_agent.dirichlet_epsilon = s['centi_dirichlet_epsilon'] / 100
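The token offsets assumed by the parser above follow directly from the two UCI command shapes; a small self-check (example commands only):

# "position startpos moves ..." -> moves start at index 3
cmd_list = "position startpos moves e2e4 g8f6".split()
assert cmd_list[1] == "startpos"
assert cmd_list[3:] == ["e2e4", "g8f6"]        # mv_list

# "position fen ... moves ..." -> 6 FEN fields at indices 2..7, moves from 9
cmd_list = ("position fen rn2N2k/pp5p/3pp1pN/3p4/3q1P2/3P1p2/PP3PPP/RN3RK1/Qrbbpbb"
            " b - - 3 27 moves d4f2 f1f2").split()
assert " ".join(cmd_list[2:8]).endswith("3 27")
assert cmd_list[9:] == ["d4f2", "f1f2"]        # mv_list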
 
 def set_options(cmd_list):
     """
@@ -292,41 +456,49 @@
     # SETTINGS
     global s
 
-    if cmd_list[1] != 'name' or cmd_list[3] != 'value':
-        log_print("info string The given setoption command wasn't understood")
-        log_print('info string An example call could be: "setoption name threads value 4"')
-    else:
-        option_name = cmd_list[2]
+    # make sure there are enough items in the given command list, like "setoption name nb_threads value 1"
+    if len(cmd_list) >= 5:
+        if cmd_list[1] != 'name' or cmd_list[3] != 'value':
+            log_print("info string The given setoption command wasn't understood")
+            log_print('info string An example call could be: "setoption name threads value 4"')
+        else:
+            option_name = cmd_list[2]
 
-        if option_name not in s:
-            raise Exception("The given option %s wasn't found in the settings list" % option_name)
+            if option_name not in s:
+                log_print("info string The given option %s wasn't found in the settings list" % option_name)
+            else:
 
-        if option_name in ['UCI_Variant', 'context', 'use_raw_network',
-                           'extend_time_on_bad_position', 'verbose', 'check_mate_in_one', 'enable_timeout']:
+                if option_name in ['UCI_Variant', 'context', 'use_raw_network',
+                                   'extend_time_on_bad_position', 'verbose', 'check_mate_in_one', 'use_pruning',
+                                   'use_oscillating_cpuct', 'use_time_management']:
 
-            value = cmd_list[4]
-        else:
-            value = int(cmd_list[4])
-
-        if option_name == 'use_raw_network':
-            s['use_raw_network'] = True if value == 'true' else False
-        elif option_name == 'extend_time_on_bad_position':
-            s['extend_time_on_bad_position'] = True if value == 'true' else False
-        elif option_name == 'verbose':
-            s['verbose'] = True if value == 'true' else False
-        elif option_name == 'check_mate_in_one':
-            s['check_mate_in_one'] = True if value == 'true' else False
-        elif option_name == 'enable_timeout':
-            s['enable_timeout'] = True if value == 'true' else False
-        else:
-            # by default all options are treated as integers
-            s[option_name] = value
+                    value = cmd_list[4]
+                else:
+                    value = int(cmd_list[4])
+
+                if option_name == 'use_raw_network':
+                    s['use_raw_network'] = True if value == 'true' else False
+                elif option_name == 'extend_time_on_bad_position':
+                    s['extend_time_on_bad_position'] = True if value == 'true' else False
+                elif option_name == 'verbose':
+                    s['verbose'] = True if value == 'true' else False
+                elif option_name == 'check_mate_in_one':
+                    s['check_mate_in_one'] = True if value == 'true' else False
+                elif option_name == 'use_pruning':
+                    s['use_pruning'] = True if value == 'true' else False
+                elif option_name == 'use_oscillating_cpuct':
+                    s['use_oscillating_cpuct'] = True if value == 'true' else False
+                elif option_name == 'use_time_management':
+                    s['use_time_management'] = True if value == 'true' else False
+                else:
+                    # by default all options are treated as integers
+                    s[option_name] = value
 
-        # Guard threads limits
-        if option_name == 'threads':
-            s[option_name] = min(4096, max(1, s[option_name]))
+                # Guard threads limits
+                if option_name == 'threads':
+                    s[option_name] = min(4096, max(1, s[option_name]))
 
-        log_print('info string Updated option %s to %s' % (option_name, value))
+                log_print('info string Updated option %s to %s' % (option_name, value))
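A sample round trip for the option handling above (UCI input on the left, resulting state change on the right):

setoption name threads value 4            -> s['threads'] = 4 (clamped to [1, 4096])
setoption name use_raw_network value true -> s['use_raw_network'] = True
setoption name centi_temperature value 7  -> s['centi_temperature'] = 7 (parsed as int by default)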
 
 def adjust_moves_left(moves_left, tc_type, prev_bm_value):
     """
@@ -367,20 +539,23 @@ def uci_reply():
     log_print('id author %s' % client['authors'])
     # tell the GUI all possible options
     log_print('option name UCI_Variant type combo default crazyhouse var crazyhouse')
-    log_print('option name context type combo default cpu var cpu var gpu')
+    log_print('option name context type combo default %s var cpu var gpu' % s['context'])
     log_print('option name use_raw_network type check default %s' %\
               ('false' if not s['use_raw_network'] else 'true'))
     log_print('option name threads type spin default %d min 1 max 4096' % s['threads'])
-    log_print('option name batch_size type spin default %d min 1 max 4096' % s['batch_size'])
+    log_print('option name batch_size type spin default %d min 1 max 4096' % s['batch_size'])
+    log_print('option name neural_net_services type spin default %d min 1 max 10' % s['neural_net_services'])
     log_print('option name playouts_empty_pockets type spin default %d min 56 max 8192' %\
               s['playouts_empty_pockets'])
     log_print('option name playouts_filled_pockets type spin default %d min 56 max 8192' %\
              s['playouts_filled_pockets'])
     log_print('option name centi_cpuct type spin default %d min 1 max 500' % s['centi_cpuct'])
-    log_print('option name centi_dirichlet_epsilon type spin default 10 min 0 max 100')
-    log_print('option name centi_dirichlet_alpha type spin default 20 min 0 max 100')
-    log_print('option name max_search_depth type spin default 40 min 1 max 100')
-    log_print('option name centi_temperature type spin default 0 min 0 max 100')
+    log_print('option name centi_dirichlet_epsilon type spin default %d min 0 max 100' % s['centi_dirichlet_epsilon'])
+    log_print('option name centi_dirichlet_alpha type spin default %d min 0 max 100' % s['centi_dirichlet_alpha'])
+    log_print('option name max_search_depth type spin default %d min 1 max 100' % s['max_search_depth'])
+    log_print('option name centi_temperature type spin default %d min 0 max 100' % s['centi_temperature'])
+    log_print('option name temperature_moves type spin default %d min 0 max 99999' % s['temperature_moves'])
+    log_print('option name opening_guard_moves type spin default %d min 0 max 99999' % s['opening_guard_moves'])
     log_print('option name centi_clip_quantil type spin default 0 min 0 max 100')
     log_print('option name virtual_loss type spin default 3 min 0 max 10')
     log_print('option name centi_q_value_weight type spin default %d min 0 max 100' % s['centi_q_value_weight'])
@@ -394,8 +569,12 @@
               s['max_move_num_to_reduce_movetime'])
     log_print('option name check_mate_in_one type check default %s' %\
               ('false' if not s['check_mate_in_one'] else 'true'))
-    log_print('option name enable_timeout type check default %s' %\
-              ('false' if not s['check_mate_in_one'] else 'true'))
+    log_print('option name use_pruning type check default %s' %\
+              ('false' if not s['use_pruning'] else 'true'))
+    log_print('option name use_oscillating_cpuct type check default %s' %\
+              ('false' if not s['use_oscillating_cpuct'] else 'true'))
+    log_print('option name use_time_management type check default %s' %\
+              ('false' if not s['use_time_management'] else 'true'))
     log_print('option name verbose type check default %s' %\
               ('false' if not s['verbose'] else 'true'))
@@ -411,6 +590,8 @@ def main():
 
     while True:
         line = input()
+        print_if_debug("waiting ...")
+        print_if_debug(line)
 
         # wait for a stdin input command
         if line:
@@ -431,6 +612,7 @@
             elif main_cmd == 'ucinewgame':
                 bestmove_value = None
                 engine_played_move = 0
+                new_game()
             elif main_cmd == "position":
                 setup_gamestate(cmd_list)
             elif main_cmd == "setoption":
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 7ce3dc89..00000000
--- a/setup.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""
-@file: setup.py
-Created on 02.11.18
-@project: CrazyAra
-@author: queensgambit
-
-Setup scripting for creating Cython binaries
-"""
-
-from distutils.core import setup
-from Cython.Build import cythonize
-
-setup(
-    ext_modules = cythonize(["DeepCrazyhouse/src/domain/agent/player/MCTSAgent.pyx",
-                             "DeepCrazyhouse/src/domain/agent/player/Node.pyx"], annotate=True)
-)