Skip to content

Commit

Permalink
analysis of training data, proper eval for only 1 legal move, enabled…
Browse files Browse the repository at this point in the history
… go search depth X
  • Loading branch information
QueensGambit committed Dec 1, 2018
1 parent 043d5fa commit a33c53e
Show file tree
Hide file tree
Showing 11 changed files with 559 additions and 23 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ main_config.py
# avoid pushing log-files generated by uci-communication
CrazyAra-log.txt
score-log.txt

# avoid pushing dataset files used for visualization
crazyara_lichess_dataset.pgn
crazyara_lichess_dataset_stats.csv

18 changes: 16 additions & 2 deletions DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,10 @@ def _expand_root_node_single_move(self, state, legal_moves):
:return:
"""

# set value 0 as a dummy value
value = 0
# request the value prediction for the current position
state_planes = state.get_state_planes()
[value, _] = self.nets[0].predict_single(state_planes)
# we can create the move probability vector without the NN this time
p_vec_small = np.array([1], np.float32)

# create a new root node
Expand Down Expand Up @@ -382,6 +384,10 @@ def _expand_root_node_single_move(self, state, legal_moves):
# connect the child to the root
self.root_node.child_nodes[0] = child_node

# assign the value of the root node as the q-value for the child
# here we must invert the invert the value because it's the value prediction of the next state
self.root_node.q[0] = -value

def _run_mcts_search(self, state):
"""
Runs a new or continues the mcts on the current search tree.
Expand Down Expand Up @@ -785,3 +791,11 @@ def update_movetime(self, time_ms_per_move):
:return:
"""
self.movetime_ms = time_ms_per_move

def set_max_search_depth(self, max_search_depth: int):
"""
Assigns a new maximum search depth for the next search
:param max_search_depth: Specifier of the search depth
:return:
"""
self.max_search_depth = max_search_depth
31 changes: 31 additions & 0 deletions DeepCrazyhouse/src/preprocessing/PGN2PlanesConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,37 @@ def _filter_pgn_thread(self, queue, pgn):
queue.put(batch_black_won)
queue.put(batch_draw)

def filter_all_pgns(self):
"""
Filters out all games based on the given conditions in the constructor and returns all games in
:return: lst_all_pgn_sel: List of selected games in String-IO format
lst_nb_games_sel: Number of selected games for each pgn file
lst_batch_white_won: Number of white wins in each pgn file
lst_black_won: Number of black wins in each pgn file
lst_draw_won: Number of draws in each pgn file
"""

total_games_exported = 0

lst_all_pgn_sel = []
lst_nb_games_sel = []
lst_batch_white_won = []
lst_batch_black_won = []
lst_batch_draw = []

pgns = os.listdir(self._import_dir)
for pgn_name in pgns:
self._pgn_name = pgn_name
all_pgn_sel, nb_games_sel, batch_white_won, batch_black_won, batch_draw = self.filter_pgn()
lst_all_pgn_sel.append(all_pgn_sel)
lst_nb_games_sel.append(nb_games_sel)
lst_batch_white_won.append(batch_white_won)
lst_batch_black_won.append(batch_black_won)
lst_batch_draw.append(batch_draw)

return lst_all_pgn_sel, lst_nb_games_sel, lst_batch_white_won, lst_batch_black_won, lst_batch_draw


def convert_all_pgns_to_planes(self):
"""
Master function which calls convert_pgn_to_planes() for all available pgns in the import directory
Expand Down
Loading

0 comments on commit a33c53e

Please sign in to comment.