From dbcaf664f01f3b7f41a8c115b95685a01b5f2b32 Mon Sep 17 00:00:00 2001 From: Disservin Date: Fri, 15 Mar 2024 17:09:15 +0100 Subject: [PATCH 1/3] Add formatting utilities to Makefile --- .clang-format | 44 ++++++++++++++++++++++++++++++++++++++++++++ Makefile | 6 ++++++ 2 files changed, 50 insertions(+) create mode 100644 .clang-format create mode 100644 Makefile diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..e2f0581d --- /dev/null +++ b/.clang-format @@ -0,0 +1,44 @@ +AccessModifierOffset: -1 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: Consecutive +AlignConsecutiveDeclarations: Consecutive +AlignEscapedNewlines: DontAlign +AlignOperands: AlignAfterOperator +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortCaseLabelsOnASingleLine: false +AllowShortEnumsOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AlwaysBreakTemplateDeclarations: Yes +BasedOnStyle: WebKit +BitFieldColonSpacing: After +BinPackParameters: false +BreakBeforeBinaryOperators: NonAssignment +BreakBeforeBraces: Custom +BraceWrapping: + AfterFunction: false + AfterClass: false + AfterControlStatement: true + BeforeElse: true +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: AfterColon +BreakStringLiterals: false +ColumnLimit: 100 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: true +IndentGotoLabels: false +IndentPPDirectives: BeforeHash +IndentWidth: 4 +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: None +PackConstructorInitializers: Never +ReflowComments: false +SortIncludes: false +SortUsingDeclarations: false +SpaceAfterCStyleCast: true +SpaceAfterTemplateKeyword: false +SpaceBeforeCaseColon: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeInheritanceColon: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 2 diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..60626998 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +SRCS = training_data_loader.cpp +HEADERS = lib/nnue_training_data_formats.h lib/nnue_training_data_stream.h lib/rng.h + +format: + black . + clang-format -i $(SRCS) $(HEADERS) -style=file \ No newline at end of file From d6a9f582208fe818f8b6c51e636f487cd89494e0 Mon Sep 17 00:00:00 2001 From: Disservin Date: Fri, 15 Mar 2024 17:09:36 +0100 Subject: [PATCH 2/3] Format Code --- cross_check_eval.py | 152 +- delete_bad_nets.py | 89 +- do_plots.py | 164 +- feature_block.py | 22 +- feature_set.py | 42 +- feature_transformer.py | 344 +- features.py | 27 +- ftperm.py | 267 +- halfka.py | 115 +- halfka_v2.py | 127 +- halfka_v2_hm.py | 196 +- halfkp.py | 132 +- lib/nnue_training_data_formats.h | 12272 ++++++++++++++--------------- lib/nnue_training_data_stream.h | 383 +- lib/rng.h | 12 +- model.py | 748 +- nnue_dataset.py | 408 +- perf_sigmoid_fitter.py | 87 +- ranger.py | 141 +- run_games.py | 322 +- scripts/easy_train.py | 1586 ++-- serialize.py | 747 +- train.py | 482 +- training_data_loader.cpp | 900 +-- visualize.py | 549 +- visualize_multi_hist.py | 130 +- 26 files changed, 10706 insertions(+), 9738 deletions(-) diff --git a/cross_check_eval.py b/cross_check_eval.py index ff858738..5d3f3d29 100644 --- a/cross_check_eval.py +++ b/cross_check_eval.py @@ -7,24 +7,53 @@ import chess from model import NNUE + def read_model(nnue_path, feature_set): - with open(nnue_path, 'rb') as f: + with open(nnue_path, "rb") as f: reader = serialize.NNUEReader(f, feature_set) return reader.model + def make_fen_batch_provider(data_path, batch_size): return nnue_dataset.FenBatchProvider(data_path, True, 1, batch_size, False, 10) -def eval_model_batch(model, batch): - us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch.contents.get_tensors('cuda') - evals = [v.item() for v in model.forward(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * 600.0] +def eval_model_batch(model, batch): + ( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + outcome, + score, + psqt_indices, + layer_stack_indices, + ) = batch.contents.get_tensors("cuda") + + evals = [ + v.item() + for v in model.forward( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + psqt_indices, + layer_stack_indices, + ) + * 600.0 + ] for i in range(len(evals)): if them[i] > 0.5: evals[i] = -evals[i] return evals -re_nnue_eval = re.compile(r'NNUE evaluation:?\s*?([-+]?\d*?\.\d*)') + +re_nnue_eval = re.compile(r"NNUE evaluation:?\s*?([-+]?\d*?\.\d*)") + def compute_basic_eval_stats(evals): min_engine_eval = min(evals) @@ -34,39 +63,78 @@ def compute_basic_eval_stats(evals): return min_engine_eval, max_engine_eval, avg_engine_eval, avg_abs_engine_eval + def compute_correlation(engine_evals, model_evals): if len(engine_evals) != len(model_evals): - raise Exception("number of engine evals doesn't match the number of model evals") - - min_engine_eval, max_engine_eval, avg_engine_eval, avg_abs_engine_eval = compute_basic_eval_stats(engine_evals) - min_model_eval, max_model_eval, avg_model_eval, avg_abs_model_eval = compute_basic_eval_stats(model_evals) - - print('Min engine/model eval: {} / {}'.format(min_engine_eval, min_model_eval)) - print('Max engine/model eval: {} / {}'.format(max_engine_eval, max_model_eval)) - print('Avg engine/model eval: {} / {}'.format(avg_engine_eval, avg_model_eval)) - print('Avg abs engine/model eval: {} / {}'.format(avg_abs_engine_eval, avg_abs_model_eval)) - - relative_model_error = sum(abs(model - engine) / (abs(engine)+0.001) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals) - relative_engine_error = sum(abs(model - engine) / (abs(model)+0.001) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals) - min_diff = min(abs(model - engine) for model, engine in zip(model_evals, engine_evals)) - max_diff = max(abs(model - engine) for model, engine in zip(model_evals, engine_evals)) - print('Relative engine error: {}'.format(relative_engine_error)) - print('Relative model error: {}'.format(relative_model_error)) - print('Avg abs difference: {}'.format(sum(abs(model - engine) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals))) - print('Min difference: {}'.format(min_diff)) - print('Max difference: {}'.format(max_diff)) + raise Exception( + "number of engine evals doesn't match the number of model evals" + ) + + ( + min_engine_eval, + max_engine_eval, + avg_engine_eval, + avg_abs_engine_eval, + ) = compute_basic_eval_stats(engine_evals) + ( + min_model_eval, + max_model_eval, + avg_model_eval, + avg_abs_model_eval, + ) = compute_basic_eval_stats(model_evals) + + print("Min engine/model eval: {} / {}".format(min_engine_eval, min_model_eval)) + print("Max engine/model eval: {} / {}".format(max_engine_eval, max_model_eval)) + print("Avg engine/model eval: {} / {}".format(avg_engine_eval, avg_model_eval)) + print( + "Avg abs engine/model eval: {} / {}".format( + avg_abs_engine_eval, avg_abs_model_eval + ) + ) + + relative_model_error = sum( + abs(model - engine) / (abs(engine) + 0.001) + for model, engine in zip(model_evals, engine_evals) + ) / len(engine_evals) + relative_engine_error = sum( + abs(model - engine) / (abs(model) + 0.001) + for model, engine in zip(model_evals, engine_evals) + ) / len(engine_evals) + min_diff = min( + abs(model - engine) for model, engine in zip(model_evals, engine_evals) + ) + max_diff = max( + abs(model - engine) for model, engine in zip(model_evals, engine_evals) + ) + print("Relative engine error: {}".format(relative_engine_error)) + print("Relative model error: {}".format(relative_model_error)) + print( + "Avg abs difference: {}".format( + sum(abs(model - engine) for model, engine in zip(model_evals, engine_evals)) + / len(engine_evals) + ) + ) + print("Min difference: {}".format(min_diff)) + print("Max difference: {}".format(max_diff)) + def eval_engine_batch(engine_path, net_path, fens): - engine = subprocess.Popen([engine_path], stdin=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True) - parts = ['uci', 'setoption name EvalFile value {}'.format(net_path)] + engine = subprocess.Popen( + [engine_path], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + parts = ["uci", "setoption name EvalFile value {}".format(net_path)] for fen in fens: - parts.append('position fen {}'.format(fen)) - parts.append('eval') - parts.append('quit') - query = '\n'.join(parts) + parts.append("position fen {}".format(fen)) + parts.append("eval") + parts.append("quit") + query = "\n".join(parts) out = engine.communicate(input=query)[0] evals = re.findall(re_nnue_eval, out) - return [int(float(v)*208) for v in evals] + return [int(float(v) * 208) for v in evals] + def filter_fens(fens): # We don't want fens where a king is in check, as these cannot be evaluated by the engine. @@ -77,13 +145,20 @@ def filter_fens(fens): filtered_fens.append(fen) return filtered_fens + def main(): parser = argparse.ArgumentParser(description="") parser.add_argument("--net", type=str, help="path to a .nnue net") parser.add_argument("--engine", type=str, help="path to stockfish") parser.add_argument("--data", type=str, help="path to a .bin or .binpack dataset") - parser.add_argument("--checkpoint", type=str, help="Optional checkpoint (used instead of nnue for local eval)") - parser.add_argument("--count", type=int, default=100, help="number of datapoints to process") + parser.add_argument( + "--checkpoint", + type=str, + help="Optional checkpoint (used instead of nnue for local eval)", + ) + parser.add_argument( + "--count", type=int, default=100, help="number of datapoints to process" + ) features.add_argparse_args(parser) args = parser.parse_args() @@ -102,20 +177,23 @@ def main(): engine_evals = [] done = 0 - print('Processed {} positions.'.format(done)) + print("Processed {} positions.".format(done)) while done < args.count: fens = filter_fens(next(fen_batch_provider)) - b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens)) + b = nnue_dataset.make_sparse_batch_from_fens( + feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens) + ) model_evals += eval_model_batch(model, b) nnue_dataset.destroy_sparse_batch(b) engine_evals += eval_engine_batch(args.engine, args.net, fens) done += len(fens) - print('Processed {} positions.'.format(done)) + print("Processed {} positions.".format(done)) compute_correlation(engine_evals, model_evals) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/delete_bad_nets.py b/delete_bad_nets.py index 54ac2d0f..d4b00535 100644 --- a/delete_bad_nets.py +++ b/delete_bad_nets.py @@ -3,13 +3,14 @@ import os import itertools + def parse_ordo(ordo_filename): ordo_scores = [] - with open(ordo_filename, 'r') as ordo_file: + with open(ordo_filename, "r") as ordo_file: lines = ordo_file.readlines() for line in lines: - if 'nn-epoch' in line: + if "nn-epoch" in line: fields = line.split() net = fields[1] rating = float(fields[3]) @@ -18,8 +19,9 @@ def parse_ordo(ordo_filename): return ordo_scores + def find_ckpt_files(root_dir): - p = re.compile('.*\\.ckpt') + p = re.compile(".*\\.ckpt") ckpt_files = [] for path, subdirs, files in os.walk(root_dir, followlinks=False): for filename in files: @@ -28,8 +30,9 @@ def find_ckpt_files(root_dir): ckpt_files.append(os.path.join(path, filename)) return ckpt_files + def find_nnue_files(root_dir): - p = re.compile('.*\\.nnue') + p = re.compile(".*\\.nnue") nnue_files = [] for path, subdirs, files in os.walk(root_dir, followlinks=False): for filename in files: @@ -38,15 +41,18 @@ def find_nnue_files(root_dir): nnue_files.append(os.path.join(path, filename)) return nnue_files + def get_net_dir(net_path): return os.path.dirname(net_path) + def split_nets_by_strength(nets, split_point=16): nets.sort(key=lambda x: -x[1]) - best_nets = nets[:min(split_point, len(nets))] - worst_nets = nets[min(split_point, len(nets)):] + best_nets = nets[: min(split_point, len(nets))] + worst_nets = nets[min(split_point, len(nets)) :] return best_nets, worst_nets + def get_nets_by_directory(best_nets, worst_nets, num_best_to_keep=16): binned_best_nets = dict() binned_worst_nets = dict() @@ -68,76 +74,84 @@ def get_nets_by_directory(best_nets, worst_nets, num_best_to_keep=16): return binned_best_nets, binned_worst_nets + def delete_bad_nets(root_dir, num_best_to_keep=16): net_epoch_p = re.compile(".*epoch([0-9]*)\\.nnue") ckpt_epoch_p = re.compile(".*epoch=([0-9]*).*\\.ckpt") ordo_filename = os.path.join(root_dir, "ordo.out") if not os.path.exists(ordo_filename): - print('No ordo file found. Exiting.') + print("No ordo file found. Exiting.") return else: nets = parse_ordo(ordo_filename) best_nets, worst_nets = split_nets_by_strength(nets, num_best_to_keep) - best_nets_by_dir, worst_nets_by_dir = get_nets_by_directory(best_nets, worst_nets, num_best_to_keep) + best_nets_by_dir, worst_nets_by_dir = get_nets_by_directory( + best_nets, worst_nets, num_best_to_keep + ) for basedir, worst_nets_in_dir in worst_nets_by_dir.items(): ckpt_files = find_ckpt_files(basedir) nnue_files = find_nnue_files(basedir) - worst_epochs = [net_epoch_p.match(net_name)[1] for net_name in worst_nets_in_dir] + worst_epochs = [ + net_epoch_p.match(net_name)[1] for net_name in worst_nets_in_dir + ] for ckpt_file in ckpt_files: try: ckpt_epoch = ckpt_epoch_p.match(ckpt_file)[1] if ckpt_epoch in worst_epochs: - print('Delete {}'.format(ckpt_file)) + print("Delete {}".format(ckpt_file)) os.remove(ckpt_file) except: pass - print('Keep {}'.format(ckpt_file)) + print("Keep {}".format(ckpt_file)) for nnue_file in nnue_files: try: nnue_epoch = net_epoch_p.match(nnue_file)[1] if nnue_epoch in worst_epochs: - print('Delete {}'.format(nnue_file)) + print("Delete {}".format(nnue_file)) os.remove(nnue_file) except: pass - print('Keep {}'.format(nnue_file)) + print("Keep {}".format(nnue_file)) def show_help(): - print('Usage: python delete_bad_nets.py root_dir [num_best_to_keep]') + print("Usage: python delete_bad_nets.py root_dir [num_best_to_keep]") print('root_dir - the directory to "cleanup"') - print('num_best_to_keep - the number of best nets to keep. Default: 16') - print('') - print('It expects to find ordo.out somewhere within root_dir.') - print('If the ordo.out is not found nothing is deleted.') - print('It uses the ratings from the ordo file to determine which nets are best.') - print('The engine names must contain the network name in the') + print("num_best_to_keep - the number of best nets to keep. Default: 16") + print("") + print("It expects to find ordo.out somewhere within root_dir.") + print("If the ordo.out is not found nothing is deleted.") + print("It uses the ratings from the ordo file to determine which nets are best.") + print("The engine names must contain the network name in the") print('following format: "nn-epoch[0-9]*\\.nnue". The network file') - print('can be specified with a parent directory (for example') + print("can be specified with a parent directory (for example") print('"run_0/nn-epoch100.nnue"), in which case the .ckpt file corresponding') - print('to this .nnue file will only be searched for in the parent ("run_0") directory.') + print( + 'to this .nnue file will only be searched for in the parent ("run_0") directory.' + ) print('The .ckpt files must contain "epoch=([0-9]*).*\\.ckpt".') - print('Both ckpt and nnue files are deleted. Only nets listed in the ordo') - print('file can be deleted. Other nets are always kept.') - print('The .nnue and .ckpt files are matched by epoch.') - print('') - print('The directory layout can be for example:') - print('- root_dir') - print(' - run_0') - print(' - a/b/c/d.ckpt') - print(' - *.nnue') - print(' - run_1') - print(' - a/b/c/d.ckpt') - print(' - *.nnue') - print(' - ordo.out') - print(' (in this case ony lines with engine name matching') + print("Both ckpt and nnue files are deleted. Only nets listed in the ordo") + print("file can be deleted. Other nets are always kept.") + print("The .nnue and .ckpt files are matched by epoch.") + print("") + print("The directory layout can be for example:") + print("- root_dir") + print(" - run_0") + print(" - a/b/c/d.ckpt") + print(" - *.nnue") + print(" - run_1") + print(" - a/b/c/d.ckpt") + print(" - *.nnue") + print(" - ordo.out") + print(" (in this case ony lines with engine name matching") print(' "run_[01]/nn-epoch[0-9]*\\.nnue" will be used.)') + def main(): if len(sys.argv) < 2: show_help() @@ -147,5 +161,6 @@ def main(): num_best_to_keep = sys.argv[2] if len(sys.argv) >= 3 else 16 delete_bad_nets(root_dir, num_best_to_keep) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/do_plots.py b/do_plots.py index 06f496a9..9b6c1a0b 100644 --- a/do_plots.py +++ b/do_plots.py @@ -9,8 +9,9 @@ import os import collections + def find_event_files(root_dir): - p = re.compile('events\\.out\\.tfevents.*') + p = re.compile("events\\.out\\.tfevents.*") tfevent_files = [] for path, subdirs, files in os.walk(root_dir, followlinks=False): for filename in files: @@ -19,30 +20,33 @@ def find_event_files(root_dir): tfevent_files.append(os.path.join(path, filename)) return tfevent_files + def find_ordo_file(root_dir): for path, subdirs, files in os.walk(root_dir, followlinks=False): for filename in files: - if filename == 'ordo.out': + if filename == "ordo.out": return os.path.join(path, filename) return None -def get_list_aggregator(aggregation_mode='avg'): - if aggregation_mode == 'min': + +def get_list_aggregator(aggregation_mode="avg"): + if aggregation_mode == "min": return lambda x: min(x) - elif aggregation_mode == 'max': + elif aggregation_mode == "max": return lambda x: max(x) - elif aggregation_mode == 'avg': + elif aggregation_mode == "avg": return lambda x: sum(x) / len(x) else: - raise Exception('Invalid aggregation_mode {}'.format(aggregation_mode)) + raise Exception("Invalid aggregation_mode {}".format(aggregation_mode)) + -def aggregate_dict(values, aggregation_mode='avg'): - ''' +def aggregate_dict(values, aggregation_mode="avg"): + """ values must be a dict of lists each list is aggregated to a single scalar based on the aggregation_mode can be one of 'min', 'max', 'avg' - ''' + """ aggregate_list = get_list_aggregator(aggregation_mode) @@ -51,6 +55,7 @@ def aggregate_dict(values, aggregation_mode='avg'): res[k] = aggregate_list(v) return res + def dict_to_xy(d): x = [] y = [] @@ -59,13 +64,14 @@ def dict_to_xy(d): y.append(v) return x, y + def parse_ordo_file(filename, label): - p = re.compile('.*nn-epoch(\\d*)\\.nnue') - with open(filename, 'r') as ordo_file: + p = re.compile(".*nn-epoch(\\d*)\\.nnue") + with open(filename, "r") as ordo_file: rows = [] lines = ordo_file.readlines() for line in lines: - if 'nn-epoch' in line and label in line: + if "nn-epoch" in line and label in line: fields = line.split() net = fields[1] epoch = int(p.match(net)[1]) @@ -75,24 +81,26 @@ def parse_ordo_file(filename, label): return rows + def transpose_list_of_tuples(l): return list(map(list, zip(*l))) + def do_plots(out_filename, root_dirs, elo_range, loss_range, split): - ''' - 1. Find tfevents files for each root directory - 2. Look for metrics - 2.1. Look for 'val_loss' - 3. Look for ordo.out - 3.1. Parse elo from ordo. - 4. Do plots. - ''' + """ + 1. Find tfevents files for each root directory + 2. Look for metrics + 2.1. Look for 'val_loss' + 3. Look for ordo.out + 3.1. Parse elo from ordo. + 4. Do plots. + """ tf_size_guidance = { - 'compressedHistograms': 10, - 'images': 0, - 'scalars': 0, - 'histograms': 1 + "compressedHistograms": 10, + "images": 0, + "scalars": 0, + "histograms": 1, } fig = plt.figure() @@ -101,12 +109,11 @@ def do_plots(out_filename, root_dirs, elo_range, loss_range, split): ax_val_loss = fig.add_subplot(312) ax_elo = None - ax_val_loss.set_xlabel('step') - ax_val_loss.set_ylabel('val_loss') - - ax_train_loss.set_xlabel('step') - ax_train_loss.set_ylabel('train_loss') + ax_val_loss.set_xlabel("step") + ax_val_loss.set_ylabel("val_loss") + ax_train_loss.set_xlabel("step") + ax_train_loss.set_ylabel("train_loss") for user_root_dir in root_dirs: @@ -115,67 +122,71 @@ def do_plots(out_filename, root_dirs, elo_range, loss_range, split): # we use the ordo file in the root dir, but split the content. split_root_dirs = [user_root_dir] if split: - split_root_dirs = [] - for item in os.listdir(user_root_dir): - if os.path.isdir(os.path.join(user_root_dir, item)): - root_dir = os.path.join(user_root_dir, item) - if len(find_event_files(root_dir)) > 0: - split_root_dirs.append(root_dir) - split_root_dirs.sort() + split_root_dirs = [] + for item in os.listdir(user_root_dir): + if os.path.isdir(os.path.join(user_root_dir, item)): + root_dir = os.path.join(user_root_dir, item) + if len(find_event_files(root_dir)) > 0: + split_root_dirs.append(root_dir) + split_root_dirs.sort() for root_dir in split_root_dirs: - print('Processing root_dir {}'.format(root_dir)) + print("Processing root_dir {}".format(root_dir)) tfevents_files = find_event_files(root_dir) - print('Found {} tfevents files.'.format(len(tfevents_files))) + print("Found {} tfevents files.".format(len(tfevents_files))) val_losses = collections.defaultdict(lambda: []) train_losses = collections.defaultdict(lambda: []) for i, tfevents_file in enumerate(tfevents_files): - print('Processing tfevents file {}/{}: {}'.format(i+1, len(tfevents_files), tfevents_file)) + print( + "Processing tfevents file {}/{}: {}".format( + i + 1, len(tfevents_files), tfevents_file + ) + ) events_acc = EventAccumulator(tfevents_file, tf_size_guidance) events_acc.Reload() - vv = events_acc.Scalars('val_loss') - print('Found {} val_loss entries.'.format(len(vv))) + vv = events_acc.Scalars("val_loss") + print("Found {} val_loss entries.".format(len(vv))) minloss = min([v[2] for v in vv]) for v in vv: if v[2] < minloss + loss_range: - step = v[1] - val_losses[step].append(v[2]) + step = v[1] + val_losses[step].append(v[2]) - vv = events_acc.Scalars('train_loss') + vv = events_acc.Scalars("train_loss") minloss = min([v[2] for v in vv]) - print('Found {} train_loss entries.'.format(len(vv))) + print("Found {} train_loss entries.".format(len(vv))) for v in vv: if v[2] < minloss + loss_range: - step = v[1] - train_losses[step].append(v[2]) + step = v[1] + train_losses[step].append(v[2]) - print('Aggregating data...') + print("Aggregating data...") - val_loss = aggregate_dict(val_losses, 'min') + val_loss = aggregate_dict(val_losses, "min") x, y = dict_to_xy(val_loss) ax_val_loss.plot(x, y, label=root_dir) - train_loss = aggregate_dict(train_losses, 'min') + train_loss = aggregate_dict(train_losses, "min") x, y = dict_to_xy(train_loss) ax_train_loss.plot(x, y, label=root_dir) - print('Finished aggregating data.') + print("Finished aggregating data.") ordo_file = find_ordo_file(user_root_dir) if ordo_file: - print('Found ordo file {}'.format(ordo_file)) + print("Found ordo file {}".format(ordo_file)) if ax_elo is None: ax_elo = fig.add_subplot(313) - ax_elo.set_xlabel('epoch') - ax_elo.set_ylabel('Elo') + ax_elo.set_xlabel("epoch") + ax_elo.set_ylabel("Elo") for root_dir in split_root_dirs: rows = parse_ordo_file(ordo_file, root_dir if split else "nnue") if len(rows) == 0: - continue - rows = sorted(rows, key=lambda x:x[1]) + continue + rows = sorted(rows, key=lambda x: x[1]) epochs = [] elos = [] errors = [] @@ -185,31 +196,30 @@ def do_plots(out_filename, root_dirs, elo_range, loss_range, split): elo = row[2] error = row[3] if not epoch in epochs: - if elo > maxelo - elo_range: - epochs.append(epoch) - elos.append(elo) - errors.append(error) + if elo > maxelo - elo_range: + epochs.append(epoch) + elos.append(elo) + errors.append(error) - print('Found ordo data for {} epochs'.format(len(epochs))) + print("Found ordo data for {} epochs".format(len(epochs))) ax_elo.errorbar(epochs, elos, yerr=errors, label=root_dir) else: - print('Did not find ordo file. Skipping.') - + print("Did not find ordo file. Skipping.") ax_val_loss.legend() ax_train_loss.legend() if ax_elo: ax_elo.legend() - print('Saving plot at {}'.format(out_filename)) - #plt.show() + print("Saving plot at {}".format(out_filename)) + # plt.show() plt.savefig(out_filename, dpi=300) def main(): - #do_plots('test_plot_out.png', ['../nnue-pytorch-training/experiment_10', '../nnue-pytorch-training/experiment_11']) + # do_plots('test_plot_out.png', ['../nnue-pytorch-training/experiment_10', '../nnue-pytorch-training/experiment_11']) parser = argparse.ArgumentParser( description="Generate plots of losses and Elo for experiments run", @@ -218,8 +228,8 @@ def main(): parser.add_argument( "root_dirs", type=str, - nargs='+', - help="multiple root directories (containing ordo.out and tensorflow event files)" + nargs="+", + help="multiple root directories (containing ordo.out and tensorflow event files)", ) parser.add_argument( "--output", @@ -239,14 +249,22 @@ def main(): default=0.004, help="Limit loss data shown to the best result + loss_range", ) - parser.add_argument("--split", - action='store_true', + parser.add_argument( + "--split", + action="store_true", help="Split the root dirs provided, assumes the ordo file is still at the root, and nets in that ordo file match root_dir/sub_dir/", ) args = parser.parse_args() print(args.root_dirs) - do_plots(args.output, args.root_dirs, elo_range = args.elo_range, loss_range = args.loss_range, split = args.split) + do_plots( + args.output, + args.root_dirs, + elo_range=args.elo_range, + loss_range=args.loss_range, + split=args.split, + ) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/feature_block.py b/feature_block.py index 178424a4..1b2b1417 100644 --- a/feature_block.py +++ b/feature_block.py @@ -1,10 +1,12 @@ from collections import OrderedDict + def _get_main_factor_name(full_name): - return full_name.replace('^', '') + return full_name.replace("^", "") + class FeatureBlock: - ''' + """ This is the base class for all the network input features. All features must inherit from this class. It abstracts a named set of features in a way that @@ -45,11 +47,11 @@ class FeatureBlock: get_active_features (def get_active_features(self, board: chess.Board)), which takes the board and returns the list of indices of the features that are active for this board. - ''' + """ def __init__(self, name, hash, factors): if not isinstance(factors, OrderedDict): - raise Exception('Factors must be an collections.OrderedDict') + raise Exception("Factors must be an collections.OrderedDict") self.name = name self.hash = hash @@ -61,18 +63,20 @@ def __init__(self, name, hash, factors): def get_main_factor_name(self): return _get_main_factor_name(self.name) - ''' + """ This method represents the default factorizer. If your feature block has multiple factors you need to override this method to return a list of factors for a given feature. - ''' + """ + def get_feature_factors(self, idx): return [idx] - ''' + """ This method takes a string name of a factor and returns the offset of the first feature in this factor when consulted with the sizes of the previous factors. - ''' + """ + def get_factor_base_feature(self, name): offset = 0 for n, s in self.factors.items(): @@ -81,4 +85,4 @@ def get_factor_base_feature(self, name): offset += s - raise Exception('No factor named {} in {}'.format(name, self.name)) + raise Exception("No factor named {} in {}".format(name, self.name)) diff --git a/feature_set.py b/feature_set.py index ab1b670a..87414f29 100644 --- a/feature_set.py +++ b/feature_set.py @@ -3,46 +3,53 @@ import torch import chess + def _calculate_features_hash(features): if len(features) == 1: return features[0].hash tail_hash = calculate_features_hash(features[1:]) - return features[0].hash ^ (tail_hash << 1) ^ (tail_hash >> 1) & 0xffffffff + return features[0].hash ^ (tail_hash << 1) ^ (tail_hash >> 1) & 0xFFFFFFFF + class FeatureSet: - ''' + """ A feature set is nothing more than a list of named FeatureBlocks. It itself functions similarily to a feature block, but we don't want to be explicit about it as we don't want it to be used as a building block for other feature sets. You can think of this class as a composite, but not the full extent. It is basically a concatenation of feature blocks. - ''' + """ def __init__(self, features): for feature in features: if not isinstance(feature, FeatureBlock): - raise Exception('All features must subclass FeatureBlock') + raise Exception("All features must subclass FeatureBlock") self.features = features self.hash = _calculate_features_hash(features) - self.name = '+'.join(feature.name for feature in features) + self.name = "+".join(feature.name for feature in features) self.num_real_features = sum(feature.num_real_features for feature in features) - self.num_virtual_features = sum(feature.num_virtual_features for feature in features) + self.num_virtual_features = sum( + feature.num_virtual_features for feature in features + ) self.num_features = sum(feature.num_features for feature in features) - ''' + """ This method returns the feature ranges for the virtual factors of the underlying feature blocks. This is useful to know during initialization, when we want to zero initialize the virtual feature weights, but give some other values to the real feature weights. - ''' + """ + def get_virtual_feature_ranges(self): ranges = [] offset = 0 for feature in self.features: if feature.num_virtual_features: - ranges.append((offset + feature.num_real_features, offset + feature.num_features)) + ranges.append( + (offset + feature.num_real_features, offset + feature.num_features) + ) offset += feature.num_features return ranges @@ -56,12 +63,13 @@ def get_real_feature_ranges(self): return ranges - ''' + """ This method goes over all of the feature blocks and gathers the active features. Each block has its own index space assigned so the features from two different blocks will never have the same index here. Basically the thing you would expect to happen after concatenating many feature blocks. - ''' + """ + def get_active_features(self, board): w = torch.zeros(0) b = torch.zeros(0) @@ -77,11 +85,12 @@ def get_active_features(self, board): return w, b - ''' + """ This method takes a feature idx and looks for the block that owns it. If it found the block it asks it to factorize the index, otherwise it throws and Exception. The idx must refer to a real feature. - ''' + """ + def get_feature_factors(self, idx): offset = 0 for feature in self.features: @@ -89,9 +98,9 @@ def get_feature_factors(self, idx): return [offset + i for i in feature.get_feature_factors(idx - offset)] offset += feature.num_features - raise Exception('No feature block to factorize {}'.format(idx)) + raise Exception("No feature block to factorize {}".format(idx)) - ''' + """ This method does what get_feature_factors does but for all valid features at the same time. It returns a list of length self.num_real_features with ith element being a list of factors @@ -100,7 +109,8 @@ def get_feature_factors(self, idx): slightly faster when there's many feature blocks. It might be worth to add a similar method to the FeatureBlock itself - to make it faster for feature blocks with many factors. - ''' + """ + def get_virtual_to_real_features_gather_indices(self): indices = [] real_offset = 0 diff --git a/feature_transformer.py b/feature_transformer.py index 7084c1f4..bc47d19c 100644 --- a/feature_transformer.py +++ b/feature_transformer.py @@ -4,54 +4,71 @@ import cupy as cp import math + def _find_nearest_divisor(value, target): divisors = [] - for i in range(1, value+1): + for i in range(1, value + 1): if value % i == 0: - divisors.append((i, abs(target-i))) - divisors.sort(key=lambda x:x[1]) + divisors.append((i, abs(target - i))) + divisors.sort(key=lambda x: x[1]) return divisors[0][0] + _num_threads_forward_cache = dict() + + def _get_num_threads_for_forward(output_size): optimal_num_threads = 512 if output_size not in _num_threads_forward_cache: - _num_threads_forward_cache[output_size] = _find_nearest_divisor(output_size, optimal_num_threads) + _num_threads_forward_cache[output_size] = _find_nearest_divisor( + output_size, optimal_num_threads + ) return _num_threads_forward_cache[output_size] + _num_threads_backward_cache = dict() + + def _get_num_threads_for_backward(output_size): optimal_num_threads = 512 if output_size not in _num_threads_backward_cache: - _num_threads_backward_cache[output_size] = _find_nearest_divisor(output_size, optimal_num_threads) + _num_threads_backward_cache[output_size] = _find_nearest_divisor( + output_size, optimal_num_threads + ) return _num_threads_backward_cache[output_size] + def _kernel_with_threads(kernel, threads): def f(grid, args): kernel(grid=grid, block=threads, args=args) + return f + _feature_transformer_slice_forward_kernel_cache = dict() + + def make_feature_transformer_slice_forward_kernel(max_active_features, output_size): - ''' - @param: max_active_features - The maximum number of features that are active - (non-zero) for a single position. This value determines - the shape of the inputs. - This value is of type uint32_t. - - @param: output_size - The number of outputs. Must match the shape of weights - and biases. - This value is of type uint32. - ''' + """ + @param: max_active_features + The maximum number of features that are active + (non-zero) for a single position. This value determines + the shape of the inputs. + This value is of type uint32_t. + + @param: output_size + The number of outputs. Must match the shape of weights + and biases. + This value is of type uint32. + """ num_threads = _get_num_threads_for_forward(output_size) output_thread_slice_size = output_size // num_threads key = (max_active_features, output_size, num_threads) if key not in _feature_transformer_slice_forward_kernel_cache: - kernel = cp.RawKernel(r''' + kernel = cp.RawKernel( + r""" typedef unsigned int uint32_t; typedef int int32_t; @@ -144,34 +161,42 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si }} }} -'''.format( +""".format( max_active_features=max_active_features, output_thread_slice_size=output_thread_slice_size, - output_size=output_size), - 'feature_transformer_slice_forward') + output_size=output_size, + ), + "feature_transformer_slice_forward", + ) kernel.compile() - _feature_transformer_slice_forward_kernel_cache[key] = _kernel_with_threads(kernel, (num_threads,)) + _feature_transformer_slice_forward_kernel_cache[key] = _kernel_with_threads( + kernel, (num_threads,) + ) return _feature_transformer_slice_forward_kernel_cache[key] + _feature_transformer_slice_backward_kernel_cache = dict() + + def make_feature_transformer_slice_backward_kernel(max_active_features, output_size): - '''' - @param: max_active_features - The maximum number of features that are active - (non-zero) for a single position. This value determines - the shape of the inputs. - This value is of type uint32_t. - - @param: output_size - The number of outputs. Must match the shape of weights - and biases. - This value is of type uint32. - ''' + """' + @param: max_active_features + The maximum number of features that are active + (non-zero) for a single position. This value determines + the shape of the inputs. + This value is of type uint32_t. + + @param: output_size + The number of outputs. Must match the shape of weights + and biases. + This value is of type uint32. + """ num_threads = _get_num_threads_for_backward(output_size) output_thread_slice_size = output_size // num_threads key = (max_active_features, output_size, num_threads) if key not in _feature_transformer_slice_backward_kernel_cache: - kernel = cp.RawKernel(r''' + kernel = cp.RawKernel( + r""" typedef unsigned int uint32_t; typedef int int32_t; @@ -273,17 +298,21 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s }} }} -'''.format( +""".format( max_active_features=max_active_features, output_thread_slice_size=output_thread_slice_size, - output_size=output_size), - 'feature_transformer_slice_backward') + output_size=output_size, + ), + "feature_transformer_slice_backward", + ) kernel.compile() - _feature_transformer_slice_backward_kernel_cache[key] = _kernel_with_threads(kernel, (num_threads,)) + _feature_transformer_slice_backward_kernel_cache[key] = _kernel_with_threads( + kernel, (num_threads,) + ) return _feature_transformer_slice_backward_kernel_cache[key] -class FeatureTransformerSliceFunction(autograd.Function): +class FeatureTransformerSliceFunction(autograd.Function): @staticmethod def forward(ctx, feature_indices, feature_values, weight, bias): ctx.save_for_backward(feature_indices, feature_values, weight, bias) @@ -320,9 +349,17 @@ def forward(ctx, feature_indices, feature_values, weight, bias): max_active_features = feature_indices.shape[1] output_size = weight.shape[1] - output = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) + output = torch.empty( + batch_size, + output_size, + dtype=torch.float32, + device=device, + requires_grad=True, + ) - kernel = make_feature_transformer_slice_forward_kernel(max_active_features, output_size) + kernel = make_feature_transformer_slice_forward_kernel( + max_active_features, output_size + ) kernel( grid=(batch_size,), args=( @@ -330,8 +367,8 @@ def forward(ctx, feature_indices, feature_values, weight, bias): feature_values.data_ptr(), weight.data_ptr(), bias.data_ptr(), - output.data_ptr() - ) + output.data_ptr(), + ), ) return output @@ -350,10 +387,14 @@ def backward(ctx, grad_output): max_active_features = feature_indices.shape[1] output_size = weight.shape[1] - weight_grad = torch.zeros(weight.shape[0], weight.shape[1], dtype=torch.float32, device=device) + weight_grad = torch.zeros( + weight.shape[0], weight.shape[1], dtype=torch.float32, device=device + ) bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device) - kernel = make_feature_transformer_slice_backward_kernel(max_active_features, output_size) + kernel = make_feature_transformer_slice_backward_kernel( + max_active_features, output_size + ) kernel( grid=(batch_size,), args=( @@ -361,17 +402,32 @@ def backward(ctx, grad_output): feature_values.data_ptr(), weight_grad.data_ptr(), bias_grad.data_ptr(), - grad_output.data_ptr() - ) + grad_output.data_ptr(), + ), ) return None, None, weight_grad, bias_grad -class DoubleFeatureTransformerSliceFunction(autograd.Function): +class DoubleFeatureTransformerSliceFunction(autograd.Function): @staticmethod - def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias): - ctx.save_for_backward(feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias) + def forward( + ctx, + feature_indices_0, + feature_values_0, + feature_indices_1, + feature_values_1, + weight, + bias, + ): + ctx.save_for_backward( + feature_indices_0, + feature_values_0, + feature_indices_1, + feature_values_1, + weight, + bias, + ) assert len(feature_indices_0.shape) == 2 assert len(feature_values_0.shape) == 2 @@ -418,10 +474,24 @@ def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature max_active_features = feature_indices_0.shape[1] output_size = weight.shape[1] - output0 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) - output1 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) + output0 = torch.empty( + batch_size, + output_size, + dtype=torch.float32, + device=device, + requires_grad=True, + ) + output1 = torch.empty( + batch_size, + output_size, + dtype=torch.float32, + device=device, + requires_grad=True, + ) - kernel = make_feature_transformer_slice_forward_kernel(max_active_features, output_size) + kernel = make_feature_transformer_slice_forward_kernel( + max_active_features, output_size + ) kernel( grid=(batch_size,), args=( @@ -429,8 +499,8 @@ def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature feature_values_0.data_ptr(), weight.data_ptr(), bias.data_ptr(), - output0.data_ptr() - ) + output0.data_ptr(), + ), ) kernel( @@ -440,8 +510,8 @@ def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature feature_values_1.data_ptr(), weight.data_ptr(), bias.data_ptr(), - output1.data_ptr() - ) + output1.data_ptr(), + ), ) return output0, output1 @@ -454,17 +524,28 @@ def backward(ctx, grad_output_0, grad_output_1): grad_output_0 = grad_output_0.contiguous() grad_output_1 = grad_output_1.contiguous() - feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias = ctx.saved_tensors + ( + feature_indices_0, + feature_values_0, + feature_indices_1, + feature_values_1, + weight, + bias, + ) = ctx.saved_tensors device = feature_indices_0.device batch_size = feature_indices_0.shape[0] max_active_features = feature_indices_0.shape[1] output_size = weight.shape[1] - weight_grad = torch.zeros(weight.shape[0], weight.shape[1], dtype=torch.float32, device=device) + weight_grad = torch.zeros( + weight.shape[0], weight.shape[1], dtype=torch.float32, device=device + ) bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device) - kernel = make_feature_transformer_slice_backward_kernel(max_active_features, output_size) + kernel = make_feature_transformer_slice_backward_kernel( + max_active_features, output_size + ) kernel( grid=(batch_size,), args=( @@ -472,8 +553,8 @@ def backward(ctx, grad_output_0, grad_output_1): feature_values_0.data_ptr(), weight_grad.data_ptr(), bias_grad.data_ptr(), - grad_output_0.data_ptr() - ) + grad_output_0.data_ptr(), + ), ) kernel( @@ -483,24 +564,33 @@ def backward(ctx, grad_output_0, grad_output_1): feature_values_1.data_ptr(), weight_grad.data_ptr(), bias_grad.data_ptr(), - grad_output_1.data_ptr() - ) + grad_output_1.data_ptr(), + ), ) return None, None, None, None, weight_grad, bias_grad + class FeatureTransformerSlice(nn.Module): def __init__(self, num_inputs, num_outputs): super(FeatureTransformerSlice, self).__init__() self.num_inputs = num_inputs self.num_outputs = num_outputs - sigma = math.sqrt(1/num_inputs) - self.weight = nn.Parameter(torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma) - sigma) - self.bias = nn.Parameter(torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma) + sigma = math.sqrt(1 / num_inputs) + self.weight = nn.Parameter( + torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma) + - sigma + ) + self.bias = nn.Parameter( + torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma + ) def forward(self, feature_indices, feature_values): - return FeatureTransformerSliceFunction.apply(feature_indices, feature_values, self.weight, self.bias) + return FeatureTransformerSliceFunction.apply( + feature_indices, feature_values, self.weight, self.bias + ) + class DoubleFeatureTransformerSlice(nn.Module): def __init__(self, num_inputs, num_outputs): @@ -508,23 +598,42 @@ def __init__(self, num_inputs, num_outputs): self.num_inputs = num_inputs self.num_outputs = num_outputs - sigma = math.sqrt(1/num_inputs) - self.weight = nn.Parameter(torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma) - sigma) - self.bias = nn.Parameter(torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma) + sigma = math.sqrt(1 / num_inputs) + self.weight = nn.Parameter( + torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma) + - sigma + ) + self.bias = nn.Parameter( + torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma + ) - def forward(self, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1): - return DoubleFeatureTransformerSliceFunction.apply(feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, self.weight, self.bias) + def forward( + self, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1 + ): + return DoubleFeatureTransformerSliceFunction.apply( + feature_indices_0, + feature_values_0, + feature_indices_1, + feature_values_1, + self.weight, + self.bias, + ) -if __name__ == '__main__': + +if __name__ == "__main__": import time import sys import os - def FeatureTransformerSliceFunctionEmulate(feature_indices, feature_values, weight, bias): + def FeatureTransformerSliceFunctionEmulate( + feature_indices, feature_values, weight, bias + ): batch_size = feature_indices.shape[0] num_inputs = weight.shape[0] max_active_features = feature_indices.shape[1] - inputs = torch.zeros(batch_size, num_inputs, dtype=torch.float32, device=weight.device) + inputs = torch.zeros( + batch_size, num_inputs, dtype=torch.float32, device=weight.device + ) for i in range(batch_size): for j in range(max_active_features): feature = feature_indices[i, j] @@ -541,21 +650,40 @@ def test(): MAX_ERROR = 1e-4 torch.manual_seed(0) - weight0 = torch.rand(INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True) + weight0 = torch.rand( + INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True + ) bias0 = torch.rand(STRIDE, dtype=torch.float32, requires_grad=True) torch.manual_seed(0) - weight1 = torch.rand(INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True) + weight1 = torch.rand( + INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True + ) bias1 = torch.rand(STRIDE, dtype=torch.float32, requires_grad=True) - indices0 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to(dtype=torch.int32) - indices1 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to(dtype=torch.int32) + indices0 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to( + dtype=torch.int32 + ) + indices1 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to( + dtype=torch.int32 + ) values0 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32) values1 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32) - output00 = FeatureTransformerSliceFunctionEmulate(indices0.clone(), values0.clone(), weight0, bias0) - output01 = FeatureTransformerSliceFunctionEmulate(indices1.clone(), values1.clone(), weight0, bias0) - #output10 = FeatureTransformerSliceFunction.apply(indices0.clone().cuda(), values0.clone().cuda(), weight1.cuda(), bias1.cuda()) - #output11 = FeatureTransformerSliceFunction.apply(indices1.clone().cuda(), values1.clone().cuda(), weight1.cuda(), bias1.cuda()) - output10, output11 = DoubleFeatureTransformerSliceFunction.apply(indices0.clone().cuda(), values0.clone().cuda(), indices1.clone().cuda(), values1.clone().cuda(), weight1.cuda(), bias1.cuda()) + output00 = FeatureTransformerSliceFunctionEmulate( + indices0.clone(), values0.clone(), weight0, bias0 + ) + output01 = FeatureTransformerSliceFunctionEmulate( + indices1.clone(), values1.clone(), weight0, bias0 + ) + # output10 = FeatureTransformerSliceFunction.apply(indices0.clone().cuda(), values0.clone().cuda(), weight1.cuda(), bias1.cuda()) + # output11 = FeatureTransformerSliceFunction.apply(indices1.clone().cuda(), values1.clone().cuda(), weight1.cuda(), bias1.cuda()) + output10, output11 = DoubleFeatureTransformerSliceFunction.apply( + indices0.clone().cuda(), + values0.clone().cuda(), + indices1.clone().cuda(), + values1.clone().cuda(), + weight1.cuda(), + bias1.cuda(), + ) assert torch.max(output00.cpu() - output10.cpu()) < MAX_ERROR assert torch.max(output01.cpu() - output11.cpu()) < MAX_ERROR @@ -563,7 +691,7 @@ def test(): (output10 - output11).sum().backward() assert torch.max(weight0.grad.cpu() - weight1.grad.cpu()) < MAX_ERROR assert torch.max(bias0.grad.cpu() - bias1.grad.cpu()) < MAX_ERROR - print('Tests passed.') + print("Tests passed.") def bench(): INPUT_SIZE = 40960 @@ -573,10 +701,36 @@ def bench(): MAX_ACTIVE_FEATURES = 64 layer = DoubleFeatureTransformerSlice(INPUT_SIZE, STRIDE).cuda() - indices0 = torch.cat([torch.sort((torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES * 3 // 4) * INPUT_SIZE), dim=1)[0].to(dtype=torch.int32), torch.full((BATCH_SIZE, MAX_ACTIVE_FEATURES // 4), -1, dtype=torch.int32)], dim=1).cuda() - values0 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32).cuda() - indices1 = torch.cat([torch.sort((torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES * 3 // 4)) * INPUT_SIZE, dim=1)[0].to(dtype=torch.int32), torch.full((BATCH_SIZE, MAX_ACTIVE_FEATURES // 4), -1, dtype=torch.int32)], dim=1).cuda() - values1 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32).cuda() + indices0 = torch.cat( + [ + torch.sort( + (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES * 3 // 4) * INPUT_SIZE), + dim=1, + )[0].to(dtype=torch.int32), + torch.full( + (BATCH_SIZE, MAX_ACTIVE_FEATURES // 4), -1, dtype=torch.int32 + ), + ], + dim=1, + ).cuda() + values0 = torch.rand( + BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32 + ).cuda() + indices1 = torch.cat( + [ + torch.sort( + (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES * 3 // 4)) * INPUT_SIZE, + dim=1, + )[0].to(dtype=torch.int32), + torch.full( + (BATCH_SIZE, MAX_ACTIVE_FEATURES // 4), -1, dtype=torch.int32 + ), + ], + dim=1, + ).cuda() + values1 = torch.rand( + BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32 + ).cuda() output0, output1 = layer(indices0, values0, indices1, values1) @@ -589,17 +743,17 @@ def bench(): output0 = torch.clamp(output0, 0.0, 1.0) output1 = torch.clamp(output1, 0.0, 1.0) - g = ((output0 - output1)**2).mean() + g = ((output0 - output1) ** 2).mean() g.backward() torch.cuda.synchronize() end = time.time() - #for param in layer.parameters(): + # for param in layer.parameters(): # print(param.grad) - print('{} pos/s'.format((ITERS * BATCH_SIZE) / (end - start))) + print("{} pos/s".format((ITERS * BATCH_SIZE) / (end - start))) test() - bench() \ No newline at end of file + bench() diff --git a/features.py b/features.py index 6325c302..117a3f26 100644 --- a/features.py +++ b/features.py @@ -3,12 +3,12 @@ import argparse -''' +""" Each module that defines feature blocks must be imported here and added to the _feature_modules list. Each such module must define a function `get_feature_block_clss` at module scope that returns the list of feature block classes in that module. -''' +""" import halfkp import halfka import halfka_v2 @@ -18,35 +18,50 @@ _feature_blocks_by_name = dict() + def _add_feature_block(feature_block_cls): feature_block = feature_block_cls() _feature_blocks_by_name[feature_block.name] = feature_block + def _add_features_blocks_from_module(module): feature_block_clss = module.get_feature_block_clss() for feature_block_cls in feature_block_clss: _add_feature_block(feature_block_cls) + def get_feature_block_from_name(name): return _feature_blocks_by_name[name] + def get_feature_blocks_from_names(names): return [_feature_blocks_by_name[name] for name in names] + def get_feature_set_from_name(name): - feature_block_names = name.split('+') + feature_block_names = name.split("+") blocks = get_feature_blocks_from_names(feature_block_names) return FeatureSet(blocks) + def get_available_feature_blocks_names(): return list(iter(_feature_blocks_by_name)) + def add_argparse_args(parser): - _default_feature_set_name = 'HalfKAv2_hm^' - parser.add_argument("--features", dest='features', default=_default_feature_set_name, help="The feature set to use. Can be a union of feature blocks (for example P+HalfKP). \"^\" denotes a factorized block. Currently available feature blocks are: " + ', '.join(get_available_feature_blocks_names())) + _default_feature_set_name = "HalfKAv2_hm^" + parser.add_argument( + "--features", + dest="features", + default=_default_feature_set_name, + help='The feature set to use. Can be a union of feature blocks (for example P+HalfKP). "^" denotes a factorized block. Currently available feature blocks are: ' + + ", ".join(get_available_feature_blocks_names()), + ) + def _init(): for module in _feature_modules: _add_features_blocks_from_module(module) -_init() \ No newline at end of file + +_init() diff --git a/ftperm.py b/ftperm.py index c1539913..fbebccff 100644 --- a/ftperm.py +++ b/ftperm.py @@ -1,4 +1,4 @@ -''' +""" NOTE: This script uses CUDA and may requires large amounts of VRAM. Decrease --count if encountering problems. @@ -29,7 +29,7 @@ python serialize.py nn-5af11540bbfe.nnue permuted.nnue --features=HalfKAv2_hm --ft_optimize --ft_optimize_data=noob_master_leaf_static_d12_85M_0.binpack --ft_optimize_count=10000 -''' +""" import time import argparse @@ -45,41 +45,42 @@ import cupy as cp from math import ceil -''' +""" Algorithm by Daniel Monroe. Github @Ergodice. -''' +""" ZERO_BLOCK_SIZE = 4 VERBOSE = False USE_CUPY = False + def batched(arr, batch_size): - ''' + """ Utility generator that yields chunks of array `arr` of size `batch_size` Expects arr to be a numpy-like array - ''' + """ n_samples = arr.shape[0] idx = 0 while idx < n_samples: - yield arr[idx:min(idx+batch_size, n_samples)] + yield arr[idx : min(idx + batch_size, n_samples)] idx += batch_size def apply_swap(perm, i, j): - ''' + """ Swap `i`-th and `j`-th elements in the array `perm`. - ''' + """ perm[i], perm[j] = perm[j], perm[i] def apply_rotate_right(perm, indices): - ''' + """ Rotates right the values in `perm` at selected indices `indices`. The rotation is performed as-if the selected indices were layed out in the order specified in the `indices` list. - ''' + """ values = [perm[i] for i in indices] new_values = [values[-1]] + values[:-1] for i, j in zip(indices, new_values): @@ -93,7 +94,9 @@ def get_swapped_zero_positive_count(actmat_flat, use_cupy=True): shape = actmat_flat.shape # Group into blocks that are processed at once during inference # actmat is a boolean matrix of shape (N, L1 // 2) with "True" meaning 0 - actmat_chunked = actmat_flat.reshape((actmat_flat.shape[0], actmat_flat.shape[1]//ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE)) + actmat_chunked = actmat_flat.reshape( + (actmat_flat.shape[0], actmat_flat.shape[1] // ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE) + ) if use_cupy: # Calculate number of zeros in each block @@ -110,22 +113,32 @@ def get_swapped_zero_positive_count(actmat_flat, use_cupy=True): # actmat_chunked = [... [... [1, 1, 0, 1], [0, 0, 1, 0], [1, 1, 1, 1] ...] ...] # rest_zero_indicator = [... [... [0, 0, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1] ...] ...] # - rest_zero_indicator = (num_zeros - actmat_chunked == ZERO_BLOCK_SIZE - 1).reshape(shape).astype(cp.int8) + rest_zero_indicator = ( + (num_zeros - actmat_chunked == ZERO_BLOCK_SIZE - 1) + .reshape(shape) + .astype(cp.int8) + ) # Sum all possible pairs of elements in a single sample of actmat_flat and rest_zero_indicator. # Aggregate sum over the whole batch. # This tells us how much "good" a swap of i-th and j-th slices would do. It doesn't consider # how much "bad" it would do though, that will be accounted for later, for performance reasons. - swapped_zero_count = cp.einsum('bi,bj->ij', actmat_flat, rest_zero_indicator, dtype=int) + swapped_zero_count = cp.einsum( + "bi,bj->ij", actmat_flat, rest_zero_indicator, dtype=int + ) else: # Same operation but with numpy num_zeros = np.sum(actmat_chunked, axis=2, keepdims=True) num_zeros = np.tile(num_zeros, (1, 1, ZERO_BLOCK_SIZE)) - - rest_zero_indicator = (num_zeros - actmat_chunked == ZERO_BLOCK_SIZE - 1).reshape(shape).astype(int) - swapped_zero_count = np.einsum('bi,bj->ij', actmat_flat, rest_zero_indicator) + rest_zero_indicator = ( + (num_zeros - actmat_chunked == ZERO_BLOCK_SIZE - 1) + .reshape(shape) + .astype(int) + ) + + swapped_zero_count = np.einsum("bi,bj->ij", actmat_flat, rest_zero_indicator) return swapped_zero_count @@ -138,7 +151,9 @@ def get_swapped_zero_increase(actmat, use_cupy=True): # TODO: Find a good batch size. Try lowest as possible as VRAM is an issue on low end devices. BATCH_SIZE = 10000 for actmat_batch in batched(actmat, BATCH_SIZE): - swapped_zero_count += get_swapped_zero_positive_count(actmat_batch, use_cupy=use_cupy) + swapped_zero_count += get_swapped_zero_positive_count( + actmat_batch, use_cupy=use_cupy + ) # (L1/2) x (L1/2) if use_cupy: @@ -146,11 +161,15 @@ def get_swapped_zero_increase(actmat, use_cupy=True): # This is the place where we account for how much "bad" it would do. # It is done here because we process earlier in batches, but this operation is distributive, # so it needs to only be done once at the end. - swapped_zero_increase = swapped_zero_count - cp.reshape(cp.diag(swapped_zero_count), (1, n_neurons)) + swapped_zero_increase = swapped_zero_count - cp.reshape( + cp.diag(swapped_zero_count), (1, n_neurons) + ) swapped_zero_increase = cp.asnumpy(swapped_zero_increase) else: - swapped_zero_increase = swapped_zero_count - np.reshape(np.diag(swapped_zero_count), (1, n_neurons)) + swapped_zero_increase = swapped_zero_count - np.reshape( + np.diag(swapped_zero_count), (1, n_neurons) + ) return swapped_zero_increase @@ -170,9 +189,9 @@ def get_score_change(actmat, use_cupy=True): def make_swaps_2(actmat, use_cupy=True): - ''' + """ Returns a series of independent 2-swap operations that collectively improve the objective function. - ''' + """ # For each pair of nodes, we want to calculate the difference between the number of 4-zero runs when swapping them start_time = time.time() @@ -188,7 +207,7 @@ def make_swaps_2(actmat, use_cupy=True): score_change = score_change + score_change.T def all_indices_in_same_block(i): - ''' Returns a list of indices of all neurons in the same block as the i-th neuron. ''' + """Returns a list of indices of all neurons in the same block as the i-th neuron.""" # Floor to the start of the block. base = i // ZERO_BLOCK_SIZE * ZERO_BLOCK_SIZE return list(range(base, base + ZERO_BLOCK_SIZE)) @@ -218,7 +237,9 @@ def all_indices_in_same_block(i): score_change[:, index] = 0 score_change[index, :] = 0 - total_improvement = total_score_change / n_samples / (n_neurons//ZERO_BLOCK_SIZE) * 100 + total_improvement = ( + total_score_change / n_samples / (n_neurons // ZERO_BLOCK_SIZE) * 100 + ) print(f"Time elapsed: {time.time() - start_time:0.3f}") print(f"Improvement this iteration: {total_improvement:0.3f}") @@ -227,9 +248,9 @@ def all_indices_in_same_block(i): def make_swaps_3(actmat, use_cupy=True): - ''' + """ Returns a series of independent left-rotates operations that collectively improve the objective function. - ''' + """ # For each triplet of nodes, we want to calculate the change in score when moving them in a cycle print("Starting make_swaps_3") @@ -243,7 +264,11 @@ def make_swaps_3(actmat, use_cupy=True): # For each neuron i, j, k we sum score_change[i, j] + score_change[j, k] + score_change[k, i] # This is the cumulative impact of the right-rotation. - score_changes = score_changes[:, :, None] + score_changes[None, :, :] + (score_changes.T)[:, None, :] + score_changes = ( + score_changes[:, :, None] + + score_changes[None, :, :] + + (score_changes.T)[:, None, :] + ) orig_shape = (n_neurons,) * 3 compressed_shape = (n_blocks, ZERO_BLOCK_SIZE) * 3 @@ -253,11 +278,15 @@ def make_swaps_3(actmat, use_cupy=True): if use_cupy: # We don't want to have to go through an enormous array so compress it to represent blocks rather than neurons # Cupy doesn't support a list of axes so we go one by one. - max_values = cp.amax(cp.reshape(score_changes, compressed_shape), axis=5, keepdims=False) + max_values = cp.amax( + cp.reshape(score_changes, compressed_shape), axis=5, keepdims=False + ) max_values = cp.amax(max_values, axis=3, keepdims=False) max_values = cp.amax(max_values, axis=1, keepdims=False) else: - max_values = np.amax(np.reshape(score_changes, compressed_shape), axis=(5, 3, 1), keepdims=False) + max_values = np.amax( + np.reshape(score_changes, compressed_shape), axis=(5, 3, 1), keepdims=False + ) # Kill rotates that would only affect less than 3 different blocks. # We must do this, because the rest of the algorithm relies on it for correctness. @@ -281,11 +310,15 @@ def make_swaps_3(actmat, use_cupy=True): # Now we need to find the best set of neurons for this rotation in the found blocks # (we already know there is a gain available) - local_score_changes = score_changes[i:i+ZERO_BLOCK_SIZE, j:j+ZERO_BLOCK_SIZE, k:k+ZERO_BLOCK_SIZE] + local_score_changes = score_changes[ + i : i + ZERO_BLOCK_SIZE, j : j + ZERO_BLOCK_SIZE, k : k + ZERO_BLOCK_SIZE + ] best_neurons = local_score_changes.argmax() improvement_neurons = local_score_changes.flatten()[best_neurons] assert improvement_blocks == improvement_neurons - i1, j1, k1 = np.unravel_index(best_neurons, (ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE)) + i1, j1, k1 = np.unravel_index( + best_neurons, (ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE, ZERO_BLOCK_SIZE) + ) i, j, k = i + i1, j + j1, k + k1 if VERBOSE: @@ -301,14 +334,14 @@ def make_swaps_3(actmat, use_cupy=True): max_values[:, b, :] = 0 max_values[:, :, b] = 0 - total_improvement = total_score_change / n_samples / (n_neurons//4) * 100 + total_improvement = total_score_change / n_samples / (n_neurons // 4) * 100 print(f"Time elapsed: {time.time() - start_time:0.3f}") print(f"Improvement this iteration: {total_improvement:0.3f}") return cycles, total_improvement def find_perm_impl(actmat): - actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1]//2)) + actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2)) if USE_CUPY: actmat = cp.asarray(actmat, dtype=cp.int8) actmat_orig = actmat.copy() @@ -324,7 +357,7 @@ def find_perm_impl(actmat): num_fails = 0 for i in range(50): - print("Iteration", i+1) + print("Iteration", i + 1) # Choose the current stage optimization function swap_fn = stages[stage_id] @@ -340,7 +373,7 @@ def find_perm_impl(actmat): apply_rotate_right(perm, cycle) total_score_change += score_change - print(f'Total improvement: {total_score_change}\n') + print(f"Total improvement: {total_score_change}\n") if score_change == 0: num_fails += 1 @@ -348,25 +381,30 @@ def find_perm_impl(actmat): num_fails = 0 stage_id += 1 - if stage_id >= len(stages) or (stop_after_stage is not None and stage_id > stop_after_stage): - print('No more improvement possible.') + if stage_id >= len(stages) or ( + stop_after_stage is not None and stage_id > stop_after_stage + ): + print("No more improvement possible.") break - print(f'Switching to stage {stage_id}') + print(f"Switching to stage {stage_id}") return perm + # ------------------------------------------------------------- + def read_model(nnue_path, feature_set): - with open(nnue_path, 'rb') as f: + with open(nnue_path, "rb") as f: reader = serialize.NNUEReader(f, feature_set) return reader.model - + def make_fen_batch_provider(data_path, batch_size): return nnue_dataset.FenBatchProvider(data_path, True, 1, batch_size, False, 10) + def filter_fens(fens): # We don't want fens where a king is in check, as these cannot be evaluated by the engine. filtered_fens = [] @@ -376,11 +414,23 @@ def filter_fens(fens): filtered_fens.append(fen) return filtered_fens + def quantize_ft(model): model.input.weight.data = model.input.weight.data.mul(model.quantized_one).round() model.input.bias.data = model.input.bias.data.mul(model.quantized_one).round() -def forward_ft(model, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices): + +def forward_ft( + model, + us, + them, + white_indices, + white_values, + black_indices, + black_values, + psqt_indices, + layer_stack_indices, +): wp, bp = model.input(white_indices, white_values, black_indices, black_values) w, wpsqt = torch.split(wp, M.L1, dim=1) b, bpsqt = torch.split(bp, M.L1, dim=1) @@ -391,22 +441,47 @@ def forward_ft(model, us, them, white_indices, white_values, black_indices, blac l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]] # We multiply by 127/128 because in the quantized network 1.0 is represented by 127 # and it's more efficient to divide by 128 instead. - l0_ = torch.cat(l0_s1, dim=1) * (1/128) + l0_ = torch.cat(l0_s1, dim=1) * (1 / 128) return l0_.round() + def eval_ft(model, batch): with torch.no_grad(): - us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch.contents.get_tensors('cuda') - res = forward_ft(model, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) + ( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + outcome, + score, + psqt_indices, + layer_stack_indices, + ) = batch.contents.get_tensors("cuda") + res = forward_ft( + model, + us, + them, + white_indices, + white_values, + black_indices, + black_values, + psqt_indices, + layer_stack_indices, + ) return res + def ft_permute_impl(model, permutation): permutation = list(permutation) l1_size = model.layer_stacks.l1.in_features - if l1_size != len(permutation)*2: - raise Exception(f'Invalid permutation size. Expected {l1_size}. Got {len(permutation)*2}.') + if l1_size != len(permutation) * 2: + raise Exception( + f"Invalid permutation size. Expected {l1_size}. Got {len(permutation)*2}." + ) # Both sides of the FT must use the same permutation. permutation.extend([x + l1_size // 2 for x in permutation]) @@ -417,16 +492,20 @@ def ft_permute_impl(model, permutation): # Apply the permutation in place. model.input.weight.data = model.input.weight.data[:, ft_permutation] model.input.bias.data = model.input.bias.data[ft_permutation] - model.layer_stacks.l1.weight.data = model.layer_stacks.l1.weight.data[:, permutation] + model.layer_stacks.l1.weight.data = model.layer_stacks.l1.weight.data[ + :, permutation + ] + def ft_permute(model, ft_perm_path): - with open(ft_perm_path, 'rb') as f: + with open(ft_perm_path, "rb") as f: permutation = np.load(f) ft_permute_impl(model, permutation) + def gather_impl(model, dataset, count): - ZERO_POINT = 0.0 # Vary this to check hypothetical forced larger truncation to zero + ZERO_POINT = 0.0 # Vary this to check hypothetical forced larger truncation to zero BATCH_SIZE = 1000 old_device = model.device @@ -440,21 +519,28 @@ def gather_impl(model, dataset, count): actmats = [] done = 0 - print('Processed {} positions.'.format(done)) + print("Processed {} positions.".format(done)) while done < count: fens = filter_fens(next(fen_batch_provider)) - b = nnue_dataset.make_sparse_batch_from_fens(quantized_model.feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens)) + b = nnue_dataset.make_sparse_batch_from_fens( + quantized_model.feature_set, + fens, + [0] * len(fens), + [1] * len(fens), + [0] * len(fens), + ) actmat = eval_ft(quantized_model, b).cpu() - actmat = (actmat <= ZERO_POINT) + actmat = actmat <= ZERO_POINT actmats.append(actmat.numpy()) nnue_dataset.destroy_sparse_batch(b) done += len(fens) - print('Processed {} positions.'.format(done)) + print("Processed {} positions.".format(done)) return np.concatenate(actmats, axis=0) + def command_gather(args): feature_set = features.get_feature_set_from_name(args.features) if args.checkpoint: @@ -466,67 +552,69 @@ def command_gather(args): actmat = gather_impl(model, args.data, args.count) - with open(args.out, 'wb') as file: + with open(args.out, "wb") as file: np.save(file, actmat) + def eval_act_mat(actmat): - actmat = actmat.reshape((actmat.shape[0], actmat.shape[1]//4, 4)) + actmat = actmat.reshape((actmat.shape[0], actmat.shape[1] // 4, 4)) r = np.all(actmat, axis=2) return np.count_nonzero(r) / r.shape[0] / r.shape[1] def eval_perm_impl(actmat, perm=None): - actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1]//2)) + actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2)) actmat_eval = eval_act_mat(actmat) - print(f'Combined zeros in base matrix: {actmat_eval*100:0.6f}') + print(f"Combined zeros in base matrix: {actmat_eval*100:0.6f}") if perm is not None: perm_act_mat = actmat[:, perm] perm_act_mat_eval = eval_act_mat(perm_act_mat) - print(f'Combined zeros in perm matrix: {perm_act_mat_eval*100:0.6f}') + print(f"Combined zeros in perm matrix: {perm_act_mat_eval*100:0.6f}") def command_eval_perm(args): - with open(args.data, 'rb') as file: + with open(args.data, "rb") as file: actmat = np.load(file) if args.perm is not None: - with open(args.perm, 'rb') as file: + with open(args.perm, "rb") as file: perm = np.load(file) else: perm = None eval_perm_impl(actmat, perm) + def command_find_perm(args): - with open(args.data, 'rb') as file: + with open(args.data, "rb") as file: actmat = np.load(file) perm = find_perm_impl(actmat) # perm = np.random.permutation([i for i in range(M.L1)]) - with open(args.out, 'wb') as file: + with open(args.out, "wb") as file: np.save(file, perm) def ft_optimize(model, dataset_path, count, actmat_save_path=None, perm_save_path=None): - print('Gathering activation data...') + print("Gathering activation data...") actmat = gather_impl(model, dataset_path, count) if actmat_save_path is not None: - with open(actmat_save_path, 'wb') as file: + with open(actmat_save_path, "wb") as file: np.save(file, actmat) - print('Finding permutation...') + print("Finding permutation...") perm = find_perm_impl(actmat) if actmat_save_path is not None: - with open(perm_save_path, 'wb') as file: + with open(perm_save_path, "wb") as file: np.save(file, perm) - print('Evaluating permutation...') + print("Evaluating permutation...") eval_perm_impl(actmat, perm) - print('Applying permutation...') + print("Applying permutation...") ft_permute_impl(model, perm) @@ -534,27 +622,46 @@ def main(): parser = argparse.ArgumentParser(description="") subparsers = parser.add_subparsers() - parser_gather = subparsers.add_parser('gather', help='a help') + parser_gather = subparsers.add_parser("gather", help="a help") parser_gather.add_argument("--net", type=str, help="path to a .nnue net") - parser_gather.add_argument("--data", type=str, help="path to a .bin or .binpack dataset") - parser_gather.add_argument("--checkpoint", type=str, help="Optional checkpoint (used instead of nnue for local eval)") - parser_gather.add_argument("--count", type=int, default=1000, help="number of datapoints to process") - parser_gather.add_argument("--out", type=str, help="Filename under which to save the resulting ft matrix") + parser_gather.add_argument( + "--data", type=str, help="path to a .bin or .binpack dataset" + ) + parser_gather.add_argument( + "--checkpoint", + type=str, + help="Optional checkpoint (used instead of nnue for local eval)", + ) + parser_gather.add_argument( + "--count", type=int, default=1000, help="number of datapoints to process" + ) + parser_gather.add_argument( + "--out", type=str, help="Filename under which to save the resulting ft matrix" + ) features.add_argparse_args(parser_gather) parser_gather.set_defaults(func=command_gather) - parser_gather = subparsers.add_parser('find_perm', help='a help') - parser_gather.add_argument("--data", type=str, help="path to the previously gathered ft activation data") - parser_gather.add_argument("--out", type=str, help="path to where to save the permutation") + parser_gather = subparsers.add_parser("find_perm", help="a help") + parser_gather.add_argument( + "--data", type=str, help="path to the previously gathered ft activation data" + ) + parser_gather.add_argument( + "--out", type=str, help="path to where to save the permutation" + ) parser_gather.set_defaults(func=command_find_perm) - parser_gather = subparsers.add_parser('eval_perm', help='a help') - parser_gather.add_argument("--data", type=str, help="path to the previously gathered ft activation data") - parser_gather.add_argument("--perm", type=str, help="path to the previously generated perm file") + parser_gather = subparsers.add_parser("eval_perm", help="a help") + parser_gather.add_argument( + "--data", type=str, help="path to the previously gathered ft activation data" + ) + parser_gather.add_argument( + "--perm", type=str, help="path to the previously generated perm file" + ) parser_gather.set_defaults(func=command_eval_perm) args = parser.parse_args() args.func(args) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/halfka.py b/halfka.py index 7826e9c3..85ec8689 100644 --- a/halfka.py +++ b/halfka.py @@ -6,72 +6,89 @@ NUM_SQ = 64 NUM_PT = 12 -NUM_PLANES = (NUM_SQ * NUM_PT + 1) +NUM_PLANES = NUM_SQ * NUM_PT + 1 + def orient(is_white_pov: bool, sq: int): - return (56 * (not is_white_pov)) ^ sq + return (56 * (not is_white_pov)) ^ sq + def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece): - p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) - return 1 + orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES + p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) + return 1 + orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES + def halfka_psqts(): - # values copied from stockfish, in stockfish internal units - piece_values = { - chess.PAWN : 126, - chess.KNIGHT : 781, - chess.BISHOP : 825, - chess.ROOK : 1276, - chess.QUEEN : 2538 - } - - values = [0] * (NUM_PLANES * NUM_SQ) - - for ksq in range(64): - for s in range(64): - for pt, val in piece_values.items(): - idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) - idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) - values[idxw] = val - values[idxb] = -val - - return values + # values copied from stockfish, in stockfish internal units + piece_values = { + chess.PAWN: 126, + chess.KNIGHT: 781, + chess.BISHOP: 825, + chess.ROOK: 1276, + chess.QUEEN: 2538, + } + + values = [0] * (NUM_PLANES * NUM_SQ) + + for ksq in range(64): + for s in range(64): + for pt, val in piece_values.items(): + idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) + idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) + values[idxw] = val + values[idxb] = -val + + return values + class Features(FeatureBlock): - def __init__(self): - super(Features, self).__init__('HalfKA', 0x5f134cb8, OrderedDict([('HalfKA', NUM_PLANES * NUM_SQ)])) + def __init__(self): + super(Features, self).__init__( + "HalfKA", 0x5F134CB8, OrderedDict([("HalfKA", NUM_PLANES * NUM_SQ)]) + ) + + def get_active_features(self, board: chess.Board): + def piece_features(turn): + indices = torch.zeros(NUM_PLANES * NUM_SQ) + for sq, p in board.piece_map().items(): + indices[halfka_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 + return indices + + return (piece_features(chess.WHITE), piece_features(chess.BLACK)) - def get_active_features(self, board: chess.Board): - def piece_features(turn): - indices = torch.zeros(NUM_PLANES * NUM_SQ) - for sq, p in board.piece_map().items(): - indices[halfka_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 - return indices - return (piece_features(chess.WHITE), piece_features(chess.BLACK)) + def get_initial_psqt_features(self): + return halfka_psqts() - def get_initial_psqt_features(self): - return halfka_psqts() class FactorizedFeatures(FeatureBlock): - def __init__(self): - super(FactorizedFeatures, self).__init__('HalfKA^', 0x5f134cb8, OrderedDict([('HalfKA', NUM_PLANES * NUM_SQ), ('A', NUM_SQ * NUM_PT)])) + def __init__(self): + super(FactorizedFeatures, self).__init__( + "HalfKA^", + 0x5F134CB8, + OrderedDict([("HalfKA", NUM_PLANES * NUM_SQ), ("A", NUM_SQ * NUM_PT)]), + ) - def get_active_features(self, board: chess.Board): - raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training') + def get_active_features(self, board: chess.Board): + raise Exception( + "Not supported yet, you must use the c++ data loader for factorizer support during training" + ) - def get_feature_factors(self, idx): - if idx >= self.num_real_features: - raise Exception('Feature must be real') + def get_feature_factors(self, idx): + if idx >= self.num_real_features: + raise Exception("Feature must be real") - a_idx = idx % NUM_PLANES - 1 + a_idx = idx % NUM_PLANES - 1 - return [idx, self.get_factor_base_feature('A') + a_idx] + return [idx, self.get_factor_base_feature("A") + a_idx] - def get_initial_psqt_features(self): - return halfka_psqts() + [0] * (NUM_SQ * NUM_PT) + def get_initial_psqt_features(self): + return halfka_psqts() + [0] * (NUM_SQ * NUM_PT) -''' + +""" This is used by the features module for discovery of feature blocks. -''' +""" + + def get_feature_block_clss(): - return [Features, FactorizedFeatures] + return [Features, FactorizedFeatures] diff --git a/halfka_v2.py b/halfka_v2.py index 6b9096fc..65736c29 100644 --- a/halfka_v2.py +++ b/halfka_v2.py @@ -11,76 +11,97 @@ NUM_PLANES_VIRTUAL = NUM_SQ * NUM_PT_VIRTUAL NUM_INPUTS = NUM_PLANES_REAL * NUM_SQ + def orient(is_white_pov: bool, sq: int): - return (56 * (not is_white_pov)) ^ sq + return (56 * (not is_white_pov)) ^ sq + def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece): - p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) - if p_idx == 11: - p_idx -= 1 - return orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES_REAL + p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) + if p_idx == 11: + p_idx -= 1 + return orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES_REAL + def halfka_psqts(): - # values copied from stockfish, in stockfish internal units - piece_values = { - chess.PAWN : 126, - chess.KNIGHT : 781, - chess.BISHOP : 825, - chess.ROOK : 1276, - chess.QUEEN : 2538 - } - - values = [0] * (NUM_PLANES_REAL * NUM_SQ) - - for ksq in range(64): - for s in range(64): - for pt, val in piece_values.items(): - idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) - idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) - values[idxw] = val - values[idxb] = -val - - return values + # values copied from stockfish, in stockfish internal units + piece_values = { + chess.PAWN: 126, + chess.KNIGHT: 781, + chess.BISHOP: 825, + chess.ROOK: 1276, + chess.QUEEN: 2538, + } + + values = [0] * (NUM_PLANES_REAL * NUM_SQ) + + for ksq in range(64): + for s in range(64): + for pt, val in piece_values.items(): + idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) + idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) + values[idxw] = val + values[idxb] = -val + + return values + class Features(FeatureBlock): - def __init__(self): - super(Features, self).__init__('HalfKAv2', 0x5f234cb8, OrderedDict([('HalfKAv2', NUM_PLANES_REAL * NUM_SQ)])) + def __init__(self): + super(Features, self).__init__( + "HalfKAv2", + 0x5F234CB8, + OrderedDict([("HalfKAv2", NUM_PLANES_REAL * NUM_SQ)]), + ) + + def get_active_features(self, board: chess.Board): + def piece_features(turn): + indices = torch.zeros(NUM_PLANES_REAL * NUM_SQ) + for sq, p in board.piece_map().items(): + indices[halfka_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 + return indices + + return (piece_features(chess.WHITE), piece_features(chess.BLACK)) - def get_active_features(self, board: chess.Board): - def piece_features(turn): - indices = torch.zeros(NUM_PLANES_REAL * NUM_SQ) - for sq, p in board.piece_map().items(): - indices[halfka_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 - return indices - return (piece_features(chess.WHITE), piece_features(chess.BLACK)) + def get_initial_psqt_features(self): + return halfka_psqts() - def get_initial_psqt_features(self): - return halfka_psqts() class FactorizedFeatures(FeatureBlock): - def __init__(self): - super(FactorizedFeatures, self).__init__('HalfKAv2^', 0x5f234cb8, OrderedDict([('HalfKAv2', NUM_PLANES_REAL * NUM_SQ), ('A', NUM_PLANES_VIRTUAL)])) + def __init__(self): + super(FactorizedFeatures, self).__init__( + "HalfKAv2^", + 0x5F234CB8, + OrderedDict( + [("HalfKAv2", NUM_PLANES_REAL * NUM_SQ), ("A", NUM_PLANES_VIRTUAL)] + ), + ) - def get_active_features(self, board: chess.Board): - raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training') + def get_active_features(self, board: chess.Board): + raise Exception( + "Not supported yet, you must use the c++ data loader for factorizer support during training" + ) - def get_feature_factors(self, idx): - if idx >= self.num_real_features: - raise Exception('Feature must be real') + def get_feature_factors(self, idx): + if idx >= self.num_real_features: + raise Exception("Feature must be real") - a_idx = idx % NUM_PLANES_REAL - k_idx = idx // NUM_PLANES_REAL + a_idx = idx % NUM_PLANES_REAL + k_idx = idx // NUM_PLANES_REAL - if a_idx // NUM_SQ == 10 and k_idx != a_idx % NUM_SQ: - a_idx += NUM_SQ + if a_idx // NUM_SQ == 10 and k_idx != a_idx % NUM_SQ: + a_idx += NUM_SQ - return [idx, self.get_factor_base_feature('A') + a_idx] + return [idx, self.get_factor_base_feature("A") + a_idx] - def get_initial_psqt_features(self): - return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL + def get_initial_psqt_features(self): + return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL -''' + +""" This is used by the features module for discovery of feature blocks. -''' +""" + + def get_feature_block_clss(): - return [Features, FactorizedFeatures] + return [Features, FactorizedFeatures] diff --git a/halfka_v2_hm.py b/halfka_v2_hm.py index 4ac6cbb2..916ddcc9 100644 --- a/halfka_v2_hm.py +++ b/halfka_v2_hm.py @@ -12,84 +12,162 @@ NUM_INPUTS = NUM_PLANES_REAL * NUM_SQ // 2 KingBuckets = [ - -1, -1, -1, -1, 31, 30, 29, 28, - -1, -1, -1, -1, 27, 26, 25, 24, - -1, -1, -1, -1, 23, 22, 21, 20, - -1, -1, -1, -1, 19, 18, 17, 16, - -1, -1, -1, -1, 15, 14, 13, 12, - -1, -1, -1, -1, 11, 10, 9, 8, - -1, -1, -1, -1, 7, 6, 5, 4, - -1, -1, -1, -1, 3, 2, 1, 0 + -1, + -1, + -1, + -1, + 31, + 30, + 29, + 28, + -1, + -1, + -1, + -1, + 27, + 26, + 25, + 24, + -1, + -1, + -1, + -1, + 23, + 22, + 21, + 20, + -1, + -1, + -1, + -1, + 19, + 18, + 17, + 16, + -1, + -1, + -1, + -1, + 15, + 14, + 13, + 12, + -1, + -1, + -1, + -1, + 11, + 10, + 9, + 8, + -1, + -1, + -1, + -1, + 7, + 6, + 5, + 4, + -1, + -1, + -1, + -1, + 3, + 2, + 1, + 0, ] + def orient(is_white_pov: bool, sq: int, ksq: int): - # ksq must not be oriented - kfile = (ksq % 8) - return (7 * (kfile < 4)) ^ (56 * (not is_white_pov)) ^ sq + # ksq must not be oriented + kfile = ksq % 8 + return (7 * (kfile < 4)) ^ (56 * (not is_white_pov)) ^ sq + def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece): - p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) - o_ksq = orient(is_white_pov, king_sq, king_sq) - if p_idx == 11: - p_idx -= 1 - return orient(is_white_pov, sq, king_sq) + p_idx * NUM_SQ + KingBuckets[o_ksq] * NUM_PLANES_REAL + p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) + o_ksq = orient(is_white_pov, king_sq, king_sq) + if p_idx == 11: + p_idx -= 1 + return ( + orient(is_white_pov, sq, king_sq) + + p_idx * NUM_SQ + + KingBuckets[o_ksq] * NUM_PLANES_REAL + ) + def halfka_psqts(): - # values copied from stockfish, in stockfish internal units - piece_values = { - chess.PAWN : 126, - chess.KNIGHT : 781, - chess.BISHOP : 825, - chess.ROOK : 1276, - chess.QUEEN : 2538 - } - - values = [0] * NUM_INPUTS - - for ksq in range(64): - for s in range(64): - for pt, val in piece_values.items(): - idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) - idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) - values[idxw] = val - values[idxb] = -val - - return values + # values copied from stockfish, in stockfish internal units + piece_values = { + chess.PAWN: 126, + chess.KNIGHT: 781, + chess.BISHOP: 825, + chess.ROOK: 1276, + chess.QUEEN: 2538, + } + + values = [0] * NUM_INPUTS + + for ksq in range(64): + for s in range(64): + for pt, val in piece_values.items(): + idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE)) + idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK)) + values[idxw] = val + values[idxb] = -val + + return values + class Features(FeatureBlock): - def __init__(self): - super(Features, self).__init__('HalfKAv2_hm', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS)])) + def __init__(self): + super(Features, self).__init__( + "HalfKAv2_hm", 0x7F234CB8, OrderedDict([("HalfKAv2_hm", NUM_INPUTS)]) + ) + + def get_active_features(self, board: chess.Board): + raise Exception( + "Not supported yet, you must use the c++ data loader for support during training" + ) - def get_active_features(self, board: chess.Board): - raise Exception('Not supported yet, you must use the c++ data loader for support during training') + def get_initial_psqt_features(self): + return halfka_psqts() - def get_initial_psqt_features(self): - return halfka_psqts() class FactorizedFeatures(FeatureBlock): - def __init__(self): - super(FactorizedFeatures, self).__init__('HalfKAv2_hm^', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS), ('A', NUM_PLANES_VIRTUAL)])) + def __init__(self): + super(FactorizedFeatures, self).__init__( + "HalfKAv2_hm^", + 0x7F234CB8, + OrderedDict([("HalfKAv2_hm", NUM_INPUTS), ("A", NUM_PLANES_VIRTUAL)]), + ) + + def get_active_features(self, board: chess.Board): + raise Exception( + "Not supported yet, you must use the c++ data loader for factorizer support during training" + ) - def get_active_features(self, board: chess.Board): - raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training') + def get_feature_factors(self, idx): + if idx >= self.num_real_features: + raise Exception("Feature must be real") - def get_feature_factors(self, idx): - if idx >= self.num_real_features: - raise Exception('Feature must be real') + a_idx = idx % NUM_PLANES_REAL + k_idx = idx // NUM_PLANES_REAL - a_idx = idx % NUM_PLANES_REAL - k_idx = idx // NUM_PLANES_REAL + if a_idx // NUM_SQ == 10 and k_idx != KingBuckets[a_idx % NUM_SQ]: + a_idx += NUM_SQ - if a_idx // NUM_SQ == 10 and k_idx != KingBuckets[a_idx % NUM_SQ]: - a_idx += NUM_SQ + return [idx, self.get_factor_base_feature("A") + a_idx] - return [idx, self.get_factor_base_feature('A') + a_idx] + def get_initial_psqt_features(self): + return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL - def get_initial_psqt_features(self): - return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL -''' +""" This is used by the features module for discovery of feature blocks. -''' +""" + + def get_feature_block_clss(): - return [Features, FactorizedFeatures] + return [Features, FactorizedFeatures] diff --git a/halfkp.py b/halfkp.py index bb23abab..1603ebe7 100644 --- a/halfkp.py +++ b/halfkp.py @@ -6,68 +6,90 @@ NUM_SQ = 64 NUM_PT = 10 -NUM_PLANES = (NUM_SQ * NUM_PT + 1) +NUM_PLANES = NUM_SQ * NUM_PT + 1 + def orient(is_white_pov: bool, sq: int): - return (63 * (not is_white_pov)) ^ sq + return (63 * (not is_white_pov)) ^ sq + def halfkp_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece): - p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) - return 1 + orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES + p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov) + return 1 + orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES + class Features(FeatureBlock): - def __init__(self): - super(Features, self).__init__('HalfKP', 0x5d69d5b8, OrderedDict([('HalfKP', NUM_PLANES * NUM_SQ)])) - - def get_active_features(self, board: chess.Board): - def piece_features(turn): - indices = torch.zeros(NUM_PLANES * NUM_SQ) - for sq, p in board.piece_map().items(): - if p.piece_type == chess.KING: - continue - indices[halfkp_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 - return indices - return (piece_features(chess.WHITE), piece_features(chess.BLACK)) - - def get_initial_psqt_features(self): - raise Exception('Not supported yet. See HalfKA') + def __init__(self): + super(Features, self).__init__( + "HalfKP", 0x5D69D5B8, OrderedDict([("HalfKP", NUM_PLANES * NUM_SQ)]) + ) + + def get_active_features(self, board: chess.Board): + def piece_features(turn): + indices = torch.zeros(NUM_PLANES * NUM_SQ) + for sq, p in board.piece_map().items(): + if p.piece_type == chess.KING: + continue + indices[halfkp_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0 + return indices + + return (piece_features(chess.WHITE), piece_features(chess.BLACK)) + + def get_initial_psqt_features(self): + raise Exception("Not supported yet. See HalfKA") + class FactorizedFeatures(FeatureBlock): - def __init__(self): - super(FactorizedFeatures, self).__init__('HalfKP^', 0x5d69d5b8, OrderedDict([('HalfKP', NUM_PLANES * NUM_SQ), ('HalfK', NUM_SQ), ('P', NUM_SQ * 10 )])) - self.base = Features() - - def get_active_features(self, board: chess.Board): - white, black = self.base.get_active_features(board) - def piece_features(base, color): - indices = torch.zeros(NUM_SQ * 11) - piece_count = 0 - # P feature - for sq, p in board.piece_map().items(): - if p.piece_type == chess.KING: - continue - piece_count += 1 - p_idx = (p.piece_type - 1) * 2 + (p.color != color) - indices[(p_idx + 1) * NUM_SQ + orient(color, sq)] = 1.0 - # HalfK feature - indices[orient(color, board.king(color))] = piece_count - return torch.cat((base, indices)) - return (piece_features(white, chess.WHITE), piece_features(black, chess.BLACK)) - - def get_feature_factors(self, idx): - if idx >= self.num_real_features: - raise Exception('Feature must be real') - - k_idx = idx // NUM_PLANES - p_idx = idx % NUM_PLANES - 1 - - return [idx, self.get_factor_base_feature('HalfK') + k_idx, self.get_factor_base_feature('P') + p_idx] - - def get_initial_psqt_features(self): - raise Exception('Not supported yet. See HalfKA^') - -''' + def __init__(self): + super(FactorizedFeatures, self).__init__( + "HalfKP^", + 0x5D69D5B8, + OrderedDict( + [("HalfKP", NUM_PLANES * NUM_SQ), ("HalfK", NUM_SQ), ("P", NUM_SQ * 10)] + ), + ) + self.base = Features() + + def get_active_features(self, board: chess.Board): + white, black = self.base.get_active_features(board) + + def piece_features(base, color): + indices = torch.zeros(NUM_SQ * 11) + piece_count = 0 + # P feature + for sq, p in board.piece_map().items(): + if p.piece_type == chess.KING: + continue + piece_count += 1 + p_idx = (p.piece_type - 1) * 2 + (p.color != color) + indices[(p_idx + 1) * NUM_SQ + orient(color, sq)] = 1.0 + # HalfK feature + indices[orient(color, board.king(color))] = piece_count + return torch.cat((base, indices)) + + return (piece_features(white, chess.WHITE), piece_features(black, chess.BLACK)) + + def get_feature_factors(self, idx): + if idx >= self.num_real_features: + raise Exception("Feature must be real") + + k_idx = idx // NUM_PLANES + p_idx = idx % NUM_PLANES - 1 + + return [ + idx, + self.get_factor_base_feature("HalfK") + k_idx, + self.get_factor_base_feature("P") + p_idx, + ] + + def get_initial_psqt_features(self): + raise Exception("Not supported yet. See HalfKA^") + + +""" This is used by the features module for discovery of feature blocks. -''' +""" + + def get_feature_block_clss(): - return [Features, FactorizedFeatures] + return [Features, FactorizedFeatures] diff --git a/lib/nnue_training_data_formats.h b/lib/nnue_training_data_formats.h index abcead68..70ad9400 100644 --- a/lib/nnue_training_data_formats.h +++ b/lib/nnue_training_data_formats.h @@ -51,8127 +51,7295 @@ THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. #include "rng.h" #if (defined(_MSC_VER) || defined(__INTEL_COMPILER)) && !defined(__clang__) -#include + #include #endif -namespace chess -{ - #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) +namespace chess { +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) #define FORCEINLINE __attribute__((always_inline)) - #elif defined(_MSC_VER) +#elif defined(_MSC_VER) // NOTE: for some reason it breaks the profiler a little // keep it on only when not profiling. //#define FORCEINLINE __forceinline #define FORCEINLINE - #else +#else #define FORCEINLINE inline - #endif +#endif - #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) #define NOINLINE __attribute__((noinline)) - #elif defined(_MSC_VER) +#elif defined(_MSC_VER) #define NOINLINE __declspec(noinline) - #else +#else #define NOINLINE - #endif +#endif - namespace intrin +namespace intrin { +[[nodiscard]] constexpr int popcount_constexpr(std::uint64_t value) { + int r = 0; + while (value) { - [[nodiscard]] constexpr int popcount_constexpr(std::uint64_t value) - { - int r = 0; - while (value) - { - value &= value - 1; - ++r; - } - return r; - } + value &= value - 1; + ++r; + } + return r; +} - [[nodiscard]] constexpr int lsb_constexpr(std::uint64_t value) - { - int c = 0; - value &= ~value + 1; // leave only the lsb - if ((value & 0x00000000FFFFFFFFull) == 0) c += 32; - if ((value & 0x0000FFFF0000FFFFull) == 0) c += 16; - if ((value & 0x00FF00FF00FF00FFull) == 0) c += 8; - if ((value & 0x0F0F0F0F0F0F0F0Full) == 0) c += 4; - if ((value & 0x3333333333333333ull) == 0) c += 2; - if ((value & 0x5555555555555555ull) == 0) c += 1; - return c; - } +[[nodiscard]] constexpr int lsb_constexpr(std::uint64_t value) { + int c = 0; + value &= ~value + 1; // leave only the lsb + if ((value & 0x00000000FFFFFFFFull) == 0) + c += 32; + if ((value & 0x0000FFFF0000FFFFull) == 0) + c += 16; + if ((value & 0x00FF00FF00FF00FFull) == 0) + c += 8; + if ((value & 0x0F0F0F0F0F0F0F0Full) == 0) + c += 4; + if ((value & 0x3333333333333333ull) == 0) + c += 2; + if ((value & 0x5555555555555555ull) == 0) + c += 1; + return c; +} - [[nodiscard]] constexpr int msb_constexpr(std::uint64_t value) - { - int c = 63; - if ((value & 0xFFFFFFFF00000000ull) == 0) { c -= 32; value <<= 32; } - if ((value & 0xFFFF000000000000ull) == 0) { c -= 16; value <<= 16; } - if ((value & 0xFF00000000000000ull) == 0) { c -= 8; value <<= 8; } - if ((value & 0xF000000000000000ull) == 0) { c -= 4; value <<= 4; } - if ((value & 0xC000000000000000ull) == 0) { c -= 2; value <<= 2; } - if ((value & 0x8000000000000000ull) == 0) { c -= 1; } - return c; - } +[[nodiscard]] constexpr int msb_constexpr(std::uint64_t value) { + int c = 63; + if ((value & 0xFFFFFFFF00000000ull) == 0) + { + c -= 32; + value <<= 32; } - - namespace intrin + if ((value & 0xFFFF000000000000ull) == 0) { - [[nodiscard]] inline int popcount(std::uint64_t b) - { - #if (defined(_MSC_VER) || defined(__INTEL_COMPILER)) && !defined(__clang__) + c -= 16; + value <<= 16; + } + if ((value & 0xFF00000000000000ull) == 0) + { + c -= 8; + value <<= 8; + } + if ((value & 0xF000000000000000ull) == 0) + { + c -= 4; + value <<= 4; + } + if ((value & 0xC000000000000000ull) == 0) + { + c -= 2; + value <<= 2; + } + if ((value & 0x8000000000000000ull) == 0) + { + c -= 1; + } + return c; +} +} - return static_cast(_mm_popcnt_u64(b)); +namespace intrin { +[[nodiscard]] inline int popcount(std::uint64_t b) { +#if (defined(_MSC_VER) || defined(__INTEL_COMPILER)) && !defined(__clang__) - #else + return static_cast(_mm_popcnt_u64(b)); - return static_cast(__builtin_popcountll(b)); +#else - #endif - } + return static_cast(__builtin_popcountll(b)); - #if defined(_MSC_VER) && !defined(__clang__) +#endif +} - [[nodiscard]] inline int lsb(std::uint64_t value) - { - assert(value != 0); +#if defined(_MSC_VER) && !defined(__clang__) - unsigned long idx; - _BitScanForward64(&idx, value); - return static_cast(idx); - } +[[nodiscard]] inline int lsb(std::uint64_t value) { + assert(value != 0); - [[nodiscard]] inline int msb(std::uint64_t value) - { - assert(value != 0); + unsigned long idx; + _BitScanForward64(&idx, value); + return static_cast(idx); +} - unsigned long idx; - _BitScanReverse64(&idx, value); - return static_cast(idx); - } +[[nodiscard]] inline int msb(std::uint64_t value) { + assert(value != 0); - #else + unsigned long idx; + _BitScanReverse64(&idx, value); + return static_cast(idx); +} - [[nodiscard]] inline int lsb(std::uint64_t value) - { - assert(value != 0); +#else - return __builtin_ctzll(value); - } +[[nodiscard]] inline int lsb(std::uint64_t value) { + assert(value != 0); - [[nodiscard]] inline int msb(std::uint64_t value) - { - assert(value != 0); + return __builtin_ctzll(value); +} - return 63 ^ __builtin_clzll(value); - } +[[nodiscard]] inline int msb(std::uint64_t value) { + assert(value != 0); - #endif - } + return 63 ^ __builtin_clzll(value); +} - template - [[nodiscard]] constexpr IntT floorLog2(IntT value) - { - return intrin::msb_constexpr(value); - } +#endif +} - template - constexpr auto computeMasks() - { - static_assert(std::is_unsigned_v); +template +[[nodiscard]] constexpr IntT floorLog2(IntT value) { + return intrin::msb_constexpr(value); +} - constexpr std::size_t numBits = sizeof(IntT) * CHAR_BIT; - std::array nbitmasks{}; +template +constexpr auto computeMasks() { + static_assert(std::is_unsigned_v); - for (std::size_t i = 0; i < numBits; ++i) - { - nbitmasks[i] = (static_cast(1u) << i) - 1u; - } - nbitmasks[numBits] = ~static_cast(0u); + constexpr std::size_t numBits = sizeof(IntT) * CHAR_BIT; + std::array nbitmasks{}; - return nbitmasks; + for (std::size_t i = 0; i < numBits; ++i) + { + nbitmasks[i] = (static_cast(1u) << i) - 1u; } + nbitmasks[numBits] = ~static_cast(0u); - template - constexpr auto nbitmask = computeMasks(); + return nbitmasks; +} - template > - inline ToT signExtend(FromT value) - { - static_assert(std::is_signed_v); - static_assert(std::is_unsigned_v); - static_assert(sizeof(ToT) == sizeof(FromT)); +template +constexpr auto nbitmask = computeMasks(); - constexpr std::size_t totalBits = sizeof(FromT) * CHAR_BIT; +template> +inline ToT signExtend(FromT value) { + static_assert(std::is_signed_v); + static_assert(std::is_unsigned_v); + static_assert(sizeof(ToT) == sizeof(FromT)); - static_assert(N > 0 && N <= totalBits); + constexpr std::size_t totalBits = sizeof(FromT) * CHAR_BIT; - constexpr std::size_t unusedBits = totalBits - N; - if constexpr (ToT(~FromT(0)) >> 1 == ToT(~FromT(0))) - { - return ToT(value << unusedBits) >> ToT(unusedBits); - } - else + static_assert(N > 0 && N <= totalBits); + + constexpr std::size_t unusedBits = totalBits - N; + if constexpr (ToT(~FromT(0)) >> 1 == ToT(~FromT(0))) + { + return ToT(value << unusedBits) >> ToT(unusedBits); + } + else + { + constexpr FromT mask = (~FromT(0)) >> unusedBits; + value &= mask; + if (value & (FromT(1) << (N - 1))) { - constexpr FromT mask = (~FromT(0)) >> unusedBits; - value &= mask; - if (value & (FromT(1) << (N - 1))) - { - value |= ~mask; - } - return static_cast(value); + value |= ~mask; } + return static_cast(value); } +} + +namespace lookup { +constexpr int nthSetBitIndexNaive(std::uint64_t value, int n) { + for (int i = 0; i < n; ++i) + { + value &= value - 1; + } + return intrin::lsb_constexpr(value); +} + +constexpr std::array, 256> nthSetBitIndex = []() { + std::array, 256> t{}; - namespace lookup + for (int i = 0; i < 256; ++i) { - constexpr int nthSetBitIndexNaive(std::uint64_t value, int n) + for (int j = 0; j < 8; ++j) { - for (int i = 0; i < n; ++i) - { - value &= value - 1; - } - return intrin::lsb_constexpr(value); + t[i][j] = nthSetBitIndexNaive(i, j); } + } - constexpr std::array, 256> nthSetBitIndex = []() - { - std::array, 256> t{}; + return t; +}(); +} - for (int i = 0; i < 256; ++i) - { - for (int j = 0; j < 8; ++j) - { - t[i][j] = nthSetBitIndexNaive(i, j); - } - } +inline int nthSetBitIndex(std::uint64_t v, std::uint64_t n) { + std::uint64_t shift = 0; - return t; - }(); - } + std::uint64_t p = intrin::popcount(v & 0xFFFFFFFFull); + std::uint64_t pmask = static_cast(p > n) - 1ull; + v >>= 32 & pmask; + shift += 32 & pmask; + n -= p & pmask; - inline int nthSetBitIndex(std::uint64_t v, std::uint64_t n) - { - std::uint64_t shift = 0; + p = intrin::popcount(v & 0xFFFFull); + pmask = static_cast(p > n) - 1ull; + v >>= 16 & pmask; + shift += 16 & pmask; + n -= p & pmask; - std::uint64_t p = intrin::popcount(v & 0xFFFFFFFFull); - std::uint64_t pmask = static_cast(p > n) - 1ull; - v >>= 32 & pmask; - shift += 32 & pmask; - n -= p & pmask; + p = intrin::popcount(v & 0xFFull); + pmask = static_cast(p > n) - 1ull; + shift += 8 & pmask; + v >>= 8 & pmask; + n -= p & pmask; - p = intrin::popcount(v & 0xFFFFull); - pmask = static_cast(p > n) - 1ull; - v >>= 16 & pmask; - shift += 16 & pmask; - n -= p & pmask; + return static_cast(lookup::nthSetBitIndex[v & 0xFFull][n] + shift); +} - p = intrin::popcount(v & 0xFFull); - pmask = static_cast(p > n) - 1ull; - shift += 8 & pmask; - v >>= 8 & pmask; - n -= p & pmask; +namespace util { +inline std::size_t usedBits(std::size_t value) { + if (value == 0) + return 0; + return intrin::msb(value) + 1; +} +} - return static_cast(lookup::nthSetBitIndex[v & 0xFFull][n] + shift); - } +template +struct EnumTraits; - namespace util - { - inline std::size_t usedBits(std::size_t value) - { - if (value == 0) return 0; - return intrin::msb(value) + 1; - } - } +template +[[nodiscard]] constexpr auto hasEnumTraits() -> decltype(EnumTraits::cardinaliy, bool{}) { + return true; +} - template - struct EnumTraits; +template +[[nodiscard]] constexpr bool hasEnumTraits(...) { + return false; +} - template - [[nodiscard]] constexpr auto hasEnumTraits() -> decltype(EnumTraits::cardinaliy, bool{}) - { - return true; - } +template +[[nodiscard]] constexpr bool isNaturalIndex() noexcept { + return EnumTraits::isNaturalIndex; +} - template - [[nodiscard]] constexpr bool hasEnumTraits(...) - { - return false; - } +template +[[nodiscard]] constexpr int cardinality() noexcept { + return EnumTraits::cardinality; +} - template - [[nodiscard]] constexpr bool isNaturalIndex() noexcept - { - return EnumTraits::isNaturalIndex; - } +template +[[nodiscard]] constexpr const std::array()>& values() noexcept { + return EnumTraits::values; +} - template - [[nodiscard]] constexpr int cardinality() noexcept - { - return EnumTraits::cardinality; - } +template +[[nodiscard]] constexpr EnumT fromOrdinal(int id) noexcept { + assert(!EnumTraits::isNaturalIndex || (id >= 0 && id < EnumTraits::cardinality)); - template - [[nodiscard]] constexpr const std::array()>& values() noexcept - { - return EnumTraits::values; - } + return EnumTraits::fromOrdinal(id); +} - template - [[nodiscard]] constexpr EnumT fromOrdinal(int id) noexcept - { - assert(!EnumTraits::isNaturalIndex || (id >= 0 && id < EnumTraits::cardinality)); +template +[[nodiscard]] constexpr typename EnumTraits::IdType ordinal(EnumT v) noexcept { + return EnumTraits::ordinal(v); +} - return EnumTraits::fromOrdinal(id); - } +template()>> +[[nodiscard]] constexpr decltype(auto) toString(EnumT v, ArgsTs&&... args) { + return EnumTraits::toString(v, std::forward(args)...); +} - template - [[nodiscard]] constexpr typename EnumTraits::IdType ordinal(EnumT v) noexcept - { - return EnumTraits::ordinal(v); - } +template +[[nodiscard]] constexpr decltype(auto) toString(EnumT v) { + return EnumTraits::toString(v); +} - template ()>> - [[nodiscard]] constexpr decltype(auto) toString(EnumT v, ArgsTs&&... args) - { - return EnumTraits::toString(v, std::forward(args)...); - } +template()>> +[[nodiscard]] constexpr decltype(auto) toString(FormatT&& f, EnumT v) { + return EnumTraits::toString(std::forward(f), v); +} - template - [[nodiscard]] constexpr decltype(auto) toString(EnumT v) - { - return EnumTraits::toString(v); - } +template +[[nodiscard]] constexpr decltype(auto) toChar(EnumT v) { + return EnumTraits::toChar(v); +} - template ()>> - [[nodiscard]] constexpr decltype(auto) toString(FormatT&& f, EnumT v) - { - return EnumTraits::toString(std::forward(f), v); - } +template +[[nodiscard]] constexpr decltype(auto) toChar(FormatT&& f, EnumT v) { + return EnumTraits::toChar(std::forward(f), v); +} - template - [[nodiscard]] constexpr decltype(auto) toChar(EnumT v) - { - return EnumTraits::toChar(v); - } +template +[[nodiscard]] constexpr decltype(auto) fromString(ArgsTs&&... args) { + return EnumTraits::fromString(std::forward(args)...); +} - template - [[nodiscard]] constexpr decltype(auto) toChar(FormatT&& f, EnumT v) - { - return EnumTraits::toChar(std::forward(f), v); - } +template +[[nodiscard]] constexpr decltype(auto) fromChar(ArgsTs&&... args) { + return EnumTraits::fromChar(std::forward(args)...); +} - template - [[nodiscard]] constexpr decltype(auto) fromString(ArgsTs&& ... args) - { - return EnumTraits::fromString(std::forward(args)...); +template<> +struct EnumTraits { + using IdType = int; + using EnumType = bool; + + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{false, true}; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); } - template - [[nodiscard]] constexpr decltype(auto) fromChar(ArgsTs&& ... args) - { - return EnumTraits::fromChar(std::forward(args)...); + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + return static_cast(id); } +}; - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = bool; +template()> +struct EnumArray { + static_assert(isNaturalIndex(), "Enum must start with 0 and end with cardinality-1."); - static constexpr int cardinality = 2; - static constexpr bool isNaturalIndex = true; + using value_type = ValueT; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = ValueT*; + using const_pointer = const ValueT*; + using reference = ValueT&; + using const_reference = const ValueT&; - static constexpr std::array values{ - false, - true - }; + using iterator = pointer; + using const_iterator = const_pointer; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + using KeyType = EnumT; + using ValueType = ValueT; + + constexpr void fill(const ValueType& init) { + for (auto& v : elements) { - return static_cast(id); + v = init; } - }; + } - template ()> - struct EnumArray - { - static_assert(isNaturalIndex(), "Enum must start with 0 and end with cardinality-1."); + [[nodiscard]] constexpr ValueType& operator[](const KeyType& dir) { + assert(static_cast(ordinal(dir)) < static_cast(SizeV)); - using value_type = ValueT; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; - using pointer = ValueT *; - using const_pointer = const ValueT*; - using reference = ValueT &; - using const_reference = const ValueT &; + return elements[ordinal(dir)]; + } - using iterator = pointer; - using const_iterator = const_pointer; + [[nodiscard]] constexpr const ValueType& operator[](const KeyType& dir) const { + assert(static_cast(ordinal(dir)) < static_cast(SizeV)); - using reverse_iterator = std::reverse_iterator; - using const_reverse_iterator = std::reverse_iterator; + return elements[ordinal(dir)]; + } - using KeyType = EnumT; - using ValueType = ValueT; + [[nodiscard]] constexpr ValueType& front() { return elements[0]; } - constexpr void fill(const ValueType& init) - { - for (auto& v : elements) - { - v = init; - } - } + [[nodiscard]] constexpr const ValueType& front() const { return elements[0]; } - [[nodiscard]] constexpr ValueType& operator[](const KeyType& dir) - { - assert(static_cast(ordinal(dir)) < static_cast(SizeV)); + [[nodiscard]] constexpr ValueType& back() { return elements[SizeV - 1]; } - return elements[ordinal(dir)]; - } + [[nodiscard]] constexpr const ValueType& back() const { return elements[SizeV - 1]; } - [[nodiscard]] constexpr const ValueType& operator[](const KeyType& dir) const - { - assert(static_cast(ordinal(dir)) < static_cast(SizeV)); + [[nodiscard]] constexpr pointer data() { return elements; } - return elements[ordinal(dir)]; - } + [[nodiscard]] constexpr const_pointer data() const { return elements; } - [[nodiscard]] constexpr ValueType& front() - { - return elements[0]; - } + [[nodiscard]] constexpr iterator begin() noexcept { return elements; } - [[nodiscard]] constexpr const ValueType& front() const - { - return elements[0]; - } + [[nodiscard]] constexpr const_iterator begin() const noexcept { return elements; } - [[nodiscard]] constexpr ValueType& back() - { - return elements[SizeV - 1]; - } + [[nodiscard]] constexpr iterator end() noexcept { return elements + SizeV; } - [[nodiscard]] constexpr const ValueType& back() const - { - return elements[SizeV - 1]; - } + [[nodiscard]] constexpr const_iterator end() const noexcept { return elements + SizeV; } - [[nodiscard]] constexpr pointer data() - { - return elements; - } + [[nodiscard]] constexpr reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } - [[nodiscard]] constexpr const_pointer data() const - { - return elements; - } + [[nodiscard]] constexpr const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } - [[nodiscard]] constexpr iterator begin() noexcept - { - return elements; - } + [[nodiscard]] constexpr reverse_iterator rend() noexcept { return reverse_iterator(begin()); } - [[nodiscard]] constexpr const_iterator begin() const noexcept - { - return elements; - } + [[nodiscard]] constexpr const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } - [[nodiscard]] constexpr iterator end() noexcept - { - return elements + SizeV; - } + [[nodiscard]] constexpr const_iterator cbegin() const noexcept { return begin(); } - [[nodiscard]] constexpr const_iterator end() const noexcept - { - return elements + SizeV; - } + [[nodiscard]] constexpr const_iterator cend() const noexcept { return end(); } - [[nodiscard]] constexpr reverse_iterator rbegin() noexcept - { - return reverse_iterator(end()); - } + [[nodiscard]] constexpr const_reverse_iterator crbegin() const noexcept { return rbegin(); } - [[nodiscard]] constexpr const_reverse_iterator rbegin() const noexcept - { - return const_reverse_iterator(end()); - } + [[nodiscard]] constexpr const_reverse_iterator crend() const noexcept { return rend(); } - [[nodiscard]] constexpr reverse_iterator rend() noexcept - { - return reverse_iterator(begin()); - } + [[nodiscard]] constexpr size_type size() const noexcept { return SizeV; } - [[nodiscard]] constexpr const_reverse_iterator rend() const noexcept - { - return const_reverse_iterator(begin()); - } + ValueT elements[SizeV]; +}; - [[nodiscard]] constexpr const_iterator cbegin() const noexcept - { - return begin(); - } +template(), + std::size_t Size2V = cardinality()> +using EnumArray2 = EnumArray, Size1V>; - [[nodiscard]] constexpr const_iterator cend() const noexcept - { - return end(); - } +enum struct Color : std::uint8_t { + White, + Black +}; - [[nodiscard]] constexpr const_reverse_iterator crbegin() const noexcept - { - return rbegin(); - } +template<> +struct EnumTraits { + using IdType = int; + using EnumType = Color; - [[nodiscard]] constexpr const_reverse_iterator crend() const noexcept - { - return rend(); - } + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; - [[nodiscard]] constexpr size_type size() const noexcept - { - return SizeV; - } + static constexpr std::array values{Color::White, Color::Black}; - ValueT elements[SizeV]; - }; + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - template (), std::size_t Size2V = cardinality()> - using EnumArray2 = EnumArray, Size1V>; + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - enum struct Color : std::uint8_t - { - White, - Black - }; + return static_cast(id); + } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = Color; + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept { + return std::string_view("wb").substr(ordinal(c), 1); + } - static constexpr int cardinality = 2; - static constexpr bool isNaturalIndex = true; + [[nodiscard]] static constexpr char toChar(EnumType c) noexcept { return "wb"[ordinal(c)]; } - static constexpr std::array values{ - Color::White, - Color::Black - }; + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept { + if (c == 'w') + return Color::White; + if (c == 'b') + return Color::Black; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + return {}; + } - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - assert(id >= 0 && id < cardinality); + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept { + if (sv.size() != 1) + return {}; - return static_cast(id); - } + return fromChar(sv[0]); + } +}; - [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept - { - return std::string_view("wb").substr(ordinal(c), 1); - } +constexpr Color operator!(Color c) { return fromOrdinal(ordinal(c) ^ 1); } - [[nodiscard]] static constexpr char toChar(EnumType c) noexcept - { - return "wb"[ordinal(c)]; - } +enum struct PieceType : std::uint8_t { + Pawn, + Knight, + Bishop, + Rook, + Queen, + King, - [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept - { - if (c == 'w') return Color::White; - if (c == 'b') return Color::Black; + None +}; - return {}; - } +template<> +struct EnumTraits { + using IdType = int; + using EnumType = PieceType; - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 1) return {}; + static constexpr int cardinality = 7; + static constexpr bool isNaturalIndex = true; - return fromChar(sv[0]); - } - }; + static constexpr std::array values{ + PieceType::Pawn, PieceType::Knight, PieceType::Bishop, PieceType::Rook, + PieceType::Queen, PieceType::King, PieceType::None}; - constexpr Color operator!(Color c) - { - return fromOrdinal(ordinal(c) ^ 1); + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); } - enum struct PieceType : std::uint8_t - { - Pawn, - Knight, - Bishop, - Rook, - Queen, - King, + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - None - }; + return static_cast(id); + } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = PieceType; + [[nodiscard]] static constexpr std::string_view toString(EnumType p, Color c) noexcept { + return std::string_view("PpNnBbRrQqKk ") + .substr((chess::ordinal(p) * 2 + chess::ordinal(c)), 1); + } - static constexpr int cardinality = 7; - static constexpr bool isNaturalIndex = true; + [[nodiscard]] static constexpr char toChar(EnumType p, Color c) noexcept { + return "PpNnBbRrQqKk "[chess::ordinal(p) * 2 + chess::ordinal(c)]; + } - static constexpr std::array values{ - PieceType::Pawn, - PieceType::Knight, - PieceType::Bishop, - PieceType::Rook, - PieceType::Queen, - PieceType::King, - PieceType::None - }; + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept { + auto it = std::string_view("PpNnBbRrQqKk ").find(c); + if (it == std::string::npos) + return {}; + else + return static_cast(it / 2); + } - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + [[nodiscard]] static constexpr std::optional + fromString(std::string_view sv) noexcept { + if (sv.size() != 1) + return {}; - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - assert(id >= 0 && id < cardinality); + return fromChar(sv[0]); + } +}; - return static_cast(id); - } +struct Piece { + [[nodiscard]] static constexpr Piece fromId(int id) { return Piece(id); } - [[nodiscard]] static constexpr std::string_view toString(EnumType p, Color c) noexcept - { - return std::string_view("PpNnBbRrQqKk ").substr((chess::ordinal(p) * 2 + chess::ordinal(c)), 1); - } + [[nodiscard]] static constexpr Piece none() { return Piece(PieceType::None, Color::White); } - [[nodiscard]] static constexpr char toChar(EnumType p, Color c) noexcept - { - return "PpNnBbRrQqKk "[chess::ordinal(p) * 2 + chess::ordinal(c)]; - } + constexpr Piece() noexcept : + Piece(PieceType::None, Color::White) {} - [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept - { - auto it = std::string_view("PpNnBbRrQqKk ").find(c); - if (it == std::string::npos) return {}; - else return static_cast(it/2); - } + constexpr Piece(PieceType type, Color color) noexcept : + m_id((ordinal(type) << 1) | ordinal(color)) { + assert(type != PieceType::None || color == Color::White); + } - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 1) return {}; + constexpr Piece& operator=(const Piece& other) = default; - return fromChar(sv[0]); - } - }; + [[nodiscard]] constexpr friend bool operator==(Piece lhs, Piece rhs) noexcept { + return lhs.m_id == rhs.m_id; + } - struct Piece - { - [[nodiscard]] static constexpr Piece fromId(int id) - { - return Piece(id); - } + [[nodiscard]] constexpr friend bool operator!=(Piece lhs, Piece rhs) noexcept { + return !(lhs == rhs); + } - [[nodiscard]] static constexpr Piece none() - { - return Piece(PieceType::None, Color::White); - } + [[nodiscard]] constexpr PieceType type() const { return fromOrdinal(m_id >> 1); } - constexpr Piece() noexcept : - Piece(PieceType::None, Color::White) - { + [[nodiscard]] constexpr Color color() const { return fromOrdinal(m_id & 1); } - } + [[nodiscard]] constexpr std::pair parts() const { + return std::make_pair(type(), color()); + } - constexpr Piece(PieceType type, Color color) noexcept : - m_id((ordinal(type) << 1) | ordinal(color)) - { - assert(type != PieceType::None || color == Color::White); - } + [[nodiscard]] constexpr explicit operator int() const { return static_cast(m_id); } - constexpr Piece& operator=(const Piece& other) = default; + private: + constexpr Piece(int id) : + m_id(id) {} - [[nodiscard]] constexpr friend bool operator==(Piece lhs, Piece rhs) noexcept - { - return lhs.m_id == rhs.m_id; - } + std::uint8_t m_id; // lowest bit is a color, 7 highest bits are a piece type +}; - [[nodiscard]] constexpr friend bool operator!=(Piece lhs, Piece rhs) noexcept - { - return !(lhs == rhs); - } +[[nodiscard]] constexpr Piece operator|(PieceType type, Color color) noexcept { + return Piece(type, color); +} - [[nodiscard]] constexpr PieceType type() const - { - return fromOrdinal(m_id >> 1); - } +[[nodiscard]] constexpr Piece operator|(Color color, PieceType type) noexcept { + return Piece(type, color); +} - [[nodiscard]] constexpr Color color() const - { - return fromOrdinal(m_id & 1); - } +constexpr Piece whitePawn = Piece(PieceType::Pawn, Color::White); +constexpr Piece whiteKnight = Piece(PieceType::Knight, Color::White); +constexpr Piece whiteBishop = Piece(PieceType::Bishop, Color::White); +constexpr Piece whiteRook = Piece(PieceType::Rook, Color::White); +constexpr Piece whiteQueen = Piece(PieceType::Queen, Color::White); +constexpr Piece whiteKing = Piece(PieceType::King, Color::White); - [[nodiscard]] constexpr std::pair parts() const - { - return std::make_pair(type(), color()); - } +constexpr Piece blackPawn = Piece(PieceType::Pawn, Color::Black); +constexpr Piece blackKnight = Piece(PieceType::Knight, Color::Black); +constexpr Piece blackBishop = Piece(PieceType::Bishop, Color::Black); +constexpr Piece blackRook = Piece(PieceType::Rook, Color::Black); +constexpr Piece blackQueen = Piece(PieceType::Queen, Color::Black); +constexpr Piece blackKing = Piece(PieceType::King, Color::Black); - [[nodiscard]] constexpr explicit operator int() const - { - return static_cast(m_id); - } +static_assert(Piece::none().type() == PieceType::None); - private: - constexpr Piece(int id) : - m_id(id) - { - } +template<> +struct EnumTraits { + using IdType = int; + using EnumType = Piece; - std::uint8_t m_id; // lowest bit is a color, 7 highest bits are a piece type - }; + static constexpr int cardinality = 13; + static constexpr bool isNaturalIndex = true; - [[nodiscard]] constexpr Piece operator|(PieceType type, Color color) noexcept - { - return Piece(type, color); - } + static constexpr std::array values{ + whitePawn, blackPawn, whiteKnight, blackKnight, whiteBishop, blackBishop, whiteRook, + blackRook, whiteQueen, blackQueen, whiteKing, blackKing, Piece::none()}; - [[nodiscard]] constexpr Piece operator|(Color color, PieceType type) noexcept - { - return Piece(type, color); + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); } - constexpr Piece whitePawn = Piece(PieceType::Pawn, Color::White); - constexpr Piece whiteKnight = Piece(PieceType::Knight, Color::White); - constexpr Piece whiteBishop = Piece(PieceType::Bishop, Color::White); - constexpr Piece whiteRook = Piece(PieceType::Rook, Color::White); - constexpr Piece whiteQueen = Piece(PieceType::Queen, Color::White); - constexpr Piece whiteKing = Piece(PieceType::King, Color::White); + [[nodiscard]] static constexpr EnumType fromOrdinal(int id) noexcept { + assert(id >= 0 && id < cardinality); - constexpr Piece blackPawn = Piece(PieceType::Pawn, Color::Black); - constexpr Piece blackKnight = Piece(PieceType::Knight, Color::Black); - constexpr Piece blackBishop = Piece(PieceType::Bishop, Color::Black); - constexpr Piece blackRook = Piece(PieceType::Rook, Color::Black); - constexpr Piece blackQueen = Piece(PieceType::Queen, Color::Black); - constexpr Piece blackKing = Piece(PieceType::King, Color::Black); + return Piece::fromId(id); + } - static_assert(Piece::none().type() == PieceType::None); + [[nodiscard]] static constexpr std::string_view toString(EnumType p) noexcept { + return std::string_view("PpNnBbRrQqKk ").substr(ordinal(p), 1); + } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = Piece; + [[nodiscard]] static constexpr char toChar(EnumType p) noexcept { + return "PpNnBbRrQqKk "[ordinal(p)]; + } - static constexpr int cardinality = 13; - static constexpr bool isNaturalIndex = true; + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept { + auto it = std::string_view("PpNnBbRrQqKk ").find(c); + if (it == std::string::npos) + return {}; + else + return Piece::fromId(static_cast(it)); + } - static constexpr std::array values{ - whitePawn, - blackPawn, - whiteKnight, - blackKnight, - whiteBishop, - blackBishop, - whiteRook, - blackRook, - whiteQueen, - blackQueen, - whiteKing, - blackKing, - Piece::none() - }; + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept { + if (sv.size() != 1) + return {}; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + return fromChar(sv[0]); + } +}; - [[nodiscard]] static constexpr EnumType fromOrdinal(int id) noexcept - { - assert(id >= 0 && id < cardinality); +template +struct Coord { + constexpr Coord() noexcept : + m_i(0) {} - return Piece::fromId(id); - } + constexpr explicit Coord(int i) noexcept : + m_i(i) {} - [[nodiscard]] static constexpr std::string_view toString(EnumType p) noexcept - { - return std::string_view("PpNnBbRrQqKk ").substr(ordinal(p), 1); - } + [[nodiscard]] constexpr explicit operator int() const { return static_cast(m_i); } - [[nodiscard]] static constexpr char toChar(EnumType p) noexcept - { - return "PpNnBbRrQqKk "[ordinal(p)]; - } + constexpr friend Coord& operator++(Coord& c) { + ++c.m_i; + return c; + } - [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept - { - auto it = std::string_view("PpNnBbRrQqKk ").find(c); - if (it == std::string::npos) return {}; - else return Piece::fromId(static_cast(it)); - } + constexpr friend Coord& operator--(Coord& c) { + --c.m_i; + return c; + } - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 1) return {}; + constexpr friend Coord& operator+=(Coord& c, int d) { + c.m_i += d; + return c; + } - return fromChar(sv[0]); - } - }; + constexpr friend Coord& operator-=(Coord& c, int d) { + c.m_i -= d; + return c; + } - template - struct Coord - { - constexpr Coord() noexcept : - m_i(0) - { - } + constexpr friend Coord operator+(const Coord& c, int d) { + Coord cpy(c); + cpy += d; + return cpy; + } - constexpr explicit Coord(int i) noexcept : - m_i(i) - { - } + constexpr friend Coord operator-(const Coord& c, int d) { + Coord cpy(c); + cpy -= d; + return cpy; + } - [[nodiscard]] constexpr explicit operator int() const - { - return static_cast(m_i); - } + constexpr friend int operator-(const Coord& c1, const Coord& c2) { return c1.m_i - c2.m_i; } - constexpr friend Coord& operator++(Coord& c) - { - ++c.m_i; - return c; - } + [[nodiscard]] constexpr friend bool operator==(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i == c2.m_i; + } - constexpr friend Coord& operator--(Coord& c) - { - --c.m_i; - return c; - } + [[nodiscard]] constexpr friend bool operator!=(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i != c2.m_i; + } - constexpr friend Coord& operator+=(Coord& c, int d) - { - c.m_i += d; - return c; - } + [[nodiscard]] constexpr friend bool operator<(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i < c2.m_i; + } - constexpr friend Coord& operator-=(Coord& c, int d) - { - c.m_i -= d; - return c; - } + [[nodiscard]] constexpr friend bool operator<=(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i <= c2.m_i; + } - constexpr friend Coord operator+(const Coord& c, int d) - { - Coord cpy(c); - cpy += d; - return cpy; - } + [[nodiscard]] constexpr friend bool operator>(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i > c2.m_i; + } - constexpr friend Coord operator-(const Coord& c, int d) - { - Coord cpy(c); - cpy -= d; - return cpy; - } + [[nodiscard]] constexpr friend bool operator>=(const Coord& c1, const Coord& c2) noexcept { + return c1.m_i >= c2.m_i; + } - constexpr friend int operator-(const Coord& c1, const Coord& c2) - { - return c1.m_i - c2.m_i; - } + private: + std::int8_t m_i; +}; + +struct FileTag; +struct RankTag; +using File = Coord; +using Rank = Coord; + +constexpr File fileA = File(0); +constexpr File fileB = File(1); +constexpr File fileC = File(2); +constexpr File fileD = File(3); +constexpr File fileE = File(4); +constexpr File fileF = File(5); +constexpr File fileG = File(6); +constexpr File fileH = File(7); + +constexpr Rank rank1 = Rank(0); +constexpr Rank rank2 = Rank(1); +constexpr Rank rank3 = Rank(2); +constexpr Rank rank4 = Rank(3); +constexpr Rank rank5 = Rank(4); +constexpr Rank rank6 = Rank(5); +constexpr Rank rank7 = Rank(6); +constexpr Rank rank8 = Rank(7); + +template<> +struct EnumTraits { + using IdType = int; + using EnumType = File; + + static constexpr int cardinality = 8; + static constexpr bool isNaturalIndex = true; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - [[nodiscard]] constexpr friend bool operator==(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i == c2.m_i; - } + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - [[nodiscard]] constexpr friend bool operator!=(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i != c2.m_i; - } + return static_cast(id); + } - [[nodiscard]] constexpr friend bool operator<(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i < c2.m_i; - } + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept { + assert(ordinal(c) >= 0 && ordinal(c) < 8); - [[nodiscard]] constexpr friend bool operator<=(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i <= c2.m_i; - } + return std::string_view("abcdefgh").substr(ordinal(c), 1); + } - [[nodiscard]] constexpr friend bool operator>(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i > c2.m_i; - } + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept { + if (c < 'a' || c > 'h') + return {}; + return static_cast(c - 'a'); + } - [[nodiscard]] constexpr friend bool operator>=(const Coord& c1, const Coord& c2) noexcept - { - return c1.m_i >= c2.m_i; - } + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept { + if (sv.size() != 1) + return {}; - private: - std::int8_t m_i; - }; + return fromChar(sv[0]); + } +}; - struct FileTag; - struct RankTag; - using File = Coord; - using Rank = Coord; +template<> +struct EnumTraits { + using IdType = int; + using EnumType = Rank; - constexpr File fileA = File(0); - constexpr File fileB = File(1); - constexpr File fileC = File(2); - constexpr File fileD = File(3); - constexpr File fileE = File(4); - constexpr File fileF = File(5); - constexpr File fileG = File(6); - constexpr File fileH = File(7); + static constexpr int cardinality = 8; + static constexpr bool isNaturalIndex = true; - constexpr Rank rank1 = Rank(0); - constexpr Rank rank2 = Rank(1); - constexpr Rank rank3 = Rank(2); - constexpr Rank rank4 = Rank(3); - constexpr Rank rank5 = Rank(4); - constexpr Rank rank6 = Rank(5); - constexpr Rank rank7 = Rank(6); - constexpr Rank rank8 = Rank(7); + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = File; + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - static constexpr int cardinality = 8; - static constexpr bool isNaturalIndex = true; + return static_cast(id); + } - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept { + assert(ordinal(c) >= 0 && ordinal(c) < 8); - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - assert(id >= 0 && id < cardinality); + return std::string_view("12345678").substr(ordinal(c), 1); + } - return static_cast(id); - } + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept { + if (c < '1' || c > '8') + return {}; + return static_cast(c - '1'); + } - [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept - { - assert(ordinal(c) >= 0 && ordinal(c) < 8); + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept { + if (sv.size() != 1) + return {}; - return std::string_view("abcdefgh").substr(ordinal(c), 1); - } + return fromChar(sv[0]); + } +}; - [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept - { - if (c < 'a' || c > 'h') return {}; - return static_cast(c - 'a'); - } +// files east +// ranks north +struct FlatSquareOffset { + std::int8_t value; - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 1) return {}; + constexpr FlatSquareOffset() noexcept : + value(0) {} - return fromChar(sv[0]); - } - }; + constexpr FlatSquareOffset(int files, int ranks) noexcept : + value(files + ranks * cardinality()) { + assert(files + ranks * cardinality() >= std::numeric_limits::min()); + assert(files + ranks * cardinality() <= std::numeric_limits::max()); + } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = Rank; + constexpr FlatSquareOffset operator-() const noexcept { return FlatSquareOffset(-value); } - static constexpr int cardinality = 8; - static constexpr bool isNaturalIndex = true; + private: + constexpr FlatSquareOffset(int v) noexcept : + value(v) {} +}; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } +struct Offset { + std::int8_t files; + std::int8_t ranks; - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - assert(id >= 0 && id < cardinality); + constexpr Offset() : + files(0), + ranks(0) {} - return static_cast(id); - } + constexpr Offset(int files_, int ranks_) : + files(files_), + ranks(ranks_) {} - [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept - { - assert(ordinal(c) >= 0 && ordinal(c) < 8); + [[nodiscard]] constexpr FlatSquareOffset flat() const { return {files, ranks}; } - return std::string_view("12345678").substr(ordinal(c), 1); - } + [[nodiscard]] constexpr Offset operator-() const { return {-files, -ranks}; } +}; - [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept - { - if (c < '1' || c > '8') return {}; - return static_cast(c - '1'); - } +struct SquareCoords { + File file; + Rank rank; - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 1) return {}; + constexpr SquareCoords() noexcept : + file{}, + rank{} {} - return fromChar(sv[0]); - } - }; + constexpr SquareCoords(File f, Rank r) noexcept : + file(f), + rank(r) {} - // files east - // ranks north - struct FlatSquareOffset - { - std::int8_t value; + constexpr friend SquareCoords& operator+=(SquareCoords& c, Offset offset) { + c.file += offset.files; + c.rank += offset.ranks; + return c; + } - constexpr FlatSquareOffset() noexcept : - value(0) - { - } + [[nodiscard]] constexpr friend SquareCoords operator+(const SquareCoords& c, Offset offset) { + SquareCoords cpy(c); + cpy.file += offset.files; + cpy.rank += offset.ranks; + return cpy; + } - constexpr FlatSquareOffset(int files, int ranks) noexcept : - value(files + ranks * cardinality()) - { - assert(files + ranks * cardinality() >= std::numeric_limits::min()); - assert(files + ranks * cardinality() <= std::numeric_limits::max()); - } + [[nodiscard]] constexpr bool isOk() const { + return file >= fileA && file <= fileH && rank >= rank1 && rank <= rank8; + } +}; - constexpr FlatSquareOffset operator-() const noexcept - { - return FlatSquareOffset(-value); - } +struct Square { + private: + static constexpr std::int8_t m_noneId = cardinality() * cardinality(); - private: - constexpr FlatSquareOffset(int v) noexcept : - value(v) - { - } - }; + static constexpr std::uint8_t fileMask = 0b111; + static constexpr std::uint8_t rankMask = 0b111000; + static constexpr std::uint8_t rankShift = 3; - struct Offset - { - std::int8_t files; - std::int8_t ranks; + public: + [[nodiscard]] static constexpr Square none() { return Square(m_noneId); } - constexpr Offset() : - files(0), - ranks(0) - { - } + constexpr Square() noexcept : + m_id(0) {} - constexpr Offset(int files_, int ranks_) : - files(files_), - ranks(ranks_) - { - } + constexpr explicit Square(int idx) noexcept : + m_id(idx) { + assert(isOk() || m_id == m_noneId); + } - [[nodiscard]] constexpr FlatSquareOffset flat() const - { - return { files, ranks }; - } + constexpr Square(File file, Rank rank) noexcept : + m_id(ordinal(file) + ordinal(rank) * cardinality()) { + assert(isOk()); + } - [[nodiscard]] constexpr Offset operator-() const - { - return { -files, -ranks }; - } - }; + constexpr explicit Square(SquareCoords coords) noexcept : + Square(coords.file, coords.rank) {} - struct SquareCoords - { - File file; - Rank rank; + [[nodiscard]] constexpr friend bool operator<(Square lhs, Square rhs) noexcept { + return lhs.m_id < rhs.m_id; + } - constexpr SquareCoords() noexcept : - file{}, - rank{} - { - } + [[nodiscard]] constexpr friend bool operator>(Square lhs, Square rhs) noexcept { + return lhs.m_id > rhs.m_id; + } - constexpr SquareCoords(File f, Rank r) noexcept : - file(f), - rank(r) - { - } + [[nodiscard]] constexpr friend bool operator<=(Square lhs, Square rhs) noexcept { + return lhs.m_id <= rhs.m_id; + } - constexpr friend SquareCoords& operator+=(SquareCoords& c, Offset offset) - { - c.file += offset.files; - c.rank += offset.ranks; - return c; - } + [[nodiscard]] constexpr friend bool operator>=(Square lhs, Square rhs) noexcept { + return lhs.m_id >= rhs.m_id; + } - [[nodiscard]] constexpr friend SquareCoords operator+(const SquareCoords& c, Offset offset) - { - SquareCoords cpy(c); - cpy.file += offset.files; - cpy.rank += offset.ranks; - return cpy; - } + [[nodiscard]] constexpr friend bool operator==(Square lhs, Square rhs) noexcept { + return lhs.m_id == rhs.m_id; + } - [[nodiscard]] constexpr bool isOk() const - { - return file >= fileA && file <= fileH && rank >= rank1 && rank <= rank8; - } - }; + [[nodiscard]] constexpr friend bool operator!=(Square lhs, Square rhs) noexcept { + return !(lhs == rhs); + } - struct Square - { - private: - static constexpr std::int8_t m_noneId = cardinality() * cardinality(); + constexpr friend Square& operator++(Square& sq) { + ++sq.m_id; + return sq; + } - static constexpr std::uint8_t fileMask = 0b111; - static constexpr std::uint8_t rankMask = 0b111000; - static constexpr std::uint8_t rankShift = 3; + constexpr friend Square& operator--(Square& sq) { + --sq.m_id; + return sq; + } - public: - [[nodiscard]] static constexpr Square none() - { - return Square(m_noneId); - } + [[nodiscard]] constexpr friend Square operator+(Square sq, FlatSquareOffset offset) { + Square sqCpy = sq; + sqCpy += offset; + return sqCpy; + } - constexpr Square() noexcept : - m_id(0) - { - } + constexpr friend Square& operator+=(Square& sq, FlatSquareOffset offset) { + assert(sq.m_id + offset.value >= 0 && sq.m_id + offset.value < Square::m_noneId); + sq.m_id += offset.value; + return sq; + } - constexpr explicit Square(int idx) noexcept : - m_id(idx) - { - assert(isOk() || m_id == m_noneId); - } + [[nodiscard]] constexpr friend Square operator+(Square sq, Offset offset) { + assert(sq.file() + offset.files >= fileA); + assert(sq.file() + offset.files <= fileH); + assert(sq.rank() + offset.ranks >= rank1); + assert(sq.rank() + offset.ranks <= rank8); + return operator+(sq, offset.flat()); + } - constexpr Square(File file, Rank rank) noexcept : - m_id(ordinal(file) + ordinal(rank) * cardinality()) - { - assert(isOk()); - } + constexpr friend Square& operator+=(Square& sq, Offset offset) { + return operator+=(sq, offset.flat()); + } - constexpr explicit Square(SquareCoords coords) noexcept : - Square(coords.file, coords.rank) - { - } + [[nodiscard]] constexpr explicit operator int() const { return m_id; } - [[nodiscard]] constexpr friend bool operator<(Square lhs, Square rhs) noexcept - { - return lhs.m_id < rhs.m_id; - } + [[nodiscard]] constexpr File file() const { + assert(isOk()); + return File(static_cast(m_id) & fileMask); + } - [[nodiscard]] constexpr friend bool operator>(Square lhs, Square rhs) noexcept - { - return lhs.m_id > rhs.m_id; - } + [[nodiscard]] constexpr Rank rank() const { + assert(isOk()); + return Rank(static_cast(m_id) >> rankShift); + } - [[nodiscard]] constexpr friend bool operator<=(Square lhs, Square rhs) noexcept - { - return lhs.m_id <= rhs.m_id; - } + [[nodiscard]] constexpr SquareCoords coords() const { return {file(), rank()}; } - [[nodiscard]] constexpr friend bool operator>=(Square lhs, Square rhs) noexcept - { - return lhs.m_id >= rhs.m_id; - } + [[nodiscard]] constexpr Color color() const { + assert(isOk()); + return !fromOrdinal((ordinal(rank()) + ordinal(file())) & 1); + } - [[nodiscard]] constexpr friend bool operator==(Square lhs, Square rhs) noexcept - { - return lhs.m_id == rhs.m_id; - } + constexpr void flipVertically() { m_id ^= rankMask; } + + constexpr void flipHorizontally() { m_id ^= fileMask; } + + constexpr Square flippedVertically() const { return Square(m_id ^ rankMask); } + + constexpr Square flippedHorizontally() const { return Square(m_id ^ fileMask); } + + [[nodiscard]] constexpr bool isOk() const { return m_id >= 0 && m_id < m_noneId; } + + private: + std::int8_t m_id; +}; + +constexpr Square a1(fileA, rank1); +constexpr Square a2(fileA, rank2); +constexpr Square a3(fileA, rank3); +constexpr Square a4(fileA, rank4); +constexpr Square a5(fileA, rank5); +constexpr Square a6(fileA, rank6); +constexpr Square a7(fileA, rank7); +constexpr Square a8(fileA, rank8); + +constexpr Square b1(fileB, rank1); +constexpr Square b2(fileB, rank2); +constexpr Square b3(fileB, rank3); +constexpr Square b4(fileB, rank4); +constexpr Square b5(fileB, rank5); +constexpr Square b6(fileB, rank6); +constexpr Square b7(fileB, rank7); +constexpr Square b8(fileB, rank8); + +constexpr Square c1(fileC, rank1); +constexpr Square c2(fileC, rank2); +constexpr Square c3(fileC, rank3); +constexpr Square c4(fileC, rank4); +constexpr Square c5(fileC, rank5); +constexpr Square c6(fileC, rank6); +constexpr Square c7(fileC, rank7); +constexpr Square c8(fileC, rank8); + +constexpr Square d1(fileD, rank1); +constexpr Square d2(fileD, rank2); +constexpr Square d3(fileD, rank3); +constexpr Square d4(fileD, rank4); +constexpr Square d5(fileD, rank5); +constexpr Square d6(fileD, rank6); +constexpr Square d7(fileD, rank7); +constexpr Square d8(fileD, rank8); + +constexpr Square e1(fileE, rank1); +constexpr Square e2(fileE, rank2); +constexpr Square e3(fileE, rank3); +constexpr Square e4(fileE, rank4); +constexpr Square e5(fileE, rank5); +constexpr Square e6(fileE, rank6); +constexpr Square e7(fileE, rank7); +constexpr Square e8(fileE, rank8); + +constexpr Square f1(fileF, rank1); +constexpr Square f2(fileF, rank2); +constexpr Square f3(fileF, rank3); +constexpr Square f4(fileF, rank4); +constexpr Square f5(fileF, rank5); +constexpr Square f6(fileF, rank6); +constexpr Square f7(fileF, rank7); +constexpr Square f8(fileF, rank8); + +constexpr Square g1(fileG, rank1); +constexpr Square g2(fileG, rank2); +constexpr Square g3(fileG, rank3); +constexpr Square g4(fileG, rank4); +constexpr Square g5(fileG, rank5); +constexpr Square g6(fileG, rank6); +constexpr Square g7(fileG, rank7); +constexpr Square g8(fileG, rank8); + +constexpr Square h1(fileH, rank1); +constexpr Square h2(fileH, rank2); +constexpr Square h3(fileH, rank3); +constexpr Square h4(fileH, rank4); +constexpr Square h5(fileH, rank5); +constexpr Square h6(fileH, rank6); +constexpr Square h7(fileH, rank7); +constexpr Square h8(fileH, rank8); + +static_assert(e1.color() == Color::Black); +static_assert(e8.color() == Color::White); + +static_assert(e1.file() == fileE); +static_assert(e1.rank() == rank1); + +static_assert(e1.flippedHorizontally() == d1); +static_assert(e1.flippedVertically() == e8); + +template<> +struct EnumTraits { + using IdType = int; + using EnumType = Square; + + static constexpr int cardinality = chess::cardinality() * chess::cardinality(); + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + a1, b1, c1, d1, e1, f1, g1, h1, a2, b2, c2, d2, e2, f2, g2, h2, a3, b3, c3, d3, e3, f3, + g3, h3, a4, b4, c4, d4, e4, f4, g4, h4, a5, b5, c5, d5, e5, f5, g5, h5, a6, b6, c6, d6, + e6, f6, g6, h6, a7, b7, c7, d7, e7, f7, g7, h7, a8, b8, c8, d8, e8, f8, g8, h8}; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - [[nodiscard]] constexpr friend bool operator!=(Square lhs, Square rhs) noexcept - { - return !(lhs == rhs); - } + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality + 1); - constexpr friend Square& operator++(Square& sq) - { - ++sq.m_id; - return sq; - } + return static_cast(id); + } - constexpr friend Square& operator--(Square& sq) - { - --sq.m_id; - return sq; - } + [[nodiscard]] static constexpr std::string_view toString(Square sq) { + assert(sq.isOk()); - [[nodiscard]] constexpr friend Square operator+(Square sq, FlatSquareOffset offset) - { - Square sqCpy = sq; - sqCpy += offset; - return sqCpy; - } + return std::string_view("a1b1c1d1e1f1g1h1" + "a2b2c2d2e2f2g2h2" + "a3b3c3d3e3f3g3h3" + "a4b4c4d4e4f4g4h4" + "a5b5c5d5e5f5g5h5" + "a6b6c6d6e6f6g6h6" + "a7b7c7d7e7f7g7h7" + "a8b8c8d8e8f8g8h8") + .substr(ordinal(sq) * 2, 2); + } - constexpr friend Square& operator+=(Square& sq, FlatSquareOffset offset) - { - assert(sq.m_id + offset.value >= 0 && sq.m_id + offset.value < Square::m_noneId); - sq.m_id += offset.value; - return sq; - } + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept { + if (sv.size() != 2) + return {}; - [[nodiscard]] constexpr friend Square operator+(Square sq, Offset offset) - { - assert(sq.file() + offset.files >= fileA); - assert(sq.file() + offset.files <= fileH); - assert(sq.rank() + offset.ranks >= rank1); - assert(sq.rank() + offset.ranks <= rank8); - return operator+(sq, offset.flat()); - } + const char f = sv[0]; + const char r = sv[1]; + if (f < 'a' || f > 'h') + return {}; + if (r < '1' || r > '8') + return {}; - constexpr friend Square& operator+=(Square& sq, Offset offset) - { - return operator+=(sq, offset.flat()); - } + return Square(static_cast(f - 'a'), static_cast(r - '1')); + } +}; - [[nodiscard]] constexpr explicit operator int() const - { - return m_id; - } +static_assert(toString(d1) == std::string_view("d1")); +static_assert(values()[29] == f4); - [[nodiscard]] constexpr File file() const - { - assert(isOk()); - return File(static_cast(m_id) & fileMask); - } +enum struct MoveType : std::uint8_t { + Normal, + Promotion, + Castle, + EnPassant +}; - [[nodiscard]] constexpr Rank rank() const - { - assert(isOk()); - return Rank(static_cast(m_id) >> rankShift); - } +template<> +struct EnumTraits { + using IdType = int; + using EnumType = MoveType; - [[nodiscard]] constexpr SquareCoords coords() const - { - return { file(), rank() }; - } + static constexpr int cardinality = 4; + static constexpr bool isNaturalIndex = true; - [[nodiscard]] constexpr Color color() const - { - assert(isOk()); - return !fromOrdinal((ordinal(rank()) + ordinal(file())) & 1); - } + static constexpr std::array values{ + MoveType::Normal, MoveType::Promotion, MoveType::Castle, MoveType::EnPassant}; - constexpr void flipVertically() - { - m_id ^= rankMask; - } + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - constexpr void flipHorizontally() - { - m_id ^= fileMask; - } + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - constexpr Square flippedVertically() const - { - return Square(m_id ^ rankMask); - } + return static_cast(id); + } +}; - constexpr Square flippedHorizontally() const - { - return Square(m_id ^ fileMask); - } +enum struct CastleType : std::uint8_t { + Short, + Long +}; - [[nodiscard]] constexpr bool isOk() const - { - return m_id >= 0 && m_id < m_noneId; - } +[[nodiscard]] constexpr CastleType operator!(CastleType ct) { + return static_cast(static_cast(ct) ^ 1); +} - private: - std::int8_t m_id; - }; +template<> +struct EnumTraits { + using IdType = int; + using EnumType = CastleType; - constexpr Square a1(fileA, rank1); - constexpr Square a2(fileA, rank2); - constexpr Square a3(fileA, rank3); - constexpr Square a4(fileA, rank4); - constexpr Square a5(fileA, rank5); - constexpr Square a6(fileA, rank6); - constexpr Square a7(fileA, rank7); - constexpr Square a8(fileA, rank8); - - constexpr Square b1(fileB, rank1); - constexpr Square b2(fileB, rank2); - constexpr Square b3(fileB, rank3); - constexpr Square b4(fileB, rank4); - constexpr Square b5(fileB, rank5); - constexpr Square b6(fileB, rank6); - constexpr Square b7(fileB, rank7); - constexpr Square b8(fileB, rank8); - - constexpr Square c1(fileC, rank1); - constexpr Square c2(fileC, rank2); - constexpr Square c3(fileC, rank3); - constexpr Square c4(fileC, rank4); - constexpr Square c5(fileC, rank5); - constexpr Square c6(fileC, rank6); - constexpr Square c7(fileC, rank7); - constexpr Square c8(fileC, rank8); - - constexpr Square d1(fileD, rank1); - constexpr Square d2(fileD, rank2); - constexpr Square d3(fileD, rank3); - constexpr Square d4(fileD, rank4); - constexpr Square d5(fileD, rank5); - constexpr Square d6(fileD, rank6); - constexpr Square d7(fileD, rank7); - constexpr Square d8(fileD, rank8); - - constexpr Square e1(fileE, rank1); - constexpr Square e2(fileE, rank2); - constexpr Square e3(fileE, rank3); - constexpr Square e4(fileE, rank4); - constexpr Square e5(fileE, rank5); - constexpr Square e6(fileE, rank6); - constexpr Square e7(fileE, rank7); - constexpr Square e8(fileE, rank8); - - constexpr Square f1(fileF, rank1); - constexpr Square f2(fileF, rank2); - constexpr Square f3(fileF, rank3); - constexpr Square f4(fileF, rank4); - constexpr Square f5(fileF, rank5); - constexpr Square f6(fileF, rank6); - constexpr Square f7(fileF, rank7); - constexpr Square f8(fileF, rank8); - - constexpr Square g1(fileG, rank1); - constexpr Square g2(fileG, rank2); - constexpr Square g3(fileG, rank3); - constexpr Square g4(fileG, rank4); - constexpr Square g5(fileG, rank5); - constexpr Square g6(fileG, rank6); - constexpr Square g7(fileG, rank7); - constexpr Square g8(fileG, rank8); - - constexpr Square h1(fileH, rank1); - constexpr Square h2(fileH, rank2); - constexpr Square h3(fileH, rank3); - constexpr Square h4(fileH, rank4); - constexpr Square h5(fileH, rank5); - constexpr Square h6(fileH, rank6); - constexpr Square h7(fileH, rank7); - constexpr Square h8(fileH, rank8); - - static_assert(e1.color() == Color::Black); - static_assert(e8.color() == Color::White); - - static_assert(e1.file() == fileE); - static_assert(e1.rank() == rank1); - - static_assert(e1.flippedHorizontally() == d1); - static_assert(e1.flippedVertically() == e8); - - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = Square; - - static constexpr int cardinality = chess::cardinality() * chess::cardinality(); - static constexpr bool isNaturalIndex = true; - - static constexpr std::array values{ - a1, b1, c1, d1, e1, f1, g1, h1, - a2, b2, c2, d2, e2, f2, g2, h2, - a3, b3, c3, d3, e3, f3, g3, h3, - a4, b4, c4, d4, e4, f4, g4, h4, - a5, b5, c5, d5, e5, f5, g5, h5, - a6, b6, c6, d6, e6, f6, g6, h6, - a7, b7, c7, d7, e7, f7, g7, h7, - a8, b8, c8, d8, e8, f8, g8, h8 - }; + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + static constexpr std::array values{CastleType::Short, CastleType::Long}; - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - assert(id >= 0 && id < cardinality + 1); + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); + } - return static_cast(id); - } + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + assert(id >= 0 && id < cardinality); - [[nodiscard]] static constexpr std::string_view toString(Square sq) - { - assert(sq.isOk()); + return static_cast(id); + } +}; - return - std::string_view( - "a1b1c1d1e1f1g1h1" - "a2b2c2d2e2f2g2h2" - "a3b3c3d3e3f3g3h3" - "a4b4c4d4e4f4g4h4" - "a5b5c5d5e5f5g5h5" - "a6b6c6d6e6f6g6h6" - "a7b7c7d7e7f7g7h7" - "a8b8c8d8e8f8g8h8" - ).substr(ordinal(sq) * 2, 2); - } +struct CompressedMove; - [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept - { - if (sv.size() != 2) return {}; +// castling is encoded as a king capturing rook +// ep is encoded as a normal pawn capture (move.to is empty on the board) +struct Move { + Square from; + Square to; + MoveType type = MoveType::Normal; + Piece promotedPiece = Piece::none(); - const char f = sv[0]; - const char r = sv[1]; - if (f < 'a' || f > 'h') return {}; - if (r < '1' || r > '8') return {}; + [[nodiscard]] constexpr friend bool operator==(const Move& lhs, const Move& rhs) noexcept { + return lhs.from == rhs.from && lhs.to == rhs.to && lhs.type == rhs.type + && lhs.promotedPiece == rhs.promotedPiece; + } - return Square(static_cast(f - 'a'), static_cast(r - '1')); - } - }; + [[nodiscard]] constexpr friend bool operator!=(const Move& lhs, const Move& rhs) noexcept { + return !(lhs == rhs); + } - static_assert(toString(d1) == std::string_view("d1")); - static_assert(values()[29] == f4); + [[nodiscard]] constexpr CompressedMove compress() const noexcept; - enum struct MoveType : std::uint8_t - { - Normal, - Promotion, - Castle, - EnPassant - }; + [[nodiscard]] constexpr static Move null() { return Move{Square::none(), Square::none()}; } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = MoveType; + [[nodiscard]] constexpr static Move castle(CastleType ct, Color c); - static constexpr int cardinality = 4; - static constexpr bool isNaturalIndex = true; + [[nodiscard]] constexpr static Move normal(Square from, Square to) { + return Move{from, to, MoveType::Normal, Piece::none()}; + } - static constexpr std::array values{ - MoveType::Normal, - MoveType::Promotion, - MoveType::Castle, - MoveType::EnPassant - }; + [[nodiscard]] constexpr static Move enPassant(Square from, Square to) { + return Move{from, to, MoveType::EnPassant, Piece::none()}; + } - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + [[nodiscard]] constexpr static Move promotion(Square from, Square to, Piece piece) { + return Move{from, to, MoveType::Promotion, piece}; + } +}; + +namespace detail::castle { +constexpr EnumArray2 moves = { + {{{{e1, h1, MoveType::Castle}, {e8, h8, MoveType::Castle}}}, + {{{e1, a1, MoveType::Castle}, {e8, a8, MoveType::Castle}}}}}; +} + +[[nodiscard]] constexpr Move Move::castle(CastleType ct, Color c) { + return detail::castle::moves[ct][c]; +} + +static_assert(sizeof(Move) == 4); + +struct CompressedMove { + private: + // from most significant bits + // 2 bits for move type + // 6 bits for from square + // 6 bits for to square + // 2 bits for promoted piece type + // 0 if not a promotion + static constexpr std::uint16_t squareMask = 0b111111u; + static constexpr std::uint16_t promotedPieceTypeMask = 0b11u; + static constexpr std::uint16_t moveTypeMask = 0b11u; + + public: + [[nodiscard]] constexpr static CompressedMove readFromBigEndian(const unsigned char* data) { + CompressedMove move{}; + move.m_packed = (data[0] << 8) | data[1]; + return move; + } + + constexpr CompressedMove() noexcept : + m_packed(0) {} - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + // move must be either valid or a null move + constexpr CompressedMove(Move move) noexcept : + m_packed(0) { + // else null move + if (move.from != move.to) { - assert(id >= 0 && id < cardinality); + assert(move.from != Square::none()); + assert(move.to != Square::none()); - return static_cast(id); + m_packed = (static_cast(ordinal(move.type)) << (16 - 2)) + | (static_cast(ordinal(move.from)) << (16 - 2 - 6)) + | (static_cast(ordinal(move.to)) << (16 - 2 - 6 - 6)); + + if (move.type == MoveType::Promotion) + { + assert(move.promotedPiece != Piece::none()); + + m_packed |= ordinal(move.promotedPiece.type()) - ordinal(PieceType::Knight); + } + else + { + assert(move.promotedPiece == Piece::none()); + } } - }; + } - enum struct CastleType : std::uint8_t - { - Short, - Long - }; + void writeToBigEndian(unsigned char* data) const { + *data++ = m_packed >> 8; + *data++ = m_packed & 0xFF; + } - [[nodiscard]] constexpr CastleType operator!(CastleType ct) - { - return static_cast(static_cast(ct) ^ 1); + [[nodiscard]] constexpr std::uint16_t packed() const { return m_packed; } + + [[nodiscard]] constexpr MoveType type() const { + return fromOrdinal(m_packed >> (16 - 2)); } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = CastleType; + [[nodiscard]] constexpr Square from() const { + return fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); + } - static constexpr int cardinality = 2; - static constexpr bool isNaturalIndex = true; + [[nodiscard]] constexpr Square to() const { + return fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); + } - static constexpr std::array values{ - CastleType::Short, - CastleType::Long - }; + [[nodiscard]] constexpr Piece promotedPiece() const { + if (type() == MoveType::Promotion) + { + const Color color = (to().rank() == rank1) ? Color::Black : Color::White; - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + + ordinal(PieceType::Knight)); + return color | pt; + } + else { - return static_cast(c); + return Piece::none(); } + } - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + [[nodiscard]] constexpr Move decompress() const noexcept { + if (m_packed == 0) + { + return Move::null(); + } + else { - assert(id >= 0 && id < cardinality); + const MoveType type = fromOrdinal(m_packed >> (16 - 2)); + const Square from = fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); + const Square to = fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); + const Piece promotedPiece = [&]() { + if (type == MoveType::Promotion) + { + const Color color = (to.rank() == rank1) ? Color::Black : Color::White; + + const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + + ordinal(PieceType::Knight)); + return color | pt; + } + else + { + return Piece::none(); + } + }(); - return static_cast(id); + return Move{from, to, type, promotedPiece}; } - }; + } - struct CompressedMove; + private: + std::uint16_t m_packed; +}; - // castling is encoded as a king capturing rook - // ep is encoded as a normal pawn capture (move.to is empty on the board) - struct Move - { - Square from; - Square to; - MoveType type = MoveType::Normal; - Piece promotedPiece = Piece::none(); +static_assert(sizeof(CompressedMove) == 2); - [[nodiscard]] constexpr friend bool operator==(const Move& lhs, const Move& rhs) noexcept - { - return lhs.from == rhs.from - && lhs.to == rhs.to - && lhs.type == rhs.type - && lhs.promotedPiece == rhs.promotedPiece; - } +[[nodiscard]] constexpr CompressedMove Move::compress() const noexcept { + return CompressedMove(*this); +} - [[nodiscard]] constexpr friend bool operator!=(const Move& lhs, const Move& rhs) noexcept - { - return !(lhs == rhs); - } +static_assert(a4 + Offset{0, 1} == a5); +static_assert(a4 + Offset{0, 2} == a6); +static_assert(a4 + Offset{0, -2} == a2); +static_assert(a4 + Offset{0, -1} == a3); + +static_assert(e4 + Offset{1, 0} == f4); +static_assert(e4 + Offset{2, 0} == g4); +static_assert(e4 + Offset{-1, 0} == d4); +static_assert(e4 + Offset{-2, 0} == c4); + +enum struct CastlingRights : std::uint8_t { + None = 0x0, + WhiteKingSide = 0x1, + WhiteQueenSide = 0x2, + BlackKingSide = 0x4, + BlackQueenSide = 0x8, + White = WhiteKingSide | WhiteQueenSide, + Black = BlackKingSide | BlackQueenSide, + All = WhiteKingSide | WhiteQueenSide | BlackKingSide | BlackQueenSide +}; + +[[nodiscard]] constexpr CastlingRights operator|(CastlingRights lhs, CastlingRights rhs) { + return static_cast(static_cast(lhs) + | static_cast(rhs)); +} - [[nodiscard]] constexpr CompressedMove compress() const noexcept; +[[nodiscard]] constexpr CastlingRights operator&(CastlingRights lhs, CastlingRights rhs) { + return static_cast(static_cast(lhs) + & static_cast(rhs)); +} - [[nodiscard]] constexpr static Move null() - { - return Move{ Square::none(), Square::none() }; - } +[[nodiscard]] constexpr CastlingRights operator~(CastlingRights lhs) { + return static_cast(~static_cast(lhs) + & static_cast(CastlingRights::All)); +} - [[nodiscard]] constexpr static Move castle(CastleType ct, Color c); +constexpr CastlingRights& operator|=(CastlingRights& lhs, CastlingRights rhs) { + lhs = + static_cast(static_cast(lhs) | static_cast(rhs)); + return lhs; +} - [[nodiscard]] constexpr static Move normal(Square from, Square to) - { - return Move{ from, to, MoveType::Normal, Piece::none() }; - } +constexpr CastlingRights& operator&=(CastlingRights& lhs, CastlingRights rhs) { + lhs = + static_cast(static_cast(lhs) & static_cast(rhs)); + return lhs; +} +// checks whether lhs contains rhs +[[nodiscard]] constexpr bool contains(CastlingRights lhs, CastlingRights rhs) { + return (lhs & rhs) == rhs; +} - [[nodiscard]] constexpr static Move enPassant(Square from, Square to) - { - return Move{ from, to, MoveType::EnPassant, Piece::none() }; - } +template<> +struct EnumTraits { + using IdType = int; + using EnumType = CastlingRights; - [[nodiscard]] constexpr static Move promotion(Square from, Square to, Piece piece) - { - return Move{ from, to, MoveType::Promotion, piece }; - } - }; + static constexpr int cardinality = 4; + static constexpr bool isNaturalIndex = false; - namespace detail::castle - { - constexpr EnumArray2 moves = { { - {{ { e1, h1, MoveType::Castle }, { e8, h8, MoveType::Castle } }}, - {{ { e1, a1, MoveType::Castle }, { e8, a8, MoveType::Castle } }} - } }; + static constexpr std::array values{ + CastlingRights::WhiteKingSide, CastlingRights::WhiteQueenSide, CastlingRights::BlackKingSide, + CastlingRights::BlackQueenSide}; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept { + return static_cast(c); } - [[nodiscard]] constexpr Move Move::castle(CastleType ct, Color c) - { - return detail::castle::moves[ct][c]; + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept { + return static_cast(id); + } +}; + +struct CompressedReverseMove; + +struct ReverseMove { + Move move; + Piece capturedPiece; + Square oldEpSquare; + CastlingRights oldCastlingRights; + + // We need a well defined case for the starting position. + constexpr ReverseMove() : + move(Move::null()), + capturedPiece(Piece::none()), + oldEpSquare(Square::none()), + oldCastlingRights(CastlingRights::All) {} + + constexpr ReverseMove(const Move& move_, + Piece capturedPiece_, + Square oldEpSquare_, + CastlingRights oldCastlingRights_) : + move(move_), + capturedPiece(capturedPiece_), + oldEpSquare(oldEpSquare_), + oldCastlingRights(oldCastlingRights_) {} + + constexpr bool isNull() const { return move.from == move.to; } + + [[nodiscard]] constexpr CompressedReverseMove compress() const noexcept; + + [[nodiscard]] constexpr friend bool operator==(const ReverseMove& lhs, + const ReverseMove& rhs) noexcept { + return lhs.move == rhs.move && lhs.capturedPiece == rhs.capturedPiece + && lhs.oldEpSquare == rhs.oldEpSquare && lhs.oldCastlingRights == rhs.oldCastlingRights; } - static_assert(sizeof(Move) == 4); + [[nodiscard]] constexpr friend bool operator!=(const ReverseMove& lhs, + const ReverseMove& rhs) noexcept { + return !(lhs == rhs); + } +}; - struct CompressedMove - { - private: - // from most significant bits - // 2 bits for move type - // 6 bits for from square - // 6 bits for to square - // 2 bits for promoted piece type - // 0 if not a promotion - static constexpr std::uint16_t squareMask = 0b111111u; - static constexpr std::uint16_t promotedPieceTypeMask = 0b11u; - static constexpr std::uint16_t moveTypeMask = 0b11u; +static_assert(sizeof(ReverseMove) == 7); - public: - [[nodiscard]] constexpr static CompressedMove readFromBigEndian(const unsigned char* data) - { - CompressedMove move{}; - move.m_packed = (data[0] << 8) | data[1]; - return move; - } +struct CompressedReverseMove { + private: + // we use 7 bits because it can be Square::none() + static constexpr std::uint32_t squareMask = 0b1111111u; + static constexpr std::uint32_t pieceMask = 0b1111u; + static constexpr std::uint32_t castlingRightsMask = 0b1111; - constexpr CompressedMove() noexcept : - m_packed(0) - { - } + public: + constexpr CompressedReverseMove() noexcept : + m_move{}, + m_oldState{} {} - // move must be either valid or a null move - constexpr CompressedMove(Move move) noexcept : - m_packed(0) - { - // else null move - if (move.from != move.to) - { - assert(move.from != Square::none()); - assert(move.to != Square::none()); + constexpr CompressedReverseMove(const ReverseMove& rm) noexcept : + m_move(rm.move.compress()), + m_oldState{ + static_cast(((ordinal(rm.capturedPiece) & pieceMask) << 11) + | ((ordinal(rm.oldCastlingRights) & castlingRightsMask) << 7) + | (ordinal(rm.oldEpSquare) & squareMask))} {} - m_packed = - (static_cast(ordinal(move.type)) << (16 - 2)) - | (static_cast(ordinal(move.from)) << (16 - 2 - 6)) - | (static_cast(ordinal(move.to)) << (16 - 2 - 6 - 6)); + [[nodiscard]] constexpr Move move() const { return m_move.decompress(); } - if (move.type == MoveType::Promotion) - { - assert(move.promotedPiece != Piece::none()); + [[nodiscard]] const CompressedMove& compressedMove() const { return m_move; } - m_packed |= ordinal(move.promotedPiece.type()) - ordinal(PieceType::Knight); - } - else - { - assert(move.promotedPiece == Piece::none()); - } - } - } + [[nodiscard]] constexpr Piece capturedPiece() const { + return fromOrdinal(m_oldState >> 11); + } - void writeToBigEndian(unsigned char* data) const - { - *data++ = m_packed >> 8; - *data++ = m_packed & 0xFF; - } + [[nodiscard]] constexpr CastlingRights oldCastlingRights() const { + return fromOrdinal((m_oldState >> 7) & castlingRightsMask); + } - [[nodiscard]] constexpr std::uint16_t packed() const - { - return m_packed; - } + [[nodiscard]] constexpr Square oldEpSquare() const { + return fromOrdinal(m_oldState & squareMask); + } - [[nodiscard]] constexpr MoveType type() const - { - return fromOrdinal(m_packed >> (16 - 2)); - } + [[nodiscard]] constexpr ReverseMove decompress() const noexcept { + const Piece capturedPiece = fromOrdinal(m_oldState >> 11); + const CastlingRights castlingRights = + fromOrdinal((m_oldState >> 7) & castlingRightsMask); + // We could pack the ep square more, but don't have to, because + // can't save another byte anyway. + const Square epSquare = fromOrdinal(m_oldState & squareMask); - [[nodiscard]] constexpr Square from() const - { - return fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); - } + return ReverseMove(m_move.decompress(), capturedPiece, epSquare, castlingRights); + } + + private: + CompressedMove m_move; + std::uint16_t m_oldState; +}; - [[nodiscard]] constexpr Square to() const +static_assert(sizeof(CompressedReverseMove) == 4); + +[[nodiscard]] constexpr CompressedReverseMove ReverseMove::compress() const noexcept { + return CompressedReverseMove(*this); +} + +// This can be regarded as a perfect hash. Going back is hard. +struct PackedReverseMove { + static constexpr std::uint32_t mask = 0x7FFFFFFu; + static constexpr std::size_t numBits = 27; + + private: + static constexpr std::uint32_t squareMask = 0b111111u; + static constexpr std::uint32_t pieceMask = 0b1111u; + static constexpr std::uint32_t pieceTypeMask = 0b111u; + static constexpr std::uint32_t castlingRightsMask = 0b1111; + static constexpr std::uint32_t fileMask = 0b111; + + public: + constexpr PackedReverseMove(const std::uint32_t packed) : + m_packed(packed) {} + + constexpr PackedReverseMove(const ReverseMove& reverseMove) : + m_packed( + 0u + // The only move when square is none() is null move and + // then both squares are none(). No other move is like that + // so we don't lose any information by storing only + // the 6 bits of each square. + | ((ordinal(reverseMove.move.from) & squareMask) << 21) + | ((ordinal(reverseMove.move.to) & squareMask) << 15) + // Other masks are just for code clarity, they should + // never change the values. + | ((ordinal(reverseMove.capturedPiece) & pieceMask) << 11) + | ((ordinal(reverseMove.oldCastlingRights) & castlingRightsMask) << 7) + | ((ordinal(reverseMove.move.promotedPiece.type()) & pieceTypeMask) << 4) + | (((reverseMove.oldEpSquare != Square::none()) & 1) << 3) + // We probably could omit the squareMask here but for clarity it's left. + | (ordinal(Square(ordinal(reverseMove.oldEpSquare) & squareMask).file()) & fileMask)) {} + + constexpr std::uint32_t packed() const { return m_packed; } + + constexpr ReverseMove unpack(Color sideThatMoved) const { + ReverseMove rmove{}; + + rmove.move.from = fromOrdinal((m_packed >> 21) & squareMask); + rmove.move.to = fromOrdinal((m_packed >> 15) & squareMask); + rmove.capturedPiece = fromOrdinal((m_packed >> 11) & pieceMask); + rmove.oldCastlingRights = fromOrdinal((m_packed >> 7) & castlingRightsMask); + const PieceType promotedPieceType = fromOrdinal((m_packed >> 4) & pieceTypeMask); + if (promotedPieceType != PieceType::None) + { + rmove.move.promotedPiece = Piece(promotedPieceType, sideThatMoved); + rmove.move.type = MoveType::Promotion; + } + const bool hasEpSquare = static_cast((m_packed >> 3) & 1); + if (hasEpSquare) + { + // ep square is always where the opponent moved + const Rank rank = sideThatMoved == Color::White ? rank6 : rank3; + const File file = fromOrdinal(m_packed & fileMask); + rmove.oldEpSquare = Square(file, rank); + if (rmove.oldEpSquare == rmove.move.to) + { + rmove.move.type = MoveType::EnPassant; + } + } + else { - return fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); + rmove.oldEpSquare = Square::none(); } - [[nodiscard]] constexpr Piece promotedPiece() const + if (rmove.move.type == MoveType::Normal && rmove.oldCastlingRights != CastlingRights::None) { - if (type() == MoveType::Promotion) + // If castling was possible then we know it was the king that moved from e1/e8. + if (rmove.move.from == e1) { - const Color color = - (to().rank() == rank1) - ? Color::Black - : Color::White; - - const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + ordinal(PieceType::Knight)); - return color | pt; + if (rmove.move.to == h1 || rmove.move.to == a1) + { + rmove.move.type = MoveType::Castle; + } } - else + else if (rmove.move.from == e8) { - return Piece::none(); + if (rmove.move.to == h8 || rmove.move.to == a8) + { + rmove.move.type = MoveType::Castle; + } } } - [[nodiscard]] constexpr Move decompress() const noexcept - { - if (m_packed == 0) - { - return Move::null(); - } - else - { - const MoveType type = fromOrdinal(m_packed >> (16 - 2)); - const Square from = fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); - const Square to = fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); - const Piece promotedPiece = [&]() { - if (type == MoveType::Promotion) - { - const Color color = - (to.rank() == rank1) - ? Color::Black - : Color::White; + return rmove; + } - const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + ordinal(PieceType::Knight)); - return color | pt; - } - else - { - return Piece::none(); - } - }(); + private: + // Uses only 27 lowest bits. + // Bit meaning from highest to lowest. + // - 6 bits from + // - 6 bits to + // - 4 bits for the captured piece + // - 4 bits for prev castling rights + // - 3 bits promoted piece type + // - 1 bit to specify if the ep square was valid (false if none()) + // - 3 bits for prev ep square file + std::uint32_t m_packed; +}; + +struct MoveCompareLess { + [[nodiscard]] bool operator()(const Move& lhs, const Move& rhs) const noexcept { + if (ordinal(lhs.from) < ordinal(rhs.from)) + return true; + if (ordinal(lhs.from) > ordinal(rhs.from)) + return false; - return Move{ from, to, type, promotedPiece }; - } - } + if (ordinal(lhs.to) < ordinal(rhs.to)) + return true; + if (ordinal(lhs.to) > ordinal(rhs.to)) + return false; - private: - std::uint16_t m_packed; - }; + if (ordinal(lhs.type) < ordinal(rhs.type)) + return true; + if (ordinal(lhs.type) > ordinal(rhs.type)) + return false; - static_assert(sizeof(CompressedMove) == 2); + if (ordinal(lhs.promotedPiece) < ordinal(rhs.promotedPiece)) + return true; - [[nodiscard]] constexpr CompressedMove Move::compress() const noexcept - { - return CompressedMove(*this); + return false; } +}; - static_assert(a4 + Offset{ 0, 1 } == a5); - static_assert(a4 + Offset{ 0, 2 } == a6); - static_assert(a4 + Offset{ 0, -2 } == a2); - static_assert(a4 + Offset{ 0, -1 } == a3); +struct ReverseMoveCompareLess { + [[nodiscard]] bool operator()(const ReverseMove& lhs, const ReverseMove& rhs) const noexcept { + if (MoveCompareLess{}(lhs.move, rhs.move)) + return true; + if (MoveCompareLess{}(rhs.move, lhs.move)) + return false; - static_assert(e4 + Offset{ 1, 0 } == f4); - static_assert(e4 + Offset{ 2, 0 } == g4); - static_assert(e4 + Offset{ -1, 0 } == d4); - static_assert(e4 + Offset{ -2, 0 } == c4); + if (ordinal(lhs.capturedPiece) < ordinal(rhs.capturedPiece)) + return true; + if (ordinal(lhs.capturedPiece) > ordinal(rhs.capturedPiece)) + return false; - enum struct CastlingRights : std::uint8_t - { - None = 0x0, - WhiteKingSide = 0x1, - WhiteQueenSide = 0x2, - BlackKingSide = 0x4, - BlackQueenSide = 0x8, - White = WhiteKingSide | WhiteQueenSide, - Black = BlackKingSide | BlackQueenSide, - All = WhiteKingSide | WhiteQueenSide | BlackKingSide | BlackQueenSide - }; + if (static_cast(lhs.oldCastlingRights) + < static_cast(rhs.oldCastlingRights)) + return true; + if (static_cast(lhs.oldCastlingRights) + > static_cast(rhs.oldCastlingRights)) + return false; - [[nodiscard]] constexpr CastlingRights operator|(CastlingRights lhs, CastlingRights rhs) - { - return static_cast(static_cast(lhs) | static_cast(rhs)); + if (ordinal(lhs.oldEpSquare) < ordinal(rhs.oldEpSquare)) + return true; + if (ordinal(lhs.oldEpSquare) > ordinal(rhs.oldEpSquare)) + return false; + + return false; } +}; - [[nodiscard]] constexpr CastlingRights operator&(CastlingRights lhs, CastlingRights rhs) - { - return static_cast(static_cast(lhs) & static_cast(rhs)); +struct BitboardIterator { + using value_type = Square; + using difference_type = std::ptrdiff_t; + using reference = Square; + using iterator_category = std::input_iterator_tag; + using pointer = const Square*; + + constexpr BitboardIterator() noexcept : + m_squares(0) {} + + constexpr BitboardIterator(std::uint64_t v) noexcept : + m_squares(v) {} + + constexpr BitboardIterator(const BitboardIterator&) = default; + constexpr BitboardIterator(BitboardIterator&&) = default; + constexpr BitboardIterator& operator=(const BitboardIterator&) = default; + constexpr BitboardIterator& operator=(BitboardIterator&&) = default; + + [[nodiscard]] constexpr bool friend operator==(BitboardIterator lhs, + BitboardIterator rhs) noexcept { + return lhs.m_squares == rhs.m_squares; } - [[nodiscard]] constexpr CastlingRights operator~(CastlingRights lhs) - { - return static_cast(~static_cast(lhs) & static_cast(CastlingRights::All)); + [[nodiscard]] constexpr bool friend operator!=(BitboardIterator lhs, + BitboardIterator rhs) noexcept { + return lhs.m_squares != rhs.m_squares; } - constexpr CastlingRights& operator|=(CastlingRights& lhs, CastlingRights rhs) - { - lhs = static_cast(static_cast(lhs) | static_cast(rhs)); - return lhs; + [[nodiscard]] inline Square operator*() const { return first(); } + + constexpr BitboardIterator& operator++() noexcept { + popFirst(); + return *this; } - constexpr CastlingRights& operator&=(CastlingRights& lhs, CastlingRights rhs) - { - lhs = static_cast(static_cast(lhs) & static_cast(rhs)); - return lhs; + private: + std::uint64_t m_squares; + + constexpr void popFirst() noexcept { m_squares &= m_squares - 1; } + + [[nodiscard]] inline Square first() const { + assert(m_squares != 0); + + return fromOrdinal(intrin::lsb(m_squares)); } - // checks whether lhs contains rhs - [[nodiscard]] constexpr bool contains(CastlingRights lhs, CastlingRights rhs) - { - return (lhs & rhs) == rhs; +}; + +struct Bitboard { + // bits counted from the LSB + // order is A1 B2 ... G8 H8 + // just like in Square + + public: + constexpr Bitboard() noexcept : + m_squares(0) {} + + private: + constexpr explicit Bitboard(Square sq) noexcept : + m_squares(static_cast(1ULL) << ordinal(sq)) { + assert(sq.isOk()); } - template <> - struct EnumTraits - { - using IdType = int; - using EnumType = CastlingRights; + constexpr explicit Bitboard(Rank r) noexcept : + m_squares(static_cast(0xFFULL) << (ordinal(r) * 8)) {} - static constexpr int cardinality = 4; - static constexpr bool isNaturalIndex = false; + constexpr explicit Bitboard(File f) noexcept : + m_squares(static_cast(0x0101010101010101ULL) << ordinal(f)) {} - static constexpr std::array values{ - CastlingRights::WhiteKingSide, - CastlingRights::WhiteQueenSide, - CastlingRights::BlackKingSide, - CastlingRights::BlackQueenSide - }; + constexpr explicit Bitboard(Color c) noexcept : + m_squares(c == Color::White ? 0xAA55AA55AA55AA55ULL : ~0xAA55AA55AA55AA55ULL) {} - [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept - { - return static_cast(c); - } + constexpr explicit Bitboard(std::uint64_t bb) noexcept : + m_squares(bb) {} - [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept - { - return static_cast(id); - } - }; + // files A..file inclusive + static constexpr EnumArray m_filesUpToBB{ + 0x0101010101010101ULL, 0x0303030303030303ULL, 0x0707070707070707ULL, 0x0F0F0F0F0F0F0F0FULL, + 0x1F1F1F1F1F1F1F1FULL, 0x3F3F3F3F3F3F3F3FULL, 0x7F7F7F7F7F7F7F7FULL, 0xFFFFFFFFFFFFFFFFULL}; - struct CompressedReverseMove; + public: + [[nodiscard]] static constexpr Bitboard none() { return Bitboard{}; } - struct ReverseMove - { - Move move; - Piece capturedPiece; - Square oldEpSquare; - CastlingRights oldCastlingRights; + [[nodiscard]] static constexpr Bitboard all() { return ~none(); } - // We need a well defined case for the starting position. - constexpr ReverseMove() : - move(Move::null()), - capturedPiece(Piece::none()), - oldEpSquare(Square::none()), - oldCastlingRights(CastlingRights::All) - { - } + [[nodiscard]] static constexpr Bitboard square(Square sq) { return Bitboard(sq); } + + [[nodiscard]] static constexpr Bitboard file(File f) { return Bitboard(f); } + + [[nodiscard]] static constexpr Bitboard rank(Rank r) { return Bitboard(r); } + + [[nodiscard]] static constexpr Bitboard color(Color c) { return Bitboard(c); } + + [[nodiscard]] static constexpr Bitboard fromBits(std::uint64_t bits) { return Bitboard(bits); } + + // inclusive + [[nodiscard]] static constexpr Bitboard betweenFiles(File left, File right) { + assert(left <= right); - constexpr ReverseMove(const Move& move_, Piece capturedPiece_, Square oldEpSquare_, CastlingRights oldCastlingRights_) : - move(move_), - capturedPiece(capturedPiece_), - oldEpSquare(oldEpSquare_), - oldCastlingRights(oldCastlingRights_) + if (left == fileA) { + return Bitboard::fromBits(m_filesUpToBB[right]); } - - constexpr bool isNull() const + else { - return move.from == move.to; + return Bitboard::fromBits(m_filesUpToBB[right] ^ m_filesUpToBB[left - 1]); } + } + + [[nodiscard]] constexpr bool isEmpty() const { return m_squares == 0; } + + [[nodiscard]] constexpr bool isSet(Square sq) const { + return !!((m_squares >> ordinal(sq)) & 1ull); + } + + constexpr void set(Square sq) { *this |= Bitboard(sq); } + + constexpr void unset(Square sq) { *this &= ~(Bitboard(sq)); } - [[nodiscard]] constexpr CompressedReverseMove compress() const noexcept; + constexpr void toggle(Square sq) { *this ^= Bitboard(sq); } + + [[nodiscard]] constexpr BitboardIterator begin() const { return BitboardIterator(m_squares); } + + [[nodiscard]] constexpr BitboardIterator end() const { return BitboardIterator{}; } + + [[nodiscard]] constexpr BitboardIterator cbegin() const { return BitboardIterator(m_squares); } + + [[nodiscard]] constexpr BitboardIterator cend() const { return BitboardIterator{}; } + + [[nodiscard]] constexpr bool friend operator==(Bitboard lhs, Bitboard rhs) noexcept { + return lhs.m_squares == rhs.m_squares; + } + + [[nodiscard]] constexpr bool friend operator!=(Bitboard lhs, Bitboard rhs) noexcept { + return lhs.m_squares != rhs.m_squares; + } - [[nodiscard]] constexpr friend bool operator==(const ReverseMove& lhs, const ReverseMove& rhs) noexcept + constexpr Bitboard shiftedVertically(int ranks) const { + if (ranks >= 0) { - return lhs.move == rhs.move - && lhs.capturedPiece == rhs.capturedPiece - && lhs.oldEpSquare == rhs.oldEpSquare - && lhs.oldCastlingRights == rhs.oldCastlingRights; + return fromBits(m_squares << 8 * ranks); } - - [[nodiscard]] constexpr friend bool operator!=(const ReverseMove& lhs, const ReverseMove& rhs) noexcept + else { - return !(lhs == rhs); + return fromBits(m_squares >> -8 * ranks); } - }; - - static_assert(sizeof(ReverseMove) == 7); + } - struct CompressedReverseMove - { - private: - // we use 7 bits because it can be Square::none() - static constexpr std::uint32_t squareMask = 0b1111111u; - static constexpr std::uint32_t pieceMask = 0b1111u; - static constexpr std::uint32_t castlingRightsMask = 0b1111; - public: + template + constexpr void shift() { + static_assert(files >= -7); + static_assert(ranks >= -7); + static_assert(files <= 7); + static_assert(ranks <= 7); - constexpr CompressedReverseMove() noexcept : - m_move{}, - m_oldState{} + if constexpr (files != 0) { + constexpr Bitboard mask = files > 0 ? Bitboard::betweenFiles(fileA, fileH - files) + : Bitboard::betweenFiles(fileA - files, fileH); + + m_squares &= mask.m_squares; } - constexpr CompressedReverseMove(const ReverseMove& rm) noexcept : - m_move(rm.move.compress()), - m_oldState{ static_cast( - ((ordinal(rm.capturedPiece) & pieceMask) << 11) - | ((ordinal(rm.oldCastlingRights) & castlingRightsMask) << 7) - | (ordinal(rm.oldEpSquare) & squareMask) - ) - } + constexpr int shift = files + ranks * 8; + if constexpr (shift == 0) { + return; } - [[nodiscard]] constexpr Move move() const + if constexpr (shift < 0) { - return m_move.decompress(); + m_squares >>= -shift; } - - [[nodiscard]] const CompressedMove& compressedMove() const + else { - return m_move; + m_squares <<= shift; } + } - [[nodiscard]] constexpr Piece capturedPiece() const + template + constexpr Bitboard shifted() const { + Bitboard bbCpy(*this); + bbCpy.shift(); + return bbCpy; + } + + constexpr void shift(Offset offset) { + assert(offset.files >= -7); + assert(offset.ranks >= -7); + assert(offset.files <= 7); + assert(offset.ranks <= 7); + + if (offset.files != 0) { - return fromOrdinal(m_oldState >> 11); + const Bitboard mask = offset.files > 0 + ? Bitboard::betweenFiles(fileA, fileH - offset.files) + : Bitboard::betweenFiles(fileA - offset.files, fileH); + + m_squares &= mask.m_squares; } - [[nodiscard]] constexpr CastlingRights oldCastlingRights() const + const int shift = offset.files + offset.ranks * 8; + if (shift < 0) { - return fromOrdinal((m_oldState >> 7) & castlingRightsMask); + m_squares >>= -shift; } - - [[nodiscard]] constexpr Square oldEpSquare() const + else { - return fromOrdinal(m_oldState & squareMask); + m_squares <<= shift; } + } - [[nodiscard]] constexpr ReverseMove decompress() const noexcept - { - const Piece capturedPiece = fromOrdinal(m_oldState >> 11); - const CastlingRights castlingRights = fromOrdinal((m_oldState >> 7) & castlingRightsMask); - // We could pack the ep square more, but don't have to, because - // can't save another byte anyway. - const Square epSquare = fromOrdinal(m_oldState & squareMask); + [[nodiscard]] constexpr Bitboard shifted(Offset offset) const { + Bitboard bbCpy(*this); + bbCpy.shift(offset); + return bbCpy; + } - return ReverseMove(m_move.decompress(), capturedPiece, epSquare, castlingRights); - } + [[nodiscard]] constexpr Bitboard operator~() const { + Bitboard bb = *this; + bb.m_squares = ~m_squares; + return bb; + } - private: - CompressedMove m_move; - std::uint16_t m_oldState; - }; + constexpr Bitboard& operator^=(Color c) { + m_squares ^= Bitboard(c).m_squares; + return *this; + } - static_assert(sizeof(CompressedReverseMove) == 4); + constexpr Bitboard& operator&=(Color c) { + m_squares &= Bitboard(c).m_squares; + return *this; + } - [[nodiscard]] constexpr CompressedReverseMove ReverseMove::compress() const noexcept - { - return CompressedReverseMove(*this); + constexpr Bitboard& operator|=(Color c) { + m_squares |= Bitboard(c).m_squares; + return *this; } - // This can be regarded as a perfect hash. Going back is hard. - struct PackedReverseMove - { - static constexpr std::uint32_t mask = 0x7FFFFFFu; - static constexpr std::size_t numBits = 27; + [[nodiscard]] constexpr Bitboard operator^(Color c) const { + Bitboard bb = *this; + bb ^= c; + return bb; + } - private: - static constexpr std::uint32_t squareMask = 0b111111u; - static constexpr std::uint32_t pieceMask = 0b1111u; - static constexpr std::uint32_t pieceTypeMask = 0b111u; - static constexpr std::uint32_t castlingRightsMask = 0b1111; - static constexpr std::uint32_t fileMask = 0b111; + [[nodiscard]] constexpr Bitboard operator&(Color c) const { + Bitboard bb = *this; + bb &= c; + return bb; + } - public: - constexpr PackedReverseMove(const std::uint32_t packed) : - m_packed(packed) - { + [[nodiscard]] constexpr Bitboard operator|(Color c) const { + Bitboard bb = *this; + bb |= c; + return bb; + } - } + constexpr Bitboard& operator^=(Square sq) { + m_squares ^= Bitboard(sq).m_squares; + return *this; + } - constexpr PackedReverseMove(const ReverseMove& reverseMove) : - m_packed( - 0u - // The only move when square is none() is null move and - // then both squares are none(). No other move is like that - // so we don't lose any information by storing only - // the 6 bits of each square. - | ((ordinal(reverseMove.move.from) & squareMask) << 21) - | ((ordinal(reverseMove.move.to) & squareMask) << 15) - // Other masks are just for code clarity, they should - // never change the values. - | ((ordinal(reverseMove.capturedPiece) & pieceMask) << 11) - | ((ordinal(reverseMove.oldCastlingRights) & castlingRightsMask) << 7) - | ((ordinal(reverseMove.move.promotedPiece.type()) & pieceTypeMask) << 4) - | (((reverseMove.oldEpSquare != Square::none()) & 1) << 3) - // We probably could omit the squareMask here but for clarity it's left. - | (ordinal(Square(ordinal(reverseMove.oldEpSquare) & squareMask).file()) & fileMask) - ) - { - } + constexpr Bitboard& operator&=(Square sq) { + m_squares &= Bitboard(sq).m_squares; + return *this; + } - constexpr std::uint32_t packed() const - { - return m_packed; - } + constexpr Bitboard& operator|=(Square sq) { + m_squares |= Bitboard(sq).m_squares; + return *this; + } - constexpr ReverseMove unpack(Color sideThatMoved) const - { - ReverseMove rmove{}; + [[nodiscard]] constexpr Bitboard operator^(Square sq) const { + Bitboard bb = *this; + bb ^= sq; + return bb; + } - rmove.move.from = fromOrdinal((m_packed >> 21) & squareMask); - rmove.move.to = fromOrdinal((m_packed >> 15) & squareMask); - rmove.capturedPiece = fromOrdinal((m_packed >> 11) & pieceMask); - rmove.oldCastlingRights = fromOrdinal((m_packed >> 7) & castlingRightsMask); - const PieceType promotedPieceType = fromOrdinal((m_packed >> 4) & pieceTypeMask); - if (promotedPieceType != PieceType::None) - { - rmove.move.promotedPiece = Piece(promotedPieceType, sideThatMoved); - rmove.move.type = MoveType::Promotion; - } - const bool hasEpSquare = static_cast((m_packed >> 3) & 1); - if (hasEpSquare) - { - // ep square is always where the opponent moved - const Rank rank = - sideThatMoved == Color::White - ? rank6 - : rank3; - const File file = fromOrdinal(m_packed & fileMask); - rmove.oldEpSquare = Square(file, rank); - if (rmove.oldEpSquare == rmove.move.to) - { - rmove.move.type = MoveType::EnPassant; - } - } - else - { - rmove.oldEpSquare = Square::none(); - } - - if (rmove.move.type == MoveType::Normal && rmove.oldCastlingRights != CastlingRights::None) - { - // If castling was possible then we know it was the king that moved from e1/e8. - if (rmove.move.from == e1) - { - if (rmove.move.to == h1 || rmove.move.to == a1) - { - rmove.move.type = MoveType::Castle; - } - } - else if (rmove.move.from == e8) - { - if (rmove.move.to == h8 || rmove.move.to == a8) - { - rmove.move.type = MoveType::Castle; - } - } - } - - return rmove; - } - - private: - // Uses only 27 lowest bits. - // Bit meaning from highest to lowest. - // - 6 bits from - // - 6 bits to - // - 4 bits for the captured piece - // - 4 bits for prev castling rights - // - 3 bits promoted piece type - // - 1 bit to specify if the ep square was valid (false if none()) - // - 3 bits for prev ep square file - std::uint32_t m_packed; - }; + [[nodiscard]] constexpr Bitboard operator&(Square sq) const { + Bitboard bb = *this; + bb &= sq; + return bb; + } - struct MoveCompareLess - { - [[nodiscard]] bool operator()(const Move& lhs, const Move& rhs) const noexcept - { - if (ordinal(lhs.from) < ordinal(rhs.from)) return true; - if (ordinal(lhs.from) > ordinal(rhs.from)) return false; + [[nodiscard]] constexpr Bitboard operator|(Square sq) const { + Bitboard bb = *this; + bb |= sq; + return bb; + } - if (ordinal(lhs.to) < ordinal(rhs.to)) return true; - if (ordinal(lhs.to) > ordinal(rhs.to)) return false; + [[nodiscard]] constexpr friend Bitboard operator^(Square sq, Bitboard bb) { return bb ^ sq; } - if (ordinal(lhs.type) < ordinal(rhs.type)) return true; - if (ordinal(lhs.type) > ordinal(rhs.type)) return false; + [[nodiscard]] constexpr friend Bitboard operator&(Square sq, Bitboard bb) { return bb & sq; } - if (ordinal(lhs.promotedPiece) < ordinal(rhs.promotedPiece)) return true; + [[nodiscard]] constexpr friend Bitboard operator|(Square sq, Bitboard bb) { return bb | sq; } - return false; - } - }; + constexpr Bitboard& operator^=(Bitboard rhs) { + m_squares ^= rhs.m_squares; + return *this; + } - struct ReverseMoveCompareLess - { - [[nodiscard]] bool operator()(const ReverseMove& lhs, const ReverseMove& rhs) const noexcept - { - if (MoveCompareLess{}(lhs.move, rhs.move)) return true; - if (MoveCompareLess{}(rhs.move, lhs.move)) return false; + constexpr Bitboard& operator&=(Bitboard rhs) { + m_squares &= rhs.m_squares; + return *this; + } - if (ordinal(lhs.capturedPiece) < ordinal(rhs.capturedPiece)) return true; - if (ordinal(lhs.capturedPiece) > ordinal(rhs.capturedPiece)) return false; + constexpr Bitboard& operator|=(Bitboard rhs) { + m_squares |= rhs.m_squares; + return *this; + } - if (static_cast(lhs.oldCastlingRights) < static_cast(rhs.oldCastlingRights)) return true; - if (static_cast(lhs.oldCastlingRights) > static_cast(rhs.oldCastlingRights)) return false; + [[nodiscard]] constexpr Bitboard operator^(Bitboard sq) const { + Bitboard bb = *this; + bb ^= sq; + return bb; + } - if (ordinal(lhs.oldEpSquare) < ordinal(rhs.oldEpSquare)) return true; - if (ordinal(lhs.oldEpSquare) > ordinal(rhs.oldEpSquare)) return false; + [[nodiscard]] constexpr Bitboard operator&(Bitboard sq) const { + Bitboard bb = *this; + bb &= sq; + return bb; + } - return false; - } - }; + [[nodiscard]] constexpr Bitboard operator|(Bitboard sq) const { + Bitboard bb = *this; + bb |= sq; + return bb; + } - struct BitboardIterator - { - using value_type = Square; - using difference_type = std::ptrdiff_t; - using reference = Square; - using iterator_category = std::input_iterator_tag; - using pointer = const Square*; + [[nodiscard]] inline int count() const { return static_cast(intrin::popcount(m_squares)); } - constexpr BitboardIterator() noexcept : - m_squares(0) - { - } + [[nodiscard]] constexpr bool moreThanOne() const { return !!(m_squares & (m_squares - 1)); } - constexpr BitboardIterator(std::uint64_t v) noexcept : - m_squares(v) - { - } + [[nodiscard]] constexpr bool exactlyOne() const { return m_squares != 0 && !moreThanOne(); } - constexpr BitboardIterator(const BitboardIterator&) = default; - constexpr BitboardIterator(BitboardIterator&&) = default; - constexpr BitboardIterator& operator=(const BitboardIterator&) = default; - constexpr BitboardIterator& operator=(BitboardIterator&&) = default; + [[nodiscard]] constexpr bool any() const { return !!m_squares; } - [[nodiscard]] constexpr bool friend operator==(BitboardIterator lhs, BitboardIterator rhs) noexcept - { - return lhs.m_squares == rhs.m_squares; - } + [[nodiscard]] inline Square first() const { + assert(m_squares != 0); - [[nodiscard]] constexpr bool friend operator!=(BitboardIterator lhs, BitboardIterator rhs) noexcept - { - return lhs.m_squares != rhs.m_squares; - } + return fromOrdinal(intrin::lsb(m_squares)); + } - [[nodiscard]] inline Square operator*() const - { - return first(); - } + [[nodiscard]] inline Square nth(int n) const { + assert(count() > n); - constexpr BitboardIterator& operator++() noexcept - { - popFirst(); - return *this; - } + Bitboard cpy = *this; + while (n--) + cpy.popFirst(); + return cpy.first(); + } - private: - std::uint64_t m_squares; + [[nodiscard]] inline Square last() const { + assert(m_squares != 0); - constexpr void popFirst() noexcept - { - m_squares &= m_squares - 1; - } + return fromOrdinal(intrin::msb(m_squares)); + } - [[nodiscard]] inline Square first() const - { - assert(m_squares != 0); + [[nodiscard]] constexpr std::uint64_t bits() const { return m_squares; } - return fromOrdinal(intrin::lsb(m_squares)); - } - }; + constexpr void popFirst() { + assert(m_squares != 0); - struct Bitboard - { - // bits counted from the LSB - // order is A1 B2 ... G8 H8 - // just like in Square + m_squares &= m_squares - 1; + } - public: - constexpr Bitboard() noexcept : - m_squares(0) - { - } + constexpr Bitboard& operator=(const Bitboard& other) = default; - private: - constexpr explicit Bitboard(Square sq) noexcept : - m_squares(static_cast(1ULL) << ordinal(sq)) - { - assert(sq.isOk()); - } + private: + std::uint64_t m_squares; +}; - constexpr explicit Bitboard(Rank r) noexcept : - m_squares(static_cast(0xFFULL) << (ordinal(r) * 8)) - { - } +[[nodiscard]] constexpr Bitboard operator^(Square sq0, Square sq1) { + return Bitboard::square(sq0) ^ sq1; +} - constexpr explicit Bitboard(File f) noexcept : - m_squares(static_cast(0x0101010101010101ULL) << ordinal(f)) - { - } +[[nodiscard]] constexpr Bitboard operator&(Square sq0, Square sq1) { + return Bitboard::square(sq0) & sq1; +} - constexpr explicit Bitboard(Color c) noexcept : - m_squares(c == Color::White ? 0xAA55AA55AA55AA55ULL : ~0xAA55AA55AA55AA55ULL) - { - } +[[nodiscard]] constexpr Bitboard operator|(Square sq0, Square sq1) { + return Bitboard::square(sq0) | sq1; +} - constexpr explicit Bitboard(std::uint64_t bb) noexcept : - m_squares(bb) - { - } +[[nodiscard]] constexpr Bitboard operator""_bb(unsigned long long bits) { + return Bitboard::fromBits(bits); +} - // files A..file inclusive - static constexpr EnumArray m_filesUpToBB{ - 0x0101010101010101ULL, - 0x0303030303030303ULL, - 0x0707070707070707ULL, - 0x0F0F0F0F0F0F0F0FULL, - 0x1F1F1F1F1F1F1F1FULL, - 0x3F3F3F3F3F3F3F3FULL, - 0x7F7F7F7F7F7F7F7FULL, - 0xFFFFFFFFFFFFFFFFULL - }; +namespace bb { +namespace fancy_magics { +// Implementation based on https://github.com/syzygy1/Cfish + +alignas(64) constexpr EnumArray g_rookMagics{ + {0x0A80004000801220ull, 0x8040004010002008ull, 0x2080200010008008ull, 0x1100100008210004ull, + 0xC200209084020008ull, 0x2100010004000208ull, 0x0400081000822421ull, 0x0200010422048844ull, + 0x0800800080400024ull, 0x0001402000401000ull, 0x3000801000802001ull, 0x4400800800100083ull, + 0x0904802402480080ull, 0x4040800400020080ull, 0x0018808042000100ull, 0x4040800080004100ull, + 0x0040048001458024ull, 0x00A0004000205000ull, 0x3100808010002000ull, 0x4825010010000820ull, + 0x5004808008000401ull, 0x2024818004000A00ull, 0x0005808002000100ull, 0x2100060004806104ull, + 0x0080400880008421ull, 0x4062220600410280ull, 0x010A004A00108022ull, 0x0000100080080080ull, + 0x0021000500080010ull, 0x0044000202001008ull, 0x0000100400080102ull, 0xC020128200040545ull, + 0x0080002000400040ull, 0x0000804000802004ull, 0x0000120022004080ull, 0x010A386103001001ull, + 0x9010080080800400ull, 0x8440020080800400ull, 0x0004228824001001ull, 0x000000490A000084ull, + 0x0080002000504000ull, 0x200020005000C000ull, 0x0012088020420010ull, 0x0010010080080800ull, + 0x0085001008010004ull, 0x0002000204008080ull, 0x0040413002040008ull, 0x0000304081020004ull, + 0x0080204000800080ull, 0x3008804000290100ull, 0x1010100080200080ull, 0x2008100208028080ull, + 0x5000850800910100ull, 0x8402019004680200ull, 0x0120911028020400ull, 0x0000008044010200ull, + 0x0020850200244012ull, 0x0020850200244012ull, 0x0000102001040841ull, 0x140900040A100021ull, + 0x000200282410A102ull, 0x000200282410A102ull, 0x000200282410A102ull, 0x4048240043802106ull}}; + +alignas(64) constexpr EnumArray g_bishopMagics{ + {0x40106000A1160020ull, 0x0020010250810120ull, 0x2010010220280081ull, 0x002806004050C040ull, + 0x0002021018000000ull, 0x2001112010000400ull, 0x0881010120218080ull, 0x1030820110010500ull, + 0x0000120222042400ull, 0x2000020404040044ull, 0x8000480094208000ull, 0x0003422A02000001ull, + 0x000A220210100040ull, 0x8004820202226000ull, 0x0018234854100800ull, 0x0100004042101040ull, + 0x0004001004082820ull, 0x0010000810010048ull, 0x1014004208081300ull, 0x2080818802044202ull, + 0x0040880C00A00100ull, 0x0080400200522010ull, 0x0001000188180B04ull, 0x0080249202020204ull, + 0x1004400004100410ull, 0x00013100A0022206ull, 0x2148500001040080ull, 0x4241080011004300ull, + 0x4020848004002000ull, 0x10101380D1004100ull, 0x0008004422020284ull, 0x01010A1041008080ull, + 0x0808080400082121ull, 0x0808080400082121ull, 0x0091128200100C00ull, 0x0202200802010104ull, + 0x8C0A020200440085ull, 0x01A0008080B10040ull, 0x0889520080122800ull, 0x100902022202010Aull, + 0x04081A0816002000ull, 0x0000681208005000ull, 0x8170840041008802ull, 0x0A00004200810805ull, + 0x0830404408210100ull, 0x2602208106006102ull, 0x1048300680802628ull, 0x2602208106006102ull, + 0x0602010120110040ull, 0x0941010801043000ull, 0x000040440A210428ull, 0x0008240020880021ull, + 0x0400002012048200ull, 0x00AC102001210220ull, 0x0220021002009900ull, 0x84440C080A013080ull, + 0x0001008044200440ull, 0x0004C04410841000ull, 0x2000500104011130ull, 0x1A0C010011C20229ull, + 0x0044800112202200ull, 0x0434804908100424ull, 0x0300404822C08200ull, 0x48081010008A2A80ull}}; + +alignas(64) static EnumArray g_rookMasks; +alignas(64) static EnumArray g_rookShifts; +alignas(64) static EnumArray g_rookAttacks; + +alignas(64) static EnumArray g_bishopMasks; +alignas(64) static EnumArray g_bishopShifts; +alignas(64) static EnumArray g_bishopAttacks; + +alignas(64) static std::array g_allRookAttacks; +alignas(64) static std::array g_allBishopAttacks; + +inline Bitboard bishopAttacks(Square s, Bitboard occupied) { + const std::size_t idx = + (occupied & fancy_magics::g_bishopMasks[s]).bits() * fancy_magics::g_bishopMagics[s] + >> fancy_magics::g_bishopShifts[s]; + + return fancy_magics::g_bishopAttacks[s][idx]; +} - public: +inline Bitboard rookAttacks(Square s, Bitboard occupied) { + const std::size_t idx = + (occupied & fancy_magics::g_rookMasks[s]).bits() * fancy_magics::g_rookMagics[s] + >> fancy_magics::g_rookShifts[s]; - [[nodiscard]] static constexpr Bitboard none() - { - return Bitboard{}; - } + return fancy_magics::g_rookAttacks[s][idx]; +} +} - [[nodiscard]] static constexpr Bitboard all() - { - return ~none(); - } +[[nodiscard]] constexpr Bitboard square(Square sq) { return Bitboard::square(sq); } - [[nodiscard]] static constexpr Bitboard square(Square sq) - { - return Bitboard(sq); - } +[[nodiscard]] constexpr Bitboard rank(Rank rank) { return Bitboard::rank(rank); } - [[nodiscard]] static constexpr Bitboard file(File f) - { - return Bitboard(f); - } +[[nodiscard]] constexpr Bitboard file(File file) { return Bitboard::file(file); } - [[nodiscard]] static constexpr Bitboard rank(Rank r) - { - return Bitboard(r); - } +[[nodiscard]] constexpr Bitboard color(Color c) { return Bitboard::color(c); } - [[nodiscard]] static constexpr Bitboard color(Color c) - { - return Bitboard(c); - } +[[nodiscard]] constexpr Bitboard before(Square sq) { + return Bitboard::fromBits(nbitmask[ordinal(sq)]); +} - [[nodiscard]] static constexpr Bitboard fromBits(std::uint64_t bits) - { - return Bitboard(bits); - } +constexpr Bitboard lightSquares = bb::color(Color::White); +constexpr Bitboard darkSquares = bb::color(Color::Black); + +constexpr Bitboard fileA = bb::file(chess::fileA); +constexpr Bitboard fileB = bb::file(chess::fileB); +constexpr Bitboard fileC = bb::file(chess::fileC); +constexpr Bitboard fileD = bb::file(chess::fileD); +constexpr Bitboard fileE = bb::file(chess::fileE); +constexpr Bitboard fileF = bb::file(chess::fileF); +constexpr Bitboard fileG = bb::file(chess::fileG); +constexpr Bitboard fileH = bb::file(chess::fileH); + +constexpr Bitboard rank1 = bb::rank(chess::rank1); +constexpr Bitboard rank2 = bb::rank(chess::rank2); +constexpr Bitboard rank3 = bb::rank(chess::rank3); +constexpr Bitboard rank4 = bb::rank(chess::rank4); +constexpr Bitboard rank5 = bb::rank(chess::rank5); +constexpr Bitboard rank6 = bb::rank(chess::rank6); +constexpr Bitboard rank7 = bb::rank(chess::rank7); +constexpr Bitboard rank8 = bb::rank(chess::rank8); + +constexpr Bitboard a1 = bb::square(chess::a1); +constexpr Bitboard a2 = bb::square(chess::a2); +constexpr Bitboard a3 = bb::square(chess::a3); +constexpr Bitboard a4 = bb::square(chess::a4); +constexpr Bitboard a5 = bb::square(chess::a5); +constexpr Bitboard a6 = bb::square(chess::a6); +constexpr Bitboard a7 = bb::square(chess::a7); +constexpr Bitboard a8 = bb::square(chess::a8); + +constexpr Bitboard b1 = bb::square(chess::b1); +constexpr Bitboard b2 = bb::square(chess::b2); +constexpr Bitboard b3 = bb::square(chess::b3); +constexpr Bitboard b4 = bb::square(chess::b4); +constexpr Bitboard b5 = bb::square(chess::b5); +constexpr Bitboard b6 = bb::square(chess::b6); +constexpr Bitboard b7 = bb::square(chess::b7); +constexpr Bitboard b8 = bb::square(chess::b8); + +constexpr Bitboard c1 = bb::square(chess::c1); +constexpr Bitboard c2 = bb::square(chess::c2); +constexpr Bitboard c3 = bb::square(chess::c3); +constexpr Bitboard c4 = bb::square(chess::c4); +constexpr Bitboard c5 = bb::square(chess::c5); +constexpr Bitboard c6 = bb::square(chess::c6); +constexpr Bitboard c7 = bb::square(chess::c7); +constexpr Bitboard c8 = bb::square(chess::c8); + +constexpr Bitboard d1 = bb::square(chess::d1); +constexpr Bitboard d2 = bb::square(chess::d2); +constexpr Bitboard d3 = bb::square(chess::d3); +constexpr Bitboard d4 = bb::square(chess::d4); +constexpr Bitboard d5 = bb::square(chess::d5); +constexpr Bitboard d6 = bb::square(chess::d6); +constexpr Bitboard d7 = bb::square(chess::d7); +constexpr Bitboard d8 = bb::square(chess::d8); + +constexpr Bitboard e1 = bb::square(chess::e1); +constexpr Bitboard e2 = bb::square(chess::e2); +constexpr Bitboard e3 = bb::square(chess::e3); +constexpr Bitboard e4 = bb::square(chess::e4); +constexpr Bitboard e5 = bb::square(chess::e5); +constexpr Bitboard e6 = bb::square(chess::e6); +constexpr Bitboard e7 = bb::square(chess::e7); +constexpr Bitboard e8 = bb::square(chess::e8); + +constexpr Bitboard f1 = bb::square(chess::f1); +constexpr Bitboard f2 = bb::square(chess::f2); +constexpr Bitboard f3 = bb::square(chess::f3); +constexpr Bitboard f4 = bb::square(chess::f4); +constexpr Bitboard f5 = bb::square(chess::f5); +constexpr Bitboard f6 = bb::square(chess::f6); +constexpr Bitboard f7 = bb::square(chess::f7); +constexpr Bitboard f8 = bb::square(chess::f8); + +constexpr Bitboard g1 = bb::square(chess::g1); +constexpr Bitboard g2 = bb::square(chess::g2); +constexpr Bitboard g3 = bb::square(chess::g3); +constexpr Bitboard g4 = bb::square(chess::g4); +constexpr Bitboard g5 = bb::square(chess::g5); +constexpr Bitboard g6 = bb::square(chess::g6); +constexpr Bitboard g7 = bb::square(chess::g7); +constexpr Bitboard g8 = bb::square(chess::g8); + +constexpr Bitboard h1 = bb::square(chess::h1); +constexpr Bitboard h2 = bb::square(chess::h2); +constexpr Bitboard h3 = bb::square(chess::h3); +constexpr Bitboard h4 = bb::square(chess::h4); +constexpr Bitboard h5 = bb::square(chess::h5); +constexpr Bitboard h6 = bb::square(chess::h6); +constexpr Bitboard h7 = bb::square(chess::h7); +constexpr Bitboard h8 = bb::square(chess::h8); + +[[nodiscard]] Bitboard between(Square s1, Square s2); + +[[nodiscard]] Bitboard line(Square s1, Square s2); + +template +[[nodiscard]] Bitboard pseudoAttacks(Square sq); + +[[nodiscard]] Bitboard pseudoAttacks(PieceType pt, Square sq); + +template +Bitboard attacks(Square sq, Bitboard occupied) { + static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); + + assert(sq.isOk()); + + if constexpr (PieceTypeV == PieceType::Bishop) + { + return fancy_magics::bishopAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Rook) + { + return fancy_magics::rookAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Queen) + { + return fancy_magics::bishopAttacks(sq, occupied) | fancy_magics::rookAttacks(sq, occupied); + } + else + { + return pseudoAttacks(sq); + } +} - // inclusive - [[nodiscard]] static constexpr Bitboard betweenFiles(File left, File right) - { - assert(left <= right); +[[nodiscard]] inline Bitboard attacks(PieceType pt, Square sq, Bitboard occupied) { + assert(sq.isOk()); - if (left == fileA) - { - return Bitboard::fromBits(m_filesUpToBB[right]); - } - else - { - return Bitboard::fromBits(m_filesUpToBB[right] ^ m_filesUpToBB[left - 1]); - } - } + switch (pt) + { + case PieceType::Bishop : + return attacks(sq, occupied); + case PieceType::Rook : + return attacks(sq, occupied); + case PieceType::Queen : + return attacks(sq, occupied); + default : + return pseudoAttacks(pt, sq); + } +} - [[nodiscard]] constexpr bool isEmpty() const - { - return m_squares == 0; - } +[[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color); - [[nodiscard]] constexpr bool isSet(Square sq) const - { - return !!((m_squares >> ordinal(sq)) & 1ull); - } +[[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color); - constexpr void set(Square sq) - { - *this |= Bitboard(sq); - } +[[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color); - constexpr void unset(Square sq) - { - *this &= ~(Bitboard(sq)); - } +[[nodiscard]] inline bool +isAttackedBySlider(Square sq, Bitboard bishops, Bitboard rooks, Bitboard queens, Bitboard occupied); - constexpr void toggle(Square sq) - { - *this ^= Bitboard(sq); - } +namespace detail { +static constexpr std::array knightOffsets{ + {{-1, -2}, {-1, 2}, {1, -2}, {1, 2}, {-2, -1}, {-2, 1}, {2, -1}, {2, 1}}}; +static constexpr std::array kingOffsets{ + {{-1, -1}, {-1, 0}, {-1, 1}, {0, -1}, {0, 1}, {1, -1}, {1, 0}, {1, 1}}}; - [[nodiscard]] constexpr BitboardIterator begin() const - { - return BitboardIterator(m_squares); - } +enum Direction { + North = 0, + NorthEast, + East, + SouthEast, + South, + SouthWest, + West, + NorthWest +}; - [[nodiscard]] constexpr BitboardIterator end() const - { - return BitboardIterator{}; - } +constexpr std::array offsets = { + {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}, {-1, -1}, {-1, 0}, {-1, 1}}}; - [[nodiscard]] constexpr BitboardIterator cbegin() const - { - return BitboardIterator(m_squares); - } +static constexpr std::array bishopOffsets{offsets[NorthEast], offsets[SouthEast], + offsets[SouthWest], offsets[NorthWest]}; +static constexpr std::array rookOffsets{offsets[North], offsets[East], offsets[South], + offsets[West]}; - [[nodiscard]] constexpr BitboardIterator cend() const - { - return BitboardIterator{}; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_Pawn() { + // pseudo attacks don't make sense for pawns + return {}; +} - [[nodiscard]] constexpr bool friend operator==(Bitboard lhs, Bitboard rhs) noexcept - { - return lhs.m_squares == rhs.m_squares; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_Knight() { + EnumArray bbs{}; - [[nodiscard]] constexpr bool friend operator!=(Bitboard lhs, Bitboard rhs) noexcept - { - return lhs.m_squares != rhs.m_squares; - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + Bitboard bb{}; - constexpr Bitboard shiftedVertically(int ranks) const + for (auto&& offset : knightOffsets) { - if (ranks >= 0) + const SquareCoords toSq = fromSq.coords() + offset; + if (toSq.isOk()) { - return fromBits(m_squares << 8 * ranks); - } - else - { - return fromBits(m_squares >> -8 * ranks); + bb |= Square(toSq); } } - template - constexpr void shift() - { - static_assert(files >= -7); - static_assert(ranks >= -7); - static_assert(files <= 7); - static_assert(ranks <= 7); - - if constexpr (files != 0) - { - constexpr Bitboard mask = - files > 0 - ? Bitboard::betweenFiles(fileA, fileH - files) - : Bitboard::betweenFiles(fileA - files, fileH); + bbs[fromSq] = bb; + } - m_squares &= mask.m_squares; - } + return bbs; +} - constexpr int shift = files + ranks * 8; - if constexpr (shift == 0) - { - return; - } +[[nodiscard]] static Bitboard generateSliderPseudoAttacks(const std::array& offsets_, + Square fromSq) { + assert(fromSq.isOk()); - if constexpr (shift < 0) - { - m_squares >>= -shift; - } - else - { - m_squares <<= shift; - } - } + Bitboard bb{}; - template - constexpr Bitboard shifted() const - { - Bitboard bbCpy(*this); - bbCpy.shift(); - return bbCpy; - } + for (auto&& offset : offsets_) + { + SquareCoords fromSqC = fromSq.coords(); - constexpr void shift(Offset offset) + for (;;) { - assert(offset.files >= -7); - assert(offset.ranks >= -7); - assert(offset.files <= 7); - assert(offset.ranks <= 7); + fromSqC += offset; - if (offset.files != 0) + if (!fromSqC.isOk()) { - const Bitboard mask = - offset.files > 0 - ? Bitboard::betweenFiles(fileA, fileH - offset.files) - : Bitboard::betweenFiles(fileA - offset.files, fileH); - - m_squares &= mask.m_squares; + break; } - const int shift = offset.files + offset.ranks * 8; - if (shift < 0) - { - m_squares >>= -shift; - } - else - { - m_squares <<= shift; - } + bb |= Square(fromSqC); } + } - [[nodiscard]] constexpr Bitboard shifted(Offset offset) const - { - Bitboard bbCpy(*this); - bbCpy.shift(offset); - return bbCpy; - } + return bb; +} - [[nodiscard]] constexpr Bitboard operator~() const - { - Bitboard bb = *this; - bb.m_squares = ~m_squares; - return bb; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_Bishop() { + EnumArray bbs{}; - constexpr Bitboard& operator^=(Color c) - { - m_squares ^= Bitboard(c).m_squares; - return *this; - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generateSliderPseudoAttacks(bishopOffsets, fromSq); + } - constexpr Bitboard& operator&=(Color c) - { - m_squares &= Bitboard(c).m_squares; - return *this; - } + return bbs; +} - constexpr Bitboard& operator|=(Color c) - { - m_squares |= Bitboard(c).m_squares; - return *this; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_Rook() { + EnumArray bbs{}; - [[nodiscard]] constexpr Bitboard operator^(Color c) const - { - Bitboard bb = *this; - bb ^= c; - return bb; - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generateSliderPseudoAttacks(rookOffsets, fromSq); + } - [[nodiscard]] constexpr Bitboard operator&(Color c) const - { - Bitboard bb = *this; - bb &= c; - return bb; - } + return bbs; +} - [[nodiscard]] constexpr Bitboard operator|(Color c) const - { - Bitboard bb = *this; - bb |= c; - return bb; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_Queen() { + EnumArray bbs{}; - constexpr Bitboard& operator^=(Square sq) - { - m_squares ^= Bitboard(sq).m_squares; - return *this; - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generateSliderPseudoAttacks(bishopOffsets, fromSq) + | generateSliderPseudoAttacks(rookOffsets, fromSq); + } - constexpr Bitboard& operator&=(Square sq) - { - m_squares &= Bitboard(sq).m_squares; - return *this; - } + return bbs; +} - constexpr Bitboard& operator|=(Square sq) - { - m_squares |= Bitboard(sq).m_squares; - return *this; - } +[[nodiscard]] static EnumArray generatePseudoAttacks_King() { + EnumArray bbs{}; - [[nodiscard]] constexpr Bitboard operator^(Square sq) const - { - Bitboard bb = *this; - bb ^= sq; - return bb; - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + Bitboard bb{}; - [[nodiscard]] constexpr Bitboard operator&(Square sq) const + for (auto&& offset : kingOffsets) { - Bitboard bb = *this; - bb &= sq; - return bb; + const SquareCoords toSq = fromSq.coords() + offset; + if (toSq.isOk()) + { + bb |= Square(toSq); + } } - [[nodiscard]] constexpr Bitboard operator|(Square sq) const - { - Bitboard bb = *this; - bb |= sq; - return bb; - } + bbs[fromSq] = bb; + } - [[nodiscard]] constexpr friend Bitboard operator^(Square sq, Bitboard bb) - { - return bb ^ sq; - } + return bbs; +} - [[nodiscard]] constexpr friend Bitboard operator&(Square sq, Bitboard bb) - { - return bb & sq; - } +[[nodiscard]] static EnumArray2 generatePseudoAttacks() { + return EnumArray2{ + generatePseudoAttacks_Pawn(), generatePseudoAttacks_Knight(), generatePseudoAttacks_Bishop(), + generatePseudoAttacks_Rook(), generatePseudoAttacks_Queen(), generatePseudoAttacks_King()}; +} - [[nodiscard]] constexpr friend Bitboard operator|(Square sq, Bitboard bb) - { - return bb | sq; - } +static const EnumArray2& pseudoAttacks() { + static const EnumArray2 s_pseudoAttacks = generatePseudoAttacks(); + return s_pseudoAttacks; +} - constexpr Bitboard& operator^=(Bitboard rhs) - { - m_squares ^= rhs.m_squares; - return *this; - } +[[nodiscard]] static Bitboard generatePositiveRayAttacks(Direction dir, Square fromSq) { + assert(fromSq.isOk()); - constexpr Bitboard& operator&=(Bitboard rhs) - { - m_squares &= rhs.m_squares; - return *this; - } + Bitboard bb{}; - constexpr Bitboard& operator|=(Bitboard rhs) - { - m_squares |= rhs.m_squares; - return *this; - } + const auto offset = offsets[dir]; + SquareCoords fromSqC = fromSq.coords(); + for (;;) + { + fromSqC += offset; - [[nodiscard]] constexpr Bitboard operator^(Bitboard sq) const + if (!fromSqC.isOk()) { - Bitboard bb = *this; - bb ^= sq; - return bb; + break; } - [[nodiscard]] constexpr Bitboard operator&(Bitboard sq) const - { - Bitboard bb = *this; - bb &= sq; - return bb; - } + bb |= Square(fromSqC); + } - [[nodiscard]] constexpr Bitboard operator|(Bitboard sq) const - { - Bitboard bb = *this; - bb |= sq; - return bb; - } + return bb; +} - [[nodiscard]] inline int count() const - { - return static_cast(intrin::popcount(m_squares)); - } +// classical slider move generation approach https://www.chessprogramming.org/Classical_Approach - [[nodiscard]] constexpr bool moreThanOne() const - { - return !!(m_squares & (m_squares - 1)); - } +[[nodiscard]] static EnumArray generatePositiveRayAttacks(Direction dir) { + EnumArray bbs{}; - [[nodiscard]] constexpr bool exactlyOne() const - { - return m_squares != 0 && !moreThanOne(); - } + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generatePositiveRayAttacks(dir, fromSq); + } - [[nodiscard]] constexpr bool any() const - { - return !!m_squares; - } + return bbs; +} - [[nodiscard]] inline Square first() const - { - assert(m_squares != 0); +[[nodiscard]] static std::array, 8> generatePositiveRayAttacks() { + std::array, 8> bbs{}; - return fromOrdinal(intrin::lsb(m_squares)); - } + bbs[North] = generatePositiveRayAttacks(North); + bbs[NorthEast] = generatePositiveRayAttacks(NorthEast); + bbs[East] = generatePositiveRayAttacks(East); + bbs[SouthEast] = generatePositiveRayAttacks(SouthEast); + bbs[South] = generatePositiveRayAttacks(South); + bbs[SouthWest] = generatePositiveRayAttacks(SouthWest); + bbs[West] = generatePositiveRayAttacks(West); + bbs[NorthWest] = generatePositiveRayAttacks(NorthWest); - [[nodiscard]] inline Square nth(int n) const - { - assert(count() > n); + return bbs; +} - Bitboard cpy = *this; - while (n--) cpy.popFirst(); - return cpy.first(); - } - [[nodiscard]] inline Square last() const - { - assert(m_squares != 0); +static const std::array, 8>& positiveRayAttacks() { + static const std::array, 8> s_positiveRayAttacks = + generatePositiveRayAttacks(); + return s_positiveRayAttacks; +} - return fromOrdinal(intrin::msb(m_squares)); - } +template +[[nodiscard]] static Bitboard slidingAttacks(Square sq, Bitboard occupied) { + assert(sq.isOk()); - [[nodiscard]] constexpr std::uint64_t bits() const - { - return m_squares; - } + Bitboard attacks = positiveRayAttacks()[DirV][sq]; - constexpr void popFirst() - { - assert(m_squares != 0); + if constexpr (DirV == NorthWest || DirV == North || DirV == NorthEast || DirV == East) + { + Bitboard blocker = (attacks & occupied) | h8; // set highest bit (H8) so msb never fails + return attacks ^ positiveRayAttacks()[DirV][blocker.first()]; + } + else + { + Bitboard blocker = (attacks & occupied) | a1; + return attacks ^ positiveRayAttacks()[DirV][blocker.last()]; + } +} - m_squares &= m_squares - 1; - } +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); +template Bitboard slidingAttacks(Square, Bitboard); - constexpr Bitboard& operator=(const Bitboard& other) = default; +template +[[nodiscard]] inline Bitboard pieceSlidingAttacks(Square sq, Bitboard occupied) { + static_assert(PieceTypeV == PieceType::Rook || PieceTypeV == PieceType::Bishop + || PieceTypeV == PieceType::Queen); - private: - std::uint64_t m_squares; - }; + assert(sq.isOk()); - [[nodiscard]] constexpr Bitboard operator^(Square sq0, Square sq1) - { - return Bitboard::square(sq0) ^ sq1; - } - - [[nodiscard]] constexpr Bitboard operator&(Square sq0, Square sq1) - { - return Bitboard::square(sq0) & sq1; - } - - [[nodiscard]] constexpr Bitboard operator|(Square sq0, Square sq1) - { - return Bitboard::square(sq0) | sq1; - } - - [[nodiscard]] constexpr Bitboard operator""_bb(unsigned long long bits) - { - return Bitboard::fromBits(bits); - } - - namespace bb - { - namespace fancy_magics - { - // Implementation based on https://github.com/syzygy1/Cfish - - alignas(64) constexpr EnumArray g_rookMagics{ { - 0x0A80004000801220ull, - 0x8040004010002008ull, - 0x2080200010008008ull, - 0x1100100008210004ull, - 0xC200209084020008ull, - 0x2100010004000208ull, - 0x0400081000822421ull, - 0x0200010422048844ull, - 0x0800800080400024ull, - 0x0001402000401000ull, - 0x3000801000802001ull, - 0x4400800800100083ull, - 0x0904802402480080ull, - 0x4040800400020080ull, - 0x0018808042000100ull, - 0x4040800080004100ull, - 0x0040048001458024ull, - 0x00A0004000205000ull, - 0x3100808010002000ull, - 0x4825010010000820ull, - 0x5004808008000401ull, - 0x2024818004000A00ull, - 0x0005808002000100ull, - 0x2100060004806104ull, - 0x0080400880008421ull, - 0x4062220600410280ull, - 0x010A004A00108022ull, - 0x0000100080080080ull, - 0x0021000500080010ull, - 0x0044000202001008ull, - 0x0000100400080102ull, - 0xC020128200040545ull, - 0x0080002000400040ull, - 0x0000804000802004ull, - 0x0000120022004080ull, - 0x010A386103001001ull, - 0x9010080080800400ull, - 0x8440020080800400ull, - 0x0004228824001001ull, - 0x000000490A000084ull, - 0x0080002000504000ull, - 0x200020005000C000ull, - 0x0012088020420010ull, - 0x0010010080080800ull, - 0x0085001008010004ull, - 0x0002000204008080ull, - 0x0040413002040008ull, - 0x0000304081020004ull, - 0x0080204000800080ull, - 0x3008804000290100ull, - 0x1010100080200080ull, - 0x2008100208028080ull, - 0x5000850800910100ull, - 0x8402019004680200ull, - 0x0120911028020400ull, - 0x0000008044010200ull, - 0x0020850200244012ull, - 0x0020850200244012ull, - 0x0000102001040841ull, - 0x140900040A100021ull, - 0x000200282410A102ull, - 0x000200282410A102ull, - 0x000200282410A102ull, - 0x4048240043802106ull - } }; - - alignas(64) constexpr EnumArray g_bishopMagics{ { - 0x40106000A1160020ull, - 0x0020010250810120ull, - 0x2010010220280081ull, - 0x002806004050C040ull, - 0x0002021018000000ull, - 0x2001112010000400ull, - 0x0881010120218080ull, - 0x1030820110010500ull, - 0x0000120222042400ull, - 0x2000020404040044ull, - 0x8000480094208000ull, - 0x0003422A02000001ull, - 0x000A220210100040ull, - 0x8004820202226000ull, - 0x0018234854100800ull, - 0x0100004042101040ull, - 0x0004001004082820ull, - 0x0010000810010048ull, - 0x1014004208081300ull, - 0x2080818802044202ull, - 0x0040880C00A00100ull, - 0x0080400200522010ull, - 0x0001000188180B04ull, - 0x0080249202020204ull, - 0x1004400004100410ull, - 0x00013100A0022206ull, - 0x2148500001040080ull, - 0x4241080011004300ull, - 0x4020848004002000ull, - 0x10101380D1004100ull, - 0x0008004422020284ull, - 0x01010A1041008080ull, - 0x0808080400082121ull, - 0x0808080400082121ull, - 0x0091128200100C00ull, - 0x0202200802010104ull, - 0x8C0A020200440085ull, - 0x01A0008080B10040ull, - 0x0889520080122800ull, - 0x100902022202010Aull, - 0x04081A0816002000ull, - 0x0000681208005000ull, - 0x8170840041008802ull, - 0x0A00004200810805ull, - 0x0830404408210100ull, - 0x2602208106006102ull, - 0x1048300680802628ull, - 0x2602208106006102ull, - 0x0602010120110040ull, - 0x0941010801043000ull, - 0x000040440A210428ull, - 0x0008240020880021ull, - 0x0400002012048200ull, - 0x00AC102001210220ull, - 0x0220021002009900ull, - 0x84440C080A013080ull, - 0x0001008044200440ull, - 0x0004C04410841000ull, - 0x2000500104011130ull, - 0x1A0C010011C20229ull, - 0x0044800112202200ull, - 0x0434804908100424ull, - 0x0300404822C08200ull, - 0x48081010008A2A80ull - } }; - - alignas(64) static EnumArray g_rookMasks; - alignas(64) static EnumArray g_rookShifts; - alignas(64) static EnumArray g_rookAttacks; - - alignas(64) static EnumArray g_bishopMasks; - alignas(64) static EnumArray g_bishopShifts; - alignas(64) static EnumArray g_bishopAttacks; - - alignas(64) static std::array g_allRookAttacks; - alignas(64) static std::array g_allBishopAttacks; - - inline Bitboard bishopAttacks(Square s, Bitboard occupied) - { - const std::size_t idx = - (occupied & fancy_magics::g_bishopMasks[s]).bits() - * fancy_magics::g_bishopMagics[s] - >> fancy_magics::g_bishopShifts[s]; + if constexpr (PieceTypeV == PieceType::Bishop) + { + return detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Rook) + { + return detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } + else // if constexpr (PieceTypeV == PieceType::Queen) + { + return detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } +} - return fancy_magics::g_bishopAttacks[s][idx]; - } +static Bitboard generateBetween(Square s1, Square s2) { + Bitboard bb = Bitboard::none(); - inline Bitboard rookAttacks(Square s, Bitboard occupied) - { - const std::size_t idx = - (occupied & fancy_magics::g_rookMasks[s]).bits() - * fancy_magics::g_rookMagics[s] - >> fancy_magics::g_rookShifts[s]; + if (s1 == s2) + { + return bb; + } - return fancy_magics::g_rookAttacks[s][idx]; - } - } + const int fd = s2.file() - s1.file(); + const int rd = s2.rank() - s1.rank(); - [[nodiscard]] constexpr Bitboard square(Square sq) + if (fd == 0 || rd == 0 || fd == rd || fd == -rd) + { + // s1 and s2 lie on a line. + const int fileStep = (fd > 0) - (fd < 0); + const int rankStep = (rd > 0) - (rd < 0); + const auto step = FlatSquareOffset(fileStep, rankStep); + s1 += step; // omit s1 + while (s1 != s2) // omit s2 { - return Bitboard::square(sq); - } - - [[nodiscard]] constexpr Bitboard rank(Rank rank) - { - return Bitboard::rank(rank); - } - - [[nodiscard]] constexpr Bitboard file(File file) - { - return Bitboard::file(file); - } - - [[nodiscard]] constexpr Bitboard color(Color c) - { - return Bitboard::color(c); - } - - [[nodiscard]] constexpr Bitboard before(Square sq) - { - return Bitboard::fromBits(nbitmask[ordinal(sq)]); + bb |= s1; + s1 += step; } + } - constexpr Bitboard lightSquares = bb::color(Color::White); - constexpr Bitboard darkSquares = bb::color(Color::Black); + return bb; +} - constexpr Bitboard fileA = bb::file(chess::fileA); - constexpr Bitboard fileB = bb::file(chess::fileB); - constexpr Bitboard fileC = bb::file(chess::fileC); - constexpr Bitboard fileD = bb::file(chess::fileD); - constexpr Bitboard fileE = bb::file(chess::fileE); - constexpr Bitboard fileF = bb::file(chess::fileF); - constexpr Bitboard fileG = bb::file(chess::fileG); - constexpr Bitboard fileH = bb::file(chess::fileH); +static Bitboard generateLine(Square s1, Square s2) { + for (PieceType pt : {PieceType::Bishop, PieceType::Rook}) + { + const Bitboard s1Attacks = pseudoAttacks()[pt][s1]; + if (s1Attacks.isSet(s2)) + { + const Bitboard s2Attacks = pseudoAttacks()[pt][s2]; + return (s1Attacks & s2Attacks) | s1 | s2; + } + } - constexpr Bitboard rank1 = bb::rank(chess::rank1); - constexpr Bitboard rank2 = bb::rank(chess::rank2); - constexpr Bitboard rank3 = bb::rank(chess::rank3); - constexpr Bitboard rank4 = bb::rank(chess::rank4); - constexpr Bitboard rank5 = bb::rank(chess::rank5); - constexpr Bitboard rank6 = bb::rank(chess::rank6); - constexpr Bitboard rank7 = bb::rank(chess::rank7); - constexpr Bitboard rank8 = bb::rank(chess::rank8); + return Bitboard::none(); +} - constexpr Bitboard a1 = bb::square(chess::a1); - constexpr Bitboard a2 = bb::square(chess::a2); - constexpr Bitboard a3 = bb::square(chess::a3); - constexpr Bitboard a4 = bb::square(chess::a4); - constexpr Bitboard a5 = bb::square(chess::a5); - constexpr Bitboard a6 = bb::square(chess::a6); - constexpr Bitboard a7 = bb::square(chess::a7); - constexpr Bitboard a8 = bb::square(chess::a8); +static const EnumArray2 between = []() { + EnumArray2 between_; - constexpr Bitboard b1 = bb::square(chess::b1); - constexpr Bitboard b2 = bb::square(chess::b2); - constexpr Bitboard b3 = bb::square(chess::b3); - constexpr Bitboard b4 = bb::square(chess::b4); - constexpr Bitboard b5 = bb::square(chess::b5); - constexpr Bitboard b6 = bb::square(chess::b6); - constexpr Bitboard b7 = bb::square(chess::b7); - constexpr Bitboard b8 = bb::square(chess::b8); + for (Square s1 : values()) + { + for (Square s2 : values()) + { + between_[s1][s2] = generateBetween(s1, s2); + } + } - constexpr Bitboard c1 = bb::square(chess::c1); - constexpr Bitboard c2 = bb::square(chess::c2); - constexpr Bitboard c3 = bb::square(chess::c3); - constexpr Bitboard c4 = bb::square(chess::c4); - constexpr Bitboard c5 = bb::square(chess::c5); - constexpr Bitboard c6 = bb::square(chess::c6); - constexpr Bitboard c7 = bb::square(chess::c7); - constexpr Bitboard c8 = bb::square(chess::c8); + return between_; +}(); - constexpr Bitboard d1 = bb::square(chess::d1); - constexpr Bitboard d2 = bb::square(chess::d2); - constexpr Bitboard d3 = bb::square(chess::d3); - constexpr Bitboard d4 = bb::square(chess::d4); - constexpr Bitboard d5 = bb::square(chess::d5); - constexpr Bitboard d6 = bb::square(chess::d6); - constexpr Bitboard d7 = bb::square(chess::d7); - constexpr Bitboard d8 = bb::square(chess::d8); +static const EnumArray2 line = []() { + EnumArray2 line_; - constexpr Bitboard e1 = bb::square(chess::e1); - constexpr Bitboard e2 = bb::square(chess::e2); - constexpr Bitboard e3 = bb::square(chess::e3); - constexpr Bitboard e4 = bb::square(chess::e4); - constexpr Bitboard e5 = bb::square(chess::e5); - constexpr Bitboard e6 = bb::square(chess::e6); - constexpr Bitboard e7 = bb::square(chess::e7); - constexpr Bitboard e8 = bb::square(chess::e8); + for (Square s1 : values()) + { + for (Square s2 : values()) + { + line_[s1][s2] = generateLine(s1, s2); + } + } - constexpr Bitboard f1 = bb::square(chess::f1); - constexpr Bitboard f2 = bb::square(chess::f2); - constexpr Bitboard f3 = bb::square(chess::f3); - constexpr Bitboard f4 = bb::square(chess::f4); - constexpr Bitboard f5 = bb::square(chess::f5); - constexpr Bitboard f6 = bb::square(chess::f6); - constexpr Bitboard f7 = bb::square(chess::f7); - constexpr Bitboard f8 = bb::square(chess::f8); + return line_; +}(); +} - constexpr Bitboard g1 = bb::square(chess::g1); - constexpr Bitboard g2 = bb::square(chess::g2); - constexpr Bitboard g3 = bb::square(chess::g3); - constexpr Bitboard g4 = bb::square(chess::g4); - constexpr Bitboard g5 = bb::square(chess::g5); - constexpr Bitboard g6 = bb::square(chess::g6); - constexpr Bitboard g7 = bb::square(chess::g7); - constexpr Bitboard g8 = bb::square(chess::g8); - - constexpr Bitboard h1 = bb::square(chess::h1); - constexpr Bitboard h2 = bb::square(chess::h2); - constexpr Bitboard h3 = bb::square(chess::h3); - constexpr Bitboard h4 = bb::square(chess::h4); - constexpr Bitboard h5 = bb::square(chess::h5); - constexpr Bitboard h6 = bb::square(chess::h6); - constexpr Bitboard h7 = bb::square(chess::h7); - constexpr Bitboard h8 = bb::square(chess::h8); - - [[nodiscard]] Bitboard between(Square s1, Square s2); - - [[nodiscard]] Bitboard line(Square s1, Square s2); - - template - [[nodiscard]] Bitboard pseudoAttacks(Square sq); - - [[nodiscard]] Bitboard pseudoAttacks(PieceType pt, Square sq); - - template - Bitboard attacks(Square sq, Bitboard occupied) - { - static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); - - assert(sq.isOk()); - - if constexpr (PieceTypeV == PieceType::Bishop) - { - return fancy_magics::bishopAttacks(sq, occupied); - } - else if constexpr (PieceTypeV == PieceType::Rook) - { - return fancy_magics::rookAttacks(sq, occupied); - } - else if constexpr (PieceTypeV == PieceType::Queen) - { - return - fancy_magics::bishopAttacks(sq, occupied) - | fancy_magics::rookAttacks(sq, occupied); - } - else - { - return pseudoAttacks(sq); - } - } +namespace fancy_magics { +enum struct MagicsType { + Rook, + Bishop +}; - [[nodiscard]] inline Bitboard attacks(PieceType pt, Square sq, Bitboard occupied) - { - assert(sq.isOk()); +template +[[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied) { + if (TypeV == MagicsType::Rook) + { + return chess::bb::detail::pieceSlidingAttacks(sq, occupied); + } - switch (pt) - { - case PieceType::Bishop: - return attacks(sq, occupied); - case PieceType::Rook: - return attacks(sq, occupied); - case PieceType::Queen: - return attacks(sq, occupied); - default: - return pseudoAttacks(pt, sq); - } - } + if (TypeV == MagicsType::Bishop) + { + return chess::bb::detail::pieceSlidingAttacks(sq, occupied); + } - [[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color); + return Bitboard::none(); +} - [[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color); +template +[[nodiscard]] inline bool initMagics(const EnumArray& magics, + std::array& table, + EnumArray& masks, + EnumArray& shifts, + EnumArray& attacks) { + std::size_t size = 0; + for (Square sq : values()) + { + const Bitboard edges = ((bb::rank1 | bb::rank8) & ~Bitboard::rank(sq.rank())) + | ((bb::fileA | bb::fileH) & ~Bitboard::file(sq.file())); - [[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color); + Bitboard* currentAttacks = table.data() + size; - [[nodiscard]] inline bool isAttackedBySlider( - Square sq, - Bitboard bishops, - Bitboard rooks, - Bitboard queens, - Bitboard occupied - ); + attacks[sq] = currentAttacks; + masks[sq] = slidingAttacks(sq, Bitboard::none()) & ~edges; + shifts[sq] = 64 - masks[sq].count(); - namespace detail + Bitboard occupied = Bitboard::none(); + do { - static constexpr std::array knightOffsets{ { {-1, -2}, {-1, 2}, {1, -2}, {1, 2}, {-2, -1}, {-2, 1}, {2, -1}, {2, 1} } }; - static constexpr std::array kingOffsets{ { {-1, -1}, {-1, 0}, {-1, 1}, {0, -1}, {0, 1}, {1, -1}, {1, 0}, {1, 1} } }; + const std::size_t idx = (occupied & masks[sq]).bits() * magics[sq] >> shifts[sq]; - enum Direction - { - North = 0, - NorthEast, - East, - SouthEast, - South, - SouthWest, - West, - NorthWest - }; - - constexpr std::array offsets = { { - { 0, 1 }, - { 1, 1 }, - { 1, 0 }, - { 1, -1 }, - { 0, -1 }, - { -1, -1 }, - { -1, 0 }, - { -1, 1 } - } }; - - static constexpr std::array bishopOffsets{ - offsets[NorthEast], - offsets[SouthEast], - offsets[SouthWest], - offsets[NorthWest] - }; - static constexpr std::array rookOffsets{ - offsets[North], - offsets[East], - offsets[South], - offsets[West] - }; - - [[nodiscard]] static EnumArray generatePseudoAttacks_Pawn() - { - // pseudo attacks don't make sense for pawns - return {}; - } + currentAttacks[idx] = slidingAttacks(sq, occupied); - [[nodiscard]] static EnumArray generatePseudoAttacks_Knight() - { - EnumArray bbs{}; + ++size; + occupied = Bitboard::fromBits(occupied.bits() - masks[sq].bits()) & masks[sq]; + } while (occupied.any()); + } - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - Bitboard bb{}; + return true; +} - for (auto&& offset : knightOffsets) - { - const SquareCoords toSq = fromSq.coords() + offset; - if (toSq.isOk()) - { - bb |= Square(toSq); - } - } +static bool g_isRookMagicsInitialized = initMagics( + g_rookMagics, g_allRookAttacks, g_rookMasks, g_rookShifts, g_rookAttacks); - bbs[fromSq] = bb; - } +static bool g_isBishopMagicsInitialized = initMagics( + g_bishopMagics, g_allBishopAttacks, g_bishopMasks, g_bishopShifts, g_bishopAttacks); +} - return bbs; - } +[[nodiscard]] inline Bitboard between(Square s1, Square s2) { return detail::between[s1][s2]; } - [[nodiscard]] static Bitboard generateSliderPseudoAttacks(const std::array & offsets_, Square fromSq) - { - assert(fromSq.isOk()); +[[nodiscard]] inline Bitboard line(Square s1, Square s2) { return detail::line[s1][s2]; } - Bitboard bb{}; +template +[[nodiscard]] inline Bitboard pseudoAttacks(Square sq) { + static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); - for (auto&& offset : offsets_) - { - SquareCoords fromSqC = fromSq.coords(); + assert(sq.isOk()); - for (;;) - { - fromSqC += offset; + return detail::pseudoAttacks()[PieceTypeV][sq]; +} - if (!fromSqC.isOk()) - { - break; - } +[[nodiscard]] inline Bitboard pseudoAttacks(PieceType pt, Square sq) { + assert(sq.isOk()); - bb |= Square(fromSqC); - } - } + return detail::pseudoAttacks()[pt][sq]; +} - return bb; - } +[[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color) { + if (color == Color::White) + { + return pawns.shifted<1, 1>() | pawns.shifted<-1, 1>(); + } + else + { + return pawns.shifted<1, -1>() | pawns.shifted<-1, -1>(); + } +} - [[nodiscard]] static EnumArray generatePseudoAttacks_Bishop() - { - EnumArray bbs{}; +[[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color) { + if (color == Color::White) + { + return pawns.shifted<-1, 1>(); + } + else + { + return pawns.shifted<-1, -1>(); + } +} - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - bbs[fromSq] = generateSliderPseudoAttacks(bishopOffsets, fromSq); - } +[[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color) { + if (color == Color::White) + { + return pawns.shifted<1, 1>(); + } + else + { + return pawns.shifted<1, -1>(); + } +} - return bbs; - } +[[nodiscard]] inline bool isAttackedBySlider( + Square sq, Bitboard bishops, Bitboard rooks, Bitboard queens, Bitboard occupied) { + const Bitboard opponentBishopLikePieces = (bishops | queens); + const Bitboard bishopAttacks = bb::attacks(sq, occupied); + if ((bishopAttacks & opponentBishopLikePieces).any()) + { + return true; + } - [[nodiscard]] static EnumArray generatePseudoAttacks_Rook() - { - EnumArray bbs{}; + const Bitboard opponentRookLikePieces = (rooks | queens); + const Bitboard rookAttacks = bb::attacks(sq, occupied); + return (rookAttacks & opponentRookLikePieces).any(); +} +} - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - bbs[fromSq] = generateSliderPseudoAttacks(rookOffsets, fromSq); - } +struct CastlingTraits { + static constexpr EnumArray2 rookDestination = { + {{{f1, d1}}, {{f8, d8}}}}; + static constexpr EnumArray2 kingDestination = { + {{{g1, c1}}, {{g8, c8}}}}; - return bbs; - } + static constexpr EnumArray2 rookStart = {{{{h1, a1}}, {{h8, a8}}}}; - [[nodiscard]] static EnumArray generatePseudoAttacks_Queen() - { - EnumArray bbs{}; + static constexpr EnumArray kingStart = {{e1, e8}}; - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - bbs[fromSq] = - generateSliderPseudoAttacks(bishopOffsets, fromSq) - | generateSliderPseudoAttacks(rookOffsets, fromSq); - } + static constexpr EnumArray2 castlingPath = { + {{{Bitboard::square(f1) | g1, Bitboard::square(b1) | c1 | d1}}, + {{Bitboard::square(f8) | g8, Bitboard::square(b8) | c8 | d8}}}}; - return bbs; - } + static constexpr EnumArray2 squarePassedByKing = { + {{{f1, d1}}, {{f8, d8}}}}; - [[nodiscard]] static EnumArray generatePseudoAttacks_King() - { - EnumArray bbs{}; + static constexpr EnumArray2 castlingRights = { + {{{CastlingRights::WhiteKingSide, CastlingRights::WhiteQueenSide}}, + {{CastlingRights::BlackKingSide, CastlingRights::BlackQueenSide}}}}; - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - Bitboard bb{}; + // Move has to be a legal castling move. + static constexpr CastleType moveCastlingType(const Move& move) { + return (move.to.file() == fileH) ? CastleType::Short : CastleType::Long; + } - for (auto&& offset : kingOffsets) - { - const SquareCoords toSq = fromSq.coords() + offset; - if (toSq.isOk()) - { - bb |= Square(toSq); - } - } + // Move must be a legal castling move. + static constexpr CastlingRights moveCastlingRight(Move move) { + if (move.to == h1) + return CastlingRights::WhiteKingSide; + if (move.to == a1) + return CastlingRights::WhiteQueenSide; + if (move.to == h8) + return CastlingRights::WhiteKingSide; + if (move.to == a8) + return CastlingRights::WhiteQueenSide; + return CastlingRights::None; + } +}; - bbs[fromSq] = bb; - } +namespace parser_bits { +[[nodiscard]] constexpr bool isFile(char c) { return c >= 'a' && c <= 'h'; } - return bbs; - } +[[nodiscard]] constexpr bool isRank(char c) { return c >= '1' && c <= '8'; } - [[nodiscard]] static EnumArray2 generatePseudoAttacks() - { - return EnumArray2{ - generatePseudoAttacks_Pawn(), - generatePseudoAttacks_Knight(), - generatePseudoAttacks_Bishop(), - generatePseudoAttacks_Rook(), - generatePseudoAttacks_Queen(), - generatePseudoAttacks_King() - }; - } +[[nodiscard]] constexpr Rank parseRank(char c) { + assert(isRank(c)); - static const EnumArray2& pseudoAttacks() - { - static const EnumArray2 s_pseudoAttacks = generatePseudoAttacks(); - return s_pseudoAttacks; - } + return fromOrdinal(c - '1'); +} - [[nodiscard]] static Bitboard generatePositiveRayAttacks(Direction dir, Square fromSq) - { - assert(fromSq.isOk()); +[[nodiscard]] constexpr File parseFile(char c) { + assert(isFile(c)); - Bitboard bb{}; + return fromOrdinal(c - 'a'); +} - const auto offset = offsets[dir]; - SquareCoords fromSqC = fromSq.coords(); - for (;;) - { - fromSqC += offset; +[[nodiscard]] constexpr bool isSquare(const char* s) { return isFile(s[0]) && isRank(s[1]); } - if (!fromSqC.isOk()) - { - break; - } +[[nodiscard]] constexpr Square parseSquare(const char* s) { + const File file = parseFile(s[0]); + const Rank rank = parseRank(s[1]); + return Square(file, rank); +} - bb |= Square(fromSqC); - } +[[nodiscard]] constexpr std::optional tryParseSquare(std::string_view s) { + if (s.size() != 2) + return {}; + if (!isSquare(s.data())) + return {}; + return parseSquare(s.data()); +} - return bb; - } +[[nodiscard]] constexpr std::optional tryParseEpSquare(std::string_view s) { + if (s == std::string_view("-")) + return Square::none(); + return tryParseSquare(s); +} - // classical slider move generation approach https://www.chessprogramming.org/Classical_Approach +[[nodiscard]] constexpr std::optional tryParseCastlingRights(std::string_view s) { + if (s == std::string_view("-")) + return CastlingRights::None; - [[nodiscard]] static EnumArray generatePositiveRayAttacks(Direction dir) - { - EnumArray bbs{}; + CastlingRights rights = CastlingRights::None; - for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) - { - bbs[fromSq] = generatePositiveRayAttacks(dir, fromSq); - } + for (auto& c : s) + { + CastlingRights toAdd = CastlingRights::None; + switch (c) + { + case 'K' : + toAdd = CastlingRights::WhiteKingSide; + break; + case 'Q' : + toAdd = CastlingRights::WhiteQueenSide; + break; + case 'k' : + toAdd = CastlingRights::BlackKingSide; + break; + case 'q' : + toAdd = CastlingRights::BlackQueenSide; + break; + } - return bbs; - } + // If there are duplicated castling rights specification we bail. + // If there is an invalid character we bail. + // (It always contains None) + if (contains(rights, toAdd)) + return {}; + else + rights |= toAdd; + } - [[nodiscard]] static std::array, 8> generatePositiveRayAttacks() - { - std::array, 8> bbs{}; - - bbs[North] = generatePositiveRayAttacks(North); - bbs[NorthEast] = generatePositiveRayAttacks(NorthEast); - bbs[East] = generatePositiveRayAttacks(East); - bbs[SouthEast] = generatePositiveRayAttacks(SouthEast); - bbs[South] = generatePositiveRayAttacks(South); - bbs[SouthWest] = generatePositiveRayAttacks(SouthWest); - bbs[West] = generatePositiveRayAttacks(West); - bbs[NorthWest] = generatePositiveRayAttacks(NorthWest); - - return bbs; - } + return rights; +} +[[nodiscard]] constexpr CastlingRights readCastlingRights(const char*& s) { + CastlingRights rights = CastlingRights::None; - static const std::array, 8>& positiveRayAttacks() - { - static const std::array, 8> s_positiveRayAttacks = generatePositiveRayAttacks(); - return s_positiveRayAttacks; - } + while (*s != ' ') + { + switch (*s) + { + case 'K' : + rights |= CastlingRights::WhiteKingSide; + break; + case 'Q' : + rights |= CastlingRights::WhiteQueenSide; + break; + case 'k' : + rights |= CastlingRights::BlackKingSide; + break; + case 'q' : + rights |= CastlingRights::BlackQueenSide; + break; + } - template - [[nodiscard]] static Bitboard slidingAttacks(Square sq, Bitboard occupied) - { - assert(sq.isOk()); + ++s; + } - Bitboard attacks = positiveRayAttacks()[DirV][sq]; + return rights; +} - if constexpr (DirV == NorthWest || DirV == North || DirV == NorthEast || DirV == East) - { - Bitboard blocker = (attacks & occupied) | h8; // set highest bit (H8) so msb never fails - return attacks ^ positiveRayAttacks()[DirV][blocker.first()]; - } - else - { - Bitboard blocker = (attacks & occupied) | a1; - return attacks ^ positiveRayAttacks()[DirV][blocker.last()]; - } - } +FORCEINLINE inline void appendCastlingRightsToString(CastlingRights rights, std::string& str) { + if (rights == CastlingRights::None) + { + str += '-'; + } + else + { + if (contains(rights, CastlingRights::WhiteKingSide)) + str += 'K'; + if (contains(rights, CastlingRights::WhiteQueenSide)) + str += 'Q'; + if (contains(rights, CastlingRights::BlackKingSide)) + str += 'k'; + if (contains(rights, CastlingRights::BlackQueenSide)) + str += 'q'; + } +} - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - template Bitboard slidingAttacks(Square, Bitboard); - - template - [[nodiscard]] inline Bitboard pieceSlidingAttacks(Square sq, Bitboard occupied) - { - static_assert( - PieceTypeV == PieceType::Rook - || PieceTypeV == PieceType::Bishop - || PieceTypeV == PieceType::Queen); +FORCEINLINE inline void appendSquareToString(Square sq, std::string& str) { + str += static_cast('a' + ordinal(sq.file())); + str += static_cast('1' + ordinal(sq.rank())); +} - assert(sq.isOk()); +FORCEINLINE inline void appendEpSquareToString(Square sq, std::string& str) { + if (sq == Square::none()) + { + str += '-'; + } + else + { + appendSquareToString(sq, str); + } +} - if constexpr (PieceTypeV == PieceType::Bishop) - { - return - detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied); - } - else if constexpr (PieceTypeV == PieceType::Rook) - { - return - detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied); - } - else // if constexpr (PieceTypeV == PieceType::Queen) - { - return - detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied) - | detail::slidingAttacks(sq, occupied); - } - } +FORCEINLINE inline void appendRankToString(Rank r, std::string& str) { + str += static_cast('1' + ordinal(r)); +} - static Bitboard generateBetween(Square s1, Square s2) - { - Bitboard bb = Bitboard::none(); +FORCEINLINE inline void appendFileToString(File f, std::string& str) { + str += static_cast('a' + ordinal(f)); +} - if (s1 == s2) - { - return bb; - } +[[nodiscard]] FORCEINLINE inline bool isDigit(char c) { return c >= '0' && c <= '9'; } - const int fd = s2.file() - s1.file(); - const int rd = s2.rank() - s1.rank(); +[[nodiscard]] inline std::uint16_t parseUInt16(std::string_view sv) { + assert(sv.size() > 0); + assert(sv.size() <= 5); - if (fd == 0 || rd == 0 || fd == rd || fd == -rd) - { - // s1 and s2 lie on a line. - const int fileStep = (fd > 0) - (fd < 0); - const int rankStep = (rd > 0) - (rd < 0); - const auto step = FlatSquareOffset(fileStep, rankStep); - s1 += step; // omit s1 - while(s1 != s2) // omit s2 - { - bb |= s1; - s1 += step; - } - } + std::uint16_t v = 0; - return bb; - } + std::size_t idx = 0; + switch (sv.size()) + { + case 5 : + v += (sv[idx++] - '0') * 10000; + case 4 : + v += (sv[idx++] - '0') * 1000; + case 3 : + v += (sv[idx++] - '0') * 100; + case 2 : + v += (sv[idx++] - '0') * 10; + case 1 : + v += sv[idx] - '0'; + break; - static Bitboard generateLine(Square s1, Square s2) - { - for (PieceType pt : { PieceType::Bishop, PieceType::Rook }) - { - const Bitboard s1Attacks = pseudoAttacks()[pt][s1]; - if (s1Attacks.isSet(s2)) - { - const Bitboard s2Attacks = pseudoAttacks()[pt][s2]; - return (s1Attacks & s2Attacks) | s1 | s2; - } - } + default : + assert(false); + } - return Bitboard::none(); - } + return v; +} - static const EnumArray2 between = []() - { - EnumArray2 between_; +[[nodiscard]] inline std::optional tryParseUInt16(std::string_view sv) { + if (sv.size() == 0 || sv.size() > 5) + return std::nullopt; + + std::uint32_t v = 0; + + std::size_t idx = 0; + switch (sv.size()) + { + case 5 : + v += (sv[idx++] - '0') * 10000; + case 4 : + v += (sv[idx++] - '0') * 1000; + case 3 : + v += (sv[idx++] - '0') * 100; + case 2 : + v += (sv[idx++] - '0') * 10; + case 1 : + v += sv[idx] - '0'; + break; + + default : + assert(false); + } - for (Square s1 : values()) - { - for (Square s2 : values()) - { - between_[s1][s2] = generateBetween(s1, s2); - } - } + if (v > std::numeric_limits::max()) + { + return std::nullopt; + } - return between_; - }(); + return static_cast(v); +} +} - static const EnumArray2 line = []() - { - EnumArray2 line_; - for (Square s1 : values()) - { - for (Square s2 : values()) - { - line_[s1][s2] = generateLine(s1, s2); - } - } +struct Board { + constexpr Board() noexcept : + m_pieces{}, + m_pieceBB{}, + m_piecesByColorBB{}, + m_pieceCount{} { + m_pieces.fill(Piece::none()); + m_pieceBB.fill(Bitboard::none()); + m_pieceBB[Piece::none()] = Bitboard::all(); + m_piecesByColorBB.fill(Bitboard::none()); + m_pieceCount.fill(0); + m_pieceCount[Piece::none()] = 64; + } - return line_; - }(); - } + [[nodiscard]] inline bool isValid() const { + if (piecesBB(whiteKing).count() != 1) + return false; + if (piecesBB(blackKing).count() != 1) + return false; + if (((piecesBB(whitePawn) | piecesBB(blackPawn)) & (bb::rank(rank1) | bb::rank(rank8))) + .any()) + return false; + return true; + } + + [[nodiscard]] inline std::string fen() const; - namespace fancy_magics + [[nodiscard]] inline bool trySet(std::string_view boardState) { + File f = fileA; + Rank r = rank8; + bool lastWasSkip = false; + for (auto c : boardState) { - enum struct MagicsType + Piece piece = Piece::none(); + switch (c) { - Rook, - Bishop - }; + case 'r' : + piece = Piece(PieceType::Rook, Color::Black); + break; + case 'n' : + piece = Piece(PieceType::Knight, Color::Black); + break; + case 'b' : + piece = Piece(PieceType::Bishop, Color::Black); + break; + case 'q' : + piece = Piece(PieceType::Queen, Color::Black); + break; + case 'k' : + piece = Piece(PieceType::King, Color::Black); + break; + case 'p' : + piece = Piece(PieceType::Pawn, Color::Black); + break; - template - [[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied) - { - if (TypeV == MagicsType::Rook) - { - return chess::bb::detail::pieceSlidingAttacks(sq, occupied); - } + case 'R' : + piece = Piece(PieceType::Rook, Color::White); + break; + case 'N' : + piece = Piece(PieceType::Knight, Color::White); + break; + case 'B' : + piece = Piece(PieceType::Bishop, Color::White); + break; + case 'Q' : + piece = Piece(PieceType::Queen, Color::White); + break; + case 'K' : + piece = Piece(PieceType::King, Color::White); + break; + case 'P' : + piece = Piece(PieceType::Pawn, Color::White); + break; - if (TypeV == MagicsType::Bishop) - { - return chess::bb::detail::pieceSlidingAttacks(sq, occupied); - } + case '1' : + case '2' : + case '3' : + case '4' : + case '5' : + case '6' : + case '7' : + case '8' : { + if (lastWasSkip) + return false; + lastWasSkip = true; - return Bitboard::none(); + const int skip = c - '0'; + f += skip; + if (f > fileH + 1) + return false; + break; } - template - [[nodiscard]] inline bool initMagics( - const EnumArray& magics, - std::array& table, - EnumArray& masks, - EnumArray& shifts, - EnumArray& attacks - ) - { - std::size_t size = 0; - for (Square sq : values()) - { - const Bitboard edges = - ((bb::rank1 | bb::rank8) & ~Bitboard::rank(sq.rank())) - | ((bb::fileA | bb::fileH) & ~Bitboard::file(sq.file())); - - Bitboard* currentAttacks = table.data() + size; - - attacks[sq] = currentAttacks; - masks[sq] = slidingAttacks(sq, Bitboard::none()) & ~edges; - shifts[sq] = 64 - masks[sq].count(); + case '/' : + lastWasSkip = false; + if (f != fileH + 1) + return false; + f = fileA; + --r; + break; - Bitboard occupied = Bitboard::none(); - do - { - const std::size_t idx = - (occupied & masks[sq]).bits() - * magics[sq] - >> shifts[sq]; + default : + return false; + } - currentAttacks[idx] = slidingAttacks(sq, occupied); + if (piece != Piece::none()) + { + lastWasSkip = false; - ++size; - occupied = Bitboard::fromBits(occupied.bits() - masks[sq].bits()) & masks[sq]; - } while (occupied.any()); - } + const Square sq(f, r); + if (!sq.isOk()) + return false; - return true; - } - - static bool g_isRookMagicsInitialized = - initMagics(g_rookMagics, g_allRookAttacks, g_rookMasks, g_rookShifts, g_rookAttacks); - - static bool g_isBishopMagicsInitialized = - initMagics(g_bishopMagics, g_allBishopAttacks, g_bishopMasks, g_bishopShifts, g_bishopAttacks); - } - - [[nodiscard]] inline Bitboard between(Square s1, Square s2) - { - return detail::between[s1][s2]; - } - - [[nodiscard]] inline Bitboard line(Square s1, Square s2) - { - return detail::line[s1][s2]; - } - - template - [[nodiscard]] inline Bitboard pseudoAttacks(Square sq) - { - static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); - - assert(sq.isOk()); - - return detail::pseudoAttacks()[PieceTypeV][sq]; - } - - [[nodiscard]] inline Bitboard pseudoAttacks(PieceType pt, Square sq) - { - assert(sq.isOk()); - - return detail::pseudoAttacks()[pt][sq]; - } - - [[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color) - { - if (color == Color::White) - { - return pawns.shifted<1, 1>() | pawns.shifted<-1, 1>(); - } - else - { - return pawns.shifted<1, -1>() | pawns.shifted<-1, -1>(); - } - } - - [[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color) - { - if (color == Color::White) - { - return pawns.shifted<-1, 1>(); - } - else - { - return pawns.shifted<-1, -1>(); - } - } - - [[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color) - { - if (color == Color::White) - { - return pawns.shifted<1, 1>(); - } - else - { - return pawns.shifted<1, -1>(); + place(piece, sq); + ++f; } } - [[nodiscard]] inline bool isAttackedBySlider( - Square sq, - Bitboard bishops, - Bitboard rooks, - Bitboard queens, - Bitboard occupied - ) - { - const Bitboard opponentBishopLikePieces = (bishops | queens); - const Bitboard bishopAttacks = bb::attacks(sq, occupied); - if ((bishopAttacks & opponentBishopLikePieces).any()) - { - return true; - } + if (f != fileH + 1) + return false; + if (r != rank1) + return false; - const Bitboard opponentRookLikePieces = (rooks | queens); - const Bitboard rookAttacks = bb::attacks(sq, occupied); - return (rookAttacks & opponentRookLikePieces).any(); - } + return isValid(); } - struct CastlingTraits - { - static constexpr EnumArray2 rookDestination = { { {{ f1, d1 }}, {{ f8, d8 }} } }; - static constexpr EnumArray2 kingDestination = { { {{ g1, c1 }}, {{ g8, c8 }} } }; - - static constexpr EnumArray2 rookStart = { { {{ h1, a1 }}, {{ h8, a8 }} } }; - - static constexpr EnumArray kingStart = { { e1, e8 } }; - - static constexpr EnumArray2 castlingPath = { - { - {{ Bitboard::square(f1) | g1, Bitboard::square(b1) | c1 | d1 }}, - {{ Bitboard::square(f8) | g8, Bitboard::square(b8) | c8 | d8 }} - } - }; - - static constexpr EnumArray2 squarePassedByKing = { - { - {{ f1, d1 }}, - {{ f8, d8 }} - } - }; - - static constexpr EnumArray2 castlingRights = { - { - {{ CastlingRights::WhiteKingSide, CastlingRights::WhiteQueenSide }}, - {{ CastlingRights::BlackKingSide, CastlingRights::BlackQueenSide }} - } - }; - - // Move has to be a legal castling move. - static constexpr CastleType moveCastlingType(const Move& move) - { - return (move.to.file() == fileH) ? CastleType::Short : CastleType::Long; - } - - // Move must be a legal castling move. - static constexpr CastlingRights moveCastlingRight(Move move) - { - if (move.to == h1) return CastlingRights::WhiteKingSide; - if (move.to == a1) return CastlingRights::WhiteQueenSide; - if (move.to == h8) return CastlingRights::WhiteKingSide; - if (move.to == a8) return CastlingRights::WhiteQueenSide; - return CastlingRights::None; - } - }; - - namespace parser_bits - { - [[nodiscard]] constexpr bool isFile(char c) - { - return c >= 'a' && c <= 'h'; - } - - [[nodiscard]] constexpr bool isRank(char c) - { - return c >= '1' && c <= '8'; - } - - [[nodiscard]] constexpr Rank parseRank(char c) - { - assert(isRank(c)); - - return fromOrdinal(c - '1'); - } - - [[nodiscard]] constexpr File parseFile(char c) - { - assert(isFile(c)); - - return fromOrdinal(c - 'a'); - } - - [[nodiscard]] constexpr bool isSquare(const char* s) - { - return isFile(s[0]) && isRank(s[1]); - } - - [[nodiscard]] constexpr Square parseSquare(const char* s) - { - const File file = parseFile(s[0]); - const Rank rank = parseRank(s[1]); - return Square(file, rank); - } - - [[nodiscard]] constexpr std::optional tryParseSquare(std::string_view s) - { - if (s.size() != 2) return {}; - if (!isSquare(s.data())) return {}; - return parseSquare(s.data()); - } - - [[nodiscard]] constexpr std::optional tryParseEpSquare(std::string_view s) - { - if (s == std::string_view("-")) return Square::none(); - return tryParseSquare(s); - } - - [[nodiscard]] constexpr std::optional tryParseCastlingRights(std::string_view s) - { - if (s == std::string_view("-")) return CastlingRights::None; - - CastlingRights rights = CastlingRights::None; - - for (auto& c : s) - { - CastlingRights toAdd = CastlingRights::None; - switch (c) - { - case 'K': - toAdd = CastlingRights::WhiteKingSide; - break; - case 'Q': - toAdd = CastlingRights::WhiteQueenSide; - break; - case 'k': - toAdd = CastlingRights::BlackKingSide; - break; - case 'q': - toAdd = CastlingRights::BlackQueenSide; - break; - } - - // If there are duplicated castling rights specification we bail. - // If there is an invalid character we bail. - // (It always contains None) - if (contains(rights, toAdd)) return {}; - else rights |= toAdd; - } - - return rights; - } - - [[nodiscard]] constexpr CastlingRights readCastlingRights(const char*& s) - { - CastlingRights rights = CastlingRights::None; - - while (*s != ' ') - { - switch (*s) - { - case 'K': - rights |= CastlingRights::WhiteKingSide; - break; - case 'Q': - rights |= CastlingRights::WhiteQueenSide; - break; - case 'k': - rights |= CastlingRights::BlackKingSide; - break; - case 'q': - rights |= CastlingRights::BlackQueenSide; - break; - } - - ++s; - } - - return rights; - } + // returns side to move + [[nodiscard]] constexpr const char* set(const char* fen) { + assert(fen != nullptr); - FORCEINLINE inline void appendCastlingRightsToString(CastlingRights rights, std::string& str) + File f = fileA; + Rank r = rank8; + auto current = fen; + bool done = false; + while (*current != '\0') { - if (rights == CastlingRights::None) - { - str += '-'; - } - else + Piece piece = Piece::none(); + switch (*current) { - if (contains(rights, CastlingRights::WhiteKingSide)) str += 'K'; - if (contains(rights, CastlingRights::WhiteQueenSide)) str += 'Q'; - if (contains(rights, CastlingRights::BlackKingSide)) str += 'k'; - if (contains(rights, CastlingRights::BlackQueenSide)) str += 'q'; - } - } - - FORCEINLINE inline void appendSquareToString(Square sq, std::string& str) - { - str += static_cast('a' + ordinal(sq.file())); - str += static_cast('1' + ordinal(sq.rank())); - } - - FORCEINLINE inline void appendEpSquareToString(Square sq, std::string& str) - { - if (sq == Square::none()) - { - str += '-'; - } - else - { - appendSquareToString(sq, str); - } - } - - FORCEINLINE inline void appendRankToString(Rank r, std::string& str) - { - str += static_cast('1' + ordinal(r)); - } - - FORCEINLINE inline void appendFileToString(File f, std::string& str) - { - str += static_cast('a' + ordinal(f)); - } - - [[nodiscard]] FORCEINLINE inline bool isDigit(char c) - { - return c >= '0' && c <= '9'; - } - - [[nodiscard]] inline std::uint16_t parseUInt16(std::string_view sv) - { - assert(sv.size() > 0); - assert(sv.size() <= 5); + case 'r' : + piece = Piece(PieceType::Rook, Color::Black); + break; + case 'n' : + piece = Piece(PieceType::Knight, Color::Black); + break; + case 'b' : + piece = Piece(PieceType::Bishop, Color::Black); + break; + case 'q' : + piece = Piece(PieceType::Queen, Color::Black); + break; + case 'k' : + piece = Piece(PieceType::King, Color::Black); + break; + case 'p' : + piece = Piece(PieceType::Pawn, Color::Black); + break; - std::uint16_t v = 0; + case 'R' : + piece = Piece(PieceType::Rook, Color::White); + break; + case 'N' : + piece = Piece(PieceType::Knight, Color::White); + break; + case 'B' : + piece = Piece(PieceType::Bishop, Color::White); + break; + case 'Q' : + piece = Piece(PieceType::Queen, Color::White); + break; + case 'K' : + piece = Piece(PieceType::King, Color::White); + break; + case 'P' : + piece = Piece(PieceType::Pawn, Color::White); + break; - std::size_t idx = 0; - switch (sv.size()) - { - case 5: - v += (sv[idx++] - '0') * 10000; - case 4: - v += (sv[idx++] - '0') * 1000; - case 3: - v += (sv[idx++] - '0') * 100; - case 2: - v += (sv[idx++] - '0') * 10; - case 1: - v += sv[idx] - '0'; + case ' ' : + done = true; break; - default: - assert(false); + case '1' : + case '2' : + case '3' : + case '4' : + case '5' : + case '6' : + case '7' : + case '8' : { + const int skip = (*current) - '0'; + f += skip; + break; } - return v; - } - - [[nodiscard]] inline std::optional tryParseUInt16(std::string_view sv) - { - if (sv.size() == 0 || sv.size() > 5) return std::nullopt; - - std::uint32_t v = 0; - - std::size_t idx = 0; - switch (sv.size()) - { - case 5: - v += (sv[idx++] - '0') * 10000; - case 4: - v += (sv[idx++] - '0') * 1000; - case 3: - v += (sv[idx++] - '0') * 100; - case 2: - v += (sv[idx++] - '0') * 10; - case 1: - v += sv[idx] - '0'; + case '/' : + f = fileA; + --r; break; - default: - assert(false); + default : + break; } - if (v > std::numeric_limits::max()) + if (done) { - return std::nullopt; + break; } - return static_cast(v); - } - } - - - struct Board - { - constexpr Board() noexcept : - m_pieces{}, - m_pieceBB{}, - m_piecesByColorBB{}, - m_pieceCount{} - { - m_pieces.fill(Piece::none()); - m_pieceBB.fill(Bitboard::none()); - m_pieceBB[Piece::none()] = Bitboard::all(); - m_piecesByColorBB.fill(Bitboard::none()); - m_pieceCount.fill(0); - m_pieceCount[Piece::none()] = 64; - } - - [[nodiscard]] inline bool isValid() const - { - if (piecesBB(whiteKing).count() != 1) return false; - if (piecesBB(blackKing).count() != 1) return false; - if (((piecesBB(whitePawn) | piecesBB(blackPawn)) & (bb::rank(rank1) | bb::rank(rank8))).any()) return false; - return true; - } - - [[nodiscard]] inline std::string fen() const; - - [[nodiscard]] inline bool trySet(std::string_view boardState) - { - File f = fileA; - Rank r = rank8; - bool lastWasSkip = false; - for (auto c : boardState) + if (piece != Piece::none()) { - Piece piece = Piece::none(); - switch (c) - { - case 'r': - piece = Piece(PieceType::Rook, Color::Black); - break; - case 'n': - piece = Piece(PieceType::Knight, Color::Black); - break; - case 'b': - piece = Piece(PieceType::Bishop, Color::Black); - break; - case 'q': - piece = Piece(PieceType::Queen, Color::Black); - break; - case 'k': - piece = Piece(PieceType::King, Color::Black); - break; - case 'p': - piece = Piece(PieceType::Pawn, Color::Black); - break; - - case 'R': - piece = Piece(PieceType::Rook, Color::White); - break; - case 'N': - piece = Piece(PieceType::Knight, Color::White); - break; - case 'B': - piece = Piece(PieceType::Bishop, Color::White); - break; - case 'Q': - piece = Piece(PieceType::Queen, Color::White); - break; - case 'K': - piece = Piece(PieceType::King, Color::White); - break; - case 'P': - piece = Piece(PieceType::Pawn, Color::White); - break; - - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - { - if (lastWasSkip) return false; - lastWasSkip = true; - - const int skip = c - '0'; - f += skip; - if (f > fileH + 1) return false; - break; - } - - case '/': - lastWasSkip = false; - if (f != fileH + 1) return false; - f = fileA; - --r; - break; - - default: - return false; - } - - if (piece != Piece::none()) - { - lastWasSkip = false; - - const Square sq(f, r); - if (!sq.isOk()) return false; - - place(piece, sq); - ++f; - } + place(piece, Square(f, r)); + ++f; } - if (f != fileH + 1) return false; - if (r != rank1) return false; - - return isValid(); + ++current; } - // returns side to move - [[nodiscard]] constexpr const char* set(const char* fen) - { - assert(fen != nullptr); - - File f = fileA; - Rank r = rank8; - auto current = fen; - bool done = false; - while (*current != '\0') - { - Piece piece = Piece::none(); - switch (*current) - { - case 'r': - piece = Piece(PieceType::Rook, Color::Black); - break; - case 'n': - piece = Piece(PieceType::Knight, Color::Black); - break; - case 'b': - piece = Piece(PieceType::Bishop, Color::Black); - break; - case 'q': - piece = Piece(PieceType::Queen, Color::Black); - break; - case 'k': - piece = Piece(PieceType::King, Color::Black); - break; - case 'p': - piece = Piece(PieceType::Pawn, Color::Black); - break; - - case 'R': - piece = Piece(PieceType::Rook, Color::White); - break; - case 'N': - piece = Piece(PieceType::Knight, Color::White); - break; - case 'B': - piece = Piece(PieceType::Bishop, Color::White); - break; - case 'Q': - piece = Piece(PieceType::Queen, Color::White); - break; - case 'K': - piece = Piece(PieceType::King, Color::White); - break; - case 'P': - piece = Piece(PieceType::Pawn, Color::White); - break; - - case ' ': - done = true; - break; - - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - { - const int skip = (*current) - '0'; - f += skip; - break; - } - - case '/': - f = fileA; - --r; - break; - - default: - break; - } - - if (done) - { - break; - } - - if (piece != Piece::none()) - { - place(piece, Square(f, r)); - ++f; - } - - ++current; - } - - return current; - } - - static constexpr Board fromFen(const char* fen) - { - Board board; - (void)board.set(fen); - return board; - } - - [[nodiscard]] constexpr friend bool operator==(const Board& lhs, const Board& rhs) noexcept - { - bool equal = true; - for (Square sq = a1; sq <= h8; ++sq) - { - if (lhs.m_pieces[sq] != rhs.m_pieces[sq]) - { - equal = false; - break; - } - } - - assert(bbsEqual(lhs, rhs) == equal); + return current; + } - return equal; - } + static constexpr Board fromFen(const char* fen) { + Board board; + (void) board.set(fen); + return board; + } - constexpr void place(Piece piece, Square sq) + [[nodiscard]] constexpr friend bool operator==(const Board& lhs, const Board& rhs) noexcept { + bool equal = true; + for (Square sq = a1; sq <= h8; ++sq) { - assert(sq.isOk()); - - auto oldPiece = m_pieces[sq]; - m_pieceBB[oldPiece] ^= sq; - if (oldPiece != Piece::none()) + if (lhs.m_pieces[sq] != rhs.m_pieces[sq]) { - m_piecesByColorBB[oldPiece.color()] ^= sq; + equal = false; + break; } - m_pieces[sq] = piece; - m_pieceBB[piece] |= sq; - m_piecesByColorBB[piece.color()] |= sq; - --m_pieceCount[oldPiece]; - ++m_pieceCount[piece]; } - // returns captured piece - // doesn't check validity - inline constexpr Piece doMove(Move move) - { - if (move.type == MoveType::Normal) - { - const Piece capturedPiece = m_pieces[move.to]; - const Piece piece = m_pieces[move.from]; - - const Bitboard frombb = Bitboard::square(move.from); - const Bitboard tobb = Bitboard::square(move.to); - const Bitboard xormove = frombb ^ tobb; + assert(bbsEqual(lhs, rhs) == equal); - m_pieces[move.to] = piece; - m_pieces[move.from] = Piece::none(); - - m_pieceBB[piece] ^= xormove; - - m_piecesByColorBB[piece.color()] ^= xormove; - - if (capturedPiece == Piece::none()) - { - m_pieceBB[Piece::none()] ^= xormove; - } - else - { - m_pieceBB[capturedPiece] ^= tobb; - m_pieceBB[Piece::none()] ^= frombb; - - m_piecesByColorBB[capturedPiece.color()] ^= tobb; - - --m_pieceCount[capturedPiece]; - ++m_pieceCount[Piece::none()]; - } - - return capturedPiece; - } + return equal; + } - return doMoveColdPath(move); - } + constexpr void place(Piece piece, Square sq) { + assert(sq.isOk()); - inline constexpr Piece doMoveColdPath(Move move) + auto oldPiece = m_pieces[sq]; + m_pieceBB[oldPiece] ^= sq; + if (oldPiece != Piece::none()) { - if (move.type == MoveType::Promotion) - { - // We split it even though it's similar just because - // the normal case is much more common. - const Piece capturedPiece = m_pieces[move.to]; - const Piece fromPiece = m_pieces[move.from]; - const Piece toPiece = move.promotedPiece; - - m_pieces[move.to] = toPiece; - m_pieces[move.from] = Piece::none(); - - m_pieceBB[fromPiece] ^= move.from; - m_pieceBB[toPiece] ^= move.to; - - m_pieceBB[capturedPiece] ^= move.to; - m_pieceBB[Piece::none()] ^= move.from; - - m_piecesByColorBB[fromPiece.color()] ^= move.to; - m_piecesByColorBB[fromPiece.color()] ^= move.from; - if (capturedPiece != Piece::none()) - { - m_piecesByColorBB[capturedPiece.color()] ^= move.to; - --m_pieceCount[capturedPiece]; - ++m_pieceCount[Piece::none()]; - } - - --m_pieceCount[fromPiece]; - ++m_pieceCount[toPiece]; - - return capturedPiece; - } - else if (move.type == MoveType::EnPassant) - { - const Piece movedPiece = m_pieces[move.from]; - const Piece capturedPiece(PieceType::Pawn, !movedPiece.color()); - const Square capturedPieceSq(move.to.file(), move.from.rank()); - - // on ep move there are 3 squares involved - m_pieces[move.to] = movedPiece; - m_pieces[move.from] = Piece::none(); - m_pieces[capturedPieceSq] = Piece::none(); - - m_pieceBB[movedPiece] ^= move.from; - m_pieceBB[movedPiece] ^= move.to; - - m_pieceBB[Piece::none()] ^= move.from; - m_pieceBB[Piece::none()] ^= move.to; - - m_pieceBB[capturedPiece] ^= capturedPieceSq; - m_pieceBB[Piece::none()] ^= capturedPieceSq; - - m_piecesByColorBB[movedPiece.color()] ^= move.to; - m_piecesByColorBB[movedPiece.color()] ^= move.from; - m_piecesByColorBB[capturedPiece.color()] ^= capturedPieceSq; - - --m_pieceCount[capturedPiece]; - ++m_pieceCount[Piece::none()]; - - return capturedPiece; - } - else // if (move.type == MoveType::Castle) - { - const Square rookFromSq = move.to; - const Square kingFromSq = move.from; - - const Piece rook = m_pieces[rookFromSq]; - const Piece king = m_pieces[kingFromSq]; - const Color color = king.color(); - - const CastleType castleType = CastlingTraits::moveCastlingType(move); - const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; - const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; - - // 4 squares are involved - m_pieces[rookFromSq] = Piece::none(); - m_pieces[kingFromSq] = Piece::none(); - m_pieces[rookToSq] = rook; - m_pieces[kingToSq] = king; - - m_pieceBB[rook] ^= rookFromSq; - m_pieceBB[rook] ^= rookToSq; - - m_pieceBB[king] ^= kingFromSq; - m_pieceBB[king] ^= kingToSq; - - m_pieceBB[Piece::none()] ^= rookFromSq; - m_pieceBB[Piece::none()] ^= rookToSq; - - m_pieceBB[Piece::none()] ^= kingFromSq; - m_pieceBB[Piece::none()] ^= kingToSq; - - m_piecesByColorBB[color] ^= rookFromSq; - m_piecesByColorBB[color] ^= rookToSq; - m_piecesByColorBB[color] ^= kingFromSq; - m_piecesByColorBB[color] ^= kingToSq; - - return Piece::none(); - } + m_piecesByColorBB[oldPiece.color()] ^= sq; } + m_pieces[sq] = piece; + m_pieceBB[piece] |= sq; + m_piecesByColorBB[piece.color()] |= sq; + --m_pieceCount[oldPiece]; + ++m_pieceCount[piece]; + } - constexpr void undoMove(Move move, Piece capturedPiece) + // returns captured piece + // doesn't check validity + inline constexpr Piece doMove(Move move) { + if (move.type == MoveType::Normal) { - if (move.type == MoveType::Normal || move.type == MoveType::Promotion) - { - const Piece toPiece = m_pieces[move.to]; - const Piece fromPiece = move.promotedPiece == Piece::none() ? toPiece : Piece(PieceType::Pawn, toPiece.color()); + const Piece capturedPiece = m_pieces[move.to]; + const Piece piece = m_pieces[move.from]; - m_pieces[move.from] = fromPiece; - m_pieces[move.to] = capturedPiece; + const Bitboard frombb = Bitboard::square(move.from); + const Bitboard tobb = Bitboard::square(move.to); + const Bitboard xormove = frombb ^ tobb; - m_pieceBB[fromPiece] ^= move.from; - m_pieceBB[toPiece] ^= move.to; + m_pieces[move.to] = piece; + m_pieces[move.from] = Piece::none(); - m_pieceBB[capturedPiece] ^= move.to; - m_pieceBB[Piece::none()] ^= move.from; + m_pieceBB[piece] ^= xormove; - m_piecesByColorBB[fromPiece.color()] ^= move.to; - m_piecesByColorBB[fromPiece.color()] ^= move.from; - if (capturedPiece != Piece::none()) - { - m_piecesByColorBB[capturedPiece.color()] ^= move.to; - ++m_pieceCount[capturedPiece]; - --m_pieceCount[Piece::none()]; - } + m_piecesByColorBB[piece.color()] ^= xormove; - if (move.type == MoveType::Promotion) - { - --m_pieceCount[toPiece]; - ++m_pieceCount[fromPiece]; - } + if (capturedPiece == Piece::none()) + { + m_pieceBB[Piece::none()] ^= xormove; } - else if (move.type == MoveType::EnPassant) + else { - const Piece movedPiece = m_pieces[move.to]; - const Piece capturedPiece_(PieceType::Pawn, !movedPiece.color()); - const Square capturedPieceSq(move.to.file(), move.from.rank()); + m_pieceBB[capturedPiece] ^= tobb; + m_pieceBB[Piece::none()] ^= frombb; - m_pieces[move.to] = Piece::none(); - m_pieces[move.from] = movedPiece; - m_pieces[capturedPieceSq] = capturedPiece_; + m_piecesByColorBB[capturedPiece.color()] ^= tobb; - m_pieceBB[movedPiece] ^= move.from; - m_pieceBB[movedPiece] ^= move.to; + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; + } - m_pieceBB[Piece::none()] ^= move.from; - m_pieceBB[Piece::none()] ^= move.to; + return capturedPiece; + } - // on ep move there are 3 squares involved - m_pieceBB[capturedPiece_] ^= capturedPieceSq; - m_pieceBB[Piece::none()] ^= capturedPieceSq; + return doMoveColdPath(move); + } - m_piecesByColorBB[movedPiece.color()] ^= move.to; - m_piecesByColorBB[movedPiece.color()] ^= move.from; - m_piecesByColorBB[capturedPiece_.color()] ^= capturedPieceSq; + inline constexpr Piece doMoveColdPath(Move move) { + if (move.type == MoveType::Promotion) + { + // We split it even though it's similar just because + // the normal case is much more common. + const Piece capturedPiece = m_pieces[move.to]; + const Piece fromPiece = m_pieces[move.from]; + const Piece toPiece = move.promotedPiece; - ++m_pieceCount[capturedPiece_]; - --m_pieceCount[Piece::none()]; - } - else // if (move.type == MoveType::Castle) - { - const Square rookFromSq = move.to; - const Square kingFromSq = move.from; + m_pieces[move.to] = toPiece; + m_pieces[move.from] = Piece::none(); - const Color color = move.to.rank() == rank1 ? Color::White : Color::Black; + m_pieceBB[fromPiece] ^= move.from; + m_pieceBB[toPiece] ^= move.to; - const CastleType castleType = CastlingTraits::moveCastlingType(move); - const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; - const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; + m_pieceBB[capturedPiece] ^= move.to; + m_pieceBB[Piece::none()] ^= move.from; - const Piece rook = m_pieces[rookToSq]; - const Piece king = m_pieces[kingToSq]; + m_piecesByColorBB[fromPiece.color()] ^= move.to; + m_piecesByColorBB[fromPiece.color()] ^= move.from; + if (capturedPiece != Piece::none()) + { + m_piecesByColorBB[capturedPiece.color()] ^= move.to; + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; + } - // 4 squares are involved - m_pieces[rookFromSq] = rook; - m_pieces[kingFromSq] = king; - m_pieces[rookToSq] = Piece::none(); - m_pieces[kingToSq] = Piece::none(); + --m_pieceCount[fromPiece]; + ++m_pieceCount[toPiece]; - m_pieceBB[rook] ^= rookFromSq; - m_pieceBB[rook] ^= rookToSq; + return capturedPiece; + } + else if (move.type == MoveType::EnPassant) + { + const Piece movedPiece = m_pieces[move.from]; + const Piece capturedPiece(PieceType::Pawn, !movedPiece.color()); + const Square capturedPieceSq(move.to.file(), move.from.rank()); - m_pieceBB[king] ^= kingFromSq; - m_pieceBB[king] ^= kingToSq; + // on ep move there are 3 squares involved + m_pieces[move.to] = movedPiece; + m_pieces[move.from] = Piece::none(); + m_pieces[capturedPieceSq] = Piece::none(); - m_pieceBB[Piece::none()] ^= rookFromSq; - m_pieceBB[Piece::none()] ^= rookToSq; + m_pieceBB[movedPiece] ^= move.from; + m_pieceBB[movedPiece] ^= move.to; - m_pieceBB[Piece::none()] ^= kingFromSq; - m_pieceBB[Piece::none()] ^= kingToSq; + m_pieceBB[Piece::none()] ^= move.from; + m_pieceBB[Piece::none()] ^= move.to; - m_piecesByColorBB[color] ^= rookFromSq; - m_piecesByColorBB[color] ^= rookToSq; - m_piecesByColorBB[color] ^= kingFromSq; - m_piecesByColorBB[color] ^= kingToSq; - } - } + m_pieceBB[capturedPiece] ^= capturedPieceSq; + m_pieceBB[Piece::none()] ^= capturedPieceSq; - // Returns whether a given square is attacked by any piece - // of `attackerColor` side. - [[nodiscard]] inline bool isSquareAttacked(Square sq, Color attackerColor) const; + m_piecesByColorBB[movedPiece.color()] ^= move.to; + m_piecesByColorBB[movedPiece.color()] ^= move.from; + m_piecesByColorBB[capturedPiece.color()] ^= capturedPieceSq; - // Returns whether a given square is attacked by any piece - // of `attackerColor` side after `move` is made. - // Move must be pseudo legal. - [[nodiscard]] inline bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const; + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; - // Move must be pseudo legal. - // Must not be a king move. - [[nodiscard]] inline bool createsDiscoveredAttackOnOwnKing(Move move) const; + return capturedPiece; + } + else // if (move.type == MoveType::Castle) + { + const Square rookFromSq = move.to; + const Square kingFromSq = move.from; - // Returns whether a piece on a given square is attacked - // by any enemy piece. False if square is empty. - [[nodiscard]] inline bool isPieceAttacked(Square sq) const; + const Piece rook = m_pieces[rookFromSq]; + const Piece king = m_pieces[kingFromSq]; + const Color color = king.color(); - // Returns whether a piece on a given square is attacked - // by any enemy piece after `move` is made. False if square is empty. - // Move must be pseudo legal. - [[nodiscard]] inline bool isPieceAttackedAfterMove(Move move, Square sq) const; + const CastleType castleType = CastlingTraits::moveCastlingType(move); + const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; + const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; - // Returns whether the king of the moving side is attacked - // by any enemy piece after a move is made. - // Move must be pseudo legal. - [[nodiscard]] inline bool isOwnKingAttackedAfterMove(Move move) const; + // 4 squares are involved + m_pieces[rookFromSq] = Piece::none(); + m_pieces[kingFromSq] = Piece::none(); + m_pieces[rookToSq] = rook; + m_pieces[kingToSq] = king; - // Return a bitboard with all (pseudo legal) attacks by the piece on - // the given square. Empty if no piece on the square. - [[nodiscard]] inline Bitboard attacks(Square sq) const; + m_pieceBB[rook] ^= rookFromSq; + m_pieceBB[rook] ^= rookToSq; - // Returns a bitboard with all squared that have pieces - // that attack a given square (pseudo legally) - [[nodiscard]] inline Bitboard attackers(Square sq, Color attackerColor) const; + m_pieceBB[king] ^= kingFromSq; + m_pieceBB[king] ^= kingToSq; - [[nodiscard]] constexpr Piece pieceAt(Square sq) const - { - assert(sq.isOk()); + m_pieceBB[Piece::none()] ^= rookFromSq; + m_pieceBB[Piece::none()] ^= rookToSq; - return m_pieces[sq]; - } + m_pieceBB[Piece::none()] ^= kingFromSq; + m_pieceBB[Piece::none()] ^= kingToSq; - [[nodiscard]] constexpr Bitboard piecesBB(Color c) const - { - return m_piecesByColorBB[c]; - } + m_piecesByColorBB[color] ^= rookFromSq; + m_piecesByColorBB[color] ^= rookToSq; + m_piecesByColorBB[color] ^= kingFromSq; + m_piecesByColorBB[color] ^= kingToSq; - [[nodiscard]] inline Square kingSquare(Color c) const - { - return piecesBB(Piece(PieceType::King, c)).first(); + return Piece::none(); } + } - [[nodiscard]] constexpr Bitboard piecesBB(Piece pc) const + constexpr void undoMove(Move move, Piece capturedPiece) { + if (move.type == MoveType::Normal || move.type == MoveType::Promotion) { - return m_pieceBB[pc]; - } + const Piece toPiece = m_pieces[move.to]; + const Piece fromPiece = move.promotedPiece == Piece::none() + ? toPiece + : Piece(PieceType::Pawn, toPiece.color()); - [[nodiscard]] constexpr Bitboard piecesBB() const - { - Bitboard bb{}; + m_pieces[move.from] = fromPiece; + m_pieces[move.to] = capturedPiece; - // don't collect from null piece - return piecesBB(Color::White) | piecesBB(Color::Black); + m_pieceBB[fromPiece] ^= move.from; + m_pieceBB[toPiece] ^= move.to; - return bb; - } + m_pieceBB[capturedPiece] ^= move.to; + m_pieceBB[Piece::none()] ^= move.from; - [[nodiscard]] constexpr std::uint8_t pieceCount(Piece pt) const - { - return m_pieceCount[pt]; - } + m_piecesByColorBB[fromPiece.color()] ^= move.to; + m_piecesByColorBB[fromPiece.color()] ^= move.from; + if (capturedPiece != Piece::none()) + { + m_piecesByColorBB[capturedPiece.color()] ^= move.to; + ++m_pieceCount[capturedPiece]; + --m_pieceCount[Piece::none()]; + } - [[nodiscard]] constexpr bool isPromotion(Square from, Square to) const + if (move.type == MoveType::Promotion) + { + --m_pieceCount[toPiece]; + ++m_pieceCount[fromPiece]; + } + } + else if (move.type == MoveType::EnPassant) { - assert(from.isOk() && to.isOk()); + const Piece movedPiece = m_pieces[move.to]; + const Piece capturedPiece_(PieceType::Pawn, !movedPiece.color()); + const Square capturedPieceSq(move.to.file(), move.from.rank()); - return m_pieces[from].type() == PieceType::Pawn && (to.rank() == rank1 || to.rank() == rank8); - } + m_pieces[move.to] = Piece::none(); + m_pieces[move.from] = movedPiece; + m_pieces[capturedPieceSq] = capturedPiece_; - const Piece* piecesRaw() const; + m_pieceBB[movedPiece] ^= move.from; + m_pieceBB[movedPiece] ^= move.to; - private: - EnumArray m_pieces; - EnumArray m_pieceBB; - EnumArray m_piecesByColorBB; - EnumArray m_pieceCount; + m_pieceBB[Piece::none()] ^= move.from; + m_pieceBB[Piece::none()] ^= move.to; - // NOTE: currently we don't track it because it's not - // required to perform ep if we don't need to check validity - // Square m_epSquare = Square::none(); + // on ep move there are 3 squares involved + m_pieceBB[capturedPiece_] ^= capturedPieceSq; + m_pieceBB[Piece::none()] ^= capturedPieceSq; - [[nodiscard]] static constexpr bool bbsEqual(const Board& lhs, const Board& rhs) noexcept - { - for (Piece pc : values()) - { - if (lhs.m_pieceBB[pc] != rhs.m_pieceBB[pc]) - { - return false; - } - } + m_piecesByColorBB[movedPiece.color()] ^= move.to; + m_piecesByColorBB[movedPiece.color()] ^= move.from; + m_piecesByColorBB[capturedPiece_.color()] ^= capturedPieceSq; - return true; + ++m_pieceCount[capturedPiece_]; + --m_pieceCount[Piece::none()]; } - }; + else // if (move.type == MoveType::Castle) + { + const Square rookFromSq = move.to; + const Square kingFromSq = move.from; - struct Position; + const Color color = move.to.rank() == rank1 ? Color::White : Color::Black; - struct CompressedPosition; + const CastleType castleType = CastlingTraits::moveCastlingType(move); + const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; + const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; - struct PositionHash128 - { - std::uint64_t high; - std::uint64_t low; - }; + const Piece rook = m_pieces[rookToSq]; + const Piece king = m_pieces[kingToSq]; - struct Position; + // 4 squares are involved + m_pieces[rookFromSq] = rook; + m_pieces[kingFromSq] = king; + m_pieces[rookToSq] = Piece::none(); + m_pieces[kingToSq] = Piece::none(); - struct MoveLegalityChecker - { - MoveLegalityChecker(const Position& position); + m_pieceBB[rook] ^= rookFromSq; + m_pieceBB[rook] ^= rookToSq; - [[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const; + m_pieceBB[king] ^= kingFromSq; + m_pieceBB[king] ^= kingToSq; - private: - const Position* m_position; - Bitboard m_checkers; - Bitboard m_ourBlockersForKing; - Bitboard m_potentialCheckRemovals; - Square m_ksq; - }; + m_pieceBB[Piece::none()] ^= rookFromSq; + m_pieceBB[Piece::none()] ^= rookToSq; - struct Position : public Board - { - using BaseType = Board; + m_pieceBB[Piece::none()] ^= kingFromSq; + m_pieceBB[Piece::none()] ^= kingToSq; - constexpr Position() noexcept : - Board(), - m_sideToMove(Color::White), - m_epSquare(Square::none()), - m_castlingRights(CastlingRights::All), - m_rule50Counter(0), - m_ply(0) - { + m_piecesByColorBB[color] ^= rookFromSq; + m_piecesByColorBB[color] ^= rookToSq; + m_piecesByColorBB[color] ^= kingFromSq; + m_piecesByColorBB[color] ^= kingToSq; } + } - constexpr Position(const Board& board, Color sideToMove, Square epSquare, CastlingRights castlingRights) : - Board(board), - m_sideToMove(sideToMove), - m_epSquare(epSquare), - m_castlingRights(castlingRights), - m_rule50Counter(0), - m_ply(0) - { - } + // Returns whether a given square is attacked by any piece + // of `attackerColor` side. + [[nodiscard]] inline bool isSquareAttacked(Square sq, Color attackerColor) const; - inline void set(std::string_view fen); + // Returns whether a given square is attacked by any piece + // of `attackerColor` side after `move` is made. + // Move must be pseudo legal. + [[nodiscard]] inline bool + isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const; - // Returns false if the fen was not valid - // If the returned value was false the position - // is in unspecified state. - [[nodiscard]] inline bool trySet(std::string_view fen); + // Move must be pseudo legal. + // Must not be a king move. + [[nodiscard]] inline bool createsDiscoveredAttackOnOwnKing(Move move) const; - [[nodiscard]] static inline Position fromFen(std::string_view fen); + // Returns whether a piece on a given square is attacked + // by any enemy piece. False if square is empty. + [[nodiscard]] inline bool isPieceAttacked(Square sq) const; - [[nodiscard]] static inline std::optional tryFromFen(std::string_view fen); + // Returns whether a piece on a given square is attacked + // by any enemy piece after `move` is made. False if square is empty. + // Move must be pseudo legal. + [[nodiscard]] inline bool isPieceAttackedAfterMove(Move move, Square sq) const; - [[nodiscard]] static inline Position startPosition(); + // Returns whether the king of the moving side is attacked + // by any enemy piece after a move is made. + // Move must be pseudo legal. + [[nodiscard]] inline bool isOwnKingAttackedAfterMove(Move move) const; - [[nodiscard]] inline std::string fen() const; + // Return a bitboard with all (pseudo legal) attacks by the piece on + // the given square. Empty if no piece on the square. + [[nodiscard]] inline Bitboard attacks(Square sq) const; - [[nodiscard]] MoveLegalityChecker moveLegalityChecker() const - { - return { *this }; - } + // Returns a bitboard with all squared that have pieces + // that attack a given square (pseudo legally) + [[nodiscard]] inline Bitboard attackers(Square sq, Color attackerColor) const; - constexpr void setEpSquareUnchecked(Square sq) - { - m_epSquare = sq; - } + [[nodiscard]] constexpr Piece pieceAt(Square sq) const { + assert(sq.isOk()); - void setEpSquare(Square sq) - { - m_epSquare = sq; - nullifyEpSquareIfNotPossible(); - } + return m_pieces[sq]; + } - constexpr void setSideToMove(Color color) - { - m_sideToMove = color; - } + [[nodiscard]] constexpr Bitboard piecesBB(Color c) const { return m_piecesByColorBB[c]; } - constexpr void addCastlingRights(CastlingRights rights) - { - m_castlingRights |= rights; - } + [[nodiscard]] inline Square kingSquare(Color c) const { + return piecesBB(Piece(PieceType::King, c)).first(); + } - constexpr void setCastlingRights(CastlingRights rights) - { - m_castlingRights = rights; - } + [[nodiscard]] constexpr Bitboard piecesBB(Piece pc) const { return m_pieceBB[pc]; } - constexpr void setRule50Counter(std::uint8_t v) - { - m_rule50Counter = v; - } + [[nodiscard]] constexpr Bitboard piecesBB() const { + Bitboard bb{}; - constexpr void setPly(std::uint16_t ply) - { - m_ply = ply; - } + // don't collect from null piece + return piecesBB(Color::White) | piecesBB(Color::Black); - inline ReverseMove doMove(const Move& move); + return bb; + } - constexpr void undoMove(const ReverseMove& reverseMove) - { - const Move& move = reverseMove.move; - BaseType::undoMove(move, reverseMove.capturedPiece); + [[nodiscard]] constexpr std::uint8_t pieceCount(Piece pt) const { return m_pieceCount[pt]; } - m_epSquare = reverseMove.oldEpSquare; - m_castlingRights = reverseMove.oldCastlingRights; + [[nodiscard]] constexpr bool isPromotion(Square from, Square to) const { + assert(from.isOk() && to.isOk()); - m_sideToMove = !m_sideToMove; + return m_pieces[from].type() == PieceType::Pawn + && (to.rank() == rank1 || to.rank() == rank8); + } - --m_ply; - if (m_rule50Counter > 0) - { - m_rule50Counter -= 1; - } - } + const Piece* piecesRaw() const; - [[nodiscard]] constexpr Color sideToMove() const - { - return m_sideToMove; - } + private: + EnumArray m_pieces; + EnumArray m_pieceBB; + EnumArray m_piecesByColorBB; + EnumArray m_pieceCount; - [[nodiscard]] inline std::uint8_t rule50Counter() const - { - return m_rule50Counter; - } + // NOTE: currently we don't track it because it's not + // required to perform ep if we don't need to check validity + // Square m_epSquare = Square::none(); - [[nodiscard]] inline std::uint16_t ply() const + [[nodiscard]] static constexpr bool bbsEqual(const Board& lhs, const Board& rhs) noexcept { + for (Piece pc : values()) { - return m_ply; + if (lhs.m_pieceBB[pc] != rhs.m_pieceBB[pc]) + { + return false; + } } - [[nodiscard]] inline std::uint16_t fullMove() const - { - return (m_ply + 1) / 2; - } + return true; + } +}; - inline void setFullMove(std::uint16_t hm) - { - m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black); - } +struct Position; - [[nodiscard]] inline bool isCheck() const; +struct CompressedPosition; - [[nodiscard]] inline Bitboard checkers() const; +struct PositionHash128 { + std::uint64_t high; + std::uint64_t low; +}; - [[nodiscard]] inline bool isCheckAfterMove(Move move) const; +struct Position; - [[nodiscard]] inline bool isMoveLegal(Move move) const; +struct MoveLegalityChecker { + MoveLegalityChecker(const Position& position); - [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const; + [[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const; - [[nodiscard]] inline bool isMovePseudoLegal(Move move) const; + private: + const Position* m_position; + Bitboard m_checkers; + Bitboard m_ourBlockersForKing; + Bitboard m_potentialCheckRemovals; + Square m_ksq; +}; - // Returns all pieces that block a slider - // from attacking our king. When two or more - // pieces block a single slider then none - // of these pieces are included. - [[nodiscard]] inline Bitboard blockersForKing(Color color) const; +struct Position: public Board { + using BaseType = Board; - [[nodiscard]] constexpr Square epSquare() const - { - return m_epSquare; - } + constexpr Position() noexcept : + Board(), + m_sideToMove(Color::White), + m_epSquare(Square::none()), + m_castlingRights(CastlingRights::All), + m_rule50Counter(0), + m_ply(0) {} - [[nodiscard]] constexpr CastlingRights castlingRights() const - { - return m_castlingRights; - } + constexpr Position(const Board& board, + Color sideToMove, + Square epSquare, + CastlingRights castlingRights) : + Board(board), + m_sideToMove(sideToMove), + m_epSquare(epSquare), + m_castlingRights(castlingRights), + m_rule50Counter(0), + m_ply(0) {} - [[nodiscard]] constexpr bool friend operator==(const Position& lhs, const Position& rhs) noexcept - { - return - lhs.m_sideToMove == rhs.m_sideToMove - && lhs.m_epSquare == rhs.m_epSquare - && lhs.m_castlingRights == rhs.m_castlingRights - && static_cast(lhs) == static_cast(rhs); - } + inline void set(std::string_view fen); - [[nodiscard]] constexpr bool friend operator!=(const Position& lhs, const Position& rhs) noexcept - { - return !(lhs == rhs); - } + // Returns false if the fen was not valid + // If the returned value was false the position + // is in unspecified state. + [[nodiscard]] inline bool trySet(std::string_view fen); - // these are supposed to be used only for testing - // that's why there's this assert in afterMove + [[nodiscard]] static inline Position fromFen(std::string_view fen); - [[nodiscard]] constexpr Position beforeMove(const ReverseMove& reverseMove) const - { - Position cpy(*this); - cpy.undoMove(reverseMove); - return cpy; - } + [[nodiscard]] static inline std::optional tryFromFen(std::string_view fen); - [[nodiscard]] inline Position afterMove(Move move) const; + [[nodiscard]] static inline Position startPosition(); - [[nodiscard]] constexpr bool isEpPossible() const - { - return m_epSquare != Square::none(); - } + [[nodiscard]] inline std::string fen() const; - [[nodiscard]] inline CompressedPosition compress() const; + [[nodiscard]] MoveLegalityChecker moveLegalityChecker() const { return {*this}; } - protected: - Color m_sideToMove; - Square m_epSquare; - CastlingRights m_castlingRights; - std::uint8_t m_rule50Counter; - std::uint16_t m_ply; + constexpr void setEpSquareUnchecked(Square sq) { m_epSquare = sq; } - static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4); + void setEpSquare(Square sq) { + m_epSquare = sq; + nullifyEpSquareIfNotPossible(); + } - [[nodiscard]] inline bool isEpPossible(Square epSquare, Color sideToMove) const; + constexpr void setSideToMove(Color color) { m_sideToMove = color; } - [[nodiscard]] inline bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const; + constexpr void addCastlingRights(CastlingRights rights) { m_castlingRights |= rights; } - inline void nullifyEpSquareIfNotPossible(); - }; + constexpr void setCastlingRights(CastlingRights rights) { m_castlingRights = rights; } - struct CompressedPosition - { - friend struct Position; + constexpr void setRule50Counter(std::uint8_t v) { m_rule50Counter = v; } - // Occupied bitboard has bits set for - // each square with a piece on it. - // Each packedState byte holds 2 values (nibbles). - // First one at low bits, second one at high bits. - // Values correspond to consecutive squares - // in bitboard iteration order. - // Nibble values: - // these are the same as for Piece - // knights, bishops, queens can just be copied - // 0 : white pawn - // 1 : black pawn - // 2 : white knight - // 3 : black knight - // 4 : white bishop - // 5 : black bishop - // 6 : white rook - // 7 : black rook - // 8 : white queen - // 9 : black queen - // 10 : white king - // 11 : black king - // - // these are special - // 12 : pawn with ep square behind (white or black, depending on rank) - // 13 : white rook with coresponding castling rights - // 14 : black rook with coresponding castling rights - // 15 : black king and black is side to move - // - // Let N be the number of bits set in occupied bitboard. - // Only N nibbles are present. (N+1)/2 bytes are initialized. + constexpr void setPly(std::uint16_t ply) { m_ply = ply; } - static CompressedPosition readFromBigEndian(const unsigned char* data) - { - CompressedPosition pos{}; - pos.m_occupied = Bitboard::fromBits( - (std::uint64_t)data[0] << 56 - | (std::uint64_t)data[1] << 48 - | (std::uint64_t)data[2] << 40 - | (std::uint64_t)data[3] << 32 - | (std::uint64_t)data[4] << 24 - | (std::uint64_t)data[5] << 16 - | (std::uint64_t)data[6] << 8 - | (std::uint64_t)data[7] - ); - std::memcpy(pos.m_packedState, data + 8, 16); - return pos; - } + inline ReverseMove doMove(const Move& move); - constexpr CompressedPosition() : - m_occupied{}, - m_packedState{} - { - } + constexpr void undoMove(const ReverseMove& reverseMove) { + const Move& move = reverseMove.move; + BaseType::undoMove(move, reverseMove.capturedPiece); - [[nodiscard]] friend bool operator<(const CompressedPosition& lhs, const CompressedPosition& rhs) - { - if (lhs.m_occupied.bits() < rhs.m_occupied.bits()) return true; - if (lhs.m_occupied.bits() > rhs.m_occupied.bits()) return false; + m_epSquare = reverseMove.oldEpSquare; + m_castlingRights = reverseMove.oldCastlingRights; - return std::strcmp(reinterpret_cast(lhs.m_packedState), reinterpret_cast(rhs.m_packedState)) < 0; - } + m_sideToMove = !m_sideToMove; - [[nodiscard]] friend bool operator==(const CompressedPosition& lhs, const CompressedPosition& rhs) + --m_ply; + if (m_rule50Counter > 0) { - return lhs.m_occupied == rhs.m_occupied - && std::strcmp(reinterpret_cast(lhs.m_packedState), reinterpret_cast(rhs.m_packedState)) == 0; + m_rule50Counter -= 1; } + } - [[nodiscard]] inline Position decompress() const; + [[nodiscard]] constexpr Color sideToMove() const { return m_sideToMove; } - [[nodiscard]] constexpr Bitboard pieceBB() const - { - return m_occupied; - } + [[nodiscard]] inline std::uint8_t rule50Counter() const { return m_rule50Counter; } - void writeToBigEndian(unsigned char* data) - { - const auto occupied = m_occupied.bits(); - *data++ = occupied >> 56; - *data++ = (occupied >> 48) & 0xFF; - *data++ = (occupied >> 40) & 0xFF; - *data++ = (occupied >> 32) & 0xFF; - *data++ = (occupied >> 24) & 0xFF; - *data++ = (occupied >> 16) & 0xFF; - *data++ = (occupied >> 8) & 0xFF; - *data++ = occupied & 0xFF; - std::memcpy(data, m_packedState, 16); - } + [[nodiscard]] inline std::uint16_t ply() const { return m_ply; } - private: - Bitboard m_occupied; - std::uint8_t m_packedState[16]; - }; + [[nodiscard]] inline std::uint16_t fullMove() const { return (m_ply + 1) / 2; } - namespace movegen - { - // For a pseudo-legal move the following are true: - // - the moving piece has the pos.sideToMove() color - // - the destination square is either empty or has a piece of the opposite color - // - if it is a pawn move it is valid (but may be illegal due to discovered checks) - // - if it is not a pawn move then the destination square is contained in attacks() - // - if it is a castling it is legal - // - a move other than castling may create a discovered attack on the king - // - a king may walk into a check + inline void setFullMove(std::uint16_t hm) { + m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black); + } - template - inline void forEachPseudoLegalPawnMove(const Position& pos, Square from, FuncT&& f) - { - const Color sideToMove = pos.sideToMove(); - const Square epSquare = pos.epSquare(); - const Bitboard ourPieces = pos.piecesBB(sideToMove); - const Bitboard theirPieces = pos.piecesBB(!sideToMove); - const Bitboard occupied = ourPieces | theirPieces; + [[nodiscard]] inline bool isCheck() const; - Bitboard attackTargets = theirPieces; - if (epSquare != Square::none()) - { - attackTargets |= epSquare; - } + [[nodiscard]] inline Bitboard checkers() const; - const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), sideToMove) & attackTargets; + [[nodiscard]] inline bool isCheckAfterMove(Move move) const; - const Rank secondToLastRank = sideToMove == Color::White ? rank7 : rank2; - const auto forward = sideToMove == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + [[nodiscard]] inline bool isMoveLegal(Move move) const; - // promotions - if (from.rank() == secondToLastRank) - { - // capture promotions - for (Square toSq : attacks) - { - for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) - { - Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; - f(move); - } - } + [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const; - // push promotions - const Square toSq = from + forward; - if (!occupied.isSet(toSq)) - { - for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) - { - Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; - f(move); - } - } - } - else - { - // captures - for (Square toSq : attacks) - { - Move move{ from, toSq, (toSq == epSquare) ? MoveType::EnPassant : MoveType::Normal }; - f(move); - } + [[nodiscard]] inline bool isMovePseudoLegal(Move move) const; - const Square toSq = from + forward; + // Returns all pieces that block a slider + // from attacking our king. When two or more + // pieces block a single slider then none + // of these pieces are included. + [[nodiscard]] inline Bitboard blockersForKing(Color color) const; - // single push - if (!occupied.isSet(toSq)) - { - const Rank startRank = sideToMove == Color::White ? rank2 : rank7; - if (from.rank() == startRank) - { - // double push - const Square toSq2 = toSq + forward; - if (!occupied.isSet(toSq2)) - { - Move move{ from, toSq2 }; - f(move); - } - } + [[nodiscard]] constexpr Square epSquare() const { return m_epSquare; } - Move move{ from, toSq }; - f(move); - } - } - } + [[nodiscard]] constexpr CastlingRights castlingRights() const { return m_castlingRights; } - template - inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) - { - const Square epSquare = pos.epSquare(); - const Bitboard ourPieces = pos.piecesBB(SideToMoveV); - const Bitboard theirPieces = pos.piecesBB(!SideToMoveV); - const Bitboard occupied = ourPieces | theirPieces; - const Bitboard pawns = pos.piecesBB(Piece(PieceType::Pawn, SideToMoveV)); + [[nodiscard]] constexpr bool friend operator==(const Position& lhs, + const Position& rhs) noexcept { + return lhs.m_sideToMove == rhs.m_sideToMove && lhs.m_epSquare == rhs.m_epSquare + && lhs.m_castlingRights == rhs.m_castlingRights + && static_cast(lhs) == static_cast(rhs); + } - const Bitboard secondToLastRank = SideToMoveV == Color::White ? bb::rank7 : bb::rank2; - const Bitboard secondRank = SideToMoveV == Color::White ? bb::rank2 : bb::rank7; + [[nodiscard]] constexpr bool friend operator!=(const Position& lhs, + const Position& rhs) noexcept { + return !(lhs == rhs); + } - const auto singlePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); - const auto doublePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 2) : FlatSquareOffset(0, -2); + // these are supposed to be used only for testing + // that's why there's this assert in afterMove - { - const int backward = SideToMoveV == Color::White ? -1 : 1; - const int backward2 = backward * 2; + [[nodiscard]] constexpr Position beforeMove(const ReverseMove& reverseMove) const { + Position cpy(*this); + cpy.undoMove(reverseMove); + return cpy; + } - const Bitboard doublePawnMoveStarts = - pawns - & secondRank - & ~(occupied.shiftedVertically(backward) | occupied.shiftedVertically(backward2)); + [[nodiscard]] inline Position afterMove(Move move) const; + + [[nodiscard]] constexpr bool isEpPossible() const { return m_epSquare != Square::none(); } + + [[nodiscard]] inline CompressedPosition compress() const; + + protected: + Color m_sideToMove; + Square m_epSquare; + CastlingRights m_castlingRights; + std::uint8_t m_rule50Counter; + std::uint16_t m_ply; + + static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) + == 4); + + [[nodiscard]] inline bool isEpPossible(Square epSquare, Color sideToMove) const; + + [[nodiscard]] inline bool + isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const; + + inline void nullifyEpSquareIfNotPossible(); +}; + +struct CompressedPosition { + friend struct Position; + + // Occupied bitboard has bits set for + // each square with a piece on it. + // Each packedState byte holds 2 values (nibbles). + // First one at low bits, second one at high bits. + // Values correspond to consecutive squares + // in bitboard iteration order. + // Nibble values: + // these are the same as for Piece + // knights, bishops, queens can just be copied + // 0 : white pawn + // 1 : black pawn + // 2 : white knight + // 3 : black knight + // 4 : white bishop + // 5 : black bishop + // 6 : white rook + // 7 : black rook + // 8 : white queen + // 9 : black queen + // 10 : white king + // 11 : black king + // + // these are special + // 12 : pawn with ep square behind (white or black, depending on rank) + // 13 : white rook with coresponding castling rights + // 14 : black rook with coresponding castling rights + // 15 : black king and black is side to move + // + // Let N be the number of bits set in occupied bitboard. + // Only N nibbles are present. (N+1)/2 bytes are initialized. + + static CompressedPosition readFromBigEndian(const unsigned char* data) { + CompressedPosition pos{}; + pos.m_occupied = + Bitboard::fromBits((std::uint64_t) data[0] << 56 | (std::uint64_t) data[1] << 48 + | (std::uint64_t) data[2] << 40 | (std::uint64_t) data[3] << 32 + | (std::uint64_t) data[4] << 24 | (std::uint64_t) data[5] << 16 + | (std::uint64_t) data[6] << 8 | (std::uint64_t) data[7]); + std::memcpy(pos.m_packedState, data + 8, 16); + return pos; + } - const Bitboard singlePawnMoveStarts = - pawns - & ~secondToLastRank - & ~occupied.shiftedVertically(backward); + constexpr CompressedPosition() : + m_occupied{}, + m_packedState{} {} - for (Square from : doublePawnMoveStarts) - { - const Square to = from + doublePawnMoveDestinationOffset; - f(Move::normal(from, to)); - } + [[nodiscard]] friend bool operator<(const CompressedPosition& lhs, + const CompressedPosition& rhs) { + if (lhs.m_occupied.bits() < rhs.m_occupied.bits()) + return true; + if (lhs.m_occupied.bits() > rhs.m_occupied.bits()) + return false; - for (Square from : singlePawnMoveStarts) - { - const Square to = from + singlePawnMoveDestinationOffset; - f(Move::normal(from, to)); - } - } + return std::strcmp(reinterpret_cast(lhs.m_packedState), + reinterpret_cast(rhs.m_packedState)) + < 0; + } - { - const Bitboard lastRank = SideToMoveV == Color::White ? bb::rank8 : bb::rank1; - const FlatSquareOffset westCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(-1, 1) : FlatSquareOffset(-1, -1); - const FlatSquareOffset eastCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(1, 1) : FlatSquareOffset(1, -1); + [[nodiscard]] friend bool operator==(const CompressedPosition& lhs, + const CompressedPosition& rhs) { + return lhs.m_occupied == rhs.m_occupied + && std::strcmp(reinterpret_cast(lhs.m_packedState), + reinterpret_cast(rhs.m_packedState)) + == 0; + } - const Bitboard pawnsWithWestCapture = bb::eastPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; - const Bitboard pawnsWithEastCapture = bb::westPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + [[nodiscard]] inline Position decompress() const; + + [[nodiscard]] constexpr Bitboard pieceBB() const { return m_occupied; } + + void writeToBigEndian(unsigned char* data) { + const auto occupied = m_occupied.bits(); + *data++ = occupied >> 56; + *data++ = (occupied >> 48) & 0xFF; + *data++ = (occupied >> 40) & 0xFF; + *data++ = (occupied >> 32) & 0xFF; + *data++ = (occupied >> 24) & 0xFF; + *data++ = (occupied >> 16) & 0xFF; + *data++ = (occupied >> 8) & 0xFF; + *data++ = occupied & 0xFF; + std::memcpy(data, m_packedState, 16); + } - for (Square from : pawnsWithWestCapture) - { - f(Move::normal(from, from + westCaptureOffset)); - } + private: + Bitboard m_occupied; + std::uint8_t m_packedState[16]; +}; + +namespace movegen { +// For a pseudo-legal move the following are true: +// - the moving piece has the pos.sideToMove() color +// - the destination square is either empty or has a piece of the opposite color +// - if it is a pawn move it is valid (but may be illegal due to discovered checks) +// - if it is not a pawn move then the destination square is contained in attacks() +// - if it is a castling it is legal +// - a move other than castling may create a discovered attack on the king +// - a king may walk into a check + +template +inline void forEachPseudoLegalPawnMove(const Position& pos, Square from, FuncT&& f) { + const Color sideToMove = pos.sideToMove(); + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + Bitboard attackTargets = theirPieces; + if (epSquare != Square::none()) + { + attackTargets |= epSquare; + } - for (Square from : pawnsWithEastCapture) - { - f(Move::normal(from, from + eastCaptureOffset)); - } - } + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), sideToMove) & attackTargets; - if (epSquare != Square::none()) - { - const Bitboard pawnsThatCanCapture = bb::pawnAttacks(Bitboard::square(epSquare), !SideToMoveV) & pawns; - for (Square from : pawnsThatCanCapture) - { - f(Move::enPassant(from, epSquare)); - } - } + const Rank secondToLastRank = sideToMove == Color::White ? rank7 : rank2; + const auto forward = + sideToMove == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); - for (Square from : pawns & secondToLastRank) + // promotions + if (from.rank() == secondToLastRank) + { + // capture promotions + for (Square toSq : attacks) + { + for (PieceType pt : + {PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen}) { - const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), SideToMoveV) & theirPieces; - - // capture promotions - for (Square to : attacks) - { - for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) - { - Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; - f(move); - } - } - - // push promotions - const Square to = from + singlePawnMoveDestinationOffset; - if (!occupied.isSet(to)) - { - for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) - { - Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; - f(move); - } - } + Move move{from, toSq, MoveType::Promotion, Piece(pt, sideToMove)}; + f(move); } } - template - inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) + // push promotions + const Square toSq = from + forward; + if (!occupied.isSet(toSq)) { - if (pos.sideToMove() == Color::White) - { - forEachPseudoLegalPawnMove(pos, std::forward(f)); - } - else + for (PieceType pt : + {PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen}) { - forEachPseudoLegalPawnMove(pos, std::forward(f)); + Move move{from, toSq, MoveType::Promotion, Piece(pt, sideToMove)}; + f(move); } } - - template - inline void forEachPseudoLegalPieceMove(const Position& pos, Square from, FuncT&& f) + } + else + { + // captures + for (Square toSq : attacks) { - static_assert(PieceTypeV != PieceType::None); + Move move{from, toSq, (toSq == epSquare) ? MoveType::EnPassant : MoveType::Normal}; + f(move); + } - if constexpr (PieceTypeV == PieceType::Pawn) - { - forEachPseudoLegalPawnMove(pos, from, f); - } - else - { - const Color sideToMove = pos.sideToMove(); - const Bitboard ourPieces = pos.piecesBB(sideToMove); - const Bitboard theirPieces = pos.piecesBB(!sideToMove); - const Bitboard occupied = ourPieces | theirPieces; - const Bitboard attacks = bb::attacks(from, occupied) & ~ourPieces; + const Square toSq = from + forward; - for (Square toSq : attacks) + // single push + if (!occupied.isSet(toSq)) + { + const Rank startRank = sideToMove == Color::White ? rank2 : rank7; + if (from.rank() == startRank) + { + // double push + const Square toSq2 = toSq + forward; + if (!occupied.isSet(toSq2)) { - Move move{ from, toSq }; + Move move{from, toSq2}; f(move); } } + + Move move{from, toSq}; + f(move); } + } +} - template - inline void forEachPseudoLegalPieceMove(const Position& pos, FuncT&& f) - { - static_assert(PieceTypeV != PieceType::None); +template +inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) { + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(SideToMoveV); + const Bitboard theirPieces = pos.piecesBB(!SideToMoveV); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pawns = pos.piecesBB(Piece(PieceType::Pawn, SideToMoveV)); - if constexpr (PieceTypeV == PieceType::Pawn) - { - forEachPseudoLegalPawnMove(pos, f); - } - else - { - const Color sideToMove = pos.sideToMove(); - const Bitboard ourPieces = pos.piecesBB(sideToMove); - const Bitboard theirPieces = pos.piecesBB(!sideToMove); - const Bitboard occupied = ourPieces | theirPieces; - const Bitboard pieces = pos.piecesBB(Piece(PieceTypeV, sideToMove)); - for (Square fromSq : pieces) - { - const Bitboard attacks = bb::attacks(fromSq, occupied) & ~ourPieces; - for (Square toSq : attacks) - { - Move move{ fromSq, toSq }; - f(move); - } - } - } - } + const Bitboard secondToLastRank = SideToMoveV == Color::White ? bb::rank7 : bb::rank2; + const Bitboard secondRank = SideToMoveV == Color::White ? bb::rank2 : bb::rank7; - template - inline void forEachCastlingMove(const Position& pos, FuncT&& f) - { - CastlingRights rights = pos.castlingRights(); - if (rights == CastlingRights::None) - { - return; - } + const auto singlePawnMoveDestinationOffset = + SideToMoveV == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + const auto doublePawnMoveDestinationOffset = + SideToMoveV == Color::White ? FlatSquareOffset(0, 2) : FlatSquareOffset(0, -2); + + { + const int backward = SideToMoveV == Color::White ? -1 : 1; + const int backward2 = backward * 2; - const Color sideToMove = pos.sideToMove(); - const Bitboard ourPieces = pos.piecesBB(sideToMove); - const Bitboard theirPieces = pos.piecesBB(!sideToMove); - const Bitboard occupied = ourPieces | theirPieces; + const Bitboard doublePawnMoveStarts = + pawns & secondRank + & ~(occupied.shiftedVertically(backward) | occupied.shiftedVertically(backward2)); - // we first reduce the set of legal castlings by checking the paths for pieces - if (sideToMove == Color::White) - { - if ((CastlingTraits::castlingPath[Color::White][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::WhiteKingSide; - if ((CastlingTraits::castlingPath[Color::White][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::WhiteQueenSide; - rights &= ~CastlingRights::Black; - } - else - { - if ((CastlingTraits::castlingPath[Color::Black][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::BlackKingSide; - if ((CastlingTraits::castlingPath[Color::Black][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::BlackQueenSide; - rights &= ~CastlingRights::White; - } + const Bitboard singlePawnMoveStarts = + pawns & ~secondToLastRank & ~occupied.shiftedVertically(backward); - if (rights == CastlingRights::None) - { - return; - } + for (Square from : doublePawnMoveStarts) + { + const Square to = from + doublePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + + for (Square from : singlePawnMoveStarts) + { + const Square to = from + singlePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + } + + { + const Bitboard lastRank = SideToMoveV == Color::White ? bb::rank8 : bb::rank1; + const FlatSquareOffset westCaptureOffset = + SideToMoveV == Color::White ? FlatSquareOffset(-1, 1) : FlatSquareOffset(-1, -1); + const FlatSquareOffset eastCaptureOffset = + SideToMoveV == Color::White ? FlatSquareOffset(1, 1) : FlatSquareOffset(1, -1); - // King must not be in check. Done here because it is quite expensive. - const Square ksq = pos.kingSquare(sideToMove); - if (pos.isSquareAttacked(ksq, !sideToMove)) - { - return; - } + const Bitboard pawnsWithWestCapture = + bb::eastPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + const Bitboard pawnsWithEastCapture = + bb::westPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; - // Loop through all possible castlings. - for (CastleType castlingType : values()) - { - const CastlingRights right = CastlingTraits::castlingRights[sideToMove][castlingType]; + for (Square from : pawnsWithWestCapture) + { + f(Move::normal(from, from + westCaptureOffset)); + } - if (!contains(rights, right)) - { - continue; - } + for (Square from : pawnsWithEastCapture) + { + f(Move::normal(from, from + eastCaptureOffset)); + } + } - // If we have this castling right - // we check whether the king passes an attacked square. - const Square passedSquare = CastlingTraits::squarePassedByKing[sideToMove][castlingType]; - if (pos.isSquareAttacked(passedSquare, !sideToMove)) - { - continue; - } + if (epSquare != Square::none()) + { + const Bitboard pawnsThatCanCapture = + bb::pawnAttacks(Bitboard::square(epSquare), !SideToMoveV) & pawns; + for (Square from : pawnsThatCanCapture) + { + f(Move::enPassant(from, epSquare)); + } + } - // If it's a castling move then the change in square occupation - // cannot have an effect because otherwise there would be - // a slider attacker attacking the castling king. - if (pos.isSquareAttacked(CastlingTraits::kingDestination[sideToMove][castlingType], !sideToMove)) - { - continue; - } + for (Square from : pawns& secondToLastRank) + { + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), SideToMoveV) & theirPieces; - // If not we can castle. - Move move = Move::castle(castlingType, sideToMove); + // capture promotions + for (Square to : attacks) + { + for (PieceType pt : + {PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen}) + { + Move move{from, to, MoveType::Promotion, Piece(pt, SideToMoveV)}; f(move); } } - // Calls a given function for all pseudo legal moves for the position. - // `pos` must be a legal chess position - template - inline void forEachPseudoLegalMove(const Position& pos, FuncT&& func) + // push promotions + const Square to = from + singlePawnMoveDestinationOffset; + if (!occupied.isSet(to)) { - forEachPseudoLegalPieceMove(pos, func); - forEachPseudoLegalPieceMove(pos, func); - forEachPseudoLegalPieceMove(pos, func); - forEachPseudoLegalPieceMove(pos, func); - forEachPseudoLegalPieceMove(pos, func); - forEachPseudoLegalPieceMove(pos, func); - forEachCastlingMove(pos, func); + for (PieceType pt : + {PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen}) + { + Move move{from, to, MoveType::Promotion, Piece(pt, SideToMoveV)}; + f(move); + } } + } +} - // Calls a given function for all legal moves for the position. - // `pos` must be a legal chess position - template - inline void forEachLegalMove(const Position& pos, FuncT&& func) - { - auto funcIfLegal = [&func, checker = pos.moveLegalityChecker()](Move move) { - if (checker.isPseudoLegalMoveLegal(move)) - { - func(move); - } - }; +template +inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) { + if (pos.sideToMove() == Color::White) + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } + else + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } +} - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachPseudoLegalPieceMove(pos, funcIfLegal); - forEachCastlingMove(pos, func); - } +template +inline void forEachPseudoLegalPieceMove(const Position& pos, Square from, FuncT&& f) { + static_assert(PieceTypeV != PieceType::None); - // Generates all pseudo legal moves for the position. - // `pos` must be a legal chess position - [[nodiscard]] std::vector generatePseudoLegalMoves(const Position& pos); + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, from, f); + } + else + { + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard attacks = bb::attacks(from, occupied) & ~ourPieces; - // Generates all legal moves for the position. - // `pos` must be a legal chess position - [[nodiscard]] std::vector generateLegalMoves(const Position& pos); + for (Square toSq : attacks) + { + Move move{from, toSq}; + f(move); + } } +} - [[nodiscard]] inline bool Position::isCheck() const +template +inline void forEachPseudoLegalPieceMove(const Position& pos, FuncT&& f) { + static_assert(PieceTypeV != PieceType::None); + + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, f); + } + else { - return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pieces = pos.piecesBB(Piece(PieceTypeV, sideToMove)); + for (Square fromSq : pieces) + { + const Bitboard attacks = bb::attacks(fromSq, occupied) & ~ourPieces; + for (Square toSq : attacks) + { + Move move{fromSq, toSq}; + f(move); + } + } } +} - [[nodiscard]] inline Bitboard Position::checkers() const +template +inline void forEachCastlingMove(const Position& pos, FuncT&& f) { + CastlingRights rights = pos.castlingRights(); + if (rights == CastlingRights::None) { - return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); + return; } - [[nodiscard]] inline bool Position::isCheckAfterMove(Move move) const + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + // we first reduce the set of legal castlings by checking the paths for pieces + if (sideToMove == Color::White) { - return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); + if ((CastlingTraits::castlingPath[Color::White][CastleType::Short] & occupied).any()) + rights &= ~CastlingRights::WhiteKingSide; + if ((CastlingTraits::castlingPath[Color::White][CastleType::Long] & occupied).any()) + rights &= ~CastlingRights::WhiteQueenSide; + rights &= ~CastlingRights::Black; + } + else + { + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Short] & occupied).any()) + rights &= ~CastlingRights::BlackKingSide; + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Long] & occupied).any()) + rights &= ~CastlingRights::BlackQueenSide; + rights &= ~CastlingRights::White; } - [[nodiscard]] inline bool Position::isMoveLegal(Move move) const + if (rights == CastlingRights::None) { - return - isMovePseudoLegal(move) - && isPseudoLegalMoveLegal(move); + return; } - [[nodiscard]] inline bool Position::isPseudoLegalMoveLegal(Move move) const + // King must not be in check. Done here because it is quite expensive. + const Square ksq = pos.kingSquare(sideToMove); + if (pos.isSquareAttacked(ksq, !sideToMove)) { - return - (move.type == MoveType::Castle) - || !isOwnKingAttackedAfterMove(move); + return; } - [[nodiscard]] inline bool Position::isMovePseudoLegal(Move move) const + // Loop through all possible castlings. + for (CastleType castlingType : values()) { - if (!move.from.isOk() || !move.to.isOk()) - { - return false; - } + const CastlingRights right = CastlingTraits::castlingRights[sideToMove][castlingType]; - if (move.from == move.to) + if (!contains(rights, right)) { - return false; + continue; } - if (move.type != MoveType::Promotion && move.promotedPiece != Piece::none()) + // If we have this castling right + // we check whether the king passes an attacked square. + const Square passedSquare = CastlingTraits::squarePassedByKing[sideToMove][castlingType]; + if (pos.isSquareAttacked(passedSquare, !sideToMove)) { - return false; + continue; } - const Piece movedPiece = pieceAt(move.from); - if (movedPiece == Piece::none()) + // If it's a castling move then the change in square occupation + // cannot have an effect because otherwise there would be + // a slider attacker attacking the castling king. + if (pos.isSquareAttacked(CastlingTraits::kingDestination[sideToMove][castlingType], + !sideToMove)) { - return false; + continue; } - if (movedPiece.color() != m_sideToMove) + // If not we can castle. + Move move = Move::castle(castlingType, sideToMove); + f(move); + } +} + +// Calls a given function for all pseudo legal moves for the position. +// `pos` must be a legal chess position +template +inline void forEachPseudoLegalMove(const Position& pos, FuncT&& func) { + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachCastlingMove(pos, func); +} + +// Calls a given function for all legal moves for the position. +// `pos` must be a legal chess position +template +inline void forEachLegalMove(const Position& pos, FuncT&& func) { + auto funcIfLegal = [&func, checker = pos.moveLegalityChecker()](Move move) { + if (checker.isPseudoLegalMoveLegal(move)) { - return false; + func(move); } + }; - const Bitboard occupied = piecesBB(); - const Bitboard ourPieces = piecesBB(m_sideToMove); - const bool isNormal = move.type == MoveType::Normal; + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachCastlingMove(pos, func); +} - switch (movedPiece.type()) - { - case PieceType::Pawn: +// Generates all pseudo legal moves for the position. +// `pos` must be a legal chess position +[[nodiscard]] std::vector generatePseudoLegalMoves(const Position& pos); + +// Generates all legal moves for the position. +// `pos` must be a legal chess position +[[nodiscard]] std::vector generateLegalMoves(const Position& pos); +} + +[[nodiscard]] inline bool Position::isCheck() const { + return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); +} + +[[nodiscard]] inline Bitboard Position::checkers() const { + return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); +} + +[[nodiscard]] inline bool Position::isCheckAfterMove(Move move) const { + return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); +} + +[[nodiscard]] inline bool Position::isMoveLegal(Move move) const { + return isMovePseudoLegal(move) && isPseudoLegalMoveLegal(move); +} + +[[nodiscard]] inline bool Position::isPseudoLegalMoveLegal(Move move) const { + return (move.type == MoveType::Castle) || !isOwnKingAttackedAfterMove(move); +} + +[[nodiscard]] inline bool Position::isMovePseudoLegal(Move move) const { + if (!move.from.isOk() || !move.to.isOk()) + { + return false; + } + + if (move.from == move.to) + { + return false; + } + + if (move.type != MoveType::Promotion && move.promotedPiece != Piece::none()) + { + return false; + } + + const Piece movedPiece = pieceAt(move.from); + if (movedPiece == Piece::none()) + { + return false; + } + + if (movedPiece.color() != m_sideToMove) + { + return false; + } + + const Bitboard occupied = piecesBB(); + const Bitboard ourPieces = piecesBB(m_sideToMove); + const bool isNormal = move.type == MoveType::Normal; + + switch (movedPiece.type()) + { + case PieceType::Pawn : { + bool isValid = false; + // TODO: use iterators so we don't loop over all moves + // when we can avoid it. + movegen::forEachPseudoLegalPawnMove(*this, move.from, + [&isValid, &move](const Move& genMove) { + if (move == genMove) + { + isValid = true; + } + }); + return isValid; + } + + case PieceType::Bishop : + return isNormal + && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Knight : + return isNormal + && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); + + case PieceType::Rook : + return isNormal + && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Queen : + return isNormal + && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::King : { + if (move.type == MoveType::Castle) { bool isValid = false; - // TODO: use iterators so we don't loop over all moves - // when we can avoid it. - movegen::forEachPseudoLegalPawnMove(*this, move.from, [&isValid, &move](const Move& genMove) { + movegen::forEachCastlingMove(*this, [&isValid, &move](const Move& genMove) { if (move == genMove) { isValid = true; } - }); + }); return isValid; } - - case PieceType::Bishop: - return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); - - case PieceType::Knight: - return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); - - case PieceType::Rook: - return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); - - case PieceType::Queen: - return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); - - case PieceType::King: + else { - if (move.type == MoveType::Castle) - { - bool isValid = false; - movegen::forEachCastlingMove(*this, [&isValid, &move](const Move& genMove) { - if (move == genMove) - { - isValid = true; - } - }); - return isValid; - } - else - { - return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); - } + return isNormal + && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); } + } - default: - return false; - } + default : + return false; } +} - [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const - { - const Color attackerColor = !color; +[[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const { + const Color attackerColor = !color; - const Bitboard occupied = piecesBB(); + const Bitboard occupied = piecesBB(); - const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - const Square ksq = kingSquare(color); + const Square ksq = kingSquare(color); - const Bitboard opponentBishopLikePieces = (bishops | queens); - const Bitboard bishopPseudoAttacks = bb::pseudoAttacks(ksq); + const Bitboard opponentBishopLikePieces = (bishops | queens); + const Bitboard bishopPseudoAttacks = bb::pseudoAttacks(ksq); - const Bitboard opponentRookLikePieces = (rooks | queens); - const Bitboard rookPseudoAttacks = bb::pseudoAttacks(ksq); + const Bitboard opponentRookLikePieces = (rooks | queens); + const Bitboard rookPseudoAttacks = bb::pseudoAttacks(ksq); - const Bitboard xrayers = - (bishopPseudoAttacks & opponentBishopLikePieces) - | (rookPseudoAttacks & opponentRookLikePieces); + const Bitboard xrayers = (bishopPseudoAttacks & opponentBishopLikePieces) + | (rookPseudoAttacks & opponentRookLikePieces); - Bitboard allBlockers = Bitboard::none(); + Bitboard allBlockers = Bitboard::none(); - for (Square xrayer : xrayers) + for (Square xrayer : xrayers) + { + const Bitboard blockers = bb::between(xrayer, ksq) & occupied; + if (blockers.exactlyOne()) { - const Bitboard blockers = bb::between(xrayer, ksq) & occupied; - if (blockers.exactlyOne()) - { - allBlockers |= blockers; - } + allBlockers |= blockers; } - - return allBlockers; } - inline MoveLegalityChecker::MoveLegalityChecker(const Position& position) : - m_position(&position), - m_checkers(position.checkers()), - m_ourBlockersForKing( - position.blockersForKing(position.sideToMove()) - & position.piecesBB(position.sideToMove()) - ), - m_ksq(position.kingSquare(position.sideToMove())) + return allBlockers; +} + +inline MoveLegalityChecker::MoveLegalityChecker(const Position& position) : + m_position(&position), + m_checkers(position.checkers()), + m_ourBlockersForKing(position.blockersForKing(position.sideToMove()) + & position.piecesBB(position.sideToMove())), + m_ksq(position.kingSquare(position.sideToMove())) { + if (m_checkers.exactlyOne()) { - if (m_checkers.exactlyOne()) + const Bitboard knightCheckers = m_checkers & bb::pseudoAttacks(m_ksq); + if (knightCheckers.any()) { - const Bitboard knightCheckers = m_checkers & bb::pseudoAttacks(m_ksq); - if (knightCheckers.any()) - { - // We're checked by a knight, we have to remove it or move the king. - m_potentialCheckRemovals = knightCheckers; - } - else - { - // If we're not checked by a knight we can block it. - m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers; - } + // We're checked by a knight, we have to remove it or move the king. + m_potentialCheckRemovals = knightCheckers; } else { - // Double check, king has to move. - m_potentialCheckRemovals = Bitboard::none(); + // If we're not checked by a knight we can block it. + m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers; } } + else + { + // Double check, king has to move. + m_potentialCheckRemovals = Bitboard::none(); + } +} - [[nodiscard]] inline bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const +[[nodiscard]] inline bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const { + if (m_checkers.any()) { - if (m_checkers.any()) + if (move.from == m_ksq || move.type == MoveType::EnPassant) { - if (move.from == m_ksq || move.type == MoveType::EnPassant) - { - return m_position->isPseudoLegalMoveLegal(move); - } - else - { - // This means there's only one check and we either - // blocked it or removed the piece that attacked - // our king. So the only threat is if it's a discovered check. - return - m_potentialCheckRemovals.isSet(move.to) - && !m_ourBlockersForKing.isSet(move.from); - } + return m_position->isPseudoLegalMoveLegal(move); } else { - if (move.from == m_ksq) - { - return m_position->isPseudoLegalMoveLegal(move); - } - else if (move.type == MoveType::EnPassant) - { - return !m_position->createsDiscoveredAttackOnOwnKing(move); - } - else if (m_ourBlockersForKing.isSet(move.from)) - { - // If it was a blocker it may have only moved in line with our king. - // Otherwise it's a discovered check. - return bb::line(m_ksq, move.from).isSet(move.to); - } - else - { - return true; - } + // This means there's only one check and we either + // blocked it or removed the piece that attacked + // our king. So the only threat is if it's a discovered check. + return m_potentialCheckRemovals.isSet(move.to) + && !m_ourBlockersForKing.isSet(move.from); } } - - static_assert(sizeof(CompressedPosition) == 24); - static_assert(std::is_trivially_copyable_v); - - namespace detail + else { - [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressOrdinaryPiece(const Position&, Square, Piece piece) + if (move.from == m_ksq) { - return static_cast(ordinal(piece)); + return m_position->isPseudoLegalMoveLegal(move); } - - [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressPawn(const Position& position, Square sq, Piece piece) + else if (move.type == MoveType::EnPassant) { - const Square epSquare = position.epSquare(); - if (epSquare == Square::none()) - { - return static_cast(ordinal(piece)); - } - else - { - const Color sideToMove = position.sideToMove(); - const Rank rank = sq.rank(); - const File file = sq.file(); - // use bitwise operators, there is a lot of unpredictable branches but in - // total the result is quite predictable - if ( - (file == epSquare.file()) - && ( - ((rank == rank4) & (sideToMove == Color::Black)) - | ((rank == rank5) & (sideToMove == Color::White)) - ) - ) - { - return 12; - } - else - { - return static_cast(ordinal(piece)); - } - } + return !m_position->createsDiscoveredAttackOnOwnKing(move); } - - [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressRook(const Position& position, Square sq, Piece piece) + else if (m_ourBlockersForKing.isSet(move.from)) { - const CastlingRights castlingRights = position.castlingRights(); - const Color color = piece.color(); - - if (color == Color::White - && ( - (sq == a1 && contains(castlingRights, CastlingRights::WhiteQueenSide)) - || (sq == h1 && contains(castlingRights, CastlingRights::WhiteKingSide)) - ) - ) - { - return 13; - } - else if ( - color == Color::Black - && ( - (sq == a8 && contains(castlingRights, CastlingRights::BlackQueenSide)) - || (sq == h8 && contains(castlingRights, CastlingRights::BlackKingSide)) - ) - ) - { - return 14; - } - else - { - return static_cast(ordinal(piece)); - } + // If it was a blocker it may have only moved in line with our king. + // Otherwise it's a discovered check. + return bb::line(m_ksq, move.from).isSet(move.to); } - - [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressKing(const Position& position, Square /* sq */, Piece piece) + else { - const Color color = piece.color(); - const Color sideToMove = position.sideToMove(); - - if (color == Color::White) - { - return 10; - } - else if (sideToMove == Color::White) - { - return 11; - } - else - { - return 15; - } + return true; } } +} - namespace detail::lookup - { - static constexpr EnumArray pieceCompressorFunc = []() { - EnumArray pieceCompressorFunc_{}; - - pieceCompressorFunc_[PieceType::Knight] = detail::compressOrdinaryPiece; - pieceCompressorFunc_[PieceType::Bishop] = detail::compressOrdinaryPiece; - pieceCompressorFunc_[PieceType::Queen] = detail::compressOrdinaryPiece; - - pieceCompressorFunc_[PieceType::Pawn] = detail::compressPawn; - pieceCompressorFunc_[PieceType::Rook] = detail::compressRook; - pieceCompressorFunc_[PieceType::King] = detail::compressKing; +static_assert(sizeof(CompressedPosition) == 24); +static_assert(std::is_trivially_copyable_v); - pieceCompressorFunc_[PieceType::None] = [](const Position&, Square, Piece) -> std::uint8_t { /* should never happen */ return 0; }; +namespace detail { +[[nodiscard]] FORCEINLINE constexpr std::uint8_t +compressOrdinaryPiece(const Position&, Square, Piece piece) { + return static_cast(ordinal(piece)); +} - return pieceCompressorFunc_; - }(); +[[nodiscard]] FORCEINLINE constexpr std::uint8_t +compressPawn(const Position& position, Square sq, Piece piece) { + const Square epSquare = position.epSquare(); + if (epSquare == Square::none()) + { + return static_cast(ordinal(piece)); } - - [[nodiscard]] inline CompressedPosition Position::compress() const + else { - auto compressPiece = [this](Square sq, Piece piece) -> std::uint8_t { - if (piece.type() == PieceType::Pawn) // it's likely to be a pawn - { - return detail::compressPawn(*this, sq, piece); - } - else - { - return detail::lookup::pieceCompressorFunc[piece.type()](*this, sq, piece); - } - }; - - const Bitboard occ = piecesBB(); - - CompressedPosition compressed; - compressed.m_occupied = occ; - - auto it = occ.begin(); - auto end = occ.end(); - for (int i = 0;; ++i) + const Color sideToMove = position.sideToMove(); + const Rank rank = sq.rank(); + const File file = sq.file(); + // use bitwise operators, there is a lot of unpredictable branches but in + // total the result is quite predictable + if ((file == epSquare.file()) + && (((rank == rank4) & (sideToMove == Color::Black)) + | ((rank == rank5) & (sideToMove == Color::White)))) { - if (it == end) break; - compressed.m_packedState[i] = compressPiece(*it, pieceAt(*it)); - ++it; - - if (it == end) break; - compressed.m_packedState[i] |= compressPiece(*it, pieceAt(*it)) << 4; - ++it; + return 12; + } + else + { + return static_cast(ordinal(piece)); } - - return compressed; } +} - [[nodiscard]] inline Position CompressedPosition::decompress() const - { - Position pos; - pos.setCastlingRights(CastlingRights::None); - - auto decompressPiece = [&pos](Square sq, std::uint8_t nibble) { - switch (nibble) - { - case 0: - case 1: - case 2: - case 3: - case 4: - case 5: - case 6: - case 7: - case 8: - case 9: - case 10: - case 11: - { - pos.place(fromOrdinal(nibble), sq); - return; - } - - case 12: - { - const Rank rank = sq.rank(); - if (rank == rank4) - { - pos.place(whitePawn, sq); - pos.setEpSquareUnchecked(sq + Offset{ 0, -1 }); - } - else // (rank == rank5) - { - pos.place(blackPawn, sq); - pos.setEpSquareUnchecked(sq + Offset{ 0, 1 }); - } - return; - } - - case 13: - { - pos.place(whiteRook, sq); - if (sq == a1) - { - pos.addCastlingRights(CastlingRights::WhiteQueenSide); - } - else // (sq == H1) - { - pos.addCastlingRights(CastlingRights::WhiteKingSide); - } - return; - } - - case 14: - { - pos.place(blackRook, sq); - if (sq == a8) - { - pos.addCastlingRights(CastlingRights::BlackQueenSide); - } - else // (sq == H8) - { - pos.addCastlingRights(CastlingRights::BlackKingSide); - } - return; - } - - case 15: - { - pos.place(blackKing, sq); - pos.setSideToMove(Color::Black); - return; - } - - } +[[nodiscard]] FORCEINLINE constexpr std::uint8_t +compressRook(const Position& position, Square sq, Piece piece) { + const CastlingRights castlingRights = position.castlingRights(); + const Color color = piece.color(); - return; - }; + if (color == Color::White + && ((sq == a1 && contains(castlingRights, CastlingRights::WhiteQueenSide)) + || (sq == h1 && contains(castlingRights, CastlingRights::WhiteKingSide)))) + { + return 13; + } + else if (color == Color::Black + && ((sq == a8 && contains(castlingRights, CastlingRights::BlackQueenSide)) + || (sq == h8 && contains(castlingRights, CastlingRights::BlackKingSide)))) + { + return 14; + } + else + { + return static_cast(ordinal(piece)); + } +} - const Bitboard occ = m_occupied; +[[nodiscard]] FORCEINLINE constexpr std::uint8_t +compressKing(const Position& position, Square /* sq */, Piece piece) { + const Color color = piece.color(); + const Color sideToMove = position.sideToMove(); - auto it = occ.begin(); - auto end = occ.end(); - for (int i = 0;; ++i) - { - if (it == end) break; - decompressPiece(*it, m_packedState[i] & 0xF); - ++it; + if (color == Color::White) + { + return 10; + } + else if (sideToMove == Color::White) + { + return 11; + } + else + { + return 15; + } +} +} - if (it == end) break; - decompressPiece(*it, m_packedState[i] >> 4); - ++it; - } +namespace detail::lookup { +static constexpr EnumArray + pieceCompressorFunc = []() { + EnumArray pieceCompressorFunc_{}; - return pos; - } + pieceCompressorFunc_[PieceType::Knight] = detail::compressOrdinaryPiece; + pieceCompressorFunc_[PieceType::Bishop] = detail::compressOrdinaryPiece; + pieceCompressorFunc_[PieceType::Queen] = detail::compressOrdinaryPiece; + pieceCompressorFunc_[PieceType::Pawn] = detail::compressPawn; + pieceCompressorFunc_[PieceType::Rook] = detail::compressRook; + pieceCompressorFunc_[PieceType::King] = detail::compressKing; - [[nodiscard]] bool Board::isSquareAttacked(Square sq, Color attackerColor) const - { - assert(sq.isOk()); + pieceCompressorFunc_[PieceType::None] = [](const Position&, Square, + Piece) -> std::uint8_t { /* should never happen */ + return 0; + }; - const Bitboard occupied = piecesBB(); - const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - - const Bitboard allSliders = (bishops | rooks | queens); - if ((bb::pseudoAttacks(sq) & allSliders).any()) - { - if (bb::isAttackedBySlider( - sq, - bishops, - rooks, - queens, - occupied - )) - { - return true; - } - } + return pieceCompressorFunc_; + }(); +} - const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); - if ((bb::pseudoAttacks(sq) & king).any()) +[[nodiscard]] inline CompressedPosition Position::compress() const { + auto compressPiece = [this](Square sq, Piece piece) -> std::uint8_t { + if (piece.type() == PieceType::Pawn) // it's likely to be a pawn { - return true; + return detail::compressPawn(*this, sq, piece); } - - const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); - if ((bb::pseudoAttacks(sq) & knights).any()) + else { - return true; + return detail::lookup::pieceCompressorFunc[piece.type()](*this, sq, piece); } + }; - const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); - const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); + const Bitboard occ = piecesBB(); - return pawnAttacks.isSet(sq); - } + CompressedPosition compressed; + compressed.m_occupied = occ; - [[nodiscard]] bool Board::isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const + auto it = occ.begin(); + auto end = occ.end(); + for (int i = 0;; ++i) { - const Bitboard occupiedChange = Bitboard::square(move.from) | move.to; + if (it == end) + break; + compressed.m_packedState[i] = compressPiece(*it, pieceAt(*it)); + ++it; - Bitboard occupied = (piecesBB() ^ move.from) | move.to; + if (it == end) + break; + compressed.m_packedState[i] |= compressPiece(*it, pieceAt(*it)) << 4; + ++it; + } - Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); - Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); - Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + return compressed; +} - if (move.type == MoveType::EnPassant) - { - const Square capturedPawnSq(move.to.file(), move.from.rank()); - occupied ^= capturedPawnSq; - pawns ^= capturedPawnSq; - } - else if (pieceAt(move.to) != Piece::none()) - { - const Bitboard notCaptured = ~Bitboard::square(move.to); - bishops &= notCaptured; - rooks &= notCaptured; - queens &= notCaptured; - knights &= notCaptured; - pawns &= notCaptured; +[[nodiscard]] inline Position CompressedPosition::decompress() const { + Position pos; + pos.setCastlingRights(CastlingRights::None); + + auto decompressPiece = [&pos](Square sq, std::uint8_t nibble) { + switch (nibble) + { + case 0 : + case 1 : + case 2 : + case 3 : + case 4 : + case 5 : + case 6 : + case 7 : + case 8 : + case 9 : + case 10 : + case 11 : { + pos.place(fromOrdinal(nibble), sq); + return; } - // Potential attackers may have moved. - const Piece movedPiece = pieceAt(move.from); - if (movedPiece.color() == attackerColor) - { - switch (movedPiece.type()) + case 12 : { + const Rank rank = sq.rank(); + if (rank == rank4) { - case PieceType::Pawn: - pawns ^= occupiedChange; - break; - case PieceType::Knight: - knights ^= occupiedChange; - break; - case PieceType::Bishop: - bishops ^= occupiedChange; - break; - case PieceType::Rook: - rooks ^= occupiedChange; - break; - case PieceType::Queen: - queens ^= occupiedChange; - break; - case PieceType::King: + pos.place(whitePawn, sq); + pos.setEpSquareUnchecked(sq + Offset{0, -1}); + } + else // (rank == rank5) { - if (move.type == MoveType::Castle) - { - const CastleType castleType = CastlingTraits::moveCastlingType(move); + pos.place(blackPawn, sq); + pos.setEpSquareUnchecked(sq + Offset{0, 1}); + } + return; + } - king ^= move.from; - king ^= CastlingTraits::kingDestination[attackerColor][castleType]; - rooks ^= move.to; - rooks ^= CastlingTraits::rookDestination[attackerColor][castleType]; - } - else - { - king ^= occupiedChange; - } - break; + case 13 : { + pos.place(whiteRook, sq); + if (sq == a1) + { + pos.addCastlingRights(CastlingRights::WhiteQueenSide); } - case PieceType::None: - assert(false); + else // (sq == H1) + { + pos.addCastlingRights(CastlingRights::WhiteKingSide); } + return; } - // If it's a castling move then the change in square occupation - // cannot have an effect because otherwise there would be - // a slider attacker attacking the castling king. - // (It could have an effect in chess960 if the slider - // attacker was behind the rook involved in castling, - // but we don't care about chess960.) - - const Bitboard allSliders = (bishops | rooks | queens); - if ((bb::pseudoAttacks(sq) & allSliders).any()) - { - if (bb::isAttackedBySlider( - sq, - bishops, - rooks, - queens, - occupied - )) + case 14 : { + pos.place(blackRook, sq); + if (sq == a8) + { + pos.addCastlingRights(CastlingRights::BlackQueenSide); + } + else // (sq == H8) { - return true; + pos.addCastlingRights(CastlingRights::BlackKingSide); } + return; } - if ((bb::pseudoAttacks(sq) & king).any()) - { - return true; + case 15 : { + pos.place(blackKing, sq); + pos.setSideToMove(Color::Black); + return; } - - if ((bb::pseudoAttacks(sq) & knights).any()) - { - return true; } - const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); + return; + }; - return pawnAttacks.isSet(sq); - } + const Bitboard occ = m_occupied; - [[nodiscard]] bool Board::createsDiscoveredAttackOnOwnKing(Move move) const + auto it = occ.begin(); + auto end = occ.end(); + for (int i = 0;; ++i) { - Bitboard occupied = (piecesBB() ^ move.from) | move.to; + if (it == end) + break; + decompressPiece(*it, m_packedState[i] & 0xF); + ++it; - const Piece movedPiece = pieceAt(move.from); - const Color kingColor = movedPiece.color(); - const Color attackerColor = !kingColor; - const Square ksq = kingSquare(kingColor); + if (it == end) + break; + decompressPiece(*it, m_packedState[i] >> 4); + ++it; + } - Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + return pos; +} - if (move.type == MoveType::EnPassant) - { - const Square capturedPawnSq(move.to.file(), move.from.rank()); - occupied ^= capturedPawnSq; - } - else if (pieceAt(move.to) != Piece::none()) - { - const Bitboard notCaptured = ~Bitboard::square(move.to); - bishops &= notCaptured; - rooks &= notCaptured; - queens &= notCaptured; - } - const Bitboard allSliders = (bishops | rooks | queens); - if ((bb::pseudoAttacks(ksq) & allSliders).any()) - { - if (bb::isAttackedBySlider( - ksq, - bishops, - rooks, - queens, - occupied - )) - { - return true; - } - } +[[nodiscard]] bool Board::isSquareAttacked(Square sq, Color attackerColor) const { + assert(sq.isOk()); - return false; - } + const Bitboard occupied = piecesBB(); + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - [[nodiscard]] bool Board::isPieceAttacked(Square sq) const + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(sq) & allSliders).any()) { - const Piece piece = pieceAt(sq); - - if (piece == Piece::none()) + if (bb::isAttackedBySlider(sq, bishops, rooks, queens, occupied)) { - return false; + return true; } - - return isSquareAttacked(sq, !piece.color()); } - [[nodiscard]] bool Board::isPieceAttackedAfterMove(Move move, Square sq) const + const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + if ((bb::pseudoAttacks(sq) & king).any()) { - const Piece piece = pieceAt(sq); + return true; + } - if (piece == Piece::none()) - { - return false; - } + const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + if ((bb::pseudoAttacks(sq) & knights).any()) + { + return true; + } - if (sq == move.from) - { - // We moved the piece we're interested in. - // For every move the piece ends up on the move.to except - // for the case of castling moves. - // But we know pseudo legal castling moves - // are already legal, so the king cannot be in check after. - if (move.type == MoveType::Castle) - { - return false; - } + const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); - // So update the square we're interested in. - sq = move.to; - } + return pawnAttacks.isSet(sq); +} - return isSquareAttackedAfterMove(move, sq, !piece.color()); - } +[[nodiscard]] bool +Board::isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const { + const Bitboard occupiedChange = Bitboard::square(move.from) | move.to; - [[nodiscard]] bool Board::isOwnKingAttackedAfterMove(Move move) const - { - if (move.type == MoveType::Castle) - { - // Pseudo legal castling moves are already legal. - // This is ensured by the move generator. - return false; - } + Bitboard occupied = (piecesBB() ^ move.from) | move.to; - const Piece movedPiece = pieceAt(move.from); + Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); - return isPieceAttackedAfterMove(move, kingSquare(movedPiece.color())); + if (move.type == MoveType::EnPassant) + { + const Square capturedPawnSq(move.to.file(), move.from.rank()); + occupied ^= capturedPawnSq; + pawns ^= capturedPawnSq; + } + else if (pieceAt(move.to) != Piece::none()) + { + const Bitboard notCaptured = ~Bitboard::square(move.to); + bishops &= notCaptured; + rooks &= notCaptured; + queens &= notCaptured; + knights &= notCaptured; + pawns &= notCaptured; } - [[nodiscard]] Bitboard Board::attacks(Square sq) const + // Potential attackers may have moved. + const Piece movedPiece = pieceAt(move.from); + if (movedPiece.color() == attackerColor) { - const Piece piece = pieceAt(sq); - if (piece == Piece::none()) + switch (movedPiece.type()) { - return Bitboard::none(); - } + case PieceType::Pawn : + pawns ^= occupiedChange; + break; + case PieceType::Knight : + knights ^= occupiedChange; + break; + case PieceType::Bishop : + bishops ^= occupiedChange; + break; + case PieceType::Rook : + rooks ^= occupiedChange; + break; + case PieceType::Queen : + queens ^= occupiedChange; + break; + case PieceType::King : { + if (move.type == MoveType::Castle) + { + const CastleType castleType = CastlingTraits::moveCastlingType(move); - if (piece.type() == PieceType::Pawn) - { - return bb::pawnAttacks(Bitboard::square(sq), piece.color()); + king ^= move.from; + king ^= CastlingTraits::kingDestination[attackerColor][castleType]; + rooks ^= move.to; + rooks ^= CastlingTraits::rookDestination[attackerColor][castleType]; + } + else + { + king ^= occupiedChange; + } + break; } - else - { - return bb::attacks(piece.type(), sq, piecesBB()); + case PieceType::None : + assert(false); } } - [[nodiscard]] Bitboard Board::attackers(Square sq, Color attackerColor) const - { - // En-passant square is not included. - - Bitboard allAttackers = Bitboard::none(); + // If it's a castling move then the change in square occupation + // cannot have an effect because otherwise there would be + // a slider attacker attacking the castling king. + // (It could have an effect in chess960 if the slider + // attacker was behind the rook involved in castling, + // but we don't care about chess960.) - const Bitboard occupied = piecesBB(); + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(sq) & allSliders).any()) + { + if (bb::isAttackedBySlider(sq, bishops, rooks, queens, occupied)) + { + return true; + } + } - const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); - const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); - const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + if ((bb::pseudoAttacks(sq) & king).any()) + { + return true; + } - const Bitboard bishopLikePieces = (bishops | queens); - const Bitboard bishopAttacks = bb::attacks(sq, occupied); - allAttackers |= bishopAttacks & bishopLikePieces; + if ((bb::pseudoAttacks(sq) & knights).any()) + { + return true; + } - const Bitboard rookLikePieces = (rooks | queens); - const Bitboard rookAttacks = bb::attacks(sq, occupied); - allAttackers |= rookAttacks & rookLikePieces; + const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); - const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); - allAttackers |= bb::pseudoAttacks(sq) & king; + return pawnAttacks.isSet(sq); +} - const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); - allAttackers |= bb::pseudoAttacks(sq) & knights; +[[nodiscard]] bool Board::createsDiscoveredAttackOnOwnKing(Move move) const { + Bitboard occupied = (piecesBB() ^ move.from) | move.to; - const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); - allAttackers |= bb::pawnAttacks(Bitboard::square(sq), !attackerColor) & pawns; + const Piece movedPiece = pieceAt(move.from); + const Color kingColor = movedPiece.color(); + const Color attackerColor = !kingColor; + const Square ksq = kingSquare(kingColor); - return allAttackers; - } + Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - inline const Piece* Board::piecesRaw() const + if (move.type == MoveType::EnPassant) { - return m_pieces.data(); + const Square capturedPawnSq(move.to.file(), move.from.rank()); + occupied ^= capturedPawnSq; } - - namespace detail::lookup + else if (pieceAt(move.to) != Piece::none()) { - static constexpr EnumArray fenPiece = []() { - EnumArray fenPiece_{}; - - fenPiece_[whitePawn] = 'P'; - fenPiece_[blackPawn] = 'p'; - fenPiece_[whiteKnight] = 'N'; - fenPiece_[blackKnight] = 'n'; - fenPiece_[whiteBishop] = 'B'; - fenPiece_[blackBishop] = 'b'; - fenPiece_[whiteRook] = 'R'; - fenPiece_[blackRook] = 'r'; - fenPiece_[whiteQueen] = 'Q'; - fenPiece_[blackQueen] = 'q'; - fenPiece_[whiteKing] = 'K'; - fenPiece_[blackKing] = 'k'; - fenPiece_[Piece::none()] = 'X'; - - return fenPiece_; - }(); + const Bitboard notCaptured = ~Bitboard::square(move.to); + bishops &= notCaptured; + rooks &= notCaptured; + queens &= notCaptured; } - [[nodiscard]] inline std::string Board::fen() const + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(ksq) & allSliders).any()) { - std::string fen; - fen.reserve(96); // longest fen is probably in range of around 88 - - Rank rank = rank8; - File file = fileA; - std::uint8_t emptyCounter = 0; - - for (;;) + if (bb::isAttackedBySlider(ksq, bishops, rooks, queens, occupied)) { - const Square sq(file, rank); - const Piece piece = m_pieces[sq]; - - if (piece == Piece::none()) - { - ++emptyCounter; - } - else - { - if (emptyCounter != 0) - { - fen.push_back(static_cast(emptyCounter) + '0'); - emptyCounter = 0; - } + return true; + } + } - fen.push_back(detail::lookup::fenPiece[piece]); - } + return false; +} - ++file; - if (file > fileH) - { - file = fileA; - --rank; +[[nodiscard]] bool Board::isPieceAttacked(Square sq) const { + const Piece piece = pieceAt(sq); - if (emptyCounter != 0) - { - fen.push_back(static_cast(emptyCounter) + '0'); - emptyCounter = 0; - } + if (piece == Piece::none()) + { + return false; + } - if (rank < rank1) - { - break; - } - fen.push_back('/'); - } - } + return isSquareAttacked(sq, !piece.color()); +} - return fen; - } +[[nodiscard]] bool Board::isPieceAttackedAfterMove(Move move, Square sq) const { + const Piece piece = pieceAt(sq); - void Position::set(std::string_view fen) + if (piece == Piece::none()) { - (void)trySet(fen); + return false; } - // Returns false if the fen was not valid - // If the returned value was false the position - // is in unspecified state. - [[nodiscard]] bool Position::trySet(std::string_view fen) + if (sq == move.from) { - // Lazily splits by ' '. Returns empty string views if at the end. - auto nextPart = [fen, start = std::size_t{ 0 }]() mutable { - std::size_t end = fen.find(' ', start); - if (end == std::string::npos) - { - std::string_view substr = fen.substr(start); - start = fen.size(); - return substr; - } - else - { - std::string_view substr = fen.substr(start, end - start); - start = end + 1; // to skip whitespace - return substr; - } - }; - - if (!BaseType::trySet(nextPart())) return false; - - { - const auto side = nextPart(); - if (side == std::string_view("w")) m_sideToMove = Color::White; - else if (side == std::string_view("b")) m_sideToMove = Color::Black; - else return false; - - if (isSquareAttacked(kingSquare(!m_sideToMove), m_sideToMove)) return false; - } - + // We moved the piece we're interested in. + // For every move the piece ends up on the move.to except + // for the case of castling moves. + // But we know pseudo legal castling moves + // are already legal, so the king cannot be in check after. + if (move.type == MoveType::Castle) { - const auto castlingRights = nextPart(); - auto castlingRightsOpt = parser_bits::tryParseCastlingRights(castlingRights); - if (!castlingRightsOpt.has_value()) - { - return false; - } - else - { - m_castlingRights = *castlingRightsOpt; - } + return false; } - { - const auto epSquare = nextPart(); - auto epSquareOpt = parser_bits::tryParseEpSquare(epSquare); - if (!epSquareOpt.has_value()) - { - return false; - } - else - { - m_epSquare = *epSquareOpt; - } - } + // So update the square we're interested in. + sq = move.to; + } - { - const auto rule50 = nextPart(); - if (!rule50.empty()) - { - m_rule50Counter = std::stoi(rule50.data()); - } - else - { - m_rule50Counter = 0; - } - } + return isSquareAttackedAfterMove(move, sq, !piece.color()); +} - { - const auto fullMove = nextPart(); - if (!fullMove.empty()) - { - m_ply = std::stoi(fullMove.data()) * 2 - (m_sideToMove == Color::White); - } - else - { - m_ply = 0; - } - } +[[nodiscard]] bool Board::isOwnKingAttackedAfterMove(Move move) const { + if (move.type == MoveType::Castle) + { + // Pseudo legal castling moves are already legal. + // This is ensured by the move generator. + return false; + } - nullifyEpSquareIfNotPossible(); + const Piece movedPiece = pieceAt(move.from); - return true; - } + return isPieceAttackedAfterMove(move, kingSquare(movedPiece.color())); +} - [[nodiscard]] Position Position::fromFen(std::string_view fen) +[[nodiscard]] Bitboard Board::attacks(Square sq) const { + const Piece piece = pieceAt(sq); + if (piece == Piece::none()) { - Position pos{}; - pos.set(fen); - return pos; + return Bitboard::none(); } - [[nodiscard]] std::optional Position::tryFromFen(std::string_view fen) + if (piece.type() == PieceType::Pawn) { - Position pos{}; - if (pos.trySet(fen)) return pos; - else return {}; + return bb::pawnAttacks(Bitboard::square(sq), piece.color()); } - - [[nodiscard]] Position Position::startPosition() + else { - static const Position pos = fromFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"); - return pos; + return bb::attacks(piece.type(), sq, piecesBB()); } +} - [[nodiscard]] std::string Position::fen() const - { - std::string fen = Board::fen(); +[[nodiscard]] Bitboard Board::attackers(Square sq, Color attackerColor) const { + // En-passant square is not included. - fen += ' '; - fen += m_sideToMove == Color::White ? 'w' : 'b'; + Bitboard allAttackers = Bitboard::none(); - fen += ' '; - parser_bits::appendCastlingRightsToString(m_castlingRights, fen); + const Bitboard occupied = piecesBB(); - fen += ' '; - parser_bits::appendEpSquareToString(m_epSquare, fen); + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); - fen += ' '; - fen += std::to_string(m_rule50Counter); + const Bitboard bishopLikePieces = (bishops | queens); + const Bitboard bishopAttacks = bb::attacks(sq, occupied); + allAttackers |= bishopAttacks & bishopLikePieces; - fen += ' '; - fen += std::to_string(fullMove()); + const Bitboard rookLikePieces = (rooks | queens); + const Bitboard rookAttacks = bb::attacks(sq, occupied); + allAttackers |= rookAttacks & rookLikePieces; - return fen; - } + const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + allAttackers |= bb::pseudoAttacks(sq) & king; + + const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + allAttackers |= bb::pseudoAttacks(sq) & knights; + + const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + allAttackers |= bb::pawnAttacks(Bitboard::square(sq), !attackerColor) & pawns; + + return allAttackers; +} + +inline const Piece* Board::piecesRaw() const { return m_pieces.data(); } + +namespace detail::lookup { +static constexpr EnumArray fenPiece = []() { + EnumArray fenPiece_{}; + + fenPiece_[whitePawn] = 'P'; + fenPiece_[blackPawn] = 'p'; + fenPiece_[whiteKnight] = 'N'; + fenPiece_[blackKnight] = 'n'; + fenPiece_[whiteBishop] = 'B'; + fenPiece_[blackBishop] = 'b'; + fenPiece_[whiteRook] = 'R'; + fenPiece_[blackRook] = 'r'; + fenPiece_[whiteQueen] = 'Q'; + fenPiece_[blackQueen] = 'q'; + fenPiece_[whiteKing] = 'K'; + fenPiece_[blackKing] = 'k'; + fenPiece_[Piece::none()] = 'X'; + + return fenPiece_; +}(); +} + +[[nodiscard]] inline std::string Board::fen() const { + std::string fen; + fen.reserve(96); // longest fen is probably in range of around 88 - namespace detail::lookup + Rank rank = rank8; + File file = fileA; + std::uint8_t emptyCounter = 0; + + for (;;) { - static constexpr EnumArray preservedCastlingRights = []() { - EnumArray preservedCastlingRights_{}; - for (CastlingRights& rights : preservedCastlingRights_) + const Square sq(file, rank); + const Piece piece = m_pieces[sq]; + + if (piece == Piece::none()) + { + ++emptyCounter; + } + else + { + if (emptyCounter != 0) { - rights = ~CastlingRights::None; + fen.push_back(static_cast(emptyCounter) + '0'); + emptyCounter = 0; } - preservedCastlingRights_[e1] = ~CastlingRights::White; - preservedCastlingRights_[e8] = ~CastlingRights::Black; + fen.push_back(detail::lookup::fenPiece[piece]); + } + + ++file; + if (file > fileH) + { + file = fileA; + --rank; - preservedCastlingRights_[h1] = ~CastlingRights::WhiteKingSide; - preservedCastlingRights_[a1] = ~CastlingRights::WhiteQueenSide; - preservedCastlingRights_[h8] = ~CastlingRights::BlackKingSide; - preservedCastlingRights_[a8] = ~CastlingRights::BlackQueenSide; + if (emptyCounter != 0) + { + fen.push_back(static_cast(emptyCounter) + '0'); + emptyCounter = 0; + } - return preservedCastlingRights_; - }(); + if (rank < rank1) + { + break; + } + fen.push_back('/'); + } } - inline ReverseMove Position::doMove(const Move& move) - { - assert(move.from.isOk() && move.to.isOk()); - - const PieceType movedPiece = pieceAt(move.from).type(); + return fen; +} - m_ply += 1; - m_rule50Counter += 1; +void Position::set(std::string_view fen) { (void) trySet(fen); } - if (move.type != MoveType::Castle && (movedPiece == PieceType::Pawn || pieceAt(move.to) != Piece::none())) +// Returns false if the fen was not valid +// If the returned value was false the position +// is in unspecified state. +[[nodiscard]] bool Position::trySet(std::string_view fen) { + // Lazily splits by ' '. Returns empty string views if at the end. + auto nextPart = [fen, start = std::size_t{0}]() mutable { + std::size_t end = fen.find(' ', start); + if (end == std::string::npos) { - m_rule50Counter = 0; + std::string_view substr = fen.substr(start); + start = fen.size(); + return substr; } - - const Square oldEpSquare = m_epSquare; - const CastlingRights oldCastlingRights = m_castlingRights; - m_castlingRights &= detail::lookup::preservedCastlingRights[move.from]; - m_castlingRights &= detail::lookup::preservedCastlingRights[move.to]; - - m_epSquare = Square::none(); - // for double pushes move index differs by 16 or -16; - if((movedPiece == PieceType::Pawn) & ((ordinal(move.to) ^ ordinal(move.from)) == 16)) + else { - m_epSquare = fromOrdinal((ordinal(move.to) + ordinal(move.from)) >> 1); + std::string_view substr = fen.substr(start, end - start); + start = end + 1; // to skip whitespace + return substr; } + }; - const Piece captured = BaseType::doMove(move); - m_sideToMove = !m_sideToMove; - - nullifyEpSquareIfNotPossible(); - - return { move, captured, oldEpSquare, oldCastlingRights }; - } + if (!BaseType::trySet(nextPart())) + return false; - [[nodiscard]] inline Position Position::afterMove(Move move) const { - Position cpy(*this); - auto pc = cpy.doMove(move); - - (void)pc; - //assert(cpy.beforeMove(move, pc) == *this); // this assert would result in infinite recursion + const auto side = nextPart(); + if (side == std::string_view("w")) + m_sideToMove = Color::White; + else if (side == std::string_view("b")) + m_sideToMove = Color::Black; + else + return false; - return cpy; + if (isSquareAttacked(kingSquare(!m_sideToMove), m_sideToMove)) + return false; } - [[nodiscard]] inline bool Position::isEpPossible(Square epSquare, Color sideToMove) const { - const Bitboard pawnsAttackingEpSquare = - bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove) - & piecesBB(Piece(PieceType::Pawn, sideToMove)); - - if (!pawnsAttackingEpSquare.any()) + const auto castlingRights = nextPart(); + auto castlingRightsOpt = parser_bits::tryParseCastlingRights(castlingRights); + if (!castlingRightsOpt.has_value()) { return false; } - - return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove); + else + { + m_castlingRights = *castlingRightsOpt; + } } - [[nodiscard]] inline bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const { - if (pieceAt(epSquare) != Piece::none()) + const auto epSquare = nextPart(); + auto epSquareOpt = parser_bits::tryParseEpSquare(epSquare); + if (!epSquareOpt.has_value()) { return false; } - - const auto forward = - sideToMove == chess::Color::White - ? FlatSquareOffset(0, 1) - : FlatSquareOffset(0, -1); - - if (pieceAt(epSquare + forward) != Piece::none()) + else { - return false; + m_epSquare = *epSquareOpt; } + } - if (pieceAt(epSquare + -forward) != Piece(PieceType::Pawn, !sideToMove)) + { + const auto rule50 = nextPart(); + if (!rule50.empty()) { - return false; + m_rule50Counter = std::stoi(rule50.data()); } - - // only set m_epSquare when it matters, ie. when - // the opposite side can actually capture - for (Square sq : pawnsAttackingEpSquare) + else { - // If we're here the previous move by other side - // was a double pawn move so our king is either not in check - // or is attacked only by the moved pawn - in which - // case it can be captured by our pawn if it doesn't - // create a discovered check on our king. - // So overall we only have to check whether our king - // ends up being uncovered to a slider attack. - - const Square ksq = kingSquare(sideToMove); - - const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, !sideToMove)); - const Bitboard rooks = piecesBB(Piece(PieceType::Rook, !sideToMove)); - const Bitboard queens = piecesBB(Piece(PieceType::Queen, !sideToMove)); - - const Bitboard relevantAttackers = bishops | rooks | queens; - const Bitboard pseudoSliderAttacksFromKing = bb::pseudoAttacks(ksq); - if ((relevantAttackers & pseudoSliderAttacksFromKing).isEmpty()) - { - // It's enough that one pawn can capture. - return true; - } - - const Square capturedPawnSq(epSquare.file(), sq.rank()); - const Bitboard occupied = ((piecesBB() ^ sq) | epSquare) ^ capturedPawnSq; - - if (!bb::isAttackedBySlider( - ksq, - bishops, - rooks, - queens, - occupied - )) - { - // It's enough that one pawn can capture. - return true; - } + m_rule50Counter = 0; } - - return false; } - inline void Position::nullifyEpSquareIfNotPossible() { - if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove)) + const auto fullMove = nextPart(); + if (!fullMove.empty()) + { + m_ply = std::stoi(fullMove.data()) * 2 - (m_sideToMove == Color::White); + } + else { - m_epSquare = Square::none(); + m_ply = 0; } } - namespace uci - { - [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move); - [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv); - - [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move) - { - std::string s; + nullifyEpSquareIfNotPossible(); - parser_bits::appendSquareToString(move.from, s); + return true; +} - if (move.type == MoveType::Castle) - { - const CastleType castleType = CastlingTraits::moveCastlingType(move); +[[nodiscard]] Position Position::fromFen(std::string_view fen) { + Position pos{}; + pos.set(fen); + return pos; +} - const Square kingDestination = CastlingTraits::kingDestination[pos.sideToMove()][castleType]; - parser_bits::appendSquareToString(kingDestination, s); - } - else - { - parser_bits::appendSquareToString(move.to, s); +[[nodiscard]] std::optional Position::tryFromFen(std::string_view fen) { + Position pos{}; + if (pos.trySet(fen)) + return pos; + else + return {}; +} - if (move.type == MoveType::Promotion) - { - // lowercase piece symbol - s += EnumTraits::toChar(move.promotedPiece.type(), Color::Black); - } - } +[[nodiscard]] Position Position::startPosition() { + static const Position pos = fromFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"); + return pos; +} - return s; - } +[[nodiscard]] std::string Position::fen() const { + std::string fen = Board::fen(); - [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv) - { - const Square from = parser_bits::parseSquare(sv.data()); - const Square to = parser_bits::parseSquare(sv.data() + 2); + fen += ' '; + fen += m_sideToMove == Color::White ? 'w' : 'b'; - if (sv.size() == 5) - { - const PieceType promotedPieceType = *fromChar(sv[4]); - return Move::promotion(from, to, Piece(promotedPieceType, pos.sideToMove())); - } - else - { - if ( - pos.pieceAt(from).type() == PieceType::King - && std::abs(from.file() - to.file()) > 1 - ) - { - // uci king destinations are on files C or G. - const CastleType castleType = - (to.file() == fileG) - ? CastleType::Short - : CastleType::Long; + fen += ' '; + parser_bits::appendCastlingRightsToString(m_castlingRights, fen); - return Move::castle(castleType, pos.sideToMove()); - } - else if (pos.pieceAt(from).type() == PieceType::Pawn && pos.epSquare() == to) - { - return Move::enPassant(from, to); - } - else - { - return Move::normal(from, to); - } - } - } - } -} + fen += ' '; + parser_bits::appendEpSquareToString(m_epSquare, fen); -namespace binpack -{ - constexpr std::size_t KiB = 1024; - constexpr std::size_t MiB = (1024*KiB); - constexpr std::size_t GiB = (1024*MiB); + fen += ' '; + fen += std::to_string(m_rule50Counter); - constexpr std::size_t suggestedChunkSize = MiB; - constexpr std::size_t maxMovelistSize = 10*KiB; // a safe upper bound - constexpr std::size_t maxChunkSize = 100*MiB; // to prevent malformed files from causing huge allocations + fen += ' '; + fen += std::to_string(fullMove()); - using namespace std::literals; + return fen; +} - namespace nodchip +namespace detail::lookup { +static constexpr EnumArray preservedCastlingRights = []() { + EnumArray preservedCastlingRights_{}; + for (CastlingRights& rights : preservedCastlingRights_) { - // This namespace contains modified code from https://github.com/nodchip/Stockfish - // which is released under GPL v3 license https://www.gnu.org/licenses/gpl-3.0.html + rights = ~CastlingRights::None; + } - using namespace std; + preservedCastlingRights_[e1] = ~CastlingRights::White; + preservedCastlingRights_[e8] = ~CastlingRights::Black; - struct StockfishMove - { - [[nodiscard]] static StockfishMove fromMove(chess::Move move) - { - StockfishMove sfm; + preservedCastlingRights_[h1] = ~CastlingRights::WhiteKingSide; + preservedCastlingRights_[a1] = ~CastlingRights::WhiteQueenSide; + preservedCastlingRights_[h8] = ~CastlingRights::BlackKingSide; + preservedCastlingRights_[a8] = ~CastlingRights::BlackQueenSide; - sfm.m_raw = 0; + return preservedCastlingRights_; +}(); +} - unsigned moveFlag = 0; - if (move.type == chess::MoveType::Promotion) moveFlag = 1; - else if (move.type == chess::MoveType::EnPassant) moveFlag = 2; - else if (move.type == chess::MoveType::Castle) moveFlag = 3; +inline ReverseMove Position::doMove(const Move& move) { + assert(move.from.isOk() && move.to.isOk()); - unsigned promotionIndex = 0; - if (move.type == chess::MoveType::Promotion) - { - promotionIndex = static_cast(move.promotedPiece.type()) - static_cast(chess::PieceType::Knight); - } + const PieceType movedPiece = pieceAt(move.from).type(); - sfm.m_raw |= static_cast(moveFlag); - sfm.m_raw <<= 2; - sfm.m_raw |= static_cast(promotionIndex); - sfm.m_raw <<= 6; - sfm.m_raw |= static_cast(move.from); - sfm.m_raw <<= 6; - sfm.m_raw |= static_cast(move.to); + m_ply += 1; + m_rule50Counter += 1; - return sfm; - } + if (move.type != MoveType::Castle + && (movedPiece == PieceType::Pawn || pieceAt(move.to) != Piece::none())) + { + m_rule50Counter = 0; + } - [[nodiscard]] chess::Move toMove() const - { - const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); - const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); + const Square oldEpSquare = m_epSquare; + const CastlingRights oldCastlingRights = m_castlingRights; + m_castlingRights &= detail::lookup::preservedCastlingRights[move.from]; + m_castlingRights &= detail::lookup::preservedCastlingRights[move.to]; - const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; - const chess::PieceType promotionType = static_cast(static_cast(chess::PieceType::Knight) + promotionIndex); + m_epSquare = Square::none(); + // for double pushes move index differs by 16 or -16; + if ((movedPiece == PieceType::Pawn) & ((ordinal(move.to) ^ ordinal(move.from)) == 16)) + { + m_epSquare = fromOrdinal((ordinal(move.to) + ordinal(move.from)) >> 1); + } - const unsigned moveFlag = (m_raw & (0b11 << 14)) >> 14; - chess::MoveType type = chess::MoveType::Normal; - if (moveFlag == 1) type = chess::MoveType::Promotion; - else if (moveFlag == 2) type = chess::MoveType::EnPassant; - else if (moveFlag == 3) type = chess::MoveType::Castle; + const Piece captured = BaseType::doMove(move); + m_sideToMove = !m_sideToMove; - if (type == chess::MoveType::Promotion) - { - const chess::Color stm = to.rank() == chess::rank8 ? chess::Color::White : chess::Color::Black; - return chess::Move{from, to, type, chess::Piece(promotionType, stm)}; - } + nullifyEpSquareIfNotPossible(); - return chess::Move{from, to, type}; - } + return {move, captured, oldEpSquare, oldCastlingRights}; +} - [[nodiscard]] std::string toString() const - { - const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); - const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); +[[nodiscard]] inline Position Position::afterMove(Move move) const { + Position cpy(*this); + auto pc = cpy.doMove(move); - const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; - const chess::PieceType promotionType = static_cast(static_cast(chess::PieceType::Knight) + promotionIndex); + (void) pc; + //assert(cpy.beforeMove(move, pc) == *this); // this assert would result in infinite recursion - std::string r; - chess::parser_bits::appendSquareToString(from, r); - chess::parser_bits::appendSquareToString(to, r); - if (promotionType != chess::PieceType::None) - { - r += chess::EnumTraits::toChar(promotionType, chess::Color::Black); - } + return cpy; +} - return r; - } +[[nodiscard]] inline bool Position::isEpPossible(Square epSquare, Color sideToMove) const { + const Bitboard pawnsAttackingEpSquare = bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove) + & piecesBB(Piece(PieceType::Pawn, sideToMove)); - private: - std::uint16_t m_raw; - }; - static_assert(sizeof(StockfishMove) == sizeof(std::uint16_t)); + if (!pawnsAttackingEpSquare.any()) + { + return false; + } - struct PackedSfen - { - uint8_t data[32]; - }; + return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove); +} - struct PackedSfenValue - { - // phase - PackedSfen sfen; +[[nodiscard]] inline bool Position::isEpPossibleColdPath(Square epSquare, + Bitboard pawnsAttackingEpSquare, + Color sideToMove) const { + if (pieceAt(epSquare) != Piece::none()) + { + return false; + } - // Evaluation value returned from Learner::search() - int16_t score; + const auto forward = + sideToMove == chess::Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); - // PV first move - // Used when finding the match rate with the teacher - StockfishMove move; + if (pieceAt(epSquare + forward) != Piece::none()) + { + return false; + } - // Trouble of the phase from the initial phase. - uint16_t gamePly; + if (pieceAt(epSquare + -forward) != Piece(PieceType::Pawn, !sideToMove)) + { + return false; + } - // 1 if the player on this side ultimately wins the game. -1 if you are losing. - // 0 if a draw is reached. - // The draw is in the teacher position generation command gensfen, - // Only write if LEARN_GENSFEN_DRAW_RESULT is enabled. - int8_t game_result; + // only set m_epSquare when it matters, ie. when + // the opposite side can actually capture + for (Square sq : pawnsAttackingEpSquare) + { + // If we're here the previous move by other side + // was a double pawn move so our king is either not in check + // or is attacked only by the moved pawn - in which + // case it can be captured by our pawn if it doesn't + // create a discovered check on our king. + // So overall we only have to check whether our king + // ends up being uncovered to a slider attack. - // When exchanging the file that wrote the teacher aspect with other people - //Because this structure size is not fixed, pad it so that it is 40 bytes in any environment. - uint8_t padding; + const Square ksq = kingSquare(sideToMove); - // 32 + 2 + 2 + 2 + 1 + 1 = 40bytes - }; - static_assert(sizeof(PackedSfenValue) == 40); - // Class that handles bitstream + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, !sideToMove)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, !sideToMove)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, !sideToMove)); - // useful when doing aspect encoding - struct BitStream + const Bitboard relevantAttackers = bishops | rooks | queens; + const Bitboard pseudoSliderAttacksFromKing = bb::pseudoAttacks(ksq); + if ((relevantAttackers & pseudoSliderAttacksFromKing).isEmpty()) { - // Set the memory to store the data in advance. - // Assume that memory is cleared to 0. - void set_data(uint8_t* data_) { data = data_; reset(); } - - // Get the pointer passed in set_data(). - uint8_t* get_data() const { return data; } - - // Get the cursor. - int get_cursor() const { return bit_cursor; } + // It's enough that one pawn can capture. + return true; + } - // reset the cursor - void reset() { bit_cursor = 0; } + const Square capturedPawnSq(epSquare.file(), sq.rank()); + const Bitboard occupied = ((piecesBB() ^ sq) | epSquare) ^ capturedPawnSq; - // Write 1bit to the stream. - // If b is non-zero, write out 1. If 0, write 0. - void write_one_bit(int b) - { - if (b) - data[bit_cursor / 8] |= 1 << (bit_cursor & 7); + if (!bb::isAttackedBySlider(ksq, bishops, rooks, queens, occupied)) + { + // It's enough that one pawn can capture. + return true; + } + } - ++bit_cursor; - } + return false; +} - // Get 1 bit from the stream. - int read_one_bit() - { - int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1; - ++bit_cursor; +inline void Position::nullifyEpSquareIfNotPossible() { + if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove)) + { + m_epSquare = Square::none(); + } +} - return b; - } +namespace uci { +[[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move); +[[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv); - // write n bits of data - // Data shall be written out from the lower order of d. - void write_n_bit(int d, int n) - { - for (int i = 0; i ::toChar(move.promotedPiece.type(), Color::Black); + } + } + return s; +} - // Huffman coding - // * is simplified from mini encoding to make conversion easier. - // - // Huffman Encoding - // - // Empty xxxxxxx0 - // Pawn xxxxx001 + 1 bit (Color) - // Knight xxxxx011 + 1 bit (Color) - // Bishop xxxxx101 + 1 bit (Color) - // Rook xxxxx111 + 1 bit (Color) - // Queen xxxx1001 + 1 bit (Color) - // - // Worst case: - // - 32 empty squares 32 bits - // - 30 pieces 150 bits - // - 2 kings 12 bits - // - castling rights 4 bits - // - ep square 7 bits - // - rule50 7 bits - // - game ply 16 bits - // - TOTAL 228 bits < 256 bits - - struct HuffmanedPiece - { - int code; // how it will be coded - int bits; // How many bits do you have - }; +[[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv) { + const Square from = parser_bits::parseSquare(sv.data()); + const Square to = parser_bits::parseSquare(sv.data() + 2); - // NOTE: Order adjusted for this library because originally NO_PIECE had index 0 - constexpr HuffmanedPiece huffman_table[] = + if (sv.size() == 5) + { + const PieceType promotedPieceType = *fromChar(sv[4]); + return Move::promotion(from, to, Piece(promotedPieceType, pos.sideToMove())); + } + else + { + if (pos.pieceAt(from).type() == PieceType::King && std::abs(from.file() - to.file()) > 1) { - {0b0001,4}, // PAWN 1 - {0b0011,4}, // KNIGHT 3 - {0b0101,4}, // BISHOP 5 - {0b0111,4}, // ROOK 7 - {0b1001,4}, // QUEEN 9 - {-1,-1}, // KING - unused - {0b0000,1}, // NO_PIECE 0 - }; + // uci king destinations are on files C or G. + const CastleType castleType = + (to.file() == fileG) ? CastleType::Short : CastleType::Long; - // Class for compressing/decompressing sfen - // sfen can be packed to 256bit (32bytes) by Huffman coding. - // This is proven by mini. The above is Huffman coding. - // - // Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding) - // Side to move (White = 0, Black = 1) (1bit) - // White King Position (6 bits) - // Black King Position (6 bits) - // Huffman Encoding of the board - // Castling availability (1 bit x 4) - // En passant square (1 or 1 + 6 bits) - // Rule 50 (6 bits) - // Game play (8 bits) - // - // TODO(someone): Rename SFEN to FEN. - // - struct SfenPacker - { - // Pack sfen and store in data[32]. - void pack(const chess::Position& pos) - { - memset(data, 0, 32 /* 256bit */); - stream.set_data(data); + return Move::castle(castleType, pos.sideToMove()); + } + else if (pos.pieceAt(from).type() == PieceType::Pawn && pos.epSquare() == to) + { + return Move::enPassant(from, to); + } + else + { + return Move::normal(from, to); + } + } +} +} +} - // turn - // Side to move. - stream.write_one_bit((int)(pos.sideToMove())); +namespace binpack { +constexpr std::size_t KiB = 1024; +constexpr std::size_t MiB = (1024 * KiB); +constexpr std::size_t GiB = (1024 * MiB); - // 7-bit positions for leading and trailing balls - // White king and black king, 6 bits for each. - stream.write_n_bit(static_cast(pos.kingSquare(chess::Color::White)), 6); - stream.write_n_bit(static_cast(pos.kingSquare(chess::Color::Black)), 6); +constexpr std::size_t suggestedChunkSize = MiB; +constexpr std::size_t maxMovelistSize = 10 * KiB; // a safe upper bound +constexpr std::size_t maxChunkSize = + 100 * MiB; // to prevent malformed files from causing huge allocations - // Write the pieces on the board other than the kings. - for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) - { - for (chess::File f = chess::fileA; f <= chess::fileH; ++f) - { - chess::Piece pc = pos.pieceAt(chess::Square(f, r)); - if (pc.type() == chess::PieceType::King) - continue; - write_board_piece_to_stream(pc); - } - } +using namespace std::literals; - // TODO(someone): Support chess960. - auto cr = pos.castlingRights(); - stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteKingSide)); - stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteQueenSide)); - stream.write_one_bit(contains(cr, chess::CastlingRights::BlackKingSide)); - stream.write_one_bit(contains(cr, chess::CastlingRights::BlackQueenSide)); +namespace nodchip { +// This namespace contains modified code from https://github.com/nodchip/Stockfish +// which is released under GPL v3 license https://www.gnu.org/licenses/gpl-3.0.html - if (pos.epSquare() == chess::Square::none()) { - stream.write_one_bit(0); - } - else { - stream.write_one_bit(1); - stream.write_n_bit(static_cast(pos.epSquare()), 6); - } +using namespace std; - stream.write_n_bit(pos.rule50Counter(), 6); +struct StockfishMove { + [[nodiscard]] static StockfishMove fromMove(chess::Move move) { + StockfishMove sfm; - stream.write_n_bit(pos.fullMove(), 8); + sfm.m_raw = 0; - // Write high bits of half move. This is a fix for the - // limited range of half move counter. - // This is backwards compatibile. - stream.write_n_bit(pos.fullMove() >> 8, 8); + unsigned moveFlag = 0; + if (move.type == chess::MoveType::Promotion) + moveFlag = 1; + else if (move.type == chess::MoveType::EnPassant) + moveFlag = 2; + else if (move.type == chess::MoveType::Castle) + moveFlag = 3; - // Write the highest bit of rule50 at the end. This is a backwards - // compatibile fix for rule50 having only 6 bits stored. - // This bit is just ignored by the old parsers. - stream.write_n_bit(pos.rule50Counter() >> 6, 1); + unsigned promotionIndex = 0; + if (move.type == chess::MoveType::Promotion) + { + promotionIndex = static_cast(move.promotedPiece.type()) + - static_cast(chess::PieceType::Knight); + } - assert(stream.get_cursor() <= 256); - } + sfm.m_raw |= static_cast(moveFlag); + sfm.m_raw <<= 2; + sfm.m_raw |= static_cast(promotionIndex); + sfm.m_raw <<= 6; + sfm.m_raw |= static_cast(move.from); + sfm.m_raw <<= 6; + sfm.m_raw |= static_cast(move.to); - // sfen packed by pack() (256bit = 32bytes) - // Or sfen to decode with unpack() - uint8_t *data; // uint8_t[32]; + return sfm; + } - BitStream stream; + [[nodiscard]] chess::Move toMove() const { + const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); + const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); - // Output the board pieces to stream. - void write_board_piece_to_stream(chess::Piece pc) - { - // piece type - chess::PieceType pr = pc.type(); - auto c = huffman_table[static_cast(pr)]; - stream.write_n_bit(c.code, c.bits); + const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; + const chess::PieceType promotionType = static_cast( + static_cast(chess::PieceType::Knight) + promotionIndex); - if (pc == chess::Piece::none()) - return; + const unsigned moveFlag = (m_raw & (0b11 << 14)) >> 14; + chess::MoveType type = chess::MoveType::Normal; + if (moveFlag == 1) + type = chess::MoveType::Promotion; + else if (moveFlag == 2) + type = chess::MoveType::EnPassant; + else if (moveFlag == 3) + type = chess::MoveType::Castle; - // first and second flag - stream.write_one_bit(static_cast(pc.color())); - } + if (type == chess::MoveType::Promotion) + { + const chess::Color stm = + to.rank() == chess::rank8 ? chess::Color::White : chess::Color::Black; + return chess::Move{from, to, type, chess::Piece(promotionType, stm)}; + } - // Read one board piece from stream - [[nodiscard]] chess::Piece read_board_piece_from_stream() - { - int pr = static_cast(chess::PieceType::None); - int code = 0, bits = 0; - while (true) - { - code |= stream.read_one_bit() << bits; - ++bits; + return chess::Move{from, to, type}; + } - assert(bits <= 6); + [[nodiscard]] std::string toString() const { + const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); + const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); - for (pr = static_cast(chess::PieceType::Pawn); pr <= static_cast(chess::PieceType::None); ++pr) - if (huffman_table[pr].code == code - && huffman_table[pr].bits == bits) - goto Found; - } - Found:; - if (pr == static_cast(chess::PieceType::None)) - return chess::Piece::none(); + const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; + const chess::PieceType promotionType = static_cast( + static_cast(chess::PieceType::Knight) + promotionIndex); - // first and second flag - chess::Color c = (chess::Color)stream.read_one_bit(); + std::string r; + chess::parser_bits::appendSquareToString(from, r); + chess::parser_bits::appendSquareToString(to, r); + if (promotionType != chess::PieceType::None) + { + r += chess::EnumTraits::toChar(promotionType, chess::Color::Black); + } - return chess::Piece(static_cast(pr), c); - } - }; + return r; + } + private: + std::uint16_t m_raw; +}; +static_assert(sizeof(StockfishMove) == sizeof(std::uint16_t)); + +struct PackedSfen { + uint8_t data[32]; +}; + +struct PackedSfenValue { + // phase + PackedSfen sfen; + + // Evaluation value returned from Learner::search() + int16_t score; + + // PV first move + // Used when finding the match rate with the teacher + StockfishMove move; + + // Trouble of the phase from the initial phase. + uint16_t gamePly; + + // 1 if the player on this side ultimately wins the game. -1 if you are losing. + // 0 if a draw is reached. + // The draw is in the teacher position generation command gensfen, + // Only write if LEARN_GENSFEN_DRAW_RESULT is enabled. + int8_t game_result; + + // When exchanging the file that wrote the teacher aspect with other people + //Because this structure size is not fixed, pad it so that it is 40 bytes in any environment. + uint8_t padding; + + // 32 + 2 + 2 + 2 + 1 + 1 = 40bytes +}; +static_assert(sizeof(PackedSfenValue) == 40); +// Class that handles bitstream + +// useful when doing aspect encoding +struct BitStream { + // Set the memory to store the data in advance. + // Assume that memory is cleared to 0. + void set_data(uint8_t* data_) { + data = data_; + reset(); + } - [[nodiscard]] inline chess::Position pos_from_packed_sfen(const PackedSfen& sfen) - { - SfenPacker packer; - auto& stream = packer.stream; - stream.set_data(const_cast(reinterpret_cast(&sfen))); + // Get the pointer passed in set_data(). + uint8_t* get_data() const { return data; } - chess::Position pos{}; + // Get the cursor. + int get_cursor() const { return bit_cursor; } - // Active color - pos.setSideToMove((chess::Color)stream.read_one_bit()); + // reset the cursor + void reset() { bit_cursor = 0; } - // First the position of the ball - pos.place(chess::Piece(chess::PieceType::King, chess::Color::White), static_cast(stream.read_n_bit(6))); - pos.place(chess::Piece(chess::PieceType::King, chess::Color::Black), static_cast(stream.read_n_bit(6))); + // Write 1bit to the stream. + // If b is non-zero, write out 1. If 0, write 0. + void write_one_bit(int b) { + if (b) + data[bit_cursor / 8] |= 1 << (bit_cursor & 7); - // Piece placement - for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) - { - for (chess::File f = chess::fileA; f <= chess::fileH; ++f) - { - auto sq = chess::Square(f, r); + ++bit_cursor; + } - // it seems there are already balls - chess::Piece pc; - if (pos.pieceAt(sq).type() != chess::PieceType::King) - { - assert(pos.pieceAt(sq) == chess::Piece::none()); - pc = packer.read_board_piece_from_stream(); - } - else - { - pc = pos.pieceAt(sq); - } + // Get 1 bit from the stream. + int read_one_bit() { + int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1; + ++bit_cursor; - // There may be no pieces, so skip in that case. - if (pc == chess::Piece::none()) - continue; + return b; + } - if (pc.type() != chess::PieceType::King) - { - pos.place(pc, sq); - } + // write n bits of data + // Data shall be written out from the lower order of d. + void write_n_bit(int d, int n) { + for (int i = 0; i < n; ++i) + write_one_bit(d & (1 << i)); + } - assert(stream.get_cursor() <= 256); - } - } + // read n bits of data + // Reverse conversion of write_n_bit(). + int read_n_bit(int n) { + int result = 0; + for (int i = 0; i < n; ++i) + result |= read_one_bit() ? (1 << i) : 0; - // Castling availability. - chess::CastlingRights cr = chess::CastlingRights::None; - if (stream.read_one_bit()) { - cr |= chess::CastlingRights::WhiteKingSide; - } - if (stream.read_one_bit()) { - cr |= chess::CastlingRights::WhiteQueenSide; - } - if (stream.read_one_bit()) { - cr |= chess::CastlingRights::BlackKingSide; - } - if (stream.read_one_bit()) { - cr |= chess::CastlingRights::BlackQueenSide; - } - pos.setCastlingRights(cr); + return result; + } - // En passant square. Ignore if no pawn capture is possible - if (stream.read_one_bit()) { - chess::Square ep_square = static_cast(stream.read_n_bit(6)); - pos.setEpSquare(ep_square); + private: + // Next bit position to read/write. + int bit_cursor; + + // data entity + uint8_t* data; +}; + + +// Huffman coding +// * is simplified from mini encoding to make conversion easier. +// +// Huffman Encoding +// +// Empty xxxxxxx0 +// Pawn xxxxx001 + 1 bit (Color) +// Knight xxxxx011 + 1 bit (Color) +// Bishop xxxxx101 + 1 bit (Color) +// Rook xxxxx111 + 1 bit (Color) +// Queen xxxx1001 + 1 bit (Color) +// +// Worst case: +// - 32 empty squares 32 bits +// - 30 pieces 150 bits +// - 2 kings 12 bits +// - castling rights 4 bits +// - ep square 7 bits +// - rule50 7 bits +// - game ply 16 bits +// - TOTAL 228 bits < 256 bits + +struct HuffmanedPiece { + int code; // how it will be coded + int bits; // How many bits do you have +}; + +// NOTE: Order adjusted for this library because originally NO_PIECE had index 0 +constexpr HuffmanedPiece huffman_table[] = { + {0b0001, 4}, // PAWN 1 + {0b0011, 4}, // KNIGHT 3 + {0b0101, 4}, // BISHOP 5 + {0b0111, 4}, // ROOK 7 + {0b1001, 4}, // QUEEN 9 + {-1, -1}, // KING - unused + {0b0000, 1}, // NO_PIECE 0 +}; + +// Class for compressing/decompressing sfen +// sfen can be packed to 256bit (32bytes) by Huffman coding. +// This is proven by mini. The above is Huffman coding. +// +// Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding) +// Side to move (White = 0, Black = 1) (1bit) +// White King Position (6 bits) +// Black King Position (6 bits) +// Huffman Encoding of the board +// Castling availability (1 bit x 4) +// En passant square (1 or 1 + 6 bits) +// Rule 50 (6 bits) +// Game play (8 bits) +// +// TODO(someone): Rename SFEN to FEN. +// +struct SfenPacker { + // Pack sfen and store in data[32]. + void pack(const chess::Position& pos) { + memset(data, 0, 32 /* 256bit */); + stream.set_data(data); + + // turn + // Side to move. + stream.write_one_bit((int) (pos.sideToMove())); + + // 7-bit positions for leading and trailing balls + // White king and black king, 6 bits for each. + stream.write_n_bit(static_cast(pos.kingSquare(chess::Color::White)), 6); + stream.write_n_bit(static_cast(pos.kingSquare(chess::Color::Black)), 6); + + // Write the pieces on the board other than the kings. + for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) + { + for (chess::File f = chess::fileA; f <= chess::fileH; ++f) + { + chess::Piece pc = pos.pieceAt(chess::Square(f, r)); + if (pc.type() == chess::PieceType::King) + continue; + write_board_piece_to_stream(pc); } + } - // Halfmove clock - std::uint8_t rule50 = stream.read_n_bit(6); + // TODO(someone): Support chess960. + auto cr = pos.castlingRights(); + stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteKingSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteQueenSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::BlackKingSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::BlackQueenSide)); - // Fullmove number - std::uint16_t fullmove = stream.read_n_bit(8); + if (pos.epSquare() == chess::Square::none()) + { + stream.write_one_bit(0); + } + else + { + stream.write_one_bit(1); + stream.write_n_bit(static_cast(pos.epSquare()), 6); + } - // Fullmove number, high bits - // This was added as a fix for fullmove clock - // overflowing at 256. This change is backwards compatibile. - fullmove |= stream.read_n_bit(8) << 8; + stream.write_n_bit(pos.rule50Counter(), 6); - // Read the highest bit of rule50. This was added as a fix for rule50 - // counter having only 6 bits stored. - // In older entries this will just be a zero bit. - rule50 |= stream.read_n_bit(1) << 6; + stream.write_n_bit(pos.fullMove(), 8); - pos.setFullMove(fullmove); - pos.setRule50Counter(rule50); + // Write high bits of half move. This is a fix for the + // limited range of half move counter. + // This is backwards compatibile. + stream.write_n_bit(pos.fullMove() >> 8, 8); - assert(stream.get_cursor() <= 256); + // Write the highest bit of rule50 at the end. This is a backwards + // compatibile fix for rule50 having only 6 bits stored. + // This bit is just ignored by the old parsers. + stream.write_n_bit(pos.rule50Counter() >> 6, 1); - return pos; - } + assert(stream.get_cursor() <= 256); } - inline std::ifstream::pos_type filesize(const char* filename) - { - std::ifstream in(filename, std::ifstream::ate | std::ifstream::binary); - return in.tellg(); + // sfen packed by pack() (256bit = 32bytes) + // Or sfen to decode with unpack() + uint8_t* data; // uint8_t[32]; + + BitStream stream; + + // Output the board pieces to stream. + void write_board_piece_to_stream(chess::Piece pc) { + // piece type + chess::PieceType pr = pc.type(); + auto c = huffman_table[static_cast(pr)]; + stream.write_n_bit(c.code, c.bits); + + if (pc == chess::Piece::none()) + return; + + // first and second flag + stream.write_one_bit(static_cast(pc.color())); } - struct CompressedTrainingDataFile - { - struct Header + // Read one board piece from stream + [[nodiscard]] chess::Piece read_board_piece_from_stream() { + int pr = static_cast(chess::PieceType::None); + int code = 0, bits = 0; + while (true) { - std::uint32_t chunkSize; - }; + code |= stream.read_one_bit() << bits; + ++bits; - CompressedTrainingDataFile(std::string path, std::ios_base::openmode om = std::ios_base::app) : - m_path(std::move(path)), - m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om) - { - // Racey but who cares - m_sizeBytes = filesize(m_path.c_str()); - } + assert(bits <= 6); - void append(const char* data, std::uint32_t size) - { - writeChunkHeader({size}); - m_file.write(data, size); - m_sizeBytes += size + 8; + for (pr = static_cast(chess::PieceType::Pawn); + pr <= static_cast(chess::PieceType::None); ++pr) + if (huffman_table[pr].code == code && huffman_table[pr].bits == bits) + goto Found; } +Found:; + if (pr == static_cast(chess::PieceType::None)) + return chess::Piece::none(); - [[nodiscard]] bool hasNextChunk() - { - if (!m_file) - { - return false; - } + // first and second flag + chess::Color c = (chess::Color) stream.read_one_bit(); - m_file.peek(); - return !m_file.eof(); - } + return chess::Piece(static_cast(pr), c); + } +}; - void seek_to_start() - { - m_file.seekg(0); - } - [[nodiscard]] std::vector readNextChunk() - { - auto size = readChunkHeader().chunkSize; - std::vector data(size); - m_file.read(reinterpret_cast(data.data()), size); - return data; - } +[[nodiscard]] inline chess::Position pos_from_packed_sfen(const PackedSfen& sfen) { + SfenPacker packer; + auto& stream = packer.stream; + stream.set_data(const_cast(reinterpret_cast(&sfen))); - [[nodiscard]] std::size_t sizeBytes() const - { - return m_sizeBytes; - } + chess::Position pos{}; - private: - std::string m_path; - std::fstream m_file; - std::size_t m_sizeBytes; + // Active color + pos.setSideToMove((chess::Color) stream.read_one_bit()); - void writeChunkHeader(Header h) - { - unsigned char header[8]; - header[0] = 'B'; - header[1] = 'I'; - header[2] = 'N'; - header[3] = 'P'; - header[4] = h.chunkSize; - header[5] = h.chunkSize >> 8; - header[6] = h.chunkSize >> 16; - header[7] = h.chunkSize >> 24; - m_file.write(reinterpret_cast(header), 8); - } + // First the position of the ball + pos.place(chess::Piece(chess::PieceType::King, chess::Color::White), + static_cast(stream.read_n_bit(6))); + pos.place(chess::Piece(chess::PieceType::King, chess::Color::Black), + static_cast(stream.read_n_bit(6))); - [[nodiscard]] Header readChunkHeader() + // Piece placement + for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) + { + for (chess::File f = chess::fileA; f <= chess::fileH; ++f) { - unsigned char header[8]; - m_file.read(reinterpret_cast(header), 8); - if (header[0] != 'B' || header[1] != 'I' || header[2] != 'N' || header[3] != 'P') + auto sq = chess::Square(f, r); + + // it seems there are already balls + chess::Piece pc; + if (pos.pieceAt(sq).type() != chess::PieceType::King) + { + assert(pos.pieceAt(sq) == chess::Piece::none()); + pc = packer.read_board_piece_from_stream(); + } + else { - assert(false); - // throw std::runtime_error("Invalid binpack file or chunk."); + pc = pos.pieceAt(sq); } - const std::uint32_t size = - header[4] - | (header[5] << 8) - | (header[6] << 16) - | (header[7] << 24); + // There may be no pieces, so skip in that case. + if (pc == chess::Piece::none()) + continue; - if (size > maxChunkSize) + if (pc.type() != chess::PieceType::King) { - assert(false); - // throw std::runtime_error("Chunks size larger than supported. Malformed file?"); + pos.place(pc, sq); } - return { size }; + assert(stream.get_cursor() <= 256); } - }; + } - [[nodiscard]] inline std::uint16_t signedToUnsigned(std::int16_t a) + // Castling availability. + chess::CastlingRights cr = chess::CastlingRights::None; + if (stream.read_one_bit()) { - std::uint16_t r; - std::memcpy(&r, &a, sizeof(std::uint16_t)); - if (r & 0x8000) - { - r ^= 0x7FFF; - } - r = (r << 1) | (r >> 15); - return r; + cr |= chess::CastlingRights::WhiteKingSide; } - - [[nodiscard]] inline std::int16_t unsignedToSigned(std::uint16_t r) + if (stream.read_one_bit()) { - std::int16_t a; - r = (r << 15) | (r >> 1); - if (r & 0x8000) - { - r ^= 0x7FFF; - } - std::memcpy(&a, &r, sizeof(std::uint16_t)); - return a; + cr |= chess::CastlingRights::WhiteQueenSide; + } + if (stream.read_one_bit()) + { + cr |= chess::CastlingRights::BlackKingSide; } + if (stream.read_one_bit()) + { + cr |= chess::CastlingRights::BlackQueenSide; + } + pos.setCastlingRights(cr); - struct TrainingDataEntry + // En passant square. Ignore if no pawn capture is possible + if (stream.read_one_bit()) { - chess::Position pos; - chess::Move move; - std::int16_t score; - std::uint16_t ply; - std::int16_t result; + chess::Square ep_square = static_cast(stream.read_n_bit(6)); + pos.setEpSquare(ep_square); + } - [[nodiscard]] bool isValid() const - { - return pos.isMoveLegal(move); - } + // Halfmove clock + std::uint8_t rule50 = stream.read_n_bit(6); - [[nodiscard]] bool isCapturingMove() const - { - return pos.pieceAt(move.to) != chess::Piece::none() && - pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling - } + // Fullmove number + std::uint16_t fullmove = stream.read_n_bit(8); - // The win rate model returns the probability (per mille) of winning given an eval - // and a game-ply. The model fits rather accurately the LTC fishtest statistics. - std::tuple win_rate_model() const { + // Fullmove number, high bits + // This was added as a fix for fullmove clock + // overflowing at 256. This change is backwards compatibile. + fullmove |= stream.read_n_bit(8) << 8; - // The model captures only up to 240 plies, so limit input (and rescale) - double m = std::min(240, int(ply)) / 64.0; + // Read the highest bit of rule50. This was added as a fix for rule50 + // counter having only 6 bits stored. + // In older entries this will just be a zero bit. + rule50 |= stream.read_n_bit(1) << 6; - // Coefficients of a 3rd order polynomial fit based on fishtest data - // for two parameters needed to transform eval to the argument of a - // logistic function. - double as[] = {-3.68389304, 30.07065921, -60.52878723, 149.53378557}; - double bs[] = {-2.0181857, 15.85685038, -29.83452023, 47.59078827}; - double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; - double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; + pos.setFullMove(fullmove); + pos.setRule50Counter(rule50); - // tweak wdl model, deviating from fishtest results, - // but yielding improved training results - b *= 1.5; + assert(stream.get_cursor() <= 256); - // Transform eval to centipawns with limited range - double x = std::clamp(double(100 * score) / 208, -2000.0, 2000.0); - double w = 1.0 / (1 + std::exp((a - x) / b)); - double l = 1.0 / (1 + std::exp((a + x) / b)); - double d = 1.0 - w - l; + return pos; +} +} - // Return win, loss, draw rate in per mille (rounded to nearest) - return std::make_tuple(w, l, d); - } +inline std::ifstream::pos_type filesize(const char* filename) { + std::ifstream in(filename, std::ifstream::ate | std::ifstream::binary); + return in.tellg(); +} - // how likely is end-game result with the current score? - double score_result_prob() const { - auto [w, l, d] = win_rate_model(); - if (result > 0) - return w; - if (result < 0) - return l; - return d; - } +struct CompressedTrainingDataFile { + struct Header { + std::uint32_t chunkSize; + }; + + CompressedTrainingDataFile(std::string path, std::ios_base::openmode om = std::ios_base::app) : + m_path(std::move(path)), + m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om) { + // Racey but who cares + m_sizeBytes = filesize(m_path.c_str()); + } - [[nodiscard]] bool isInCheck() const + void append(const char* data, std::uint32_t size) { + writeChunkHeader({size}); + m_file.write(data, size); + m_sizeBytes += size + 8; + } + + [[nodiscard]] bool hasNextChunk() { + if (!m_file) { - return pos.isCheck(); + return false; } - }; - [[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv) - { - TrainingDataEntry ret; + m_file.peek(); + return !m_file.eof(); + } - ret.pos = nodchip::pos_from_packed_sfen(psv.sfen); - ret.move = psv.move.toMove(); - ret.score = psv.score; - ret.ply = psv.gamePly; - ret.result = psv.game_result; + void seek_to_start() { m_file.seekg(0); } - return ret; + [[nodiscard]] std::vector readNextChunk() { + auto size = readChunkHeader().chunkSize; + std::vector data(size); + m_file.read(reinterpret_cast(data.data()), size); + return data; } - [[nodiscard]] inline nodchip::PackedSfenValue trainingDataEntryToPackedSfenValue(const TrainingDataEntry& plain) - { - nodchip::PackedSfenValue ret; + [[nodiscard]] std::size_t sizeBytes() const { return m_sizeBytes; } + + private: + std::string m_path; + std::fstream m_file; + std::size_t m_sizeBytes; + + void writeChunkHeader(Header h) { + unsigned char header[8]; + header[0] = 'B'; + header[1] = 'I'; + header[2] = 'N'; + header[3] = 'P'; + header[4] = h.chunkSize; + header[5] = h.chunkSize >> 8; + header[6] = h.chunkSize >> 16; + header[7] = h.chunkSize >> 24; + m_file.write(reinterpret_cast(header), 8); + } - nodchip::SfenPacker sp; - sp.data = reinterpret_cast(&ret.sfen); - sp.pack(plain.pos); + [[nodiscard]] Header readChunkHeader() { + unsigned char header[8]; + m_file.read(reinterpret_cast(header), 8); + if (header[0] != 'B' || header[1] != 'I' || header[2] != 'N' || header[3] != 'P') + { + assert(false); + // throw std::runtime_error("Invalid binpack file or chunk."); + } - ret.score = plain.score; - ret.move = nodchip::StockfishMove::fromMove(plain.move); - ret.gamePly = plain.ply; - ret.game_result = plain.result; - ret.padding = 0xff; // for consistency with the .bin format. + const std::uint32_t size = + header[4] | (header[5] << 8) | (header[6] << 16) | (header[7] << 24); + + if (size > maxChunkSize) + { + assert(false); + // throw std::runtime_error("Chunks size larger than supported. Malformed file?"); + } - return ret; + return {size}; } +}; - [[nodiscard]] inline bool isContinuation(const TrainingDataEntry& lhs, const TrainingDataEntry& rhs) +[[nodiscard]] inline std::uint16_t signedToUnsigned(std::int16_t a) { + std::uint16_t r; + std::memcpy(&r, &a, sizeof(std::uint16_t)); + if (r & 0x8000) { - return - lhs.result == -rhs.result - && lhs.ply + 1 == rhs.ply - && lhs.pos.afterMove(lhs.move) == rhs.pos; + r ^= 0x7FFF; } + r = (r << 1) | (r >> 15); + return r; +} - struct PackedTrainingDataEntry +[[nodiscard]] inline std::int16_t unsignedToSigned(std::uint16_t r) { + std::int16_t a; + r = (r << 15) | (r >> 1); + if (r & 0x8000) { - unsigned char bytes[32]; - }; + r ^= 0x7FFF; + } + std::memcpy(&a, &r, sizeof(std::uint16_t)); + return a; +} - [[nodiscard]] inline std::size_t usedBitsSafe(std::size_t value) - { - if (value == 0) return 0; - return chess::util::usedBits(value - 1); +struct TrainingDataEntry { + chess::Position pos; + chess::Move move; + std::int16_t score; + std::uint16_t ply; + std::int16_t result; + + [[nodiscard]] bool isValid() const { return pos.isMoveLegal(move); } + + [[nodiscard]] bool isCapturingMove() const { + return pos.pieceAt(move.to) != chess::Piece::none() + && pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling } - static constexpr std::size_t scoreVleBlockSize = 4; + // The win rate model returns the probability (per mille) of winning given an eval + // and a game-ply. The model fits rather accurately the LTC fishtest statistics. + std::tuple win_rate_model() const { + + // The model captures only up to 240 plies, so limit input (and rescale) + double m = std::min(240, int(ply)) / 64.0; + + // Coefficients of a 3rd order polynomial fit based on fishtest data + // for two parameters needed to transform eval to the argument of a + // logistic function. + double as[] = {-3.68389304, 30.07065921, -60.52878723, 149.53378557}; + double bs[] = {-2.0181857, 15.85685038, -29.83452023, 47.59078827}; + double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; + double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; + + // tweak wdl model, deviating from fishtest results, + // but yielding improved training results + b *= 1.5; + + // Transform eval to centipawns with limited range + double x = std::clamp(double(100 * score) / 208, -2000.0, 2000.0); + double w = 1.0 / (1 + std::exp((a - x) / b)); + double l = 1.0 / (1 + std::exp((a + x) / b)); + double d = 1.0 - w - l; + + // Return win, loss, draw rate in per mille (rounded to nearest) + return std::make_tuple(w, l, d); + } - struct PackedMoveScoreListReader - { - TrainingDataEntry entry; - std::uint16_t numPlies; - unsigned char* movetext; + // how likely is end-game result with the current score? + double score_result_prob() const { + auto [w, l, d] = win_rate_model(); + if (result > 0) + return w; + if (result < 0) + return l; + return d; + } - PackedMoveScoreListReader(const TrainingDataEntry& entry_, unsigned char* movetext_, std::uint16_t numPlies_) : - entry(entry_), - numPlies(numPlies_), - movetext(movetext_), - m_lastScore(-entry_.score) - { + [[nodiscard]] bool isInCheck() const { return pos.isCheck(); } +}; - } +[[nodiscard]] inline TrainingDataEntry +packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv) { + TrainingDataEntry ret; - [[nodiscard]] std::uint8_t extractBitsLE8(std::size_t count) - { - if (count == 0) return 0; + ret.pos = nodchip::pos_from_packed_sfen(psv.sfen); + ret.move = psv.move.toMove(); + ret.score = psv.score; + ret.ply = psv.gamePly; + ret.result = psv.game_result; - if (m_readBitsLeft == 0) - { - m_readOffset += 1; - m_readBitsLeft = 8; - } + return ret; +} - const std::uint8_t byte = movetext[m_readOffset] << (8 - m_readBitsLeft); - std::uint8_t bits = byte >> (8 - count); +[[nodiscard]] inline nodchip::PackedSfenValue +trainingDataEntryToPackedSfenValue(const TrainingDataEntry& plain) { + nodchip::PackedSfenValue ret; - if (count > m_readBitsLeft) - { - const auto spillCount = count - m_readBitsLeft; - bits |= movetext[m_readOffset + 1] >> (8 - spillCount); + nodchip::SfenPacker sp; + sp.data = reinterpret_cast(&ret.sfen); + sp.pack(plain.pos); - m_readBitsLeft += 8; - m_readOffset += 1; - } + ret.score = plain.score; + ret.move = nodchip::StockfishMove::fromMove(plain.move); + ret.gamePly = plain.ply; + ret.game_result = plain.result; + ret.padding = 0xff; // for consistency with the .bin format. - m_readBitsLeft -= count; + return ret; +} - return bits; - } +[[nodiscard]] inline bool isContinuation(const TrainingDataEntry& lhs, + const TrainingDataEntry& rhs) { + return lhs.result == -rhs.result && lhs.ply + 1 == rhs.ply + && lhs.pos.afterMove(lhs.move) == rhs.pos; +} - [[nodiscard]] std::uint16_t extractVle16(std::size_t blockSize) - { - auto mask = (1 << blockSize) - 1; - std::uint16_t v = 0; - std::size_t offset = 0; - for(;;) - { - std::uint16_t block = extractBitsLE8(blockSize + 1); - v |= ((block & mask) << offset); - if (!(block >> blockSize)) - { - break; - } +struct PackedTrainingDataEntry { + unsigned char bytes[32]; +}; - offset += blockSize; - } - return v; - } +[[nodiscard]] inline std::size_t usedBitsSafe(std::size_t value) { + if (value == 0) + return 0; + return chess::util::usedBits(value - 1); +} + +static constexpr std::size_t scoreVleBlockSize = 4; + +struct PackedMoveScoreListReader { + TrainingDataEntry entry; + std::uint16_t numPlies; + unsigned char* movetext; - [[nodiscard]] TrainingDataEntry nextEntry() + PackedMoveScoreListReader(const TrainingDataEntry& entry_, + unsigned char* movetext_, + std::uint16_t numPlies_) : + entry(entry_), + numPlies(numPlies_), + movetext(movetext_), + m_lastScore(-entry_.score) {} + + [[nodiscard]] std::uint8_t extractBitsLE8(std::size_t count) { + if (count == 0) + return 0; + + if (m_readBitsLeft == 0) { - entry.pos.doMove(entry.move); - auto [move, score] = nextMoveScore(entry.pos); - entry.move = move; - entry.score = score; - entry.ply += 1; - entry.result = -entry.result; - return entry; + m_readOffset += 1; + m_readBitsLeft = 8; } - [[nodiscard]] bool hasNext() const + const std::uint8_t byte = movetext[m_readOffset] << (8 - m_readBitsLeft); + std::uint8_t bits = byte >> (8 - count); + + if (count > m_readBitsLeft) { - return m_numReadPlies < numPlies; + const auto spillCount = count - m_readBitsLeft; + bits |= movetext[m_readOffset + 1] >> (8 - spillCount); + + m_readBitsLeft += 8; + m_readOffset += 1; } - [[nodiscard]] std::pair nextMoveScore(const chess::Position& pos) + m_readBitsLeft -= count; + + return bits; + } + + [[nodiscard]] std::uint16_t extractVle16(std::size_t blockSize) { + auto mask = (1 << blockSize) - 1; + std::uint16_t v = 0; + std::size_t offset = 0; + for (;;) { - chess::Move move; - std::int16_t score; + std::uint16_t block = extractBitsLE8(blockSize + 1); + v |= ((block & mask) << offset); + if (!(block >> blockSize)) + { + break; + } - const chess::Color sideToMove = pos.sideToMove(); - const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); - const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); - const chess::Bitboard occupied = ourPieces | theirPieces; + offset += blockSize; + } + return v; + } - const auto pieceId = extractBitsLE8(usedBitsSafe(ourPieces.count())); - const auto from = chess::Square(chess::nthSetBitIndex(ourPieces.bits(), pieceId)); + [[nodiscard]] TrainingDataEntry nextEntry() { + entry.pos.doMove(entry.move); + auto [move, score] = nextMoveScore(entry.pos); + entry.move = move; + entry.score = score; + entry.ply += 1; + entry.result = -entry.result; + return entry; + } - const auto pt = pos.pieceAt(from).type(); - switch (pt) - { - case chess::PieceType::Pawn: - { - const chess::Rank promotionRank = pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; - const chess::Rank startRank = pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; - const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) : chess::FlatSquareOffset(0, -1); + [[nodiscard]] bool hasNext() const { return m_numReadPlies < numPlies; } - const chess::Square epSquare = pos.epSquare(); + [[nodiscard]] std::pair nextMoveScore(const chess::Position& pos) { + chess::Move move; + std::int16_t score; - chess::Bitboard attackTargets = theirPieces; - if (epSquare != chess::Square::none()) - { - attackTargets |= epSquare; - } + const chess::Color sideToMove = pos.sideToMove(); + const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); + const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); + const chess::Bitboard occupied = ourPieces | theirPieces; + + const auto pieceId = extractBitsLE8(usedBitsSafe(ourPieces.count())); + const auto from = chess::Square(chess::nthSetBitIndex(ourPieces.bits(), pieceId)); + + const auto pt = pos.pieceAt(from).type(); + switch (pt) + { + case chess::PieceType::Pawn : { + const chess::Rank promotionRank = + pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; + const chess::Rank startRank = + pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; + const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) + : chess::FlatSquareOffset(0, -1); + + const chess::Square epSquare = pos.epSquare(); + + chess::Bitboard attackTargets = theirPieces; + if (epSquare != chess::Square::none()) + { + attackTargets |= epSquare; + } - chess::Bitboard destinations = chess::bb::pawnAttacks(chess::Bitboard::square(from), sideToMove) & attackTargets; + chess::Bitboard destinations = + chess::bb::pawnAttacks(chess::Bitboard::square(from), sideToMove) & attackTargets; - const chess::Square sqForward = from + forward; - if (!occupied.isSet(sqForward)) + const chess::Square sqForward = from + forward; + if (!occupied.isSet(sqForward)) + { + destinations |= sqForward; + if (from.rank() == startRank && !occupied.isSet(sqForward + forward)) { - destinations |= sqForward; - if ( - from.rank() == startRank - && !occupied.isSet(sqForward + forward) - ) - { - destinations |= sqForward + forward; - } + destinations |= sqForward + forward; } + } + + const auto destinationsCount = destinations.count(); + if (from.rank() == promotionRank) + { + const auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount * 4ull)); + const chess::Piece promotedPiece = + chess::Piece(chess::fromOrdinal( + ordinal(chess::PieceType::Knight) + (moveId % 4ull)), + sideToMove); + const auto to = + chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId / 4ull)); - const auto destinationsCount = destinations.count(); - if (from.rank() == promotionRank) + move = chess::Move::promotion(from, to, promotedPiece); + break; + } + else + { + auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount)); + const auto to = chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId)); + if (to == epSquare) { - const auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount * 4ull)); - const chess::Piece promotedPiece = chess::Piece( - chess::fromOrdinal(ordinal(chess::PieceType::Knight) + (moveId % 4ull)), - sideToMove - ); - const auto to = chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId / 4ull)); - - move = chess::Move::promotion(from, to, promotedPiece); + move = chess::Move::enPassant(from, to); break; } else { - auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount)); - const auto to = chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId)); - if (to == epSquare) - { - move = chess::Move::enPassant(from, to); - break; - } - else - { - move = chess::Move::normal(from, to); - break; - } + move = chess::Move::normal(from, to); + break; } } - case chess::PieceType::King: - { - const chess::CastlingRights ourCastlingRightsMask = - sideToMove == chess::Color::White - ? chess::CastlingRights::White - : chess::CastlingRights::Black; + } + case chess::PieceType::King : { + const chess::CastlingRights ourCastlingRightsMask = sideToMove == chess::Color::White + ? chess::CastlingRights::White + : chess::CastlingRights::Black; - const chess::CastlingRights castlingRights = pos.castlingRights(); + const chess::CastlingRights castlingRights = pos.castlingRights(); - const chess::Bitboard attacks = chess::bb::pseudoAttacks(from) & ~ourPieces; - const std::size_t attacksSize = attacks.count(); - const std::size_t numCastlings = chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); + const chess::Bitboard attacks = + chess::bb::pseudoAttacks(from) & ~ourPieces; + const std::size_t attacksSize = attacks.count(); + const std::size_t numCastlings = + chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); - const auto moveId = extractBitsLE8(usedBitsSafe(attacksSize + numCastlings)); + const auto moveId = extractBitsLE8(usedBitsSafe(attacksSize + numCastlings)); - if (moveId >= attacksSize) - { - const std::size_t idx = moveId - attacksSize; + if (moveId >= attacksSize) + { + const std::size_t idx = moveId - attacksSize; - const chess::CastleType castleType = - idx == 0 - && chess::contains(castlingRights, chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]) - ? chess::CastleType::Long - : chess::CastleType::Short; + const chess::CastleType castleType = + idx == 0 + && chess::contains( + castlingRights, + chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]) + ? chess::CastleType::Long + : chess::CastleType::Short; - move = chess::Move::castle(castleType, sideToMove); - break; - } - else - { - auto to = chess::Square(chess::nthSetBitIndex(attacks.bits(), moveId)); - move = chess::Move::normal(from, to); - break; - } + move = chess::Move::castle(castleType, sideToMove); break; } - default: + else { - const chess::Bitboard attacks = chess::bb::attacks(pt, from, occupied) & ~ourPieces; - const auto moveId = extractBitsLE8(usedBitsSafe(attacks.count())); auto to = chess::Square(chess::nthSetBitIndex(attacks.bits(), moveId)); - move = chess::Move::normal(from, to); + move = chess::Move::normal(from, to); break; } - } + break; + } + default : { + const chess::Bitboard attacks = chess::bb::attacks(pt, from, occupied) & ~ourPieces; + const auto moveId = extractBitsLE8(usedBitsSafe(attacks.count())); + auto to = chess::Square(chess::nthSetBitIndex(attacks.bits(), moveId)); + move = chess::Move::normal(from, to); + break; + } + } - score = m_lastScore + unsignedToSigned(extractVle16(scoreVleBlockSize)); - m_lastScore = -score; + score = m_lastScore + unsignedToSigned(extractVle16(scoreVleBlockSize)); + m_lastScore = -score; - ++m_numReadPlies; + ++m_numReadPlies; - return {move, score}; - } + return {move, score}; + } - [[nodiscard]] std::size_t numReadBytes() - { - return m_readOffset + (m_readBitsLeft != 8); - } + [[nodiscard]] std::size_t numReadBytes() { return m_readOffset + (m_readBitsLeft != 8); } - private: - std::size_t m_readBitsLeft = 8; - std::size_t m_readOffset = 0; - std::int16_t m_lastScore = 0; - std::uint16_t m_numReadPlies = 0; - }; + private: + std::size_t m_readBitsLeft = 8; + std::size_t m_readOffset = 0; + std::int16_t m_lastScore = 0; + std::uint16_t m_numReadPlies = 0; +}; - struct PackedMoveScoreList - { - std::uint16_t numPlies = 0; - std::vector movetext; +struct PackedMoveScoreList { + std::uint16_t numPlies = 0; + std::vector movetext; + + void clear(const TrainingDataEntry& e) { + numPlies = 0; + movetext.clear(); + m_bitsLeft = 0; + m_lastScore = -e.score; + } + + void addBitsLE8(std::uint8_t bits, std::size_t count) { + if (count == 0) + return; - void clear(const TrainingDataEntry& e) + if (m_bitsLeft == 0) { - numPlies = 0; - movetext.clear(); - m_bitsLeft = 0; - m_lastScore = -e.score; + movetext.emplace_back(bits << (8 - count)); + m_bitsLeft = 8; } - - void addBitsLE8(std::uint8_t bits, std::size_t count) + else if (count <= m_bitsLeft) { - if (count == 0) return; - - if (m_bitsLeft == 0) - { - movetext.emplace_back(bits << (8 - count)); - m_bitsLeft = 8; - } - else if (count <= m_bitsLeft) - { - movetext.back() |= bits << (m_bitsLeft - count); - } - else - { - const auto spillCount = count - m_bitsLeft; - movetext.back() |= bits >> spillCount; - movetext.emplace_back(bits << (8 - spillCount)); - m_bitsLeft += 8; - } - - m_bitsLeft -= count; + movetext.back() |= bits << (m_bitsLeft - count); } - - void addBitsVle16(std::uint16_t v, std::size_t blockSize) + else { - auto mask = (1 << blockSize) - 1; - for(;;) - { - std::uint8_t block = (v & mask) | ((v > mask) << blockSize); - addBitsLE8(block, blockSize + 1); - v >>= blockSize; - if (v == 0) break; - } + const auto spillCount = count - m_bitsLeft; + movetext.back() |= bits >> spillCount; + movetext.emplace_back(bits << (8 - spillCount)); + m_bitsLeft += 8; } + m_bitsLeft -= count; + } - void addMoveScore(const chess::Position& pos, chess::Move move, std::int16_t score) + void addBitsVle16(std::uint16_t v, std::size_t blockSize) { + auto mask = (1 << blockSize) - 1; + for (;;) { - const chess::Color sideToMove = pos.sideToMove(); - const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); - const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); - const chess::Bitboard occupied = ourPieces | theirPieces; + std::uint8_t block = (v & mask) | ((v > mask) << blockSize); + addBitsLE8(block, blockSize + 1); + v >>= blockSize; + if (v == 0) + break; + } + } - const std::uint8_t pieceId = (pos.piecesBB(sideToMove) & chess::bb::before(move.from)).count(); - std::size_t numMoves = 0; - int moveId = 0; - const auto pt = pos.pieceAt(move.from).type(); - switch (pt) - { - case chess::PieceType::Pawn: - { - const chess::Rank secondToLastRank = pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; - const chess::Rank startRank = pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; - const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) : chess::FlatSquareOffset(0, -1); - const chess::Square epSquare = pos.epSquare(); + void addMoveScore(const chess::Position& pos, chess::Move move, std::int16_t score) { + const chess::Color sideToMove = pos.sideToMove(); + const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); + const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); + const chess::Bitboard occupied = ourPieces | theirPieces; - chess::Bitboard attackTargets = theirPieces; - if (epSquare != chess::Square::none()) - { - attackTargets |= epSquare; - } + const std::uint8_t pieceId = + (pos.piecesBB(sideToMove) & chess::bb::before(move.from)).count(); + std::size_t numMoves = 0; + int moveId = 0; + const auto pt = pos.pieceAt(move.from).type(); + switch (pt) + { + case chess::PieceType::Pawn : { + const chess::Rank secondToLastRank = + pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; + const chess::Rank startRank = + pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; + const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) + : chess::FlatSquareOffset(0, -1); - chess::Bitboard destinations = chess::bb::pawnAttacks(chess::Bitboard::square(move.from), sideToMove) & attackTargets; + const chess::Square epSquare = pos.epSquare(); - const chess::Square sqForward = move.from + forward; - if (!occupied.isSet(sqForward)) - { - destinations |= sqForward; + chess::Bitboard attackTargets = theirPieces; + if (epSquare != chess::Square::none()) + { + attackTargets |= epSquare; + } - if ( - move.from.rank() == startRank - && !occupied.isSet(sqForward + forward) - ) - { - destinations |= sqForward + forward; - } - } + chess::Bitboard destinations = + chess::bb::pawnAttacks(chess::Bitboard::square(move.from), sideToMove) + & attackTargets; + + const chess::Square sqForward = move.from + forward; + if (!occupied.isSet(sqForward)) + { + destinations |= sqForward; - moveId = (destinations & chess::bb::before(move.to)).count(); - numMoves = destinations.count(); - if (move.from.rank() == secondToLastRank) + if (move.from.rank() == startRank && !occupied.isSet(sqForward + forward)) { - const auto promotionIndex = (ordinal(move.promotedPiece.type()) - ordinal(chess::PieceType::Knight)); - moveId = moveId * 4 + promotionIndex; - numMoves *= 4; + destinations |= sqForward + forward; } - - break; } - case chess::PieceType::King: + + moveId = (destinations & chess::bb::before(move.to)).count(); + numMoves = destinations.count(); + if (move.from.rank() == secondToLastRank) { - const chess::CastlingRights ourCastlingRightsMask = - sideToMove == chess::Color::White - ? chess::CastlingRights::White - : chess::CastlingRights::Black; + const auto promotionIndex = + (ordinal(move.promotedPiece.type()) - ordinal(chess::PieceType::Knight)); + moveId = moveId * 4 + promotionIndex; + numMoves *= 4; + } - const chess::CastlingRights castlingRights = pos.castlingRights(); + break; + } + case chess::PieceType::King : { + const chess::CastlingRights ourCastlingRightsMask = sideToMove == chess::Color::White + ? chess::CastlingRights::White + : chess::CastlingRights::Black; - const chess::Bitboard attacks = chess::bb::pseudoAttacks(move.from) & ~ourPieces; - const auto attacksSize = attacks.count(); - const auto numCastlingRights = chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); + const chess::CastlingRights castlingRights = pos.castlingRights(); - numMoves += attacksSize; - numMoves += numCastlingRights; + const chess::Bitboard attacks = + chess::bb::pseudoAttacks(move.from) & ~ourPieces; + const auto attacksSize = attacks.count(); + const auto numCastlingRights = + chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); - if (move.type == chess::MoveType::Castle) - { - const auto longCastlingRights = chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]; + numMoves += attacksSize; + numMoves += numCastlingRights; - moveId = attacksSize - 1; + if (move.type == chess::MoveType::Castle) + { + const auto longCastlingRights = + chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]; - if (chess::contains(castlingRights, longCastlingRights)) - { - // We have to add one no matter if it's the used one or not. - moveId += 1; - } + moveId = attacksSize - 1; - if (chess::CastlingTraits::moveCastlingType(move) == chess::CastleType::Short) - { - moveId += 1; - } + if (chess::contains(castlingRights, longCastlingRights)) + { + // We have to add one no matter if it's the used one or not. + moveId += 1; } - else + + if (chess::CastlingTraits::moveCastlingType(move) == chess::CastleType::Short) { - moveId = (attacks & chess::bb::before(move.to)).count(); + moveId += 1; } - break; } - default: + else { - const chess::Bitboard attacks = chess::bb::attacks(pt, move.from, occupied) & ~ourPieces; - moveId = (attacks & chess::bb::before(move.to)).count(); - numMoves = attacks.count(); - } } + break; + } + default : { + const chess::Bitboard attacks = + chess::bb::attacks(pt, move.from, occupied) & ~ourPieces; - const std::size_t numPieces = ourPieces.count(); - addBitsLE8(pieceId, usedBitsSafe(numPieces)); - addBitsLE8(moveId, usedBitsSafe(numMoves)); + moveId = (attacks & chess::bb::before(move.to)).count(); + numMoves = attacks.count(); + } + } - std::uint16_t scoreDelta = signedToUnsigned(score - m_lastScore); - addBitsVle16(scoreDelta, scoreVleBlockSize); - m_lastScore = -score; + const std::size_t numPieces = ourPieces.count(); + addBitsLE8(pieceId, usedBitsSafe(numPieces)); + addBitsLE8(moveId, usedBitsSafe(numMoves)); - ++numPlies; - } + std::uint16_t scoreDelta = signedToUnsigned(score - m_lastScore); + addBitsVle16(scoreDelta, scoreVleBlockSize); + m_lastScore = -score; - private: - std::size_t m_bitsLeft = 0; - std::int16_t m_lastScore = 0; - }; + ++numPlies; + } + private: + std::size_t m_bitsLeft = 0; + std::int16_t m_lastScore = 0; +}; - [[nodiscard]] inline PackedTrainingDataEntry packEntry(const TrainingDataEntry& plain) - { - PackedTrainingDataEntry packed; - auto compressedPos = plain.pos.compress(); - auto compressedMove = plain.move.compress(); +[[nodiscard]] inline PackedTrainingDataEntry packEntry(const TrainingDataEntry& plain) { + PackedTrainingDataEntry packed; - static_assert(sizeof(compressedPos) + sizeof(compressedMove) + 6 == sizeof(PackedTrainingDataEntry)); + auto compressedPos = plain.pos.compress(); + auto compressedMove = plain.move.compress(); - std::size_t offset = 0; - compressedPos.writeToBigEndian(packed.bytes); - offset += sizeof(compressedPos); - compressedMove.writeToBigEndian(packed.bytes + offset); - offset += sizeof(compressedMove); - std::uint16_t pr = plain.ply | (signedToUnsigned(plain.result) << 14); - packed.bytes[offset++] = signedToUnsigned(plain.score) >> 8; - packed.bytes[offset++] = signedToUnsigned(plain.score); - packed.bytes[offset++] = pr >> 8; - packed.bytes[offset++] = pr; - packed.bytes[offset++] = plain.pos.rule50Counter() >> 8; - packed.bytes[offset++] = plain.pos.rule50Counter(); + static_assert(sizeof(compressedPos) + sizeof(compressedMove) + 6 + == sizeof(PackedTrainingDataEntry)); - return packed; - } + std::size_t offset = 0; + compressedPos.writeToBigEndian(packed.bytes); + offset += sizeof(compressedPos); + compressedMove.writeToBigEndian(packed.bytes + offset); + offset += sizeof(compressedMove); + std::uint16_t pr = plain.ply | (signedToUnsigned(plain.result) << 14); + packed.bytes[offset++] = signedToUnsigned(plain.score) >> 8; + packed.bytes[offset++] = signedToUnsigned(plain.score); + packed.bytes[offset++] = pr >> 8; + packed.bytes[offset++] = pr; + packed.bytes[offset++] = plain.pos.rule50Counter() >> 8; + packed.bytes[offset++] = plain.pos.rule50Counter(); - [[nodiscard]] inline TrainingDataEntry unpackEntry(const PackedTrainingDataEntry& packed) - { - TrainingDataEntry plain; + return packed; +} - std::size_t offset = 0; - auto compressedPos = chess::CompressedPosition::readFromBigEndian(packed.bytes); - plain.pos = compressedPos.decompress(); - offset += sizeof(compressedPos); - auto compressedMove = chess::CompressedMove::readFromBigEndian(packed.bytes + offset); - plain.move = compressedMove.decompress(); - offset += sizeof(compressedMove); - plain.score = unsignedToSigned((packed.bytes[offset] << 8) | packed.bytes[offset+1]); - offset += 2; - std::uint16_t pr = (packed.bytes[offset] << 8) | packed.bytes[offset+1]; - plain.ply = pr & 0x3FFF; - plain.pos.setPly(plain.ply); - plain.result = unsignedToSigned(pr >> 14); - offset += 2; - plain.pos.setRule50Counter((packed.bytes[offset] << 8) | packed.bytes[offset+1]); +[[nodiscard]] inline TrainingDataEntry unpackEntry(const PackedTrainingDataEntry& packed) { + TrainingDataEntry plain; + + std::size_t offset = 0; + auto compressedPos = chess::CompressedPosition::readFromBigEndian(packed.bytes); + plain.pos = compressedPos.decompress(); + offset += sizeof(compressedPos); + auto compressedMove = chess::CompressedMove::readFromBigEndian(packed.bytes + offset); + plain.move = compressedMove.decompress(); + offset += sizeof(compressedMove); + plain.score = unsignedToSigned((packed.bytes[offset] << 8) | packed.bytes[offset + 1]); + offset += 2; + std::uint16_t pr = (packed.bytes[offset] << 8) | packed.bytes[offset + 1]; + plain.ply = pr & 0x3FFF; + plain.pos.setPly(plain.ply); + plain.result = unsignedToSigned(pr >> 14); + offset += 2; + plain.pos.setRule50Counter((packed.bytes[offset] << 8) | packed.bytes[offset + 1]); + + return plain; +} - return plain; +struct CompressedTrainingDataEntryWriter { + static constexpr std::size_t chunkSize = suggestedChunkSize; + + CompressedTrainingDataEntryWriter(std::string path, + std::ios_base::openmode om = std::ios_base::app) : + m_outputFile(path, om), + m_lastEntry{}, + m_movelist{}, + m_packedSize(0), + m_packedEntries(chunkSize + maxMovelistSize), + m_isFirst(true) { + m_lastEntry.ply = 0xFFFF; // so it's never a continuation + m_lastEntry.result = 0x7FFF; } - struct CompressedTrainingDataEntryWriter - { - static constexpr std::size_t chunkSize = suggestedChunkSize; - - CompressedTrainingDataEntryWriter(std::string path, std::ios_base::openmode om = std::ios_base::app) : - m_outputFile(path, om), - m_lastEntry{}, - m_movelist{}, - m_packedSize(0), - m_packedEntries(chunkSize + maxMovelistSize), - m_isFirst(true) + void addTrainingDataEntry(const TrainingDataEntry& e) { + bool isCont = isContinuation(m_lastEntry, e); + if (isCont) { - m_lastEntry.ply = 0xFFFF; // so it's never a continuation - m_lastEntry.result = 0x7FFF; + // add to movelist + m_movelist.addMoveScore(e.pos, e.move, e.score); } - - void addTrainingDataEntry(const TrainingDataEntry& e) + else { - bool isCont = isContinuation(m_lastEntry, e); - if (isCont) + if (!m_isFirst) { - // add to movelist - m_movelist.addMoveScore(e.pos, e.move, e.score); + writeMovelist(); } - else - { - if (!m_isFirst) - { - writeMovelist(); - } - - if (m_packedSize >= chunkSize) - { - m_outputFile.append(m_packedEntries.data(), m_packedSize); - m_packedSize = 0; - } - auto packed = packEntry(e); - std::memcpy(m_packedEntries.data() + m_packedSize, &packed, sizeof(PackedTrainingDataEntry)); - m_packedSize += sizeof(PackedTrainingDataEntry); + if (m_packedSize >= chunkSize) + { + m_outputFile.append(m_packedEntries.data(), m_packedSize); + m_packedSize = 0; + } - m_movelist.clear(e); + auto packed = packEntry(e); + std::memcpy(m_packedEntries.data() + m_packedSize, &packed, + sizeof(PackedTrainingDataEntry)); + m_packedSize += sizeof(PackedTrainingDataEntry); - m_isFirst = false; - } + m_movelist.clear(e); - m_lastEntry = e; + m_isFirst = false; } - ~CompressedTrainingDataEntryWriter() + m_lastEntry = e; + } + + ~CompressedTrainingDataEntryWriter() { + if (m_packedSize > 0) { - if (m_packedSize > 0) + if (!m_isFirst) { - if (!m_isFirst) - { - writeMovelist(); - } - - m_outputFile.append(m_packedEntries.data(), m_packedSize); - m_packedSize = 0; + writeMovelist(); } + + m_outputFile.append(m_packedEntries.data(), m_packedSize); + m_packedSize = 0; } + } - private: - CompressedTrainingDataFile m_outputFile; - TrainingDataEntry m_lastEntry; - PackedMoveScoreList m_movelist; - std::size_t m_packedSize; - std::vector m_packedEntries; - bool m_isFirst; + private: + CompressedTrainingDataFile m_outputFile; + TrainingDataEntry m_lastEntry; + PackedMoveScoreList m_movelist; + std::size_t m_packedSize; + std::vector m_packedEntries; + bool m_isFirst; - void writeMovelist() + void writeMovelist() { + m_packedEntries[m_packedSize++] = m_movelist.numPlies >> 8; + m_packedEntries[m_packedSize++] = m_movelist.numPlies; + if (m_movelist.numPlies > 0) { - m_packedEntries[m_packedSize++] = m_movelist.numPlies >> 8; - m_packedEntries[m_packedSize++] = m_movelist.numPlies; - if (m_movelist.numPlies > 0) - { - std::memcpy(m_packedEntries.data() + m_packedSize, m_movelist.movetext.data(), m_movelist.movetext.size()); - m_packedSize += m_movelist.movetext.size(); - } - }; + std::memcpy(m_packedEntries.data() + m_packedSize, m_movelist.movetext.data(), + m_movelist.movetext.size()); + m_packedSize += m_movelist.movetext.size(); + } }; +}; - struct CompressedTrainingDataEntryReader - { - static constexpr std::size_t chunkSize = suggestedChunkSize; +struct CompressedTrainingDataEntryReader { + static constexpr std::size_t chunkSize = suggestedChunkSize; - CompressedTrainingDataEntryReader(std::string path, std::ios_base::openmode om = std::ios_base::app) : - m_inputFile(path, om), - m_chunk(), - m_movelistReader(std::nullopt), - m_offset(0), - m_isEnd(false) + CompressedTrainingDataEntryReader(std::string path, + std::ios_base::openmode om = std::ios_base::app) : + m_inputFile(path, om), + m_chunk(), + m_movelistReader(std::nullopt), + m_offset(0), + m_isEnd(false) { + if (!m_inputFile.hasNextChunk()) { - if (!m_inputFile.hasNextChunk()) - { - m_isEnd = true; - } - else - { - m_chunk = m_inputFile.readNextChunk(); - } + m_isEnd = true; } - - [[nodiscard]] bool hasNext() + else { - return !m_isEnd; + m_chunk = m_inputFile.readNextChunk(); } + } - [[nodiscard]] TrainingDataEntry next() + [[nodiscard]] bool hasNext() { return !m_isEnd; } + + [[nodiscard]] TrainingDataEntry next() { + if (m_movelistReader.has_value()) { - if (m_movelistReader.has_value()) + const auto e = m_movelistReader->nextEntry(); + + if (!m_movelistReader->hasNext()) { - const auto e = m_movelistReader->nextEntry(); + m_offset += m_movelistReader->numReadBytes(); + m_movelistReader.reset(); - if (!m_movelistReader->hasNext()) - { - m_offset += m_movelistReader->numReadBytes(); - m_movelistReader.reset(); + fetchNextChunkIfNeeded(); + } + + return e; + } + + PackedTrainingDataEntry packed; + std::memcpy(&packed, m_chunk.data() + m_offset, sizeof(PackedTrainingDataEntry)); + m_offset += sizeof(PackedTrainingDataEntry); - fetchNextChunkIfNeeded(); - } + const std::uint16_t numPlies = (m_chunk[m_offset] << 8) | m_chunk[m_offset + 1]; + m_offset += 2; - return e; - } + const auto e = unpackEntry(packed); - PackedTrainingDataEntry packed; - std::memcpy(&packed, m_chunk.data() + m_offset, sizeof(PackedTrainingDataEntry)); - m_offset += sizeof(PackedTrainingDataEntry); + if (numPlies > 0) + { + m_movelistReader.emplace(e, reinterpret_cast(m_chunk.data()) + m_offset, + numPlies); + } + else + { + fetchNextChunkIfNeeded(); + } - const std::uint16_t numPlies = (m_chunk[m_offset] << 8) | m_chunk[m_offset + 1]; - m_offset += 2; + return e; + } - const auto e = unpackEntry(packed); + private: + CompressedTrainingDataFile m_inputFile; + std::vector m_chunk; + std::optional m_movelistReader; + std::size_t m_offset; + bool m_isEnd; - if (numPlies > 0) + void fetchNextChunkIfNeeded() { + if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size()) + { + if (m_inputFile.hasNextChunk()) { - m_movelistReader.emplace(e, reinterpret_cast(m_chunk.data()) + m_offset, numPlies); + m_chunk = m_inputFile.readNextChunk(); + m_offset = 0; } else { - fetchNextChunkIfNeeded(); + m_isEnd = true; } - - return e; } + } +}; - private: - CompressedTrainingDataFile m_inputFile; - std::vector m_chunk; - std::optional m_movelistReader; - std::size_t m_offset; - bool m_isEnd; +struct CompressedTrainingDataEntryParallelReader { + static constexpr std::size_t chunkSize = suggestedChunkSize; - void fetchNextChunkIfNeeded() + CompressedTrainingDataEntryParallelReader( + int concurrency, + std::vector paths, + std::ios_base::openmode om = std::ios_base::app, + bool cyclic = false, + std::function skipPredicate = nullptr) : + m_concurrency(concurrency), + m_bufferOffset(0), + m_cyclic(cyclic), + m_skipPredicate(std::move(skipPredicate)) { + m_numRunningWorkers.store(0); + std::vector sizes; // discrete distribution wants double weights + for (const auto& path : paths) { - if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size()) + auto& file = m_inputFiles.emplace_back(path, om); + + if (!file.hasNextChunk()) { - if (m_inputFile.hasNextChunk()) - { - m_chunk = m_inputFile.readNextChunk(); - m_offset = 0; - } - else - { - m_isEnd = true; - } + return; } - } - }; - struct CompressedTrainingDataEntryParallelReader - { - static constexpr std::size_t chunkSize = suggestedChunkSize; - - CompressedTrainingDataEntryParallelReader( - int concurrency, - std::vector paths, - std::ios_base::openmode om = std::ios_base::app, - bool cyclic = false, - std::function skipPredicate = nullptr - ) : - m_concurrency(concurrency), - m_bufferOffset(0), - m_cyclic(cyclic), - m_skipPredicate(std::move(skipPredicate)) - { - m_numRunningWorkers.store(0); - std::vector sizes; // discrete distribution wants double weights - for (const auto& path : paths) - { - auto& file = m_inputFiles.emplace_back(path, om); + sizes.emplace_back(static_cast(file.sizeBytes())); + } - if (!file.hasNextChunk()) - { - return; - } + m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end()); - sizes.emplace_back(static_cast(file.sizeBytes())); - } + m_stopFlag.store(false); - m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end()); + auto worker = [this]() { + std::vector m_chunk{}; + std::optional m_movelistReader(std::nullopt); + std::size_t m_offset(0); + std::vector m_localBuffer; + m_localBuffer.reserve(threadBufferSize); - m_stopFlag.store(false); + bool isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); - auto worker = [this]() + while (!isEnd && !m_stopFlag.load()) { - std::vector m_chunk{}; - std::optional m_movelistReader(std::nullopt); - std::size_t m_offset(0); - std::vector m_localBuffer; - m_localBuffer.reserve(threadBufferSize); - - bool isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); - - while(!isEnd && !m_stopFlag.load()) + while (m_localBuffer.size() < threadBufferSize) { - while (m_localBuffer.size() < threadBufferSize) + if (m_movelistReader.has_value()) { - if (m_movelistReader.has_value()) + const auto e = m_movelistReader->nextEntry(); + + if (!m_movelistReader->hasNext()) { - const auto e = m_movelistReader->nextEntry(); + m_offset += m_movelistReader->numReadBytes(); + m_movelistReader.reset(); - if (!m_movelistReader->hasNext()) - { - m_offset += m_movelistReader->numReadBytes(); - m_movelistReader.reset(); + isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); + } + + if (!m_skipPredicate || !m_skipPredicate(e)) + m_localBuffer.emplace_back(e); + } + else + { + PackedTrainingDataEntry packed; + std::memcpy(&packed, m_chunk.data() + m_offset, + sizeof(PackedTrainingDataEntry)); + m_offset += sizeof(PackedTrainingDataEntry); + + const std::uint16_t numPlies = + (m_chunk[m_offset] << 8) | m_chunk[m_offset + 1]; + m_offset += 2; - isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); - } + const auto e = unpackEntry(packed); - if (!m_skipPredicate || !m_skipPredicate(e)) - m_localBuffer.emplace_back(e); + if (numPlies > 0) + { + m_movelistReader.emplace( + e, reinterpret_cast(m_chunk.data()) + m_offset, + numPlies); } else { - PackedTrainingDataEntry packed; - std::memcpy(&packed, m_chunk.data() + m_offset, sizeof(PackedTrainingDataEntry)); - m_offset += sizeof(PackedTrainingDataEntry); - - const std::uint16_t numPlies = (m_chunk[m_offset] << 8) | m_chunk[m_offset + 1]; - m_offset += 2; - - const auto e = unpackEntry(packed); - - if (numPlies > 0) - { - m_movelistReader.emplace(e, reinterpret_cast(m_chunk.data()) + m_offset, numPlies); - } - else - { - isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); - } - - if (!m_skipPredicate || !m_skipPredicate(e)) - m_localBuffer.emplace_back(e); + isEnd = fetchNextChunkIfNeeded(m_offset, m_chunk); } - if (isEnd || m_stopFlag.load()) - { - break; - } + if (!m_skipPredicate || !m_skipPredicate(e)) + m_localBuffer.emplace_back(e); } - if (!m_localBuffer.empty()) + if (isEnd || m_stopFlag.load()) { - // now shuffle the local buffer - auto& prng = rng::get_thread_local_rng(); - std::shuffle(m_localBuffer.begin(), m_localBuffer.end(), prng); - - std::unique_lock lock(m_waitingBufferMutex); - m_waitingBufferEmpty.wait(lock, [this]() { return m_waitingBuffer.empty() || m_stopFlag.load(); }); - m_waitingBuffer.swap(m_localBuffer); - - lock.unlock(); - m_waitingBufferFull.notify_one(); - - m_localBuffer.clear(); + break; } } - m_numRunningWorkers.fetch_sub(1); + if (!m_localBuffer.empty()) + { + // now shuffle the local buffer + auto& prng = rng::get_thread_local_rng(); + std::shuffle(m_localBuffer.begin(), m_localBuffer.end(), prng); - m_waitingBufferFull.notify_one(); - }; + std::unique_lock lock(m_waitingBufferMutex); + m_waitingBufferEmpty.wait( + lock, [this]() { return m_waitingBuffer.empty() || m_stopFlag.load(); }); + m_waitingBuffer.swap(m_localBuffer); - for (int i = 0; i < concurrency; ++i) - { - m_workers.emplace_back(worker); + lock.unlock(); + m_waitingBufferFull.notify_one(); - // This cannot be done in the thread worker. We need - // to have a guarantee that this is incremented, but if - // we did it in the worker there's no guarantee - // that it executed. - m_numRunningWorkers.fetch_add(1); + m_localBuffer.clear(); + } } - } - - [[nodiscard]] std::optional next() - { - if (m_bufferOffset >= m_buffer.size()) - { - m_buffer.clear(); - std::unique_lock lock(m_waitingBufferMutex); - m_waitingBufferFull.wait(lock, [this]() { return !m_waitingBuffer.empty() || !m_numRunningWorkers.load(); }); - if (m_waitingBuffer.empty()) - { - return std::nullopt; - } + m_numRunningWorkers.fetch_sub(1); - m_waitingBuffer.swap(m_buffer); - m_bufferOffset = 0; + m_waitingBufferFull.notify_one(); + }; - lock.unlock(); - m_waitingBufferEmpty.notify_one(); - } + for (int i = 0; i < concurrency; ++i) + { + m_workers.emplace_back(worker); - return m_buffer[m_bufferOffset++]; + // This cannot be done in the thread worker. We need + // to have a guarantee that this is incremented, but if + // we did it in the worker there's no guarantee + // that it executed. + m_numRunningWorkers.fetch_add(1); } + } - int fill(std::vector& vec, std::size_t n) + [[nodiscard]] std::optional next() { + if (m_bufferOffset >= m_buffer.size()) { - if (m_bufferOffset >= m_buffer.size()) - { - m_buffer.clear(); + m_buffer.clear(); - std::unique_lock lock(m_waitingBufferMutex); - m_waitingBufferFull.wait(lock, [this]() { return !m_waitingBuffer.empty() || !m_numRunningWorkers.load(); }); - if (m_waitingBuffer.empty()) - { - return 0; - } + std::unique_lock lock(m_waitingBufferMutex); + m_waitingBufferFull.wait( + lock, [this]() { return !m_waitingBuffer.empty() || !m_numRunningWorkers.load(); }); + if (m_waitingBuffer.empty()) + { + return std::nullopt; + } - m_waitingBuffer.swap(m_buffer); - m_bufferOffset = 0; + m_waitingBuffer.swap(m_buffer); + m_bufferOffset = 0; - lock.unlock(); - m_waitingBufferEmpty.notify_one(); - } + lock.unlock(); + m_waitingBufferEmpty.notify_one(); + } - const std::size_t m = std::min(n, m_buffer.size() - m_bufferOffset); - vec.insert(vec.end(), m_buffer.begin() + m_bufferOffset, m_buffer.begin() + m_bufferOffset + m); + return m_buffer[m_bufferOffset++]; + } - m_bufferOffset += m; + int fill(std::vector& vec, std::size_t n) { + if (m_bufferOffset >= m_buffer.size()) + { + m_buffer.clear(); - if (m != n) + std::unique_lock lock(m_waitingBufferMutex); + m_waitingBufferFull.wait( + lock, [this]() { return !m_waitingBuffer.empty() || !m_numRunningWorkers.load(); }); + if (m_waitingBuffer.empty()) { - return m + fill(vec, n - m); - } - else - { - return m; + return 0; } + + m_waitingBuffer.swap(m_buffer); + m_bufferOffset = 0; + + lock.unlock(); + m_waitingBufferEmpty.notify_one(); } - ~CompressedTrainingDataEntryParallelReader() + const std::size_t m = std::min(n, m_buffer.size() - m_bufferOffset); + vec.insert(vec.end(), m_buffer.begin() + m_bufferOffset, + m_buffer.begin() + m_bufferOffset + m); + + m_bufferOffset += m; + + if (m != n) + { + return m + fill(vec, n - m); + } + else { - m_stopFlag.store(true); - m_waitingBufferEmpty.notify_all(); + return m; + } + } - for (auto& worker : m_workers) + ~CompressedTrainingDataEntryParallelReader() { + m_stopFlag.store(true); + m_waitingBufferEmpty.notify_all(); + + for (auto& worker : m_workers) + { + if (worker.joinable()) { - if (worker.joinable()) - { - worker.join(); - } + worker.join(); } } + } - private: - int m_concurrency; - std::vector m_inputFiles; - std::discrete_distribution<> m_inputFileDistribution; - std::atomic_int m_numRunningWorkers; - bool m_cyclic; + private: + int m_concurrency; + std::vector m_inputFiles; + std::discrete_distribution<> m_inputFileDistribution; + std::atomic_int m_numRunningWorkers; + bool m_cyclic; - static constexpr int threadBufferSize = 256 * 256 * 16; + static constexpr int threadBufferSize = 256 * 256 * 16; - std::atomic_bool m_stopFlag; - std::vector m_waitingBuffer; - std::vector m_buffer; - std::size_t m_bufferOffset; - std::mutex m_waitingBufferMutex; - std::mutex m_fileMutex; - std::condition_variable m_waitingBufferEmpty; - std::condition_variable m_waitingBufferFull; - std::function m_skipPredicate; + std::atomic_bool m_stopFlag; + std::vector m_waitingBuffer; + std::vector m_buffer; + std::size_t m_bufferOffset; + std::mutex m_waitingBufferMutex; + std::mutex m_fileMutex; + std::condition_variable m_waitingBufferEmpty; + std::condition_variable m_waitingBufferFull; + std::function m_skipPredicate; - std::vector m_workers; + std::vector m_workers; - bool fetchNextChunkIfNeeded(std::size_t& m_offset, std::vector& m_chunk) + bool fetchNextChunkIfNeeded(std::size_t& m_offset, std::vector& m_chunk) { + if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size()) { - if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size()) - { - auto& prng = rng::get_thread_local_rng(); - const std::size_t fileId = m_inputFileDistribution(prng); - auto& inputFile = m_inputFiles[fileId]; + auto& prng = rng::get_thread_local_rng(); + const std::size_t fileId = m_inputFileDistribution(prng); + auto& inputFile = m_inputFiles[fileId]; - std::unique_lock lock(m_fileMutex); + std::unique_lock lock(m_fileMutex); - if (!inputFile.hasNextChunk()) + if (!inputFile.hasNextChunk()) + { + if (m_cyclic) { - if (m_cyclic) - { - inputFile.seek_to_start(); - } - else - return true; + inputFile.seek_to_start(); } - - m_chunk = inputFile.readNextChunk(); - m_offset = 0; + else + return true; } - return false; + m_chunk = inputFile.readNextChunk(); + m_offset = 0; } - }; - inline void emitPlainEntry(std::string& buffer, const TrainingDataEntry& plain) - { - buffer += "fen "; - buffer += plain.pos.fen(); - buffer += '\n'; + return false; + } +}; - buffer += "move "; - buffer += chess::uci::moveToUci(plain.pos, plain.move); - buffer += '\n'; +inline void emitPlainEntry(std::string& buffer, const TrainingDataEntry& plain) { + buffer += "fen "; + buffer += plain.pos.fen(); + buffer += '\n'; - buffer += "score "; - buffer += std::to_string(plain.score); - buffer += '\n'; + buffer += "move "; + buffer += chess::uci::moveToUci(plain.pos, plain.move); + buffer += '\n'; - buffer += "ply "; - buffer += std::to_string(plain.ply); - buffer += '\n'; + buffer += "score "; + buffer += std::to_string(plain.score); + buffer += '\n'; - buffer += "result "; - buffer += std::to_string(plain.result); - buffer += "\ne\n"; - } + buffer += "ply "; + buffer += std::to_string(plain.ply); + buffer += '\n'; - inline void emitBinEntry(std::vector& buffer, const TrainingDataEntry& plain) - { - auto psv = trainingDataEntryToPackedSfenValue(plain); - const char* data = reinterpret_cast(&psv); - buffer.insert(buffer.end(), data, data+sizeof(psv)); - } + buffer += "result "; + buffer += std::to_string(plain.result); + buffer += "\ne\n"; +} - inline void convertPlainToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) - { - constexpr std::size_t reportEveryNPositions = 100'000; +inline void emitBinEntry(std::vector& buffer, const TrainingDataEntry& plain) { + auto psv = trainingDataEntryToPackedSfenValue(plain); + const char* data = reinterpret_cast(&psv); + buffer.insert(buffer.end(), data, data + sizeof(psv)); +} - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; +inline void convertPlainToBinpack(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t reportEveryNPositions = 100'000; - CompressedTrainingDataEntryWriter writer(outputPath, om); - TrainingDataEntry e; + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - std::string key; - std::string value; - std::string move; + CompressedTrainingDataEntryWriter writer(outputPath, om); + TrainingDataEntry e; - std::ifstream inputFile(inputPath); - const auto base = inputFile.tellg(); - std::size_t numProcessedPositions = 0; + std::string key; + std::string value; + std::string move; - for(;;) + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + + for (;;) + { + inputFile >> key; + if (!inputFile) { - inputFile >> key; - if (!inputFile) - { - break; - } + break; + } - if (key == "e"sv) + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) { - e.move = chess::uci::uciToMove(e.pos, move); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - return; - } - - writer.addTrainingDataEntry(e); + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) + << " for position " << e.pos.fen() << '\n'; + return; + } - ++numProcessedPositions; - const auto cur = inputFile.tellg(); - if (numProcessedPositions % reportEveryNPositions == 0) - { - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } + writer.addTrainingDataEntry(e); - continue; + ++numProcessedPositions; + const auto cur = inputFile.tellg(); + if (numProcessedPositions % reportEveryNPositions == 0) + { + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } - inputFile >> std::ws; - std::getline(inputFile, value, '\n'); - - if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); - if (key == "move"sv) move = value; - if (key == "score"sv) e.score = std::stoi(value); - if (key == "ply"sv) e.ply = std::stoi(value); - if (key == "result"sv) e.result = std::stoi(value); + continue; } - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; - } + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); - inline void convertBinpackToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) - { - constexpr std::size_t bufferSize = MiB; + if (key == "fen"sv) + e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) + move = value; + if (key == "score"sv) + e.score = std::stoi(value); + if (key == "ply"sv) + e.ply = std::stoi(value); + if (key == "result"sv) + e.result = std::stoi(value); + } - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} - CompressedTrainingDataEntryReader reader(inputPath); - std::ofstream outputFile(outputPath, om); - const auto base = outputFile.tellp(); - std::size_t numProcessedPositions = 0; - std::string buffer; - buffer.reserve(bufferSize * 2); +inline void convertBinpackToPlain(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t bufferSize = MiB; - while(reader.hasNext()) - { - auto e = reader.next(); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - return; - } + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - emitPlainEntry(buffer, e); + CompressedTrainingDataEntryReader reader(inputPath); + std::ofstream outputFile(outputPath, om); + const auto base = outputFile.tellp(); + std::size_t numProcessedPositions = 0; + std::string buffer; + buffer.reserve(bufferSize * 2); - ++numProcessedPositions; + while (reader.hasNext()) + { + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " + << e.pos.fen() << '\n'; + return; + } - if (buffer.size() > bufferSize) - { - outputFile << buffer; - buffer.clear(); + emitPlainEntry(buffer, e); - const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } - } + ++numProcessedPositions; - if (!buffer.empty()) + if (buffer.size() > bufferSize) { outputFile << buffer; + buffer.clear(); const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } - - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - - inline void convertBinToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + if (!buffer.empty()) { - constexpr std::size_t reportEveryNPositions = 100'000; + outputFile << buffer; - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - - CompressedTrainingDataEntryWriter writer(outputPath, om); + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; + } - std::ifstream inputFile(inputPath, std::ios_base::binary); - const auto base = inputFile.tellg(); - std::size_t numProcessedPositions = 0; + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} - nodchip::PackedSfenValue psv; - for(;;) - { - inputFile.read(reinterpret_cast(&psv), sizeof(psv)); - if (inputFile.gcount() != 40) - { - break; - } - auto e = packedSfenValueToTrainingDataEntry(psv); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - std::cerr << static_cast(e.move.type) << '\n'; - return; - } +inline void convertBinToBinpack(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t reportEveryNPositions = 100'000; - writer.addTrainingDataEntry(e); + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - ++numProcessedPositions; - const auto cur = inputFile.tellg(); - if (numProcessedPositions % reportEveryNPositions == 0) - { - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } - } + CompressedTrainingDataEntryWriter writer(outputPath, om); - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; - } + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; - inline void convertBinpackToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + nodchip::PackedSfenValue psv; + for (;;) { - constexpr std::size_t bufferSize = MiB; + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " + << e.pos.fen() << '\n'; + std::cerr << static_cast(e.move.type) << '\n'; + return; + } - CompressedTrainingDataEntryReader reader(inputPath); - std::ofstream outputFile(outputPath, std::ios_base::binary | om); - const auto base = outputFile.tellp(); - std::size_t numProcessedPositions = 0; - std::vector buffer; - buffer.reserve(bufferSize * 2); + writer.addTrainingDataEntry(e); - while(reader.hasNext()) + ++numProcessedPositions; + const auto cur = inputFile.tellg(); + if (numProcessedPositions % reportEveryNPositions == 0) { - auto e = reader.next(); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - return; - } + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; + } + } - emitBinEntry(buffer, e); + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} - ++numProcessedPositions; +inline void convertBinpackToBin(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t bufferSize = MiB; - if (buffer.size() > bufferSize) - { - outputFile.write(buffer.data(), buffer.size()); - buffer.clear(); + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } + CompressedTrainingDataEntryReader reader(inputPath); + std::ofstream outputFile(outputPath, std::ios_base::binary | om); + const auto base = outputFile.tellp(); + std::size_t numProcessedPositions = 0; + std::vector buffer; + buffer.reserve(bufferSize * 2); + + while (reader.hasNext()) + { + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " + << e.pos.fen() << '\n'; + return; } - if (!buffer.empty()) + emitBinEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) { outputFile.write(buffer.data(), buffer.size()); + buffer.clear(); const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } - - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertBinToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + if (!buffer.empty()) { - constexpr std::size_t bufferSize = MiB; + outputFile.write(buffer.data(), buffer.size()); - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - - std::ifstream inputFile(inputPath, std::ios_base::binary); - const auto base = inputFile.tellg(); - std::size_t numProcessedPositions = 0; + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; + } - std::ofstream outputFile(outputPath, om); - std::string buffer; - buffer.reserve(bufferSize * 2); + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} - nodchip::PackedSfenValue psv; - for(;;) - { - inputFile.read(reinterpret_cast(&psv), sizeof(psv)); - if (inputFile.gcount() != 40) - { - break; - } +inline void convertBinToPlain(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t bufferSize = MiB; - auto e = packedSfenValueToTrainingDataEntry(psv); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - return; - } + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - emitPlainEntry(buffer, e); + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; - ++numProcessedPositions; + std::ofstream outputFile(outputPath, om); + std::string buffer; + buffer.reserve(bufferSize * 2); - if (buffer.size() > bufferSize) - { - outputFile << buffer; - buffer.clear(); + nodchip::PackedSfenValue psv; + for (;;) + { + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } - const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " + << e.pos.fen() << '\n'; + return; } - if (!buffer.empty()) + emitPlainEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) { outputFile << buffer; + buffer.clear(); const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } - - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } - inline void convertPlainToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + if (!buffer.empty()) { - constexpr std::size_t bufferSize = MiB; + outputFile << buffer; + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} + +inline void convertPlainToBin(std::string inputPath, + std::string outputPath, + std::ios_base::openmode om, + bool validate) { + constexpr std::size_t bufferSize = MiB; - std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; - std::ofstream outputFile(outputPath, std::ios_base::binary | om); - std::vector buffer; - buffer.reserve(bufferSize * 2); + std::ofstream outputFile(outputPath, std::ios_base::binary | om); + std::vector buffer; + buffer.reserve(bufferSize * 2); - TrainingDataEntry e; + TrainingDataEntry e; - std::string key; - std::string value; - std::string move; + std::string key; + std::string value; + std::string move; - std::ifstream inputFile(inputPath); - const auto base = inputFile.tellg(); - std::size_t numProcessedPositions = 0; + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; - for(;;) + for (;;) + { + inputFile >> key; + if (!inputFile) { - inputFile >> key; - if (!inputFile) - { - break; - } + break; + } - if (key == "e"sv) + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) { - e.move = chess::uci::uciToMove(e.pos, move); - if (validate && !e.isValid()) - { - std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; - return; - } - - emitBinEntry(buffer, e); + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) + << " for position " << e.pos.fen() << '\n'; + return; + } - ++numProcessedPositions; + emitBinEntry(buffer, e); - if (buffer.size() > bufferSize) - { - outputFile.write(buffer.data(), buffer.size()); - buffer.clear(); + ++numProcessedPositions; - const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } + if (buffer.size() > bufferSize) + { + outputFile.write(buffer.data(), buffer.size()); + buffer.clear(); - continue; + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } - inputFile >> std::ws; - std::getline(inputFile, value, '\n'); - - if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); - if (key == "move"sv) move = value; - if (key == "score"sv) e.score = std::stoi(value); - if (key == "ply"sv) e.ply = std::stoi(value); - if (key == "result"sv) e.result = std::stoi(value); + continue; } - if (!buffer.empty()) - { - outputFile.write(buffer.data(), buffer.size()); + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); - const auto cur = outputFile.tellp(); - std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; - } + if (key == "fen"sv) + e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) + move = value; + if (key == "score"sv) + e.score = std::stoi(value); + if (key == "ply"sv) + e.ply = std::stoi(value); + if (key == "result"sv) + e.result = std::stoi(value); + } - std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + if (!buffer.empty()) + { + outputFile.write(buffer.data(), buffer.size()); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions + << " positions.\n"; } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; +} } diff --git a/lib/nnue_training_data_stream.h b/lib/nnue_training_data_stream.h index e41a4121..2005f52f 100644 --- a/lib/nnue_training_data_stream.h +++ b/lib/nnue_training_data_stream.h @@ -10,247 +10,236 @@ namespace training_data { - using namespace binpack; +using namespace binpack; - static bool ends_with(const std::string& lhs, const std::string& end) - { - if (end.size() > lhs.size()) return false; +static bool ends_with(const std::string& lhs, const std::string& end) { + if (end.size() > lhs.size()) + return false; - return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); - } + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); +} + +static bool has_extension(const std::string& filename, const std::string& extension) { + return ends_with(filename, "." + extension); +} - static bool has_extension(const std::string& filename, const std::string& extension) +static std::string filename_with_extension(const std::string& filename, const std::string& ext) { + if (ends_with(filename, ext)) { - return ends_with(filename, "." + extension); + return filename; } - - static std::string filename_with_extension(const std::string& filename, const std::string& ext) + else { - if (ends_with(filename, ext)) - { - return filename; - } - else - { - return filename + "." + ext; - } + return filename + "." + ext; } +} - struct BasicSfenInputStream - { - virtual std::optional next() = 0; - virtual void fill(std::vector& vec, std::size_t n) +struct BasicSfenInputStream { + virtual std::optional next() = 0; + virtual void fill(std::vector& vec, std::size_t n) { + for (std::size_t i = 0; i < n; ++i) { - for (std::size_t i = 0; i < n; ++i) + auto v = this->next(); + if (!v.has_value()) { - auto v = this->next(); - if (!v.has_value()) - { - break; - } - vec.emplace_back(*v); + break; } + vec.emplace_back(*v); } + } - virtual bool eof() const = 0; - virtual ~BasicSfenInputStream() {} - }; - - struct BinSfenInputStream : BasicSfenInputStream - { - static constexpr auto openmode = std::ios::in | std::ios::binary; - static inline const std::string extension = "bin"; - - BinSfenInputStream(std::string filename, bool cyclic, std::function skipPredicate) : - m_stream(filename, openmode), - m_filename(filename), - m_eof(!m_stream), - m_cyclic(cyclic), - m_skipPredicate(std::move(skipPredicate)) - { - } - - std::optional next() override + virtual bool eof() const = 0; + virtual ~BasicSfenInputStream() {} +}; + +struct BinSfenInputStream: BasicSfenInputStream { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "bin"; + + BinSfenInputStream(std::string filename, + bool cyclic, + std::function skipPredicate) : + m_stream(filename, openmode), + m_filename(filename), + m_eof(!m_stream), + m_cyclic(cyclic), + m_skipPredicate(std::move(skipPredicate)) {} + + std::optional next() override { + nodchip::PackedSfenValue e; + bool reopenedFileOnce = false; + for (;;) { - nodchip::PackedSfenValue e; - bool reopenedFileOnce = false; - for(;;) + if (m_stream.read(reinterpret_cast(&e), sizeof(nodchip::PackedSfenValue))) { - if(m_stream.read(reinterpret_cast(&e), sizeof(nodchip::PackedSfenValue))) - { - auto entry = packedSfenValueToTrainingDataEntry(e); - if (!m_skipPredicate || !m_skipPredicate(entry)) - return entry; - } - else - { - if (m_cyclic) - { - if (reopenedFileOnce) - return std::nullopt; - - m_stream = std::fstream(m_filename, openmode); - reopenedFileOnce = true; - if (!m_stream) - return std::nullopt; - - continue; - } - - m_eof = true; - return std::nullopt; - } + auto entry = packedSfenValueToTrainingDataEntry(e); + if (!m_skipPredicate || !m_skipPredicate(entry)) + return entry; } - } - - bool eof() const override - { - return m_eof; - } - - ~BinSfenInputStream() override {} - - private: - std::fstream m_stream; - std::string m_filename; - bool m_eof; - bool m_cyclic; - std::function m_skipPredicate; - }; - - struct BinpackSfenInputStream : BasicSfenInputStream - { - static constexpr auto openmode = std::ios::in | std::ios::binary; - static inline const std::string extension = "binpack"; - - BinpackSfenInputStream(std::string filename, bool cyclic, std::function skipPredicate) : - m_stream(std::make_unique(filename, openmode)), - m_filename(filename), - m_eof(!m_stream->hasNext()), - m_cyclic(cyclic), - m_skipPredicate(std::move(skipPredicate)) - { - } - - std::optional next() override - { - bool reopenedFileOnce = false; - for(;;) + else { - if (!m_stream->hasNext()) + if (m_cyclic) { - if (m_cyclic) - { - if (reopenedFileOnce) - return std::nullopt; + if (reopenedFileOnce) + return std::nullopt; - m_stream = std::make_unique(m_filename, openmode); - reopenedFileOnce = true; + m_stream = std::fstream(m_filename, openmode); + reopenedFileOnce = true; + if (!m_stream) + return std::nullopt; - if (!m_stream->hasNext()) - return std::nullopt; - - continue; - } - - m_eof = true; - return std::nullopt; + continue; } - auto e = m_stream->next(); - if (!m_skipPredicate || !m_skipPredicate(e)) - return e; + m_eof = true; + return std::nullopt; } } + } - bool eof() const override + bool eof() const override { return m_eof; } + + ~BinSfenInputStream() override {} + + private: + std::fstream m_stream; + std::string m_filename; + bool m_eof; + bool m_cyclic; + std::function m_skipPredicate; +}; + +struct BinpackSfenInputStream: BasicSfenInputStream { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "binpack"; + + BinpackSfenInputStream(std::string filename, + bool cyclic, + std::function skipPredicate) : + m_stream(std::make_unique(filename, openmode)), + m_filename(filename), + m_eof(!m_stream->hasNext()), + m_cyclic(cyclic), + m_skipPredicate(std::move(skipPredicate)) {} + + std::optional next() override { + bool reopenedFileOnce = false; + for (;;) { - return m_eof; - } + if (!m_stream->hasNext()) + { + if (m_cyclic) + { + if (reopenedFileOnce) + return std::nullopt; - ~BinpackSfenInputStream() override {} + m_stream = std::make_unique( + m_filename, openmode); + reopenedFileOnce = true; - private: - std::unique_ptr m_stream; - std::string m_filename; - bool m_eof; - bool m_cyclic; - std::function m_skipPredicate; - }; + if (!m_stream->hasNext()) + return std::nullopt; - struct BinpackSfenInputParallelStream : BasicSfenInputStream - { - static constexpr auto openmode = std::ios::in | std::ios::binary; - static inline const std::string extension = "binpack"; - - BinpackSfenInputParallelStream(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate) : - m_stream(std::make_unique(concurrency, filenames, openmode, cyclic, skipPredicate)), - m_filenames(filenames), - m_concurrency(concurrency), - m_eof(false), - m_cyclic(cyclic), - m_skipPredicate(skipPredicate) - { - } + continue; + } - std::optional next() override - { - // filtering is done a layer deeper. - auto v = m_stream->next(); - if (!v.has_value()) - { m_eof = true; return std::nullopt; } - return v; + auto e = m_stream->next(); + if (!m_skipPredicate || !m_skipPredicate(e)) + return e; } + } - void fill(std::vector& v, std::size_t n) override + bool eof() const override { return m_eof; } + + ~BinpackSfenInputStream() override {} + + private: + std::unique_ptr m_stream; + std::string m_filename; + bool m_eof; + bool m_cyclic; + std::function m_skipPredicate; +}; + +struct BinpackSfenInputParallelStream: BasicSfenInputStream { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "binpack"; + + BinpackSfenInputParallelStream(int concurrency, + const std::vector& filenames, + bool cyclic, + std::function skipPredicate) : + m_stream(std::make_unique( + concurrency, filenames, openmode, cyclic, skipPredicate)), + m_filenames(filenames), + m_concurrency(concurrency), + m_eof(false), + m_cyclic(cyclic), + m_skipPredicate(skipPredicate) {} + + std::optional next() override { + // filtering is done a layer deeper. + auto v = m_stream->next(); + if (!v.has_value()) { - auto k = m_stream->fill(v, n); - if (n != k) - { - m_eof = true; - } + m_eof = true; + return std::nullopt; } - bool eof() const override + return v; + } + + void fill(std::vector& v, std::size_t n) override { + auto k = m_stream->fill(v, n); + if (n != k) { - return m_eof; + m_eof = true; } - - ~BinpackSfenInputParallelStream() override {} - - private: - std::unique_ptr m_stream; - std::vector m_filenames; - int m_concurrency; - bool m_eof; - bool m_cyclic; - std::function m_skipPredicate; - }; - - inline std::unique_ptr open_sfen_input_file(const std::string& filename, bool cyclic, std::function skipPredicate = nullptr) - { - if (has_extension(filename, BinSfenInputStream::extension)) - return std::make_unique(filename, cyclic, std::move(skipPredicate)); - else if (has_extension(filename, BinpackSfenInputStream::extension)) - return std::make_unique(filename, cyclic, std::move(skipPredicate)); - - return nullptr; } - inline std::unique_ptr open_sfen_input_file_parallel(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate = nullptr) - { - // TODO (low priority): optimize and parallelize .bin reading. - if (has_extension(filenames[0], BinSfenInputStream::extension)) - return std::make_unique(filenames[0], cyclic, std::move(skipPredicate)); - else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension)) - return std::make_unique(concurrency, filenames, cyclic, std::move(skipPredicate)); + bool eof() const override { return m_eof; } + + ~BinpackSfenInputParallelStream() override {} + + private: + std::unique_ptr m_stream; + std::vector m_filenames; + int m_concurrency; + bool m_eof; + bool m_cyclic; + std::function m_skipPredicate; +}; + +inline std::unique_ptr +open_sfen_input_file(const std::string& filename, + bool cyclic, + std::function skipPredicate = nullptr) { + if (has_extension(filename, BinSfenInputStream::extension)) + return std::make_unique(filename, cyclic, std::move(skipPredicate)); + else if (has_extension(filename, BinpackSfenInputStream::extension)) + return std::make_unique(filename, cyclic, std::move(skipPredicate)); + + return nullptr; +} - return nullptr; - } +inline std::unique_ptr open_sfen_input_file_parallel( + int concurrency, + const std::vector& filenames, + bool cyclic, + std::function skipPredicate = nullptr) { + // TODO (low priority): optimize and parallelize .bin reading. + if (has_extension(filenames[0], BinSfenInputStream::extension)) + return std::make_unique(filenames[0], cyclic, std::move(skipPredicate)); + else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension)) + return std::make_unique(concurrency, filenames, cyclic, + std::move(skipPredicate)); + + return nullptr; +} } #endif diff --git a/lib/rng.h b/lib/rng.h index cca0439d..8e7fe7cb 100644 --- a/lib/rng.h +++ b/lib/rng.h @@ -2,11 +2,9 @@ #include -namespace rng -{ - inline auto& get_thread_local_rng() - { - static thread_local std::mt19937_64 s_rng(std::random_device{}()); - return s_rng; - } +namespace rng { +inline auto& get_thread_local_rng() { + static thread_local std::mt19937_64 s_rng(std::random_device{}()); + return s_rng; +} } diff --git a/model.py b/model.py index acaa98c0..34279e37 100644 --- a/model.py +++ b/model.py @@ -11,337 +11,445 @@ L2 = 15 L3 = 32 + def coalesce_ft_weights(model, layer): - weight = layer.weight.data - indices = model.feature_set.get_virtual_to_real_features_gather_indices() - weight_coalesced = weight.new_zeros((model.feature_set.num_real_features, weight.shape[1])) - for i_real, is_virtual in enumerate(indices): - weight_coalesced[i_real, :] = sum(weight[i_virtual, :] for i_virtual in is_virtual) - return weight_coalesced + weight = layer.weight.data + indices = model.feature_set.get_virtual_to_real_features_gather_indices() + weight_coalesced = weight.new_zeros( + (model.feature_set.num_real_features, weight.shape[1]) + ) + for i_real, is_virtual in enumerate(indices): + weight_coalesced[i_real, :] = sum( + weight[i_virtual, :] for i_virtual in is_virtual + ) + return weight_coalesced + def get_parameters(layers): - return [p for layer in layers for p in layer.parameters()] + return [p for layer in layers for p in layer.parameters()] + class LayerStacks(nn.Module): - def __init__(self, count): - super(LayerStacks, self).__init__() - - self.count = count - self.l1 = nn.Linear(2 * L1 // 2, (L2 + 1) * count) - # Factorizer only for the first layer because later - # there's a non-linearity and factorization breaks. - # This is by design. The weights in the further layers should be - # able to diverge a lot. - self.l1_fact = nn.Linear(2 * L1 // 2, L2 + 1, bias=True) - self.l2 = nn.Linear(L2*2, L3 * count) - self.output = nn.Linear(L3, 1 * count) - - # Cached helper tensor for choosing outputs by bucket indices. - # Initialized lazily in forward. - self.idx_offset = None - - self._init_layers() - - def _init_layers(self): - l1_weight = self.l1.weight - l1_bias = self.l1.bias - l1_fact_weight = self.l1_fact.weight - l1_fact_bias = self.l1_fact.bias - l2_weight = self.l2.weight - l2_bias = self.l2.bias - output_weight = self.output.weight - output_bias = self.output.bias - with torch.no_grad(): - l1_fact_weight.fill_(0.0) - l1_fact_bias.fill_(0.0) - output_bias.fill_(0.0) - - for i in range(1, self.count): - # Force all layer stacks to be initialized in the same way. - l1_weight[i*(L2+1):(i+1)*(L2+1), :] = l1_weight[0:(L2+1), :] - l1_bias[i*(L2+1):(i+1)*(L2+1)] = l1_bias[0:(L2+1)] - l2_weight[i*L3:(i+1)*L3, :] = l2_weight[0:L3, :] - l2_bias[i*L3:(i+1)*L3] = l2_bias[0:L3] - output_weight[i:i+1, :] = output_weight[0:1, :] - - self.l1.weight = nn.Parameter(l1_weight) - self.l1.bias = nn.Parameter(l1_bias) - self.l1_fact.weight = nn.Parameter(l1_fact_weight) - self.l1_fact.bias = nn.Parameter(l1_fact_bias) - self.l2.weight = nn.Parameter(l2_weight) - self.l2.bias = nn.Parameter(l2_bias) - self.output.weight = nn.Parameter(output_weight) - self.output.bias = nn.Parameter(output_bias) - - def forward(self, x, ls_indices): - # Precompute and cache the offset for gathers - if self.idx_offset == None or self.idx_offset.shape[0] != x.shape[0]: - self.idx_offset = torch.arange(0,x.shape[0]*self.count,self.count, device=ls_indices.device) - - indices = ls_indices.flatten() + self.idx_offset - - l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1)) - l1f_ = self.l1_fact(x) - # https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices - # basically we present it as a list of individual results and pick not only based on - # the ls index but also based on batch (they are combined into one index) - l1c_ = l1s_.view(-1, L2 + 1)[indices] - l1c_, l1c_out = l1c_.split(L2, dim=1) - l1f_, l1f_out = l1f_.split(L2, dim=1) - l1x_ = l1c_ + l1f_ - # multiply sqr crelu result by (127/128) to match quantized version - l1x_ = torch.clamp(torch.cat([torch.pow(l1x_, 2.0) * (127/128), l1x_], dim=1), 0.0, 1.0) - - l2s_ = self.l2(l1x_).reshape((-1, self.count, L3)) - l2c_ = l2s_.view(-1, L3)[indices] - l2x_ = torch.clamp(l2c_, 0.0, 1.0) - - l3s_ = self.output(l2x_).reshape((-1, self.count, 1)) - l3c_ = l3s_.view(-1, 1)[indices] - l3x_ = l3c_ + l1f_out + l1c_out - - return l3x_ - - def get_coalesced_layer_stacks(self): - # During training the buckets are represented by a single, wider, layer. - # This representation needs to be transformed into individual layers - # for the serializer, because the buckets are interpreted as separate layers. - for i in range(self.count): - with torch.no_grad(): - l1 = nn.Linear(2*L1 // 2, L2+1) - l2 = nn.Linear(L2*2, L3) - output = nn.Linear(L3, 1) - l1.weight.data = self.l1.weight[i*(L2+1):(i+1)*(L2+1), :] + self.l1_fact.weight.data - l1.bias.data = self.l1.bias[i*(L2+1):(i+1)*(L2+1)] + self.l1_fact.bias.data - l2.weight.data = self.l2.weight[i*L3:(i+1)*L3, :] - l2.bias.data = self.l2.bias[i*L3:(i+1)*L3] - output.weight.data = self.output.weight[i:(i+1), :] - output.bias.data = self.output.bias[i:(i+1)] - yield l1, l2, output + def __init__(self, count): + super(LayerStacks, self).__init__() + + self.count = count + self.l1 = nn.Linear(2 * L1 // 2, (L2 + 1) * count) + # Factorizer only for the first layer because later + # there's a non-linearity and factorization breaks. + # This is by design. The weights in the further layers should be + # able to diverge a lot. + self.l1_fact = nn.Linear(2 * L1 // 2, L2 + 1, bias=True) + self.l2 = nn.Linear(L2 * 2, L3 * count) + self.output = nn.Linear(L3, 1 * count) + + # Cached helper tensor for choosing outputs by bucket indices. + # Initialized lazily in forward. + self.idx_offset = None + + self._init_layers() + + def _init_layers(self): + l1_weight = self.l1.weight + l1_bias = self.l1.bias + l1_fact_weight = self.l1_fact.weight + l1_fact_bias = self.l1_fact.bias + l2_weight = self.l2.weight + l2_bias = self.l2.bias + output_weight = self.output.weight + output_bias = self.output.bias + with torch.no_grad(): + l1_fact_weight.fill_(0.0) + l1_fact_bias.fill_(0.0) + output_bias.fill_(0.0) + + for i in range(1, self.count): + # Force all layer stacks to be initialized in the same way. + l1_weight[i * (L2 + 1) : (i + 1) * (L2 + 1), :] = l1_weight[ + 0 : (L2 + 1), : + ] + l1_bias[i * (L2 + 1) : (i + 1) * (L2 + 1)] = l1_bias[0 : (L2 + 1)] + l2_weight[i * L3 : (i + 1) * L3, :] = l2_weight[0:L3, :] + l2_bias[i * L3 : (i + 1) * L3] = l2_bias[0:L3] + output_weight[i : i + 1, :] = output_weight[0:1, :] + + self.l1.weight = nn.Parameter(l1_weight) + self.l1.bias = nn.Parameter(l1_bias) + self.l1_fact.weight = nn.Parameter(l1_fact_weight) + self.l1_fact.bias = nn.Parameter(l1_fact_bias) + self.l2.weight = nn.Parameter(l2_weight) + self.l2.bias = nn.Parameter(l2_bias) + self.output.weight = nn.Parameter(output_weight) + self.output.bias = nn.Parameter(output_bias) + + def forward(self, x, ls_indices): + # Precompute and cache the offset for gathers + if self.idx_offset == None or self.idx_offset.shape[0] != x.shape[0]: + self.idx_offset = torch.arange( + 0, x.shape[0] * self.count, self.count, device=ls_indices.device + ) + + indices = ls_indices.flatten() + self.idx_offset + + l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1)) + l1f_ = self.l1_fact(x) + # https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices + # basically we present it as a list of individual results and pick not only based on + # the ls index but also based on batch (they are combined into one index) + l1c_ = l1s_.view(-1, L2 + 1)[indices] + l1c_, l1c_out = l1c_.split(L2, dim=1) + l1f_, l1f_out = l1f_.split(L2, dim=1) + l1x_ = l1c_ + l1f_ + # multiply sqr crelu result by (127/128) to match quantized version + l1x_ = torch.clamp( + torch.cat([torch.pow(l1x_, 2.0) * (127 / 128), l1x_], dim=1), 0.0, 1.0 + ) + + l2s_ = self.l2(l1x_).reshape((-1, self.count, L3)) + l2c_ = l2s_.view(-1, L3)[indices] + l2x_ = torch.clamp(l2c_, 0.0, 1.0) + + l3s_ = self.output(l2x_).reshape((-1, self.count, 1)) + l3c_ = l3s_.view(-1, 1)[indices] + l3x_ = l3c_ + l1f_out + l1c_out + + return l3x_ + + def get_coalesced_layer_stacks(self): + # During training the buckets are represented by a single, wider, layer. + # This representation needs to be transformed into individual layers + # for the serializer, because the buckets are interpreted as separate layers. + for i in range(self.count): + with torch.no_grad(): + l1 = nn.Linear(2 * L1 // 2, L2 + 1) + l2 = nn.Linear(L2 * 2, L3) + output = nn.Linear(L3, 1) + l1.weight.data = ( + self.l1.weight[i * (L2 + 1) : (i + 1) * (L2 + 1), :] + + self.l1_fact.weight.data + ) + l1.bias.data = ( + self.l1.bias[i * (L2 + 1) : (i + 1) * (L2 + 1)] + + self.l1_fact.bias.data + ) + l2.weight.data = self.l2.weight[i * L3 : (i + 1) * L3, :] + l2.bias.data = self.l2.bias[i * L3 : (i + 1) * L3] + output.weight.data = self.output.weight[i : (i + 1), :] + output.bias.data = self.output.bias[i : (i + 1)] + yield l1, l2, output class NNUE(pl.LightningModule): - """ - feature_set - an instance of FeatureSet defining the input features - - lambda_ = 0.0 - purely based on game results - 0.0 < lambda_ < 1.0 - interpolated score and result - lambda_ = 1.0 - purely based on search scores - - gamma - the multiplicative factor applied to the learning rate after each epoch - - lr - the initial learning rate - """ - def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800, gamma=0.992, lr=8.75e-4, param_index=0, num_psqt_buckets=8, num_ls_buckets=8): - super(NNUE, self).__init__() - self.num_psqt_buckets = num_psqt_buckets - self.num_ls_buckets = num_ls_buckets - self.input = DoubleFeatureTransformerSlice(feature_set.num_features, L1 + self.num_psqt_buckets) - self.feature_set = feature_set - self.layer_stacks = LayerStacks(self.num_ls_buckets) - self.start_lambda = start_lambda - self.end_lambda = end_lambda - self.max_epoch = max_epoch - self.gamma = gamma - self.lr = lr - self.param_index = param_index - - self.nnue2score = 600.0 - self.weight_scale_hidden = 64.0 - self.weight_scale_out = 16.0 - self.quantized_one = 127.0 - - max_hidden_weight = self.quantized_one / self.weight_scale_hidden - max_out_weight = (self.quantized_one * self.quantized_one) / (self.nnue2score * self.weight_scale_out) - self.weight_clipping = [ - {'params' : [self.layer_stacks.l1.weight], 'min_weight' : -max_hidden_weight, 'max_weight' : max_hidden_weight, 'virtual_params' : self.layer_stacks.l1_fact.weight }, - {'params' : [self.layer_stacks.l2.weight], 'min_weight' : -max_hidden_weight, 'max_weight' : max_hidden_weight }, - {'params' : [self.layer_stacks.output.weight], 'min_weight' : -max_out_weight, 'max_weight' : max_out_weight }, - ] - - self._init_layers() - - ''' + """ + feature_set - an instance of FeatureSet defining the input features + + lambda_ = 0.0 - purely based on game results + 0.0 < lambda_ < 1.0 - interpolated score and result + lambda_ = 1.0 - purely based on search scores + + gamma - the multiplicative factor applied to the learning rate after each epoch + + lr - the initial learning rate + """ + + def __init__( + self, + feature_set, + start_lambda=1.0, + end_lambda=1.0, + max_epoch=800, + gamma=0.992, + lr=8.75e-4, + param_index=0, + num_psqt_buckets=8, + num_ls_buckets=8, + ): + super(NNUE, self).__init__() + self.num_psqt_buckets = num_psqt_buckets + self.num_ls_buckets = num_ls_buckets + self.input = DoubleFeatureTransformerSlice( + feature_set.num_features, L1 + self.num_psqt_buckets + ) + self.feature_set = feature_set + self.layer_stacks = LayerStacks(self.num_ls_buckets) + self.start_lambda = start_lambda + self.end_lambda = end_lambda + self.max_epoch = max_epoch + self.gamma = gamma + self.lr = lr + self.param_index = param_index + + self.nnue2score = 600.0 + self.weight_scale_hidden = 64.0 + self.weight_scale_out = 16.0 + self.quantized_one = 127.0 + + max_hidden_weight = self.quantized_one / self.weight_scale_hidden + max_out_weight = (self.quantized_one * self.quantized_one) / ( + self.nnue2score * self.weight_scale_out + ) + self.weight_clipping = [ + { + "params": [self.layer_stacks.l1.weight], + "min_weight": -max_hidden_weight, + "max_weight": max_hidden_weight, + "virtual_params": self.layer_stacks.l1_fact.weight, + }, + { + "params": [self.layer_stacks.l2.weight], + "min_weight": -max_hidden_weight, + "max_weight": max_hidden_weight, + }, + { + "params": [self.layer_stacks.output.weight], + "min_weight": -max_out_weight, + "max_weight": max_out_weight, + }, + ] + + self._init_layers() + + """ We zero all virtual feature weights because there's not need for them to be initialized; they only aid the training of correlated features. - ''' - def _zero_virtual_feature_weights(self): - weights = self.input.weight - with torch.no_grad(): - for a, b in self.feature_set.get_virtual_feature_ranges(): - weights[a:b, :] = 0.0 - self.input.weight = nn.Parameter(weights) - - def _init_layers(self): - self._zero_virtual_feature_weights() - self._init_psqt() - - def _init_psqt(self): - input_weights = self.input.weight - input_bias = self.input.bias - # 1.0 / kPonanzaConstant - scale = 1 / self.nnue2score - with torch.no_grad(): - initial_values = self.feature_set.get_initial_psqt_features() - assert len(initial_values) == self.feature_set.num_features - for i in range(self.num_psqt_buckets): - input_weights[:, L1 + i] = torch.FloatTensor(initial_values) * scale - # Bias doesn't matter because it cancels out during - # inference during perspective averaging. We set it to 0 - # just for the sake of it. It might still diverge away from 0 - # due to gradient imprecision but it won't change anything. - input_bias[L1 + i] = 0.0 - self.input.weight = nn.Parameter(input_weights) - self.input.bias = nn.Parameter(input_bias) - - ''' + """ + + def _zero_virtual_feature_weights(self): + weights = self.input.weight + with torch.no_grad(): + for a, b in self.feature_set.get_virtual_feature_ranges(): + weights[a:b, :] = 0.0 + self.input.weight = nn.Parameter(weights) + + def _init_layers(self): + self._zero_virtual_feature_weights() + self._init_psqt() + + def _init_psqt(self): + input_weights = self.input.weight + input_bias = self.input.bias + # 1.0 / kPonanzaConstant + scale = 1 / self.nnue2score + with torch.no_grad(): + initial_values = self.feature_set.get_initial_psqt_features() + assert len(initial_values) == self.feature_set.num_features + for i in range(self.num_psqt_buckets): + input_weights[:, L1 + i] = torch.FloatTensor(initial_values) * scale + # Bias doesn't matter because it cancels out during + # inference during perspective averaging. We set it to 0 + # just for the sake of it. It might still diverge away from 0 + # due to gradient imprecision but it won't change anything. + input_bias[L1 + i] = 0.0 + self.input.weight = nn.Parameter(input_weights) + self.input.bias = nn.Parameter(input_bias) + + """ Clips the weights of the model based on the min/max values allowed by the quantization scheme. - ''' - def _clip_weights(self): - for group in self.weight_clipping: - for p in group['params']: - if 'min_weight' in group or 'max_weight' in group: - p_data_fp32 = p.data - min_weight = group['min_weight'] - max_weight = group['max_weight'] - if 'virtual_params' in group: - virtual_params = group['virtual_params'] - xs = p_data_fp32.shape[0] // virtual_params.shape[0] - ys = p_data_fp32.shape[1] // virtual_params.shape[1] - expanded_virtual_layer = virtual_params.repeat(xs, ys) - if min_weight is not None: - min_weight_t = p_data_fp32.new_full(p_data_fp32.shape, min_weight) - expanded_virtual_layer - p_data_fp32 = torch.max(p_data_fp32, min_weight_t) - if max_weight is not None: - max_weight_t = p_data_fp32.new_full(p_data_fp32.shape, max_weight) - expanded_virtual_layer - p_data_fp32 = torch.min(p_data_fp32, max_weight_t) - else: - if min_weight is not None and max_weight is not None: - p_data_fp32.clamp_(min_weight, max_weight) - else: - raise Exception('Not supported.') - p.data.copy_(p_data_fp32) - - ''' + """ + + def _clip_weights(self): + for group in self.weight_clipping: + for p in group["params"]: + if "min_weight" in group or "max_weight" in group: + p_data_fp32 = p.data + min_weight = group["min_weight"] + max_weight = group["max_weight"] + if "virtual_params" in group: + virtual_params = group["virtual_params"] + xs = p_data_fp32.shape[0] // virtual_params.shape[0] + ys = p_data_fp32.shape[1] // virtual_params.shape[1] + expanded_virtual_layer = virtual_params.repeat(xs, ys) + if min_weight is not None: + min_weight_t = ( + p_data_fp32.new_full(p_data_fp32.shape, min_weight) + - expanded_virtual_layer + ) + p_data_fp32 = torch.max(p_data_fp32, min_weight_t) + if max_weight is not None: + max_weight_t = ( + p_data_fp32.new_full(p_data_fp32.shape, max_weight) + - expanded_virtual_layer + ) + p_data_fp32 = torch.min(p_data_fp32, max_weight_t) + else: + if min_weight is not None and max_weight is not None: + p_data_fp32.clamp_(min_weight, max_weight) + else: + raise Exception("Not supported.") + p.data.copy_(p_data_fp32) + + """ This method attempts to convert the model from using the self.feature_set to new_feature_set. Currently only works for adding virtual features. - ''' - def set_feature_set(self, new_feature_set): - if self.feature_set.name == new_feature_set.name: - return - - # TODO: Implement this for more complicated conversions. - # Currently we support only a single feature block. - if len(self.feature_set.features) > 1: - raise Exception('Cannot change feature set from {} to {}.'.format(self.feature_set.name, new_feature_set.name)) - - # Currently we only support conversion for feature sets with - # one feature block each so we'll dig the feature blocks directly - # and forget about the set. - old_feature_block = self.feature_set.features[0] - new_feature_block = new_feature_set.features[0] - - # next(iter(new_feature_block.factors)) is the way to get the - # first item in a OrderedDict. (the ordered dict being str : int - # mapping of the factor name to its size). - # It is our new_feature_factor_name. - # For example old_feature_block.name == "HalfKP" - # and new_feature_factor_name == "HalfKP^" - # We assume here that the "^" denotes factorized feature block - # and we would like feature block implementers to follow this convention. - # So if our current feature_set matches the first factor in the new_feature_set - # we only have to add the virtual feature on top of the already existing real ones. - if old_feature_block.name == next(iter(new_feature_block.factors)): - # We can just extend with zeros since it's unfactorized -> factorized - weights = self.input.weight - padding = weights.new_zeros((new_feature_block.num_virtual_features, weights.shape[1])) - weights = torch.cat([weights, padding], dim=0) - self.input.weight = nn.Parameter(weights) - self.feature_set = new_feature_set - else: - raise Exception('Cannot change feature set from {} to {}.'.format(self.feature_set.name, new_feature_set.name)) - - def forward(self, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices): - wp, bp = self.input(white_indices, white_values, black_indices, black_values) - w, wpsqt = torch.split(wp, L1, dim=1) - b, bpsqt = torch.split(bp, L1, dim=1) - l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1)) - l0_ = torch.clamp(l0_, 0.0, 1.0) - - l0_s = torch.split(l0_, L1 // 2, dim=1) - l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]] - # We multiply by 127/128 because in the quantized network 1.0 is represented by 127 - # and it's more efficient to divide by 128 instead. - l0_ = torch.cat(l0_s1, dim=1) * (127/128) - - psqt_indices_unsq = psqt_indices.unsqueeze(dim=1) - wpsqt = wpsqt.gather(1, psqt_indices_unsq) - bpsqt = bpsqt.gather(1, psqt_indices_unsq) - # The PSQT values are averaged over perspectives. "Their" perspective - # has a negative influence (us-0.5 is 0.5 for white and -0.5 for black, - # which does both the averaging and sign flip for black to move) - x = self.layer_stacks(l0_, layer_stack_indices) + (wpsqt - bpsqt) * (us - 0.5) - - return x - - def step_(self, batch, batch_idx, loss_type): - # We clip weights at the start of each step. This means that after - # the last step the weights might be outside of the desired range. - # They should be also clipped accordingly in the serializer. - self._clip_weights() - - us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch - - # convert the network and search scores to an estimate match result - # based on the win_rate_model, with scalings and offsets optimized - in_scaling = 340 - out_scaling = 380 - offset = 270 - - scorenet = self(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * self.nnue2score - q = ( scorenet - offset) / in_scaling # used to compute the chance of a win - qm = (-scorenet - offset) / in_scaling # used to compute the chance of a loss - qf = 0.5 * (1.0 + q.sigmoid() - qm.sigmoid()) # estimated match result (using win, loss and draw probs). - - p = ( score - offset) / out_scaling - pm = (-score - offset) / out_scaling - pf = 0.5 * (1.0 + p.sigmoid() - pm.sigmoid()) - - t = outcome - actual_lambda = self.start_lambda + (self.end_lambda - self.start_lambda) * (self.current_epoch / self.max_epoch) - pt = pf * actual_lambda + t * (1.0 - actual_lambda) - - loss = torch.pow(torch.abs(pt - qf), 2.5).mean() - - self.log(loss_type, loss) - - return loss - - def training_step(self, batch, batch_idx): - return self.step_(batch, batch_idx, 'train_loss') - - def validation_step(self, batch, batch_idx): - self.step_(batch, batch_idx, 'val_loss') - - def test_step(self, batch, batch_idx): - self.step_(batch, batch_idx, 'test_loss') - - def configure_optimizers(self): - LR = self.lr - train_params = [ - {'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 }, - {'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR }, - {'params' : [self.layer_stacks.l1_fact.bias], 'lr' : LR }, - {'params' : [self.layer_stacks.l1.weight], 'lr' : LR }, - {'params' : [self.layer_stacks.l1.bias], 'lr' : LR }, - {'params' : [self.layer_stacks.l2.weight], 'lr' : LR }, - {'params' : [self.layer_stacks.l2.bias], 'lr' : LR }, - {'params' : [self.layer_stacks.output.weight], 'lr' : LR }, - {'params' : [self.layer_stacks.output.bias], 'lr' : LR }, - ] - # Increasing the eps leads to less saturated nets with a few dead neurons. - # Gradient localisation appears slightly harmful. - optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma) - return [optimizer], [scheduler] + """ + + def set_feature_set(self, new_feature_set): + if self.feature_set.name == new_feature_set.name: + return + + # TODO: Implement this for more complicated conversions. + # Currently we support only a single feature block. + if len(self.feature_set.features) > 1: + raise Exception( + "Cannot change feature set from {} to {}.".format( + self.feature_set.name, new_feature_set.name + ) + ) + + # Currently we only support conversion for feature sets with + # one feature block each so we'll dig the feature blocks directly + # and forget about the set. + old_feature_block = self.feature_set.features[0] + new_feature_block = new_feature_set.features[0] + + # next(iter(new_feature_block.factors)) is the way to get the + # first item in a OrderedDict. (the ordered dict being str : int + # mapping of the factor name to its size). + # It is our new_feature_factor_name. + # For example old_feature_block.name == "HalfKP" + # and new_feature_factor_name == "HalfKP^" + # We assume here that the "^" denotes factorized feature block + # and we would like feature block implementers to follow this convention. + # So if our current feature_set matches the first factor in the new_feature_set + # we only have to add the virtual feature on top of the already existing real ones. + if old_feature_block.name == next(iter(new_feature_block.factors)): + # We can just extend with zeros since it's unfactorized -> factorized + weights = self.input.weight + padding = weights.new_zeros( + (new_feature_block.num_virtual_features, weights.shape[1]) + ) + weights = torch.cat([weights, padding], dim=0) + self.input.weight = nn.Parameter(weights) + self.feature_set = new_feature_set + else: + raise Exception( + "Cannot change feature set from {} to {}.".format( + self.feature_set.name, new_feature_set.name + ) + ) + + def forward( + self, + us, + them, + white_indices, + white_values, + black_indices, + black_values, + psqt_indices, + layer_stack_indices, + ): + wp, bp = self.input(white_indices, white_values, black_indices, black_values) + w, wpsqt = torch.split(wp, L1, dim=1) + b, bpsqt = torch.split(bp, L1, dim=1) + l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1)) + l0_ = torch.clamp(l0_, 0.0, 1.0) + + l0_s = torch.split(l0_, L1 // 2, dim=1) + l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]] + # We multiply by 127/128 because in the quantized network 1.0 is represented by 127 + # and it's more efficient to divide by 128 instead. + l0_ = torch.cat(l0_s1, dim=1) * (127 / 128) + + psqt_indices_unsq = psqt_indices.unsqueeze(dim=1) + wpsqt = wpsqt.gather(1, psqt_indices_unsq) + bpsqt = bpsqt.gather(1, psqt_indices_unsq) + # The PSQT values are averaged over perspectives. "Their" perspective + # has a negative influence (us-0.5 is 0.5 for white and -0.5 for black, + # which does both the averaging and sign flip for black to move) + x = self.layer_stacks(l0_, layer_stack_indices) + (wpsqt - bpsqt) * (us - 0.5) + + return x + + def step_(self, batch, batch_idx, loss_type): + # We clip weights at the start of each step. This means that after + # the last step the weights might be outside of the desired range. + # They should be also clipped accordingly in the serializer. + self._clip_weights() + + ( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + outcome, + score, + psqt_indices, + layer_stack_indices, + ) = batch + + # convert the network and search scores to an estimate match result + # based on the win_rate_model, with scalings and offsets optimized + in_scaling = 340 + out_scaling = 380 + offset = 270 + + scorenet = ( + self( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + psqt_indices, + layer_stack_indices, + ) + * self.nnue2score + ) + q = (scorenet - offset) / in_scaling # used to compute the chance of a win + qm = (-scorenet - offset) / in_scaling # used to compute the chance of a loss + qf = 0.5 * ( + 1.0 + q.sigmoid() - qm.sigmoid() + ) # estimated match result (using win, loss and draw probs). + + p = (score - offset) / out_scaling + pm = (-score - offset) / out_scaling + pf = 0.5 * (1.0 + p.sigmoid() - pm.sigmoid()) + + t = outcome + actual_lambda = self.start_lambda + (self.end_lambda - self.start_lambda) * ( + self.current_epoch / self.max_epoch + ) + pt = pf * actual_lambda + t * (1.0 - actual_lambda) + + loss = torch.pow(torch.abs(pt - qf), 2.5).mean() + + self.log(loss_type, loss) + + return loss + + def training_step(self, batch, batch_idx): + return self.step_(batch, batch_idx, "train_loss") + + def validation_step(self, batch, batch_idx): + self.step_(batch, batch_idx, "val_loss") + + def test_step(self, batch, batch_idx): + self.step_(batch, batch_idx, "test_loss") + + def configure_optimizers(self): + LR = self.lr + train_params = [ + {"params": get_parameters([self.input]), "lr": LR, "gc_dim": 0}, + {"params": [self.layer_stacks.l1_fact.weight], "lr": LR}, + {"params": [self.layer_stacks.l1_fact.bias], "lr": LR}, + {"params": [self.layer_stacks.l1.weight], "lr": LR}, + {"params": [self.layer_stacks.l1.bias], "lr": LR}, + {"params": [self.layer_stacks.l2.weight], "lr": LR}, + {"params": [self.layer_stacks.l2.bias], "lr": LR}, + {"params": [self.layer_stacks.output.weight], "lr": LR}, + {"params": [self.layer_stacks.output.bias], "lr": LR}, + ] + # Increasing the eps leads to less saturated nets with a few dead neurons. + # Gradient localisation appears slightly harmful. + optimizer = ranger.Ranger( + train_params, betas=(0.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False + ) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=1, gamma=self.gamma + ) + return [optimizer], [scheduler] diff --git a/nnue_dataset.py b/nnue_dataset.py index 9ce73a43..13315055 100644 --- a/nnue_dataset.py +++ b/nnue_dataset.py @@ -6,84 +6,192 @@ import glob from torch.utils.data import Dataset -local_dllpath = [n for n in glob.glob('./*training_data_loader.*') if n.endswith('.so') or n.endswith('.dll') or n.endswith('.dylib')] +local_dllpath = [ + n + for n in glob.glob("./*training_data_loader.*") + if n.endswith(".so") or n.endswith(".dll") or n.endswith(".dylib") +] if not local_dllpath: - print('Cannot find data_loader shared library.') + print("Cannot find data_loader shared library.") sys.exit(1) dllpath = os.path.abspath(local_dllpath[0]) dll = ctypes.cdll.LoadLibrary(dllpath) + class SparseBatch(ctypes.Structure): _fields_ = [ - ('num_inputs', ctypes.c_int), - ('size', ctypes.c_int), - ('is_white', ctypes.POINTER(ctypes.c_float)), - ('outcome', ctypes.POINTER(ctypes.c_float)), - ('score', ctypes.POINTER(ctypes.c_float)), - ('num_active_white_features', ctypes.c_int), - ('num_active_black_features', ctypes.c_int), - ('max_active_features', ctypes.c_int), - ('white', ctypes.POINTER(ctypes.c_int)), - ('black', ctypes.POINTER(ctypes.c_int)), - ('white_values', ctypes.POINTER(ctypes.c_float)), - ('black_values', ctypes.POINTER(ctypes.c_float)), - ('psqt_indices', ctypes.POINTER(ctypes.c_int)), - ('layer_stack_indices', ctypes.POINTER(ctypes.c_int)), + ("num_inputs", ctypes.c_int), + ("size", ctypes.c_int), + ("is_white", ctypes.POINTER(ctypes.c_float)), + ("outcome", ctypes.POINTER(ctypes.c_float)), + ("score", ctypes.POINTER(ctypes.c_float)), + ("num_active_white_features", ctypes.c_int), + ("num_active_black_features", ctypes.c_int), + ("max_active_features", ctypes.c_int), + ("white", ctypes.POINTER(ctypes.c_int)), + ("black", ctypes.POINTER(ctypes.c_int)), + ("white_values", ctypes.POINTER(ctypes.c_float)), + ("black_values", ctypes.POINTER(ctypes.c_float)), + ("psqt_indices", ctypes.POINTER(ctypes.c_int)), + ("layer_stack_indices", ctypes.POINTER(ctypes.c_int)), ] def get_tensors(self, device): - white_values = torch.from_numpy(np.ctypeslib.as_array(self.white_values, shape=(self.size, self.max_active_features))).pin_memory().to(device=device, non_blocking=True) - black_values = torch.from_numpy(np.ctypeslib.as_array(self.black_values, shape=(self.size, self.max_active_features))).pin_memory().to(device=device, non_blocking=True) - white_indices = torch.from_numpy(np.ctypeslib.as_array(self.white, shape=(self.size, self.max_active_features))).pin_memory().to(device=device, non_blocking=True) - black_indices = torch.from_numpy(np.ctypeslib.as_array(self.black, shape=(self.size, self.max_active_features))).pin_memory().to(device=device, non_blocking=True) - us = torch.from_numpy(np.ctypeslib.as_array(self.is_white, shape=(self.size, 1))).pin_memory().to(device=device, non_blocking=True) + white_values = ( + torch.from_numpy( + np.ctypeslib.as_array( + self.white_values, shape=(self.size, self.max_active_features) + ) + ) + .pin_memory() + .to(device=device, non_blocking=True) + ) + black_values = ( + torch.from_numpy( + np.ctypeslib.as_array( + self.black_values, shape=(self.size, self.max_active_features) + ) + ) + .pin_memory() + .to(device=device, non_blocking=True) + ) + white_indices = ( + torch.from_numpy( + np.ctypeslib.as_array( + self.white, shape=(self.size, self.max_active_features) + ) + ) + .pin_memory() + .to(device=device, non_blocking=True) + ) + black_indices = ( + torch.from_numpy( + np.ctypeslib.as_array( + self.black, shape=(self.size, self.max_active_features) + ) + ) + .pin_memory() + .to(device=device, non_blocking=True) + ) + us = ( + torch.from_numpy(np.ctypeslib.as_array(self.is_white, shape=(self.size, 1))) + .pin_memory() + .to(device=device, non_blocking=True) + ) them = 1.0 - us - outcome = torch.from_numpy(np.ctypeslib.as_array(self.outcome, shape=(self.size, 1))).pin_memory().to(device=device, non_blocking=True) - score = torch.from_numpy(np.ctypeslib.as_array(self.score, shape=(self.size, 1))).pin_memory().to(device=device, non_blocking=True) - psqt_indices = torch.from_numpy(np.ctypeslib.as_array(self.psqt_indices, shape=(self.size,))).long().pin_memory().to(device=device, non_blocking=True) - layer_stack_indices = torch.from_numpy(np.ctypeslib.as_array(self.layer_stack_indices, shape=(self.size,))).long().pin_memory().to(device=device, non_blocking=True) - return us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices + outcome = ( + torch.from_numpy(np.ctypeslib.as_array(self.outcome, shape=(self.size, 1))) + .pin_memory() + .to(device=device, non_blocking=True) + ) + score = ( + torch.from_numpy(np.ctypeslib.as_array(self.score, shape=(self.size, 1))) + .pin_memory() + .to(device=device, non_blocking=True) + ) + psqt_indices = ( + torch.from_numpy( + np.ctypeslib.as_array(self.psqt_indices, shape=(self.size,)) + ) + .long() + .pin_memory() + .to(device=device, non_blocking=True) + ) + layer_stack_indices = ( + torch.from_numpy( + np.ctypeslib.as_array(self.layer_stack_indices, shape=(self.size,)) + ) + .long() + .pin_memory() + .to(device=device, non_blocking=True) + ) + return ( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + outcome, + score, + psqt_indices, + layer_stack_indices, + ) + SparseBatchPtr = ctypes.POINTER(SparseBatch) + class Fen(ctypes.Structure): - _fields_ = [ - ('size', ctypes.c_int), - ('fen', ctypes.c_char_p) - ] + _fields_ = [("size", ctypes.c_int), ("fen", ctypes.c_char_p)] + FenPtr = ctypes.POINTER(Fen) + class FenBatch(ctypes.Structure): - _fields_ = [ - ('size', ctypes.c_int), - ('fens', FenPtr) - ] + _fields_ = [("size", ctypes.c_int), ("fens", FenPtr)] def get_fens(self): strings = [] for i in range(self.size): - strings.append(self.fens[i].fen.decode('utf-8')) + strings.append(self.fens[i].fen.decode("utf-8")) return strings + FenBatchPtr = ctypes.POINTER(FenBatch) # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index) create_fen_batch_stream = dll.create_fen_batch_stream create_fen_batch_stream.restype = ctypes.c_void_p -create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int] +create_fen_batch_stream.argtypes = [ + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_char_p), + ctypes.c_int, + ctypes.c_bool, + ctypes.c_bool, + ctypes.c_int, + ctypes.c_bool, + ctypes.c_int, + ctypes.c_int, +] destroy_fen_batch_stream = dll.destroy_fen_batch_stream destroy_fen_batch_stream.argtypes = [ctypes.c_void_p] -def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index): + +def make_fen_batch_stream( + concurrency, + filenames, + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, +): filenames_ = (ctypes.c_char_p * len(filenames))() - filenames_[:] = [filename.encode('utf-8') for filename in filenames] - return create_fen_batch_stream(concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + filenames_[:] = [filename.encode("utf-8") for filename in filenames] + return create_fen_batch_stream( + concurrency, + len(filenames), + filenames_, + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) + fetch_next_fen_batch = dll.fetch_next_fen_batch fetch_next_fen_batch.restype = FenBatchPtr fetch_next_fen_batch.argtypes = [ctypes.c_void_p] destroy_fen_batch = dll.destroy_fen_batch + class FenBatchProvider: def __init__( self, @@ -95,7 +203,8 @@ def __init__( random_fen_skipping=0, early_fen_skipping=-1, wld_filtered=False, - param_index=0): + param_index=0, + ): self.filename = filename self.cyclic = cyclic @@ -108,9 +217,28 @@ def __init__( self.param_index = param_index if batch_size: - self.stream = make_fen_batch_stream(self.num_workers, [self.filename], batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + self.stream = make_fen_batch_stream( + self.num_workers, + [self.filename], + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) else: - self.stream = make_fen_batch_stream(self.num_workers, [self.filename], cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + self.stream = make_fen_batch_stream( + self.num_workers, + [self.filename], + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) def __iter__(self): return self @@ -128,6 +256,7 @@ def __next__(self): def __del__(self): destroy_fen_batch_stream(self.stream) + class TrainingDataProvider: def __init__( self, @@ -145,9 +274,10 @@ def __init__( wld_filtered=False, early_fen_skipping=-1, param_index=0, - device='cpu'): + device="cpu", + ): - self.feature_set = feature_set.encode('utf-8') + self.feature_set = feature_set.encode("utf-8") self.create_stream = create_stream self.destroy_stream = destroy_stream self.fetch_next = fetch_next @@ -163,9 +293,30 @@ def __init__( self.device = device if batch_size: - self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + self.stream = self.create_stream( + self.feature_set, + self.num_workers, + self.filenames, + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) else: - self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + self.stream = self.create_stream( + self.feature_set, + self.num_workers, + self.filenames, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) def __iter__(self): return self @@ -183,18 +334,56 @@ def __next__(self): def __del__(self): self.destroy_stream(self.stream) + # EXPORT Stream* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, # bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index) create_sparse_batch_stream = dll.create_sparse_batch_stream create_sparse_batch_stream.restype = ctypes.c_void_p -create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int] +create_sparse_batch_stream.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_char_p), + ctypes.c_int, + ctypes.c_bool, + ctypes.c_bool, + ctypes.c_int, + ctypes.c_bool, + ctypes.c_int, + ctypes.c_int, +] destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p] -def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index): + +def make_sparse_batch_stream( + feature_set, + concurrency, + filenames, + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, +): filenames_ = (ctypes.c_char_p * len(filenames))() - filenames_[:] = [filename.encode('utf-8') for filename in filenames] - return create_sparse_batch_stream(feature_set, concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index) + filenames_[:] = [filename.encode("utf-8") for filename in filenames] + return create_sparse_batch_stream( + feature_set, + concurrency, + len(filenames), + filenames_, + batch_size, + cyclic, + filtered, + random_fen_skipping, + wld_filtered, + early_fen_skipping, + param_index, + ) + fetch_next_sparse_batch = dll.fetch_next_sparse_batch fetch_next_sparse_batch.restype = SparseBatchPtr @@ -203,25 +392,49 @@ def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cy get_sparse_batch_from_fens = dll.get_sparse_batch_from_fens get_sparse_batch_from_fens.restype = SparseBatchPtr -get_sparse_batch_from_fens.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)] +get_sparse_batch_from_fens.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.POINTER(ctypes.c_char_p), + ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int), +] + def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results): - results_ = (ctypes.c_int*len(scores))() - scores_ = (ctypes.c_int*len(plies))() - plies_ = (ctypes.c_int*len(results))() + results_ = (ctypes.c_int * len(scores))() + scores_ = (ctypes.c_int * len(plies))() + plies_ = (ctypes.c_int * len(results))() fens_ = (ctypes.c_char_p * len(fens))() - fens_[:] = [fen.encode('utf-8') for fen in fens] + fens_[:] = [fen.encode("utf-8") for fen in fens] for i, v in enumerate(scores): scores_[i] = v for i, v in enumerate(plies): plies_[i] = v for i, v in enumerate(results): results_[i] = v - b = get_sparse_batch_from_fens(feature_set.name.encode('utf-8'), len(fens), fens_, scores_, plies_, results_) + b = get_sparse_batch_from_fens( + feature_set.name.encode("utf-8"), len(fens), fens_, scores_, plies_, results_ + ) return b + class SparseBatchProvider(TrainingDataProvider): - def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'): + def __init__( + self, + feature_set, + filenames, + batch_size, + cyclic=True, + num_workers=1, + filtered=False, + random_fen_skipping=0, + wld_filtered=False, + early_fen_skipping=-1, + param_index=0, + device="cpu", + ): super(SparseBatchProvider, self).__init__( feature_set, make_sparse_batch_stream, @@ -237,36 +450,63 @@ def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers= wld_filtered, early_fen_skipping, param_index, - device) + device, + ) + class SparseBatchDataset(torch.utils.data.IterableDataset): - def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'): - super(SparseBatchDataset).__init__() - self.feature_set = feature_set - self.filenames = filenames - self.batch_size = batch_size - self.cyclic = cyclic - self.num_workers = num_workers - self.filtered = filtered - self.random_fen_skipping = random_fen_skipping - self.wld_filtered = wld_filtered - self.early_fen_skipping = early_fen_skipping - self.param_index = param_index - self.device = device - - def __iter__(self): - return SparseBatchProvider(self.feature_set, self.filenames, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers, - filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, early_fen_skipping = self.early_fen_skipping, param_index=self.param_index, device=self.device) + def __init__( + self, + feature_set, + filenames, + batch_size, + cyclic=True, + num_workers=1, + filtered=False, + random_fen_skipping=0, + wld_filtered=False, + early_fen_skipping=-1, + param_index=0, + device="cpu", + ): + super(SparseBatchDataset).__init__() + self.feature_set = feature_set + self.filenames = filenames + self.batch_size = batch_size + self.cyclic = cyclic + self.num_workers = num_workers + self.filtered = filtered + self.random_fen_skipping = random_fen_skipping + self.wld_filtered = wld_filtered + self.early_fen_skipping = early_fen_skipping + self.param_index = param_index + self.device = device + + def __iter__(self): + return SparseBatchProvider( + self.feature_set, + self.filenames, + self.batch_size, + cyclic=self.cyclic, + num_workers=self.num_workers, + filtered=self.filtered, + random_fen_skipping=self.random_fen_skipping, + wld_filtered=self.wld_filtered, + early_fen_skipping=self.early_fen_skipping, + param_index=self.param_index, + device=self.device, + ) + class FixedNumBatchesDataset(Dataset): - def __init__(self, dataset, num_batches): - super(FixedNumBatchesDataset, self).__init__() - self.dataset = dataset; - self.iter = iter(self.dataset) - self.num_batches = num_batches + def __init__(self, dataset, num_batches): + super(FixedNumBatchesDataset, self).__init__() + self.dataset = dataset + self.iter = iter(self.dataset) + self.num_batches = num_batches - def __len__(self): - return self.num_batches + def __len__(self): + return self.num_batches - def __getitem__(self, idx): - return next(self.iter) + def __getitem__(self, idx): + return next(self.iter) diff --git a/perf_sigmoid_fitter.py b/perf_sigmoid_fitter.py index 40dffcd5..71fec616 100644 --- a/perf_sigmoid_fitter.py +++ b/perf_sigmoid_fitter.py @@ -8,16 +8,19 @@ import sys import random + def sigmoid(x, k): - y = 1 / (1 + np.exp(-k*x)) - return (y) + y = 1 / (1 + np.exp(-k * x)) + return y + def fit_data(x, y, sigma): # 1/361 is the initial guess. It's good enough to find the solution - p0 = [1/361] - popt, pcov = curve_fit(sigmoid, x, y, p0, sigma, method='dogbox') + p0 = [1 / 361] + popt, pcov = curve_fit(sigmoid, x, y, p0, sigma, method="dogbox") return popt[0] + def do_plot(data, filename): # plot of the eval distribution fig, axs = plt.subplots(2) @@ -25,11 +28,11 @@ def do_plot(data, filename): fig.suptitle(filename) x = list(data.keys()) y = [data[k][1] for k in x] - x, y = zip(*list(sorted(zip(x, y), key=lambda x:x[0]))) + x, y = zip(*list(sorted(zip(x, y), key=lambda x: x[0]))) axs[0].plot(x, y) - axs[0].set_ylabel('density') - axs[0].set_xlabel('eval') - axs[0].set_xscale('symlog') + axs[0].set_ylabel("density") + axs[0].set_xlabel("eval") + axs[0].set_xscale("symlog") # plot of the perf% by eval and the fitted sigmoid x = list(data.keys()) @@ -38,32 +41,44 @@ def do_plot(data, filename): # The inverted counts are good enough. sigma = [1 / data[k][1] for k in x] k = fit_data(x, y, sigma) - print('k: ', k) - print('inv k: ', 1/k) - axs[1].scatter(x, y, label='perf') + print("k: ", k) + print("inv k: ", 1 / k) + axs[1].scatter(x, y, label="perf") y = [sigmoid(xx, k) for xx in x] - axs[1].scatter(x, y, label='sigmoid(x/{})'.format(1.0/k)) + axs[1].scatter(x, y, label="sigmoid(x/{})".format(1.0 / k)) axs[1].legend(loc="upper left") - axs[1].set_ylabel('perf') - axs[1].set_xlabel('eval') + axs[1].set_ylabel("perf") + axs[1].set_xlabel("eval") # save to a .png file - plot_filename = '.'.join(filename.split('.')[:-1]) + '.png' + plot_filename = ".".join(filename.split(".")[:-1]) + ".png" plt.savefig(plot_filename) - print('plot saved at {}'.format(plot_filename)) + print("plot saved at {}".format(plot_filename)) + def gather_statistics_from_batches(batches, bucket_size): - ''' + """ This function takes an iterable of training batches and a bucket_size. It goes through all batches and collects evals and the outcomes. The evals are bucketed by bucket_size. Perf% is computed based on the evals and corresponding game outcomes. The result is a dictionary of the form { eval : (perf%, count) } - ''' + """ data = dict() i = 0 for batch in batches: - us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch + ( + us, + them, + white_indices, + white_values, + black_indices, + black_values, + outcome, + score, + psqt_indices, + layer_stack_indices, + ) = batch batch_size = len(us) bucket = torch.round(score / bucket_size) * bucket_size perf = outcome @@ -76,34 +91,41 @@ def gather_statistics_from_batches(batches, bucket_size): else: data[bucket_id] = (pp, 1) i += batch_size - print('Loaded {} positions...'.format(i)) + print("Loaded {} positions...".format(i)) return data + def gather_statistics_from_data(filename, count, bucket_size): - ''' + """ Takes a .bin or .binpack file and produces perf% statistics The result is a dictionary of the form { eval : (perf%, count) } - ''' + """ batch_size = 8192 cyclic = True smart_fen_skipping = True # we pass whatever feature set because we have to pass something # it doesn't actually matter, all we care about are the scores and outcomes # this is just the easiest way to do it - dataset = nnue_dataset.SparseBatchDataset('HalfKP', filename, batch_size, cyclic, smart_fen_skipping) + dataset = nnue_dataset.SparseBatchDataset( + "HalfKP", filename, batch_size, cyclic, smart_fen_skipping + ) batches = iter(dataset) num_batches = (count + batch_size - 1) // batch_size - data = gather_statistics_from_batches((next(batches) for i in range(num_batches)), bucket_size) + data = gather_statistics_from_batches( + (next(batches) for i in range(num_batches)), bucket_size + ) return data + def show_help(): - print('Usage: python perf_sigmoid_fitter.py filename [count] [bucket_size]') - print('count is the number of positions. Default: 1000000') - print('bucket_size determines how the evals are bucketed. Default: 16') - print('') - print('This file can be used as a module') - print('The function `gather_statistics_from_batches` can be used to determine') - print('the sigmoid scaling factor for each batch during training') + print("Usage: python perf_sigmoid_fitter.py filename [count] [bucket_size]") + print("count is the number of positions. Default: 1000000") + print("bucket_size determines how the evals are bucketed. Default: 16") + print("") + print("This file can be used as a module") + print("The function `gather_statistics_from_batches` can be used to determine") + print("the sigmoid scaling factor for each batch during training") + def main(): filename = sys.argv[1] @@ -112,7 +134,8 @@ def main(): data = gather_statistics_from_data(filename, count, bucket_size) do_plot(data, filename) -if __name__ == '__main__': + +if __name__ == "__main__": if len(sys.argv) <= 1: show_help() else: diff --git a/ranger.py b/ranger.py index 997dc7d6..b64f9bfb 100644 --- a/ranger.py +++ b/ranger.py @@ -29,32 +29,42 @@ # If dim is None it will be chosen automatically def centralized_gradient(x, use_gc=True, gc_conv_only=False, dim=None): - '''credit - https://github.com/Yonghongwei/Gradient-Centralization ''' + """credit - https://github.com/Yonghongwei/Gradient-Centralization""" if use_gc: dim_threshold = 3 if gc_conv_only else 1 if len(list(x.size())) > dim_threshold: - x.add_(-x.mean(dim=(dim or tuple(range(1, len(list(x.size()))))), keepdim=True)) + x.add_( + -x.mean(dim=(dim or tuple(range(1, len(list(x.size()))))), keepdim=True) + ) return x class Ranger(Optimizer): - - def __init__(self, params, lr=1e-3, # lr - alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options - betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options - # Gradient centralization on or off, applied to conv layers only or conv + fc layers - use_gc=True, gc_conv_only=False, gc_loc=True - ): + def __init__( + self, + params, + lr=1e-3, # lr + alpha=0.5, + k=6, + N_sma_threshhold=5, # Ranger options + betas=(0.95, 0.999), + eps=1e-5, + weight_decay=0, # Adam options + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + use_gc=True, + gc_conv_only=False, + gc_loc=True, + ): # parameter checks if not 0.0 <= alpha <= 1.0: - raise ValueError(f'Invalid slow update rate: {alpha}') + raise ValueError(f"Invalid slow update rate: {alpha}") if not 1 <= k: - raise ValueError(f'Invalid lookahead steps: {k}') + raise ValueError(f"Invalid lookahead steps: {k}") if not lr > 0: - raise ValueError(f'Invalid Learning Rate: {lr}') + raise ValueError(f"Invalid Learning Rate: {lr}") if not eps > 0: - raise ValueError(f'Invalid eps: {eps}') + raise ValueError(f"Invalid eps: {eps}") # parameter comments: # beta1 (momentum) of .95 seems to work better than .90... @@ -62,9 +72,17 @@ def __init__(self, params, lr=1e-3, # lr # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. # prep defaults and init torch.optim base - defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, - N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay, - gc_dim=None) + defaults = dict( + lr=lr, + alpha=alpha, + k=k, + step_counter=0, + betas=betas, + N_sma_threshhold=N_sma_threshhold, + eps=eps, + weight_decay=weight_decay, + gc_dim=None, + ) super().__init__(params, defaults) # adjustable threshold @@ -83,13 +101,14 @@ def __init__(self, params, lr=1e-3, # lr self.use_gc = use_gc self.gc_conv_only = gc_conv_only # level of gradient centralization - #self.gc_gradient_threshold = 3 if gc_conv_only else 1 + # self.gc_gradient_threshold = 3 if gc_conv_only else 1 print( - f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") - if (self.use_gc and self.gc_conv_only == False): + f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}" + ) + if self.use_gc and self.gc_conv_only == False: print(f"GC applied to both conv and fc layers") - elif (self.use_gc and self.gc_conv_only == True): + elif self.use_gc and self.gc_conv_only == True: print(f"GC applied to conv layers only") def __setstate__(self, state): @@ -105,47 +124,54 @@ def step(self, closure=None): # Evaluate averages and grad, update param tensors for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError( - 'Ranger optimizer does not support sparse gradients') + "Ranger optimizer does not support sparse gradients" + ) p_data_fp32 = p.data.float() state = self.state[p] # get state dict for this param - if len(state) == 0: # if first time to run...init dictionary with our desired entries + if ( + len(state) == 0 + ): # if first time to run...init dictionary with our desired entries # if self.first_run_check==0: # self.first_run_check=1 - #print("Initializing slow buffer...should not see this at load from saved model!") - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + # print("Initializing slow buffer...should not see this at load from saved model!") + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) # look ahead weight storage now in state dict - state['slow_buffer'] = torch.empty_like(p.data) - state['slow_buffer'].copy_(p.data) + state["slow_buffer"] = torch.empty_like(p.data) + state["slow_buffer"].copy_(p.data) else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as( - p_data_fp32) + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) # begin computations - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] # GC operation for Conv layers and FC layers # if grad.dim() > self.gc_gradient_threshold: # grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) if self.gc_loc: - grad = centralized_gradient(grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only, dim=group['gc_dim']) + grad = centralized_gradient( + grad, + use_gc=self.use_gc, + gc_conv_only=self.gc_conv_only, + dim=group["gc_dim"], + ) - state['step'] += 1 + state["step"] += 1 # compute variance mov avg exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -153,22 +179,28 @@ def step(self, closure=None): # compute mean moving avg exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - buffered = self.radam_buffer[int(state['step'] % 10)] + buffered = self.radam_buffer[int(state["step"] % 10)] - if state['step'] == buffered[0]: + if state["step"] == buffered[0]: N_sma, step_size = buffered[1], buffered[2] else: - buffered[0] = state['step'] - beta2_t = beta2 ** state['step'] + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] N_sma_max = 2 / (1 - beta2) - 1 - N_sma = N_sma_max - 2 * \ - state['step'] * beta2_t / (1 - beta2_t) + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) buffered[1] = N_sma if N_sma > self.N_sma_threshhold: - step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( - N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) else: - step_size = 1.0 / (1 - beta1 ** state['step']) + step_size = 1.0 / (1 - beta1 ** state["step"]) buffered[2] = step_size # if group['weight_decay'] != 0: @@ -177,26 +209,31 @@ def step(self, closure=None): # apply lr if N_sma > self.N_sma_threshhold: - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = exp_avg_sq.sqrt().add_(group["eps"]) G_grad = exp_avg / denom else: G_grad = exp_avg - if group['weight_decay'] != 0: - G_grad.add_(p_data_fp32, alpha=group['weight_decay']) + if group["weight_decay"] != 0: + G_grad.add_(p_data_fp32, alpha=group["weight_decay"]) # GC operation if self.gc_loc == False: - G_grad = centralized_gradient(G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only, dim=group['gc_dim']) + G_grad = centralized_gradient( + G_grad, + use_gc=self.use_gc, + gc_conv_only=self.gc_conv_only, + dim=group["gc_dim"], + ) - p_data_fp32.add_(G_grad, alpha=-step_size * group['lr']) + p_data_fp32.add_(G_grad, alpha=-step_size * group["lr"]) p.data.copy_(p_data_fp32) # integrated look ahead... # we do it at the param level instead of group level - if state['step'] % group['k'] == 0: + if state["step"] % group["k"] == 0: # get access to slow param tensor - slow_p = state['slow_buffer'] + slow_p = state["slow_buffer"] # (fast weights - slow weights) * alpha slow_p.add_(p.data - slow_p, alpha=self.alpha) # copy interpolated weights to RAdam param tensor diff --git a/run_games.py b/run_games.py index 87fa4fd2..b4df8692 100644 --- a/run_games.py +++ b/run_games.py @@ -12,6 +12,8 @@ from pathlib import Path, PurePath GLOBAL_LOCK = threading.Lock() + + def print_atomic(*args, **kwargs): GLOBAL_LOCK.acquire() try: @@ -19,8 +21,17 @@ def print_atomic(*args, **kwargs): finally: GLOBAL_LOCK.release() + class GameParams: - def __init__(self, hash, threads, games_per_round, time_per_game=None, time_increment_per_move=None, nodes_per_move=None): + def __init__( + self, + hash, + threads, + games_per_round, + time_per_game=None, + time_increment_per_move=None, + nodes_per_move=None, + ): self.hash = hash self.threads = threads self.games_per_round = games_per_round @@ -29,35 +40,35 @@ def __init__(self, hash, threads, games_per_round, time_per_game=None, time_incr self.nodes_per_move = nodes_per_move if not time_per_game and not time_increment_per_move and not nodes_per_move: - raise Exception('Invalid TC specification.') + raise Exception("Invalid TC specification.") def get_all_params(self): params = [] - params += ['-each'] + params += ["-each"] params += [ - f'option.Hash={self.hash}', - f'option.Threads={self.threads}', - f'timeout=20' + f"option.Hash={self.hash}", + f"option.Threads={self.threads}", + f"timeout=20", ] if self.nodes_per_move: params += [ - f'tc=10000+10000', - f'nodes={self.nodes_per_move}', + f"tc=10000+10000", + f"nodes={self.nodes_per_move}", ] else: inc = self.time_increment_per_move or 0 - params += [f'tc={self.time_per_game}+{inc}'] + params += [f"tc={self.time_per_game}+{inc}"] - params += ['-games', f'{self.games_per_round}'] + params += ["-games", f"{self.games_per_round}"] return params -def convert_ckpt(root_dir,features): - """ Find the list of checkpoints that are available, and convert those that have no matching .nnue """ +def convert_ckpt(root_dir, features): + """Find the list of checkpoints that are available, and convert those that have no matching .nnue""" # run96/run0/default/version_0/checkpoints/epoch=3.ckpt, or epoch=3-step=321151.ckpt p = re.compile("epoch.*\.ckpt") @@ -66,22 +77,38 @@ def convert_ckpt(root_dir,features): # lets move the .nnue files a bit up in the tree, and get rid of the = sign. # run96/run0/default/version_0/checkpoints/epoch=3.ckpt -> run96/run0/nn-epoch3.nnue for ckpt in ckpts: - nnue_file_name = re.sub("default[/\\\\]version_[0-9]+[/\\\\]checkpoints[/\\\\]", "", ckpt) # for older pytorch lightning - nnue_file_name = re.sub("lightning_logs[/\\\\]version_[0-9]+[/\\\\]checkpoints[/\\\\]", "", nnue_file_name) # for newer pytorch lightning - nnue_file_name = re.sub(r"epoch\=([0-9]+).*\.ckpt", r"nn-epoch\1.nnue", nnue_file_name) + nnue_file_name = re.sub( + "default[/\\\\]version_[0-9]+[/\\\\]checkpoints[/\\\\]", "", ckpt + ) # for older pytorch lightning + nnue_file_name = re.sub( + "lightning_logs[/\\\\]version_[0-9]+[/\\\\]checkpoints[/\\\\]", + "", + nnue_file_name, + ) # for newer pytorch lightning + nnue_file_name = re.sub( + r"epoch\=([0-9]+).*\.ckpt", r"nn-epoch\1.nnue", nnue_file_name + ) if not os.path.exists(nnue_file_name) and os.path.exists(ckpt): - with subprocess.Popen([sys.executable, 'serialize.py', ckpt, nnue_file_name, f'--features={features}']) as process: + with subprocess.Popen( + [ + sys.executable, + "serialize.py", + ckpt, + nnue_file_name, + f"--features={features}", + ] + ) as process: if process.wait(): print_atomic("Error serializing!") def find_nnue(root_dir): - """ Find the set of nnue nets that are available for testing, going through the full subtree """ + """Find the set of nnue nets that are available for testing, going through the full subtree""" return [str(file) for file in Path(root_dir).rglob("nn-epoch*.nnue")] def parse_ordo(root_dir, nnues): - """ Parse an ordo output file for rating and error """ + """Parse an ordo output file for rating and error""" ordo_file_name = os.path.join(root_dir, "ordo.out") ordo_scores = {} for name in nnues: @@ -98,67 +125,100 @@ def parse_ordo(root_dir, nnues): error = float(fields[4]) for name in nnues: if net in name: - ordo_scores[name] = (rating, error) + ordo_scores[name] = (rating, error) return ordo_scores -def run_match(best, root_dir, c_chess_exe, concurrency, book_file_name, stockfish_base, stockfish_test, game_params, tries=10): - """ Run a match using c-chess-cli adding pgns to a file to be analysed with ordo """ +def run_match( + best, + root_dir, + c_chess_exe, + concurrency, + book_file_name, + stockfish_base, + stockfish_test, + game_params, + tries=10, +): + """Run a match using c-chess-cli adding pgns to a file to be analysed with ordo""" pgn_file_name = os.path.join(root_dir, "out_temp.pgn") command = [] if sys.platform != "win32": - command += ['stdbuf', '-o0'] + command += ["stdbuf", "-o0"] command += [ c_chess_exe, - '-gauntlet', '-rounds', '1', - '-concurrency', f'{concurrency}' + "-gauntlet", + "-rounds", + "1", + "-concurrency", + f"{concurrency}", ] command += game_params.get_all_params() command += [ - '-openings', f'file={book_file_name}', 'order=random', f'srand={random.randint(0,100000000)}', '-repeat', - '-resign', 'count=3', 'score=700', - '-draw', 'count=8', 'score=10', - '-pgn', f'{pgn_file_name}', '0' + "-openings", + f"file={book_file_name}", + "order=random", + f"srand={random.randint(0,100000000)}", + "-repeat", + "-resign", + "count=3", + "score=700", + "-draw", + "count=8", + "score=10", + "-pgn", + f"{pgn_file_name}", + "0", ] - command += ['-engine', f'cmd={stockfish_base}', 'name=master'] + command += ["-engine", f"cmd={stockfish_base}", "name=master"] for net in best: evalfile = os.path.join(os.getcwd(), net) netname = PurePath(*PurePath(evalfile).parts[-2:]) - command += ['-engine', f'cmd={stockfish_test}', f'name={netname}', f'option.EvalFile={evalfile}'] + command += [ + "-engine", + f"cmd={stockfish_test}", + f"name={netname}", + f"option.EvalFile={evalfile}", + ] # Attempt to run the match multiple times in case of unforseen # errors like engine hanging or c-chess-cli having an error... for i in range(tries): print_atomic(" ".join(command)) - print_atomic("Running match with c-chess-cli ... {}".format(pgn_file_name), flush=True) - c_chess_out = open(os.path.join(root_dir, "c_chess.out"), 'w') - process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + print_atomic( + "Running match with c-chess-cli ... {}".format(pgn_file_name), flush=True + ) + c_chess_out = open(os.path.join(root_dir, "c_chess.out"), "w") + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) seen = {} for line in process.stdout: - line = line.decode('utf-8') + line = line.decode("utf-8") c_chess_out.write(line) - if 'Score' in line: - epoch_num = re.search(r'epoch(\d+)', line) + if "Score" in line: + epoch_num = re.search(r"epoch(\d+)", line) if epoch_num.group(1) not in seen: - sys.stdout.write('\n') + sys.stdout.write("\n") seen[epoch_num.group(1)] = True - sys.stdout.write('\r' + line.rstrip()) + sys.stdout.write("\r" + line.rstrip()) sys.stdout.flush() - sys.stdout.write('\n') + sys.stdout.write("\n") c_chess_out.close() if process.wait() != 0: - if i == tries-1: + if i == tries - 1: print_atomic("Error running match!") else: - print_atomic(f'Retrying running match ({i}/{tries}) in 10s ...') + print_atomic(f"Retrying running match ({i}/{tries}) in 10s ...") time.sleep(10) else: break print_atomic("Finished running match.") + class EngineResults: def __init__(self, name): self._name = name @@ -217,8 +277,9 @@ def elo(self): def elo_error_95(self): return 400 / math.sqrt(self.total_games) + def run_approximate_ordo(root_dir): - """ run an approximate ordo-like calculation on an existing pgn file """ + """run an approximate ordo-like calculation on an existing pgn file""" """ it takes advantege of the fact that all matches are ran against master """ pgn_file_name = os.path.join(root_dir, "out.pgn") ordo_file_name = os.path.join(root_dir, "ordo.out") @@ -228,73 +289,88 @@ def run_approximate_ordo(root_dir): white = None black = None try: - with open(pgn_file_name, 'r', encoding='utf-8') as pgn_file: + with open(pgn_file_name, "r", encoding="utf-8") as pgn_file: for line in pgn_file: line = line.strip() - if line.startswith('[White'): + if line.startswith("[White"): white = line[8:-2] - elif line.startswith('[Black'): + elif line.startswith("[Black"): black = line[8:-2] - elif line.startswith('[Result') and white is not None and black is not None: + elif ( + line.startswith("[Result") + and white is not None + and black is not None + ): result_str = line[9:-2] if white not in entries: entries[white] = EngineResults(white) if black not in entries: entries[black] = EngineResults(black) - if result_str == '1-0': + if result_str == "1-0": entries[white].add_wins(1) entries[black].add_losses(1) - elif result_str == '0-1': + elif result_str == "0-1": entries[white].add_losses(1) entries[black].add_wins(1) - if result_str == '1/2-1/2': + if result_str == "1/2-1/2": entries[white].add_draws(1) entries[black].add_draws(1) except: return - entries_ordered = sorted(entries.values(), key=lambda x:0 if x.name == 'master' else -x.elo) + entries_ordered = sorted( + entries.values(), key=lambda x: 0 if x.name == "master" else -x.elo + ) - with open(ordo_file_name_temp, 'w') as ordo_file: - ordo_file.write('\n') - ordo_file.write(' # PLAYER : RATING ERROR POINTS PLAYED (%)\n') + with open(ordo_file_name_temp, "w") as ordo_file: + ordo_file.write("\n") + ordo_file.write( + " # PLAYER : RATING ERROR POINTS PLAYED (%)\n" + ) for i, entry in enumerate(entries_ordered): - if entry.name == 'master': - entry_elo = ' 0.0' - entry_elo_error_95 = '----' + if entry.name == "master": + entry_elo = " 0.0" + entry_elo_error_95 = "----" else: - entry_elo = f'{entry.elo:0.1f}' - entry_elo_error_95 = f'{entry.elo_error_95:0.1f}' - entry_points = f'{entry.points:0.1f}' - entry_performance = f'{entry.performance*100:0.0f}' - ordo_file.write(f' {i+1:2} {entry.name:<26} : {entry_elo:>7} {entry_elo_error_95:>6} {entry_points:>9} {entry.total_games:>7} {entry_performance:>4}\n') - ordo_file.write('\n') + entry_elo = f"{entry.elo:0.1f}" + entry_elo_error_95 = f"{entry.elo_error_95:0.1f}" + entry_points = f"{entry.points:0.1f}" + entry_performance = f"{entry.performance*100:0.0f}" + ordo_file.write( + f" {i+1:2} {entry.name:<26} : {entry_elo:>7} {entry_elo_error_95:>6} {entry_points:>9} {entry.total_games:>7} {entry_performance:>4}\n" + ) + ordo_file.write("\n") if not os.path.exists(ordo_file_name): - os.rename(ordo_file_name_temp, ordo_file_name) + os.rename(ordo_file_name_temp, ordo_file_name) else: - os.replace(ordo_file_name_temp, ordo_file_name) + os.replace(ordo_file_name_temp, ordo_file_name) print_atomic("Finished running ordo.") + def run_ordo(root_dir, ordo_exe, concurrency): - """ run an ordo calcuation on an existing pgn file """ + """run an ordo calcuation on an existing pgn file""" pgn_file_name = os.path.join(root_dir, "out.pgn") ordo_file_name = os.path.join(root_dir, "ordo.out") ordo_file_name_temp = os.path.join(root_dir, "ordo_temp.out") command = [ ordo_exe, - '-q', - '-g', - '-J', - '-p', f'{pgn_file_name}', - '-a', '0.0', - '--anchor=master', - '--draw-auto', - '--white-auto', - '-s', '100', - f'--cpus={concurrency}', - '-o', f'{ordo_file_name_temp}' + "-q", + "-g", + "-J", + "-p", + f"{pgn_file_name}", + "-a", + "0.0", + "--anchor=master", + "--draw-auto", + "--white-auto", + "-s", + "100", + f"--cpus={concurrency}", + "-o", + f"{ordo_file_name_temp}", ] print_atomic("Running ordo ranking ... {}".format(ordo_file_name), flush=True) @@ -306,6 +382,7 @@ def run_ordo(root_dir, ordo_exe, concurrency): print_atomic("Finished running ordo.") + def run_round( root_dir, explore_factor, @@ -316,9 +393,9 @@ def run_round( book_file_name, concurrency, features, - game_params + game_params, ): - """ run a round of games, finding existing nets, analyze an ordo file to pick most suitable ones, run a round, and run ordo """ + """run a round of games, finding existing nets, analyze an ordo file to pick most suitable ones, run a round, and run ordo""" # find and convert checkpoints to .nnue convert_ckpt(root_dir, features) @@ -342,7 +419,9 @@ def run_round( ) count = 0 for net in ordo_scores: - print_atomic(" {} : {} +- {}".format(net, ordo_scores[net][0], ordo_scores[net][1])) + print_atomic( + " {} : {} +- {}".format(net, ordo_scores[net][0], ordo_scores[net][1]) + ) count += 1 if count == 3: break @@ -358,7 +437,9 @@ def run_round( ) best = [] for net in ordo_scores: - print_atomic(" {} : {} +- {}".format(net, ordo_scores[net][0], ordo_scores[net][1])) + print_atomic( + " {} : {} +- {}".format(net, ordo_scores[net][0], ordo_scores[net][1]) + ) best.append(net) if len(best) == 3: break @@ -367,17 +448,24 @@ def run_round( # and run a new ordo ranking for the games played so far run_match_thread = threading.Thread( target=run_match, - args=(best, root_dir, c_chess_exe, concurrency, book_file_name, stockfish_base, stockfish_test, game_params) + args=( + best, + root_dir, + c_chess_exe, + concurrency, + book_file_name, + stockfish_base, + stockfish_test, + game_params, + ), ) if ordo_exe: run_ordo_thread = threading.Thread( - target=run_ordo, - args=(root_dir, ordo_exe, concurrency) + target=run_ordo, args=(root_dir, ordo_exe, concurrency) ) else: run_ordo_thread = threading.Thread( - target=run_approximate_ordo, - args=(root_dir,) + target=run_approximate_ordo, args=(root_dir,) ) run_match_thread.start() @@ -392,15 +480,17 @@ def run_round( main_pgn_file_name = os.path.join(root_dir, "out.pgn") curr_pgn_file_name = os.path.join(root_dir, "out_temp.pgn") if not os.path.exists(main_pgn_file_name): - with open(main_pgn_file_name, 'w'): pass + with open(main_pgn_file_name, "w"): + pass try: - with open(main_pgn_file_name, 'a') as file_to: - with open(curr_pgn_file_name, 'r') as file_from: + with open(main_pgn_file_name, "a") as file_to: + with open(curr_pgn_file_name, "r") as file_from: for line in file_from: file_to.write(line) os.remove(curr_pgn_file_name) except: - print_atomic('Something went wrong when adding new games to the main file.') + print_atomic("Something went wrong when adding new games to the main file.") + def main(): # basic setup @@ -457,37 +547,14 @@ def main(): default="./noob_3moves.epd", help="Path to a suitable book, see https://github.com/official-stockfish/books", ) + parser.add_argument("--time_per_game", type=float, default=4.0) + parser.add_argument("--time_increment_per_move", type=float, default=0.04) parser.add_argument( - "--time_per_game", - type=float, - default=4.0 - ) - parser.add_argument( - "--time_increment_per_move", - type=float, - default=0.04 - ) - parser.add_argument( - "--nodes_per_move", - type=int, - default=None, - help="Overrides time per move." - ) - parser.add_argument( - "--hash", - type=int, - default=8 - ) - parser.add_argument( - "--threads", - type=int, - default=1 - ) - parser.add_argument( - "--games_per_round", - type=int, - default=200 + "--nodes_per_move", type=int, default=None, help="Overrides time per move." ) + parser.add_argument("--hash", type=int, default=8) + parser.add_argument("--threads", type=int, default=1) + parser.add_argument("--games_per_round", type=int, default=200) features.add_argparse_args(parser) args = parser.parse_args() @@ -497,19 +564,19 @@ def main(): stockfish_test = stockfish_base if not shutil.which(stockfish_base): - sys.exit("Stockfish base is not executable !") + sys.exit("Stockfish base is not executable !") if not shutil.which(stockfish_test): - sys.exit("Stockfish test is not executable!") + sys.exit("Stockfish test is not executable!") if args.ordo_exe and not shutil.which(args.ordo_exe): - sys.exit("ordo is not executable!") + sys.exit("ordo is not executable!") if not shutil.which(args.c_chess_exe): - sys.exit("c_chess_cli is not executable!") + sys.exit("c_chess_cli is not executable!") if not os.path.exists(args.book_file_name): - sys.exit("book does not exist!") + sys.exit("book does not exist!") random.seed() @@ -524,7 +591,14 @@ def main(): args.book_file_name, args.concurrency, args.features, - GameParams(args.hash, args.threads, args.games_per_round, args.time_per_game, args.time_increment_per_move, args.nodes_per_move) + GameParams( + args.hash, + args.threads, + args.games_per_round, + args.time_per_game, + args.time_increment_per_move, + args.nodes_per_move, + ), ) diff --git a/scripts/easy_train.py b/scripts/easy_train.py index 8a43cdb7..ca7b6039 100644 --- a/scripts/easy_train.py +++ b/scripts/easy_train.py @@ -24,100 +24,117 @@ LOGGER.addHandler(logging.StreamHandler(stream=sys.stdout)) LOGGER.propagate = False + def validate_python_version(): if sys.version_info >= (3, 7): - LOGGER.info(f'Found python version {sys.version}. OK.') + LOGGER.info(f"Found python version {sys.version}. OK.") return True else: - LOGGER.error(f'Found python version {sys.version} but 3.7 is required. Exiting.') + LOGGER.error( + f"Found python version {sys.version} but 3.7 is required. Exiting." + ) return False # Functions for checking external dependencies. + def run_for_version(name): process = subprocess.Popen( - [name, '--version'], + [name, "--version"], shell=False, bufsize=-1, universal_newlines=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT + stderr=subprocess.STDOUT, ) return process.stdout.read() + def validate_cmake(): success = True try: - out = run_for_version('cmake') - parts = out.split('\n')[0].split() + out = run_for_version("cmake") + parts = out.split("\n")[0].split() version_str = parts[-1] - major_version = int(version_str.split('.')[0]) - minor_version = int(version_str.split('.')[1]) + major_version = int(version_str.split(".")[0]) + minor_version = int(version_str.split(".")[1]) success = (major_version, minor_version) >= (3, 4) if success: - LOGGER.info(f'Found cmake executable version {version_str}. OK.') + LOGGER.info(f"Found cmake executable version {version_str}. OK.") else: - LOGGER.error(f'Found cmake executable version {version_str} but at least 3.4 required. Exiting.') + LOGGER.error( + f"Found cmake executable version {version_str} but at least 3.4 required. Exiting." + ) except: success = False - LOGGER.error('No cmake executable found. Exiting.') + LOGGER.error("No cmake executable found. Exiting.") return success + def validate_make(): success = True try: - out = run_for_version('make') - parts = out.split('\n')[0].split() + out = run_for_version("make") + parts = out.split("\n")[0].split() version_str = parts[-1] - major_version = int(version_str.split('.')[0]) + major_version = int(version_str.split(".")[0]) success = major_version >= 3 if success: - LOGGER.info(f'Found make executable version {version_str}. OK.') + LOGGER.info(f"Found make executable version {version_str}. OK.") else: - LOGGER.error(f'Found make executable version {version_str} but at least 3 required. Exiting.') + LOGGER.error( + f"Found make executable version {version_str} but at least 3 required. Exiting." + ) except: success = False - LOGGER.error('No make executable found. Exiting.') + LOGGER.error("No make executable found. Exiting.") return success + def validate_gcc(): success = True try: - out = run_for_version('gcc') - parts = out.split('\n')[0].split() + out = run_for_version("gcc") + parts = out.split("\n")[0].split() for part in parts: try: - version_str = part # sometimes there are trailing strings in the version number - major_version = int(version_str.split('.')[0]) - minor_version = int(version_str.split('.')[1]) + version_str = ( + part # sometimes there are trailing strings in the version number + ) + major_version = int(version_str.split(".")[0]) + minor_version = int(version_str.split(".")[1]) success = (major_version, minor_version) >= (9, 2) except: continue if success: - LOGGER.info(f'Found gcc executable version {version_str}. OK.') + LOGGER.info(f"Found gcc executable version {version_str}. OK.") else: - LOGGER.error(f'Found gcc executable version {version_str} but at least 9.2 required. Exiting.') + LOGGER.error( + f"Found gcc executable version {version_str} but at least 9.2 required. Exiting." + ) except: success = False - LOGGER.error('No gcc executable found. Exiting.') + LOGGER.error("No gcc executable found. Exiting.") return success + def maybe_int(v): try: return int(v) except: return v + class PackageInfo: - ''' + """ Represents an [installed] python package. - ''' + """ def __init__(self, name): self._spec = importlib.util.find_spec(name) @@ -126,7 +143,9 @@ def __init__(self, name): try: if self._spec: self._version_str = importlib.metadata.version(name) - self._version_tup = tuple(maybe_int(v) for v in self._version_str.split('.')) + self._version_tup = tuple( + maybe_int(v) for v in self._version_str.split(".") + ) except: pass @@ -141,65 +160,86 @@ def is_version_at_least(self, desired): def version(self): return self._version_str + # Functions for checking required python packages. + def validate_asciimatics(): - pkg = PackageInfo('asciimatics') + pkg = PackageInfo("asciimatics") if pkg.exists: - LOGGER.info('Found asciimatics package. OK.') + LOGGER.info("Found asciimatics package. OK.") return True else: - LOGGER.error('No asciimatics package found. Run `pip install asciimatics`. Exiting.') + LOGGER.error( + "No asciimatics package found. Run `pip install asciimatics`. Exiting." + ) return False + def validate_pytorch(): - pkg = PackageInfo('torch') + pkg = PackageInfo("torch") if pkg.exists: if pkg.is_version_at_least((1, 7)): - LOGGER.info(f'Found torch version {pkg.version}. OK.') + LOGGER.info(f"Found torch version {pkg.version}. OK.") from torch import cuda + if cuda.is_available() and cuda.device_count() > 0: - LOGGER.info(f'Found torch with CUDA. OK.') + LOGGER.info(f"Found torch with CUDA. OK.") return True else: - LOGGER.error(f'Found torch without CUDA but CUDA support required. Exiting') + LOGGER.error( + f"Found torch without CUDA but CUDA support required. Exiting" + ) return False else: - LOGGER.error(f'Found torch version {pkg.version} but at least 1.8 required. Exiting.') + LOGGER.error( + f"Found torch version {pkg.version} but at least 1.8 required. Exiting." + ) return False else: - LOGGER.error('No torch package found. Install at least torch 1.8 with cuda. See https://pytorch.org/. Exiting.') + LOGGER.error( + "No torch package found. Install at least torch 1.8 with cuda. See https://pytorch.org/. Exiting." + ) return False + def validate_pytorchlightning(): - pkg = PackageInfo('pytorch_lightning') + pkg = PackageInfo("pytorch_lightning") if pkg.exists: - LOGGER.info(f'Found pytorch_lightning version {pkg.version}. OK.') + LOGGER.info(f"Found pytorch_lightning version {pkg.version}. OK.") return True else: - LOGGER.error('No pytorch_lightning found. Run `pip install pytorch-lightning`. Exiting.') + LOGGER.error( + "No pytorch_lightning found. Run `pip install pytorch-lightning`. Exiting." + ) return False + def validate_cupy(): - pkg = PackageInfo('cupy') + pkg = PackageInfo("cupy") if pkg.exists: - LOGGER.info(f'Found cupy version {pkg.version}. OK.') + LOGGER.info(f"Found cupy version {pkg.version}. OK.") return True else: - LOGGER.error('No cupy found. Install cupy matching cuda version used by pytorch. See https://cupy.dev/. Exiting.') + LOGGER.error( + "No cupy found. Install cupy matching cuda version used by pytorch. See https://cupy.dev/. Exiting." + ) return False + def validate_gputil(): - pkg = PackageInfo('GPUtil') + pkg = PackageInfo("GPUtil") if pkg.exists: - LOGGER.info(f'Found GPUtil version {pkg.version}. OK.') + LOGGER.info(f"Found GPUtil version {pkg.version}. OK.") return True else: - LOGGER.error('No GPUtil found. Run `pip install GPUtil`. Exiting.') + LOGGER.error("No GPUtil found. Run `pip install GPUtil`. Exiting.") return False + # Validation of required external and package dependencies. + def validate_imports(): success = True success &= validate_asciimatics() @@ -209,6 +249,7 @@ def validate_imports(): success &= validate_gputil() return success + def validate_environment_requirements(): success = True try: @@ -222,13 +263,26 @@ def validate_environment_requirements(): return False return success + # Exit early if the requires packages have not been found if not validate_environment_requirements(): sys.exit(EXITCODE_MISSING_DEPENDENCIES) # Only now import the rest of the required packages -from asciimatics.widgets import Frame, ListBox, Layout, Divider, Text, Button, \ - TextBox, Widget, VerticalDivider, MultiColumnListBox, Label, PopUpDialog +from asciimatics.widgets import ( + Frame, + ListBox, + Layout, + Divider, + Text, + Button, + TextBox, + Widget, + VerticalDivider, + MultiColumnListBox, + Label, + PopUpDialog, +) from asciimatics.scene import Scene from asciimatics.screen import Screen from asciimatics.exceptions import ResizeScreenError, NextScene, StopApplication @@ -252,12 +306,13 @@ def validate_environment_requirements(): # Specify which versions of ordo and c-chess-cli we want. # We rely on specific well-tested commits because we know exactly what we need. # repo/branch, commit id -ORDO_GIT = ('michiguel/Ordo', '17eec774f2e4b9fdd2b1b38739f55ea221fb851a') -C_CHESS_CLI_GIT = ('lucasart/c-chess-cli', '6d08fee2e95b259c486b21a886f6911b61f676af') -TIMEOUT = 600.0 # on some systems starting pytorch can be really slow +ORDO_GIT = ("michiguel/Ordo", "17eec774f2e4b9fdd2b1b38739f55ea221fb851a") +C_CHESS_CLI_GIT = ("lucasart/c-chess-cli", "6d08fee2e95b259c486b21a886f6911b61f676af") +TIMEOUT = 600.0 # on some systems starting pytorch can be really slow + def terminate_process_on_exit(process): - ''' + """ Create a watchdog process that awaits the termination of this (calling) process and automatically terminates a given process (python's subprocess object) after. @@ -268,46 +323,52 @@ def terminate_process_on_exit(process): TODO: powershell version TODO: linux version - ''' + """ if sys.platform == "win32": try: # We cannot execute from string so we write the script to a file. # Doesn't do anything if the file already exists. - with open('.process_watchdog_helper.bat', 'x') as file: - file.write(""":waitforpid + with open(".process_watchdog_helper.bat", "x") as file: + file.write( + """:waitforpid tasklist /nh /fi "pid eq %1" 2>nul | find "%1" >nul if %ERRORLEVEL%==0 ( timeout /t 5 /nobreak >nul goto :waitforpid ) else ( wmic process where processid="%2" call terminate >nul -)""") +)""" + ) except: pass subprocess.Popen( - ['.process_watchdog_helper.bat', str(os.getpid()), str(process.pid)], + [".process_watchdog_helper.bat", str(os.getpid()), str(process.pid)], stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL + stderr=subprocess.DEVNULL, ) elif sys.platform == "linux": # TODO: this pass + # Exits the process forcefully after a specified amount of seconds with a given error code TUI_SCREEN = None + + def schedule_exit(timeout_seconds, errcode): def f(): time.sleep(timeout_seconds) - LOGGER.info(f'Performing a scheduled exit.') + LOGGER.info(f"Performing a scheduled exit.") if TUI_SCREEN: - if sys.platform == 'win32': + if sys.platform == "win32": TUI_SCREEN.close(restore=True) else: # We cannot call .close directly because it tries to reset signals... # But resetting signals won't work from a non-main thread... import curses + TUI_SCREEN._screen.keypad(0) curses.echo() curses.nocbreak() @@ -319,15 +380,23 @@ def f(): thread.daemon = True thread.start() + if sys.platform == "win32": import ctypes WINAPI_CreateMutex = ctypes.windll.kernel32.CreateMutexA - WINAPI_CreateMutex.argtypes = [ctypes.wintypes.LPCVOID, ctypes.wintypes.BOOL, ctypes.c_char_p] + WINAPI_CreateMutex.argtypes = [ + ctypes.wintypes.LPCVOID, + ctypes.wintypes.BOOL, + ctypes.c_char_p, + ] WINAPI_CreateMutex.restype = ctypes.wintypes.HANDLE WINAPI_WaitForSingleObject = ctypes.windll.kernel32.WaitForSingleObject - WINAPI_WaitForSingleObject.argtypes = [ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD] + WINAPI_WaitForSingleObject.argtypes = [ + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ] WINAPI_WaitForSingleObject.restype = ctypes.wintypes.DWORD WINAPI_ReleaseMutex = ctypes.windll.kernel32.ReleaseMutex @@ -343,9 +412,9 @@ def __init__(self, name): # \ is a reserved character so we have to convert them to / to be recognized as # directory delimiters # encode as utf-8 because LPCSTR is bytes not str - self.name = str(os.path.abspath(name)).replace('\\', '/').encode('utf-8') + self.name = str(os.path.abspath(name)).replace("\\", "/").encode("utf-8") self.acquired = False - self.file = open(self.name, 'a+') + self.file = open(self.name, "a+") self.handle = WINAPI_CreateMutex(None, False, self.name) if not self.handle: raise ctypes.WinError() @@ -390,6 +459,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.release() self.close() + else: import fcntl @@ -397,7 +467,7 @@ class SystemWideMutex: def __init__(self, name): self.name = name self.acquired = False - self.file = open(self.name, 'a+') + self.file = open(self.name, "a+") def acquire(self): fcntl.lockf(self.file, fcntl.LOCK_EX) @@ -424,12 +494,13 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.release() + class DecayingRunningAverage: - ''' + """ Represents an average of a list of values with exponential decay of old values. Every added value has weight of decay**n, where n is the distance from the last element. For the last added element n==0. - ''' + """ def __init__(self, decay=0.995): self._decay = decay @@ -445,21 +516,22 @@ def value(self): try: return self._total / self._count except: - return float('NaN') + return float("NaN") def update(self, value): - ''' + """ Adds a new value at the end of the implicit running average list and updates the counters to reflect the change in the running average. - ''' + """ self._total = self._total * self._decay + value self._count = self._count * self._decay + 1.0 + class SystemResources: - ''' + """ Holds information about the usage of system resources at a time point of creation. This includes GPU, CPU, and memory usage. - ''' + """ def __init__(self): self._gpus = dict() @@ -486,11 +558,12 @@ def ram_usage_mb(self): def ram_max_mb(self): return self._ram_max_mb + class SystemResourcesMonitor(Thread): - ''' + """ Periodically queries system resources. Runs as a daemon so does not need to be cleaned up. - ''' + """ def __init__(self, period_seconds): super(SystemResourcesMonitor, self).__init__() @@ -519,9 +592,9 @@ def run(self): @property def resources(self): - ''' + """ Returns the most recent system resources measurement. - ''' + """ self._mutex.acquire() try: return self._resources @@ -532,25 +605,36 @@ def stop(self): self._running = False self._stop_event.set() + def find_latest_checkpoint(root_dir): - ''' + """ Recursively searches the specified directory for the .ckpt file with the latest creation date. - ''' + """ ckpts = [file for file in Path(root_dir).rglob("*.ckpt")] if not ckpts: return None return str(max(ckpts, key=lambda p: p.stat().st_ctime_ns)) + class OrdoEntry: - ''' + """ Represents a single entry in an ordo file. Expects players to be named after network paths, if the form experiment_path/run_{}/nn-epoch{}.nnue - ''' + """ + + NET_PATTERN = re.compile(r".*?run_(\d+).*?nn-epoch(\d+)\.nnue") - NET_PATTERN = re.compile(r'.*?run_(\d+).*?nn-epoch(\d+)\.nnue') - def __init__(self, line=None, network_path=None, elo=None, elo_error=None, run_id=None, epoch=None): + def __init__( + self, + line=None, + network_path=None, + elo=None, + elo_error=None, + run_id=None, + epoch=None, + ): if line: fields = line.split() self._network_path = fields[1] @@ -586,8 +670,9 @@ def elo(self): def elo_error(self): return self._elo_error + def find_best_checkpoint(root_dir): - ''' + """ Recursively searches the specified directory the best .ckpt file as determined by an ordo output file that must be present under the path os.path.join(root_dir, 'ordo.out'). @@ -596,52 +681,57 @@ def find_best_checkpoint(root_dir): Returns None if the ordo file does not exist or no suitable checkpoint has been found. - ''' + """ ckpts = [str(file) for file in Path(root_dir).rglob("*.ckpt")] nnues = [str(file) for file in Path(root_dir).rglob("*.nnue")] - ordo_file_path = os.path.join(root_dir, 'ordo.out') + ordo_file_path = os.path.join(root_dir, "ordo.out") - with open(ordo_file_path, 'r') as ordo_file: + with open(ordo_file_path, "r") as ordo_file: entries = [] lines = ordo_file.readlines() for line in lines: - if 'nn-epoch' in line: + if "nn-epoch" in line: try: entries.append(OrdoEntry(line=line)) except: pass - entries.sort(key=lambda x:-x.elo+x.elo_error) + entries.sort(key=lambda x: -x.elo + x.elo_error) run_id = entries[0].run_id epoch = entries[0].epoch for ckpt in ckpts: - if f'run_{run_id}' in ckpt and f'epoch={epoch}' in ckpt: + if f"run_{run_id}" in ckpt and f"epoch={epoch}" in ckpt: return ckpt # fallback to .nnue if no checkpoint file for nnue in nnues: - if f'run_{run_id}' in nnue and f'nn-epoch{epoch}' in nnue: + if f"run_{run_id}" in nnue and f"nn-epoch{epoch}" in nnue: return nnue return None + # A global instance of the resource monitor. # There is no need to have more than one. RESOURCE_MONITOR = SystemResourcesMonitor(2) # A regex pattern for a float number. -NUMERIC_CONST_PATTERN = '[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?' +NUMERIC_CONST_PATTERN = "[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?" + class TrainingRun(Thread): - ''' + """ Manages a single pytorch training run. Starts it as a subprocess. Provides information about the current state of training. Runs as a separate thread and must be stopped before exiting. - ''' + """ # The regex pattern for extracting information from the pytorch lightning's tqdm process bar output - ITERATION_PATTERN = re.compile(f'Epoch (\\d+).*?(\\d+)/(\\d+).*?({NUMERIC_CONST_PATTERN})it/s, loss=({NUMERIC_CONST_PATTERN})') + ITERATION_PATTERN = re.compile( + f"Epoch (\\d+).*?(\\d+)/(\\d+).*?({NUMERIC_CONST_PATTERN})it/s, loss=({NUMERIC_CONST_PATTERN})" + ) + def __init__( self, gpu_id, @@ -671,7 +761,7 @@ def __init__( resume_training=False, start_lambda=None, end_lambda=None, - additional_args=[] + additional_args=[], ): super(TrainingRun, self).__init__() @@ -726,48 +816,48 @@ def __init__( def _get_stringified_args(self): args = [ - f'--num-workers={self._num_data_loader_threads}', - f'--threads={self._num_pytorch_threads}', - f'--max_epoch={self._num_epochs}', - f'--batch-size={self._batch_size}', - f'--random-fen-skipping={self._random_fen_skipping}', - f'--early-fen-skipping={self._early_fen_skipping}', - f'--gpus={self._gpu_id},', - f'--features={self._features}', - f'--lr={self._lr}', - f'--gamma={self._gamma}', - f'--lambda={self._lambda}', - f'--network-save-period={self._network_save_period}', - f'--save-last-network={self._save_last_network}', - f'--seed={self._seed}', - f'--epoch-size={self._epoch_size}', - f'--validation-size={self._validation_size}', - f'--default_root_dir={self._root_dir}', + f"--num-workers={self._num_data_loader_threads}", + f"--threads={self._num_pytorch_threads}", + f"--max_epoch={self._num_epochs}", + f"--batch-size={self._batch_size}", + f"--random-fen-skipping={self._random_fen_skipping}", + f"--early-fen-skipping={self._early_fen_skipping}", + f"--gpus={self._gpu_id},", + f"--features={self._features}", + f"--lr={self._lr}", + f"--gamma={self._gamma}", + f"--lambda={self._lambda}", + f"--network-save-period={self._network_save_period}", + f"--save-last-network={self._save_last_network}", + f"--seed={self._seed}", + f"--epoch-size={self._epoch_size}", + f"--validation-size={self._validation_size}", + f"--default_root_dir={self._root_dir}", ] if self._smart_fen_skipping: - args.append('--smart-fen-skipping') + args.append("--smart-fen-skipping") else: - args.append('--no-smart-fen-skipping') + args.append("--no-smart-fen-skipping") if not self._wld_fen_skipping: - args.append('--no-wld-fen-skipping') + args.append("--no-wld-fen-skipping") if self._start_lambda: - args.append(f'--start-lambda={self._start_lambda}') + args.append(f"--start-lambda={self._start_lambda}") if self._end_lambda: - args.append(f'--end-lambda={self._end_lambda}') + args.append(f"--end-lambda={self._end_lambda}") resumed = False if self._resume_training: ckpt_path = find_latest_checkpoint(self._root_dir) if ckpt_path: - args.append(f'--resume_from_checkpoint={ckpt_path}') + args.append(f"--resume_from_checkpoint={ckpt_path}") resumed = True if self._start_from_model and not resumed: - args.append(f'--resume-from-model={self._start_from_model}') + args.append(f"--resume-from-model={self._start_from_model}") for arg in self._additional_args: args.append(arg) @@ -776,12 +866,14 @@ def _get_stringified_args(self): args.append(dataset) for dataset in self._validation_datasets: - args.append(f'--validation-data={dataset}') + args.append(f"--validation-data={dataset}") return args def run(self): - if self._resume_training and os.path.exists(os.path.join(self._root_dir, 'training_finished')): + if self._resume_training and os.path.exists( + os.path.join(self._root_dir, "training_finished") + ): self._has_started = True self._has_finished = True self._running = False @@ -789,10 +881,10 @@ def run(self): self._running = True - cmd = [sys.executable, 'train.py'] + self._get_stringified_args() - LOGGER.info(f'Running training with command: {cmd}') + cmd = [sys.executable, "train.py"] + self._get_stringified_args() + LOGGER.info(f"Running training with command: {cmd}") LOGGER.info(f'Also known as: {" ".join(cmd)}') - LOGGER.info(f'Running in working directory: {self._nnue_pytorch_directory}') + LOGGER.info(f"Running in working directory: {self._nnue_pytorch_directory}") self._process = subprocess.Popen( cmd, cwd=self._nnue_pytorch_directory, @@ -834,10 +926,14 @@ def run(self): self._last_step = curr_step continue - #self._momentary_iterations_per_second = float(matches.group(4)) + # self._momentary_iterations_per_second = float(matches.group(4)) if curr_step % 10 == 0: - self._momentary_iterations_per_second = (curr_step-self._last_step)/((curr_time-self._last_time)/1e9) - self._smooth_iterations_per_second.update(self._momentary_iterations_per_second) + self._momentary_iterations_per_second = ( + curr_step - self._last_step + ) / ((curr_time - self._last_time) / 1e9) + self._smooth_iterations_per_second.update( + self._momentary_iterations_per_second + ) self._last_time = curr_time self._last_step = curr_step @@ -855,9 +951,9 @@ def run(self): LOGGER.info(line) pass - if 'CUDA_ERROR_OUT_OF_MEMORY' in line or 'CUDA out of memory' in line: + if "CUDA_ERROR_OUT_OF_MEMORY" in line or "CUDA out of memory" in line: self._process.terminate() - self._error = 'Cuda out of memory error.' + self._error = "Cuda out of memory error." break # Since _num_steps_in_epochs includes validation steps, that we cannot actually catch @@ -865,16 +961,20 @@ def run(self): # we can just estimate whether it finished with a success by using some margin... # NOTE: We still cannot catch when the trainer exits with no work, which for example # happens when resuming from a checkpoint at the end of training. - if self._has_started and self._current_epoch == self._num_epochs - 1 and self._current_step_in_epoch >= self._num_steps_in_epoch * 0.9: + if ( + self._has_started + and self._current_epoch == self._num_epochs - 1 + and self._current_step_in_epoch >= self._num_steps_in_epoch * 0.9 + ): self._has_finished = True if self._running and not self._has_finished: if not self._error: - self._error = 'Unknown error occured.' - LOGGER.warning(f'Training run {self._run_id} exited unexpectedly.') - LOGGER.error(f'Error: {self._error}') + self._error = "Unknown error occured." + LOGGER.warning(f"Training run {self._run_id} exited unexpectedly.") + LOGGER.error(f"Error: {self._error}") else: - LOGGER.info(f'Training run {self._run_id} finished.') + LOGGER.info(f"Training run {self._run_id} finished.") self._has_started = True self._running = False @@ -946,209 +1046,236 @@ def error(self): def batch_size(self): return self._batch_size + def requests_get_content(url, *args, **kwargs): try: result = requests.get(url, *args, **kwargs) result.raise_for_status() return result.content except Exception as e: - raise Exception(f'GET request to {url} failed') + raise Exception(f"GET request to {url} failed") + def get_zipfile_members_strip_common_prefix(zipfile): - ''' + """ Removes a common previx from zipfile entries. So for example will remove the top-level directory. - ''' + """ parts = [] for name in zipfile.namelist(): - if not name.endswith('/'): - parts.append(name.split('/')[:-1]) - offset = len('/'.join(os.path.commonprefix(parts)) + '/') + if not name.endswith("/"): + parts.append(name.split("/")[:-1]) + offset = len("/".join(os.path.commonprefix(parts)) + "/") for zipinfo in zipfile.infolist(): name = zipinfo.filename if len(name) > offset: zipinfo.filename = name[offset:] yield zipinfo + def git_download_branch_or_commit(directory, repo, branch_or_commit): - ''' + """ Github proves an API to download zips of specific commits, so we don't need to use git clone. - ''' - url = f'http://github.com/{repo}/zipball/{branch_or_commit}' + """ + url = f"http://github.com/{repo}/zipball/{branch_or_commit}" zipped_content = requests_get_content(url, timeout=TIMEOUT) - zipped_input = zipfile.ZipFile(io.BytesIO(zipped_content), mode='r') - zipped_input.extractall(directory, get_zipfile_members_strip_common_prefix(zipped_input)) + zipped_input = zipfile.ZipFile(io.BytesIO(zipped_content), mode="r") + zipped_input.extractall( + directory, get_zipfile_members_strip_common_prefix(zipped_input) + ) + # Utility functions for dependency setup and executable location. + def make_ordo_executable_path(directory): - path = os.path.join(directory, 'ordo') + path = os.path.join(directory, "ordo") if sys.platform == "win32": - path += '.exe' + path += ".exe" return path + def is_ordo_setup(directory): try: ordo_path = make_ordo_executable_path(directory) - with subprocess.Popen([ordo_path, '--help'], stdout=subprocess.DEVNULL) as process: + with subprocess.Popen( + [ordo_path, "--help"], stdout=subprocess.DEVNULL + ) as process: if process.wait(timeout=TIMEOUT): return False return True except: return False + def setup_ordo(directory): if is_ordo_setup(directory): - LOGGER.info(f'Ordo already setup in {directory}') + LOGGER.info(f"Ordo already setup in {directory}") return - LOGGER.info(f'Setting up ordo in {directory}.') + LOGGER.info(f"Setting up ordo in {directory}.") git_download_branch_or_commit(directory, *ORDO_GIT) if sys.platform == "win32": # need to append -DMINGW # ugly hack for a dumb makefile - with open(os.path.join(directory, 'Makefile'), 'r') as makefile: + with open(os.path.join(directory, "Makefile"), "r") as makefile: lines = makefile.readlines() for i, line in enumerate(lines): - if line.startswith('CFLAGS'): - lines.insert(i+1, 'CFLAGS += -DMINGW\n') + if line.startswith("CFLAGS"): + lines.insert(i + 1, "CFLAGS += -DMINGW\n") break - with open(os.path.join(directory, 'Makefile'), 'w') as makefile: - makefile.write(''.join(lines)) + with open(os.path.join(directory, "Makefile"), "w") as makefile: + makefile.write("".join(lines)) - with subprocess.Popen(['make'], cwd=directory) as process: + with subprocess.Popen(["make"], cwd=directory) as process: if process.wait(): - raise Exception('Ordo compilation failed.') + raise Exception("Ordo compilation failed.") if not is_ordo_setup(directory): - raise Exception('Ordo does not work.') + raise Exception("Ordo does not work.") + def make_c_chess_cli_executable_path(directory): - path = os.path.join(directory, 'c-chess-cli') + path = os.path.join(directory, "c-chess-cli") if sys.platform == "win32": - path += '.exe' + path += ".exe" return path + def is_c_chess_cli_setup(directory): try: path = make_c_chess_cli_executable_path(directory) - with subprocess.Popen([path, '-version'], stdout=subprocess.DEVNULL) as process: + with subprocess.Popen([path, "-version"], stdout=subprocess.DEVNULL) as process: if process.wait(timeout=TIMEOUT): return False return True except: return False + def setup_c_chess_cli(directory): if is_c_chess_cli_setup(directory): - LOGGER.info(f'c-chess-cli already setup in {directory}') + LOGGER.info(f"c-chess-cli already setup in {directory}") return - LOGGER.info(f'Setting up c-chess-cli in {directory}.') + LOGGER.info(f"Setting up c-chess-cli in {directory}.") git_download_branch_or_commit(directory, *C_CHESS_CLI_GIT) - with open(os.path.join(directory, 'make.py'), 'r') as makefile: + with open(os.path.join(directory, "make.py"), "r") as makefile: lines = makefile.readlines() for i, line in enumerate(lines): - if line.startswith('version = '): - lines[i] = f'version = \'easy_train_custom_{C_CHESS_CLI_GIT[1]}\'\n' + if line.startswith("version = "): + lines[i] = f"version = 'easy_train_custom_{C_CHESS_CLI_GIT[1]}'\n" - with open(os.path.join(directory, 'make.py'), 'w') as makefile: - makefile.write(''.join(lines)) + with open(os.path.join(directory, "make.py"), "w") as makefile: + makefile.write("".join(lines)) - with subprocess.Popen([sys.executable, 'make.py', '-c', 'gcc'], cwd=directory) as process: + with subprocess.Popen( + [sys.executable, "make.py", "-c", "gcc"], cwd=directory + ) as process: if process.wait(): - raise Exception('c-chess-cli compilation failed.') + raise Exception("c-chess-cli compilation failed.") if not is_c_chess_cli_setup(directory): - raise Exception('c-chess-cli does not work') + raise Exception("c-chess-cli does not work") + def make_stockfish_executable_path(directory): - path = os.path.join(directory, 'src/stockfish') + path = os.path.join(directory, "src/stockfish") if sys.platform == "win32": - path += '.exe' + path += ".exe" return path + def is_stockfish_setup(directory): try: path = make_stockfish_executable_path(directory) - with subprocess.Popen([path, 'compiler'], stdout=subprocess.DEVNULL) as process: + with subprocess.Popen([path, "compiler"], stdout=subprocess.DEVNULL) as process: if process.wait(timeout=TIMEOUT): return False return True except: return False + def setup_stockfish(directory, repo, branch_or_commit, arch, threads=1): if is_stockfish_setup(directory): - LOGGER.info(f'Stockfish already setup in {directory}.') + LOGGER.info(f"Stockfish already setup in {directory}.") return - LOGGER.info(f'Setting up stockfish in {directory}.') + LOGGER.info(f"Setting up stockfish in {directory}.") git_download_branch_or_commit(directory, repo, branch_or_commit) - srcdir = os.path.join(directory, 'src') + srcdir = os.path.join(directory, "src") env = os.environ.copy() - if sys.platform == 'win32': - env['MSYSTEM'] = 'MINGW64' + if sys.platform == "win32": + env["MSYSTEM"] = "MINGW64" with subprocess.Popen( - ['make', 'build', f'ARCH={arch}', f'-j{threads}'], - cwd=srcdir, - env=env - ) as process: + ["make", "build", f"ARCH={arch}", f"-j{threads}"], cwd=srcdir, env=env + ) as process: if process.wait(): - raise Exception(f'stockfish {repo}/{branch_or_commit} compilation failed') + raise Exception(f"stockfish {repo}/{branch_or_commit} compilation failed") if not is_stockfish_setup(directory): - raise Exception(f'stockfish {repo}/{branch_or_commit} does not work') + raise Exception(f"stockfish {repo}/{branch_or_commit} does not work") + def is_nnue_pytorch_setup(directory): try: - with subprocess.Popen([sys.executable, 'nnue_dataset.py'], cwd=directory) as process: + with subprocess.Popen( + [sys.executable, "nnue_dataset.py"], cwd=directory + ) as process: if process.wait(timeout=TIMEOUT): return False return True except: return False + def setup_nnue_pytorch(directory, repo, branch_or_commit): if is_nnue_pytorch_setup(directory): - LOGGER.info(f'nnue-pytorch already setup in {directory}') + LOGGER.info(f"nnue-pytorch already setup in {directory}") return - LOGGER.info(f'Setting up nnue-pytorch in {directory}') + LOGGER.info(f"Setting up nnue-pytorch in {directory}") git_download_branch_or_commit(directory, repo, branch_or_commit) command = [] if sys.platform == "linux": - command += ['sh'] + command += ["sh"] # It's a .bat file made for windows but works on linux too. # Just needs to be called with sh. - command += [os.path.join(directory, 'compile_data_loader.bat')] + command += [os.path.join(directory, "compile_data_loader.bat")] with subprocess.Popen(command, cwd=directory) as process: if process.wait(): - raise Exception(f'nnue-pytorch {repo}/{branch_or_commit} data loader compilation failed') + raise Exception( + f"nnue-pytorch {repo}/{branch_or_commit} data loader compilation failed" + ) if not is_nnue_pytorch_setup(directory): - raise Exception(f'Incorrect nnue-pytorch setup or timeout.') + raise Exception(f"Incorrect nnue-pytorch setup or timeout.") + class CChessCliRunningTestEntry: - ''' + """ Represents a single line of output from the run_games.py (which forwards c-chess-cli output) during network testing process. Calculates additional match statistics. - ''' + """ + + LINE_PATTERN = re.compile( + r"Score.*?run_(\d+).*?nn-epoch(\d+)\.nnue:\s*(\d+)\s*-\s*(\d+)\s*-\s*(\d+)\s*" + ) - LINE_PATTERN = re.compile(r'Score.*?run_(\d+).*?nn-epoch(\d+)\.nnue:\s*(\d+)\s*-\s*(\d+)\s*-\s*(\d+)\s*') def __init__(self, line=None): fields = CChessCliRunningTestEntry.LINE_PATTERN.search(line) self._line = line self._run_id = int(fields[1]) self._epoch = int(fields[2]) - self._losses = int(fields[3]) # from base perspective so reversed + self._losses = int(fields[3]) # from base perspective so reversed self._wins = int(fields[4]) self._draws = int(fields[5]) @@ -1200,14 +1327,14 @@ def line(self): class NetworkTesting(Thread): - ''' + """ Manages the network testing process. Encapsulates run_games.py. Provides information about the current set of networks and their results. Provides information about the currently ongoing tests. Provides information about the current ongoing network conversions. Runs as a separate thread and must be stopped before exiting. - ''' + """ def __init__( self, @@ -1215,7 +1342,7 @@ def __init__( root_dir, num_parallel_games=4, explore_factor=1.5, - book_file_path='', + book_file_path="", time_per_game=None, time_increment_per_move=None, nodes_per_move=1000, @@ -1227,7 +1354,7 @@ def __init__( stockfish_test_exe=None, features=None, active=True, - additional_args=[] + additional_args=[], ): super(NetworkTesting, self).__init__() @@ -1256,7 +1383,7 @@ def __init__( self._current_test = None self._current_convert = None self._error = None - self._has_finished = False # currently never finishes + self._has_finished = False # currently never finishes self._has_started = False self._mutex = Lock() @@ -1264,28 +1391,28 @@ def __init__( def _get_stringified_args(self): args = [ self._root_dir, - f'--concurrency={self._num_parallel_games}', - f'--explore_factor={self._explore_factor}', - f'--c_chess_exe={self._c_chess_cli_exe}', - f'--stockfish_base={self._stockfish_base_exe}', - f'--stockfish_test={self._stockfish_test_exe}', - f'--book_file_name={self._book_file_path}', - f'--hash={self._hash}', - f'--games_per_round={self._games_per_round}', - f'--features={self._features}', + f"--concurrency={self._num_parallel_games}", + f"--explore_factor={self._explore_factor}", + f"--c_chess_exe={self._c_chess_cli_exe}", + f"--stockfish_base={self._stockfish_base_exe}", + f"--stockfish_test={self._stockfish_test_exe}", + f"--book_file_name={self._book_file_path}", + f"--hash={self._hash}", + f"--games_per_round={self._games_per_round}", + f"--features={self._features}", ] if self._time_per_game: - args.append(f'--time_per_game={self._time_per_game}') + args.append(f"--time_per_game={self._time_per_game}") if self._time_increment_per_move: - args.append(f'--time_increment_per_move={self._time_increment_per_move}') + args.append(f"--time_increment_per_move={self._time_increment_per_move}") if self._nodes_per_move: - args.append(f'--nodes_per_move={self._nodes_per_move}') + args.append(f"--nodes_per_move={self._nodes_per_move}") if self._ordo_exe: - args.append(f'--ordo_exe={self._ordo_exe}'), + args.append(f"--ordo_exe={self._ordo_exe}"), for arg in self._additional_args: args.append(arg) @@ -1296,35 +1423,35 @@ def get_status_string(self): self._mutex.acquire() try: if not self._active: - return 'Network testing inactive.' + return "Network testing inactive." elif self._has_finished: - return 'Network testing finished.' + return "Network testing finished." elif not self._has_started: - return 'Starting testing process...' + return "Starting testing process..." elif not self._running: - lines = ['Network testing has exited unexpectedly.'] + lines = ["Network testing has exited unexpectedly."] if self._error: - lines.append(f'Error: {self._error}') - return '\n'.join(lines) + lines.append(f"Error: {self._error}") + return "\n".join(lines) elif self._current_convert is not None: lines = [ - f'Converting network...', - f'Run : {self._current_convert[0]}', - f'Epoch: {self._current_convert[1]}' + f"Converting network...", + f"Run : {self._current_convert[0]}", + f"Epoch: {self._current_convert[1]}", ] - return '\n'.join(lines) + return "\n".join(lines) elif self._current_test is not None: perf_pct = int(round(self._current_test.performance * 100)) cpu_usage = RESOURCE_MONITOR.resources.cpu_usage lines = [ - f'CPU load: {cpu_usage * 100:0.1f}%', - f'Testing run {self._current_test.run_id} epoch {self._current_test.epoch}', - f'+{self._current_test.wins}={self._current_test.draws}-{self._current_test.losses} [{perf_pct:0.1f}%] ({self._current_test.total_games}/{self._games_per_round})', - f'{self._current_test.elo:0.1f}±{self._current_test.elo_error_95:0.1f} Elo' + f"CPU load: {cpu_usage * 100:0.1f}%", + f"Testing run {self._current_test.run_id} epoch {self._current_test.epoch}", + f"+{self._current_test.wins}={self._current_test.draws}-{self._current_test.losses} [{perf_pct:0.1f}%] ({self._current_test.total_games}/{self._games_per_round})", + f"{self._current_test.elo:0.1f}±{self._current_test.elo_error_95:0.1f} Elo", ] - return '\n'.join(lines) + return "\n".join(lines) else: - return 'Waiting for networks...' + return "Waiting for networks..." finally: self._mutex.release() @@ -1335,10 +1462,10 @@ def run(self): self._running = True - cmd = [sys.executable, 'run_games.py'] + self._get_stringified_args() - LOGGER.info(f'Running network testing with command: {cmd}') + cmd = [sys.executable, "run_games.py"] + self._get_stringified_args() + LOGGER.info(f"Running network testing with command: {cmd}") LOGGER.info(f'Also known as: {" ".join(cmd)}') - LOGGER.info(f'Running in working directory: {self._nnue_pytorch_directory}') + LOGGER.info(f"Running in working directory: {self._nnue_pytorch_directory}") self._process = subprocess.Popen( cmd, cwd=self._nnue_pytorch_directory, @@ -1358,12 +1485,12 @@ def run(self): self._mutex.acquire() try: - if not line.startswith('Score of'): - LOGGER.info(line) + if not line.startswith("Score of"): + LOGGER.info(line) - if line.startswith('Finished running ordo.'): + if line.startswith("Finished running ordo."): self._update_results_from_ordo_file(self._get_ordo_file_path()) - elif line.startswith('Score of'): + elif line.startswith("Score of"): try: self._current_test = CChessCliRunningTestEntry(line=line) if self._current_test.total_games % 100 == 0: @@ -1371,17 +1498,19 @@ def run(self): self._current_convert = None except: self._current_test = None - elif line.startswith('Converting'): + elif line.startswith("Converting"): fields = OrdoEntry.NET_PATTERN.search(line) try: self._current_convert = (fields[1], fields[2]) self._current_test = None - LOGGER.info(f'Converting network epoch {self._current_convert[0]}, run id {self._current_convert[1]}') + LOGGER.info( + f"Converting network epoch {self._current_convert[0]}, run id {self._current_convert[1]}" + ) except: self._current_convert = None - elif line.startswith('Error running match!'): + elif line.startswith("Error running match!"): self._process.terminate() - self._error = 'Error running matches.' + self._error = "Error running matches." break else: self._current_test = None @@ -1396,12 +1525,12 @@ def run(self): self._process.wait() if self._running and not self._has_finished: - LOGGER.warning('Network testing exited unexpectedly.') + LOGGER.warning("Network testing exited unexpectedly.") if not self._error: - self._error = 'Unknown error occured.' - LOGGER.error(f'Error: {self._error}') + self._error = "Unknown error occured." + LOGGER.error(f"Error: {self._error}") - LOGGER.info('Network testing finished.') + LOGGER.info("Network testing finished.") self._has_started = True self._running = False @@ -1413,13 +1542,13 @@ def stop(self): self._process.wait() def _get_ordo_file_path(self): - return os.path.join(self._root_dir, 'ordo.out') + return os.path.join(self._root_dir, "ordo.out") def _update_results_from_ordo_file(self, ordo_file_path): new_results = [] try: - with open(ordo_file_path, 'r') as ordo_file: + with open(ordo_file_path, "r") as ordo_file: lines = ordo_file.readlines() # Pring the first few lines for the CLI interface. for line in lines[:7]: @@ -1449,27 +1578,30 @@ def is_running(self): def is_active(self): return self._active + def duration_string_from_seconds(seconds): second = int(seconds) % 60 minute = int(seconds) // 60 % 60 hour = int(seconds) // 3600 - return f'{hour}:{minute:02}:{second:02}' + return f"{hour}:{minute:02}:{second:02}" + def duration_string_from_seconds_compact(seconds): second = int(seconds) % 60 minute = int(seconds) // 60 % 60 hour = int(seconds) // 3600 if hour > 0: - return f'~{hour}h' + return f"~{hour}h" elif minute > 0: - return f'~{minute}m' + return f"~{minute}m" else: - return f'~{second}s' + return f"~{second}s" + class TrainerRunsWidget(Widget): - ''' + """ Displays information about the assigned training run. - ''' + """ def __init__(self, runs, name=None): super(TrainerRunsWidget, self).__init__(name) @@ -1490,17 +1622,19 @@ def reset(self): pass def _clear_area(self): - colour, attr, background = self._frame.palette['field'] + colour, attr, background = self._frame.palette["field"] height = self._h width = self._w - self._offset for i in range(height): self._frame.canvas.print_at( - ' ' * width, + " " * width, self._x + self._offset, self._y + i, - colour, attr, background + colour, + attr, + background, ) def _get_gpu_usage(self, gpu_ids): @@ -1510,9 +1644,9 @@ def _get_gpu_usage(self, gpu_ids): if gpu_id in gpus: gpu = gpus[gpu_id] by_gpu_id[gpu_id] = { - 'compute_pct' : int(gpu.load * 100), - 'memory_mb' : int(gpu.memoryUsed), - 'max_memory_mb' : int(gpu.memoryTotal) + "compute_pct": int(gpu.load * 100), + "memory_mb": int(gpu.memoryUsed), + "max_memory_mb": int(gpu.memoryTotal), } return by_gpu_id @@ -1526,14 +1660,14 @@ def _make_run_text(self, run): # TODO: Some output for the logger. # Right now only the progress bar in the training run is printed occasionally. if run.has_finished: - loss = run.current_loss or 'unknown' - return [f' Run {run.run_id} - Completed; Loss: {loss}'] + loss = run.current_loss or "unknown" + return [f" Run {run.run_id} - Completed; Loss: {loss}"] elif not run.has_started: - return f' Run {run.run_id} - Starting...', + return (f" Run {run.run_id} - Starting...",) elif not run.is_running: - lines = [f' Run {run.run_id} - Exited unexpectedly.'] + lines = [f" Run {run.run_id} - Exited unexpectedly."] if run.error: - lines += [f' Error: {run.error}'] + lines += [f" Error: {run.error}"] return lines else: try: @@ -1553,25 +1687,23 @@ def _make_run_text(self, run): eta_str = duration_string_from_seconds_compact(eta_seconds) return [ - f' Run {run.run_id} - {complete_pct:0.2f}% [ETA {eta_str}]', - f' Speed: {speed:0.1f}it/s; {speed_knps:0.0f}kpos/s', - f' Epoch: {epoch}/{max_epoch}; Step: {step_in_epoch}/{max_step}', - f' Loss: {loss}', + f" Run {run.run_id} - {complete_pct:0.2f}% [ETA {eta_str}]", + f" Speed: {speed:0.1f}it/s; {speed_knps:0.0f}kpos/s", + f" Epoch: {epoch}/{max_epoch}; Step: {step_in_epoch}/{max_step}", + f" Loss: {loss}", ] except: - return [ - f' Run {run.run_id} - Waiting for enough data to display...' - ] + return [f" Run {run.run_id} - Waiting for enough data to display..."] def _make_gpu_text(self, gpu_id, gpu_usage): # TODO: Some output for the logger if gpu_id in gpu_usage: - gpu_compute_pct = gpu_usage[gpu_id]['compute_pct'] - gpu_memory_mb = gpu_usage[gpu_id]['memory_mb'] - gpu_max_memory_mb = gpu_usage[gpu_id]['max_memory_mb'] - return f'GPU {gpu_id} - Usage: {gpu_compute_pct}% {gpu_memory_mb}MB/{gpu_max_memory_mb}MB ' + gpu_compute_pct = gpu_usage[gpu_id]["compute_pct"] + gpu_memory_mb = gpu_usage[gpu_id]["memory_mb"] + gpu_max_memory_mb = gpu_usage[gpu_id]["max_memory_mb"] + return f"GPU {gpu_id} - Usage: {gpu_compute_pct}% {gpu_memory_mb}MB/{gpu_max_memory_mb}MB " else: - return f'GPU {gpu_id}' + return f"GPU {gpu_id}" def update(self, frame_no): # TODO: scrolling @@ -1602,30 +1734,36 @@ def update(self, frame_no): if curr_line >= height: break - colour, attr, background = self._frame.palette['label'] + colour, attr, background = self._frame.palette["label"] text = self._make_gpu_text(curr_gpu_id, gpu_usage) if len(text) < width: - text += '-' * (len(text) - width) + text += "-" * (len(text) - width) self._frame.canvas.paint( text, self._x + self._offset, self._y + curr_line, - colour, attr, background + colour, + attr, + background, ) curr_line += 1 prev_gpu_id = curr_gpu_id - colour, attr, background = self._pick_colours('field', i == self._selected_index) + colour, attr, background = self._pick_colours( + "field", i == self._selected_index + ) for line in self._make_run_text(run): if curr_line >= height: break self._frame.canvas.paint( - line[:width-1], + line[: width - 1], self._x + self._offset, self._y + curr_line, - colour, attr, background + colour, + attr, + background, ) curr_line += 1 @@ -1642,13 +1780,17 @@ def process_event(self, event): self._selected_index = max(0, self._selected_index - 1) elif len(self._runs) > 0 and event.key_code == Screen.KEY_DOWN: # Move down one line in text - use value to trigger on_select. - self._selected_index = min(len(self._runs) - 1, self._selected_index + 1) + self._selected_index = min( + len(self._runs) - 1, self._selected_index + 1 + ) elif len(self._runs) > 0 and event.key_code == Screen.KEY_PAGE_UP: # Move up one page. self._selected_index = max(0, self._selected_index - self._h) elif len(self._runs) > 0 and event.key_code == Screen.KEY_PAGE_DOWN: # Move down one page. - self._selected_index = min(len(self._runs) - 1, self._selected_index + self._h) + self._selected_index = min( + len(self._runs) - 1, self._selected_index + self._h + ) else: return event else: @@ -1658,6 +1800,7 @@ def process_event(self, event): # If we got here, we processed the event - swallow it. return None + class MainView(Frame): def __init__(self, screen, training_runs, network_testing): super(MainView, self).__init__( @@ -1676,21 +1819,24 @@ def __init__(self, screen, training_runs, network_testing): layout = Layout([300, 10, 200], fill_frame=True) self.add_layout(layout) - layout.add_widget(TrainerRunsWidget(self._training_runs, 'TrainerRuns'), 0) + layout.add_widget(TrainerRunsWidget(self._training_runs, "TrainerRuns"), 0) layout.add_widget(VerticalDivider(), 1) layout.add_widget(Label("Testing status:", 1), 2) - self._network_testing_status = layout.add_widget(TextBox(4, line_wrap=True, readonly=True, as_string=True), 2) + self._network_testing_status = layout.add_widget( + TextBox(4, line_wrap=True, readonly=True, as_string=True), 2 + ) self._network_testing_status.disabled = True layout.add_widget(Divider(), 2) self._networks_view = layout.add_widget( MultiColumnListBox( Widget.FILL_FRAME, - ['<4', '>4', '<6', '0', '>7', '<6'], + ["<4", ">4", "<6", "0", ">7", "<6"], [], add_scroll_bar=True, - titles=['#', 'Run', 'Epoch', '', 'Elo', 'Err'] + titles=["#", "Run", "Epoch", "", "Elo", "Err"], ), - 2) + 2, + ) layouta = Layout([1]) self.add_layout(layouta) @@ -1710,14 +1856,19 @@ def reset(self): def _update_network_list(self): self._networks_view.options.clear() for i, entry in enumerate(self._network_testing.get_ordered_results()): - self._networks_view.options.append(([ - str(i+1), - str(entry.run_id), - str(entry.epoch), - '', - f'{entry.elo:0.1f}', - f'±{entry.elo_error:0.1f}' - ], i)) + self._networks_view.options.append( + ( + [ + str(i + 1), + str(entry.run_id), + str(entry.epoch), + "", + f"{entry.elo:0.1f}", + f"±{entry.elo_error:0.1f}", + ], + i, + ) + ) def _update_network_testing_status(self): self._network_testing_status.value = self._network_testing.get_status_string() @@ -1735,7 +1886,7 @@ def _quit(self): "Are you sure you want to quit?", ["Yes", "No"], has_shadow=True, - on_close=self._quit_on_yes + on_close=self._quit_on_yes, ) ) @@ -1749,6 +1900,7 @@ def _quit_on_yes(selected): def frame_update_count(self): return 1 + def app(screen, scene, training_runs, network_testing): global TUI_SCREEN @@ -1765,441 +1917,447 @@ def app(screen, scene, training_runs, network_testing): finally: TUI_SCREEN = None + def str2bool(v): - ''' + """ A "type" for argparse - ''' + """ if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") + def flatten_once(lst): return sum(lst, []) + def parse_cli_args(): default_pytorch_threads = 2 default_data_loader_threads = 4 - default_testing_threads = max(1, os.cpu_count() - default_pytorch_threads - default_data_loader_threads) + default_testing_threads = max( + 1, os.cpu_count() - default_pytorch_threads - default_data_loader_threads + ) default_build_threads = max(1, os.cpu_count() // 2) parser = argparse.ArgumentParser( description="Trains the network.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - '--workspace-path', - default='./easy_train_data', + "--workspace-path", + default="./easy_train_data", type=str, - metavar='PATH', - dest='workspace_path', - help='Specifies the directory in which the dependencies, training, and testing will be set up.' + metavar="PATH", + dest="workspace_path", + help="Specifies the directory in which the dependencies, training, and testing will be set up.", ) parser.add_argument( - '--experiment-name', + "--experiment-name", default=None, type=str, - metavar='NAME', - dest='experiment_name', + metavar="NAME", + dest="experiment_name", required=True, - help='A name of the experiment is used to identify it. The experiment\'s directory will have the name experiment_[experiment_name].' + help="A name of the experiment is used to identify it. The experiment's directory will have the name experiment_[experiment_name].", ) parser.add_argument( - '--training-dataset', + "--training-dataset", type=str, - action='append', - nargs='+', - metavar='PATH', - dest='training_datasets', + action="append", + nargs="+", + metavar="PATH", + dest="training_datasets", required=True, - help='Path to a training dataset. Supports .binpack files.' + help="Path to a training dataset. Supports .binpack files.", ) parser.add_argument( - '--validation-dataset', + "--validation-dataset", type=str, - action='append', - nargs='+', - metavar='PATH', - dest='validation_datasets', - help='Path to a validation dataset. Supports .binpack files.' + action="append", + nargs="+", + metavar="PATH", + dest="validation_datasets", + help="Path to a validation dataset. Supports .binpack files.", ) parser.add_argument( - '--lambda', + "--lambda", default=1.0, type=float, - metavar='FLOAT', - dest='lambda_', - help='Interpolation coefficient for training on evaluation/result. lambda=1.0 means train on evaluations. lambda=0.0 means train on game results. Must be in range [0, 1].' + metavar="FLOAT", + dest="lambda_", + help="Interpolation coefficient for training on evaluation/result. lambda=1.0 means train on evaluations. lambda=0.0 means train on game results. Must be in range [0, 1].", ) parser.add_argument( - '--start-lambda', + "--start-lambda", default=None, type=float, - metavar='FLOAT', - dest='start_lambda', - help='Lambda to use at the first epoch. Defaults to --lambda if not specified.' + metavar="FLOAT", + dest="start_lambda", + help="Lambda to use at the first epoch. Defaults to --lambda if not specified.", ) parser.add_argument( - '--end-lambda', + "--end-lambda", default=None, type=float, - metavar='FLOAT', - dest='end_lambda', - help='Lambda to use at the last epoch. Defaults to --lambda if not specified.' + metavar="FLOAT", + dest="end_lambda", + help="Lambda to use at the last epoch. Defaults to --lambda if not specified.", ) parser.add_argument( - '--gamma', + "--gamma", default=0.992, type=float, - metavar='FLOAT', - dest='gamma', - help='Multiplicative factor applied to the learning rate after every epoch. Values lower than 1 will cause the learning rate to decrease exponentially as training progresses.' + metavar="FLOAT", + dest="gamma", + help="Multiplicative factor applied to the learning rate after every epoch. Values lower than 1 will cause the learning rate to decrease exponentially as training progresses.", ) parser.add_argument( - '--lr', + "--lr", default=8.75e-4, type=float, - metavar='FLOAT', - dest='lr', - help='Initial learning rate.' + metavar="FLOAT", + dest="lr", + help="Initial learning rate.", ) parser.add_argument( - '--num-workers', + "--num-workers", default=default_data_loader_threads, type=int, - dest='num_workers', - help='Number of worker threads to use for supplying training data. Increase with large skipping rates, or underloaded GPUs.' + dest="num_workers", + help="Number of worker threads to use for supplying training data. Increase with large skipping rates, or underloaded GPUs.", ) parser.add_argument( - '--batch-size', + "--batch-size", default=16384, type=int, - metavar='INTEGER', - dest='batch_size', - help='Number of positions per batch (1 batch = 1 iteration).' + metavar="INTEGER", + dest="batch_size", + help="Number of positions per batch (1 batch = 1 iteration).", ) parser.add_argument( - '--threads', + "--threads", default=default_pytorch_threads, type=int, - metavar='INTEGER', - dest='threads', - help='Number of threads for pytorch to use. Generally performance does not scale well with the amount of threads.' + metavar="INTEGER", + dest="threads", + help="Number of threads for pytorch to use. Generally performance does not scale well with the amount of threads.", ) parser.add_argument( - '--seed', + "--seed", default=42, type=int, - metavar='INTEGER', - dest='seed', - help='The random number generator seed to use for training. Each run within a single session gets a slightly different (but deterministic) seed based on this master value.' + metavar="INTEGER", + dest="seed", + help="The random number generator seed to use for training. Each run within a single session gets a slightly different (but deterministic) seed based on this master value.", ) parser.add_argument( - '--smart-fen-skipping', + "--smart-fen-skipping", default=True, type=str2bool, - metavar='BOOL', - dest='smart_fen_skipping', - help='Whether to perform smart fen skipping. This attempts to heuristically skip non-quiet positions during training.' + metavar="BOOL", + dest="smart_fen_skipping", + help="Whether to perform smart fen skipping. This attempts to heuristically skip non-quiet positions during training.", ) parser.add_argument( - '--wld-fen-skipping', + "--wld-fen-skipping", default=True, type=str2bool, - metavar='BOOL', - dest='wld_fen_skipping', - help='Whether to perform position skipping during training that increases correlation between evaluations and results.' + metavar="BOOL", + dest="wld_fen_skipping", + help="Whether to perform position skipping during training that increases correlation between evaluations and results.", ) parser.add_argument( - '--random-fen-skipping', + "--random-fen-skipping", default=3, type=int, - metavar='INTEGER', - dest='random_fen_skipping', - help='Skip on average random_fen_skipping positions during training before using one. Increases diversity for data that is not fully shuffled.' + metavar="INTEGER", + dest="random_fen_skipping", + help="Skip on average random_fen_skipping positions during training before using one. Increases diversity for data that is not fully shuffled.", ) parser.add_argument( - '--early-fen-skipping', + "--early-fen-skipping", default=-1, type=int, - metavar='INTEGER', - dest='early_fen_skipping', - help='Skip all fens from training game plies <= the given number.' + metavar="INTEGER", + dest="early_fen_skipping", + help="Skip all fens from training game plies <= the given number.", ) parser.add_argument( - '--start-from-model', + "--start-from-model", default=None, type=str, - metavar='PATH', - dest='start_from_model', - help='Initializes training using the weights from the given .pt, .ckpt, or .nnue model.' + metavar="PATH", + dest="start_from_model", + help="Initializes training using the weights from the given .pt, .ckpt, or .nnue model.", ) parser.add_argument( - '--start-from-experiment', + "--start-from-experiment", default=None, type=str, - metavar='NAME', - dest='start_from_experiment', - help='Initializes training using the best network from a given experiment (by name). Uses the best net from ordo, falls back to last created.' + metavar="NAME", + dest="start_from_experiment", + help="Initializes training using the best network from a given experiment (by name). Uses the best net from ordo, falls back to last created.", ) parser.add_argument( - '--start-from-engine-test-net', + "--start-from-engine-test-net", default=False, type=str2bool, - metavar='BOOL', - dest='start_from_engine_test_net', - help='Initializes training using the weights from the .nnue model associated with --engine-test-branch.' + metavar="BOOL", + dest="start_from_engine_test_net", + help="Initializes training using the weights from the .nnue model associated with --engine-test-branch.", ) parser.add_argument( - '--gpus', + "--gpus", type=str, - metavar='INTEGER[,INTEGER]*', - dest='gpus', - default='0', - help='A single GPU ID or a list of GPU IDs to use for training. Note that a single run still uses a single GPU.' + metavar="INTEGER[,INTEGER]*", + dest="gpus", + default="0", + help="A single GPU ID or a list of GPU IDs to use for training. Note that a single run still uses a single GPU.", ) parser.add_argument( - '--runs-per-gpu', + "--runs-per-gpu", default=1, type=int, - metavar='INTEGER', - dest='runs_per_gpu', - help='Number of runs to do in parallel on each GPU. To increase the load on strong GPUs run more than one run per GPU. Doing multiple runs also means that variance has lower impact on the results.' + metavar="INTEGER", + dest="runs_per_gpu", + help="Number of runs to do in parallel on each GPU. To increase the load on strong GPUs run more than one run per GPU. Doing multiple runs also means that variance has lower impact on the results.", ) parser.add_argument( - '--features', + "--features", default=None, type=str, - metavar='FEATURESET', - help='The feature set to use. If not specified then will be inferred from the cloned nnue-pytorch repo.' + metavar="FEATURESET", + help="The feature set to use. If not specified then will be inferred from the cloned nnue-pytorch repo.", ) parser.add_argument( - '--max_epoch', '--num-epochs', # --max_epoch kept to match pytorch-lightning's name + "--max_epoch", + "--num-epochs", # --max_epoch kept to match pytorch-lightning's name default=400, type=int, - metavar='INTEGER', - dest='max_epoch', - help='Number of epochs to train for.' + metavar="INTEGER", + dest="max_epoch", + help="Number of epochs to train for.", ) parser.add_argument( - '--network-save-period', + "--network-save-period", default=20, type=int, - metavar='INTEGER', - dest='network_save_period', - help='Number of epochs between network snapshots (checkpoints). None to disable. Note that these take a lot of space.' + metavar="INTEGER", + dest="network_save_period", + help="Number of epochs between network snapshots (checkpoints). None to disable. Note that these take a lot of space.", ) parser.add_argument( - '--save-last-network', + "--save-last-network", default=True, type=str2bool, - metavar='BOOL', - dest='save_last_network', - help='Whether to always save the last produced network (checkpoint).' + metavar="BOOL", + dest="save_last_network", + help="Whether to always save the last produced network (checkpoint).", ) parser.add_argument( - '--additional-training-arg', + "--additional-training-arg", type=str, - metavar='STRING', - nargs='*', - dest='additional_training_args', - help='Additional training args passed verbatim.' + metavar="STRING", + nargs="*", + dest="additional_training_args", + help="Additional training args passed verbatim.", ) parser.add_argument( - '--additional-testing-arg', + "--additional-testing-arg", type=str, - metavar='STRING', - nargs='*', - dest='additional_testing_args', - help='Additional network testing args passed verbatim.' + metavar="STRING", + nargs="*", + dest="additional_testing_args", + help="Additional network testing args passed verbatim.", ) parser.add_argument( - '--engine-base-branch', - default='official-stockfish/Stockfish/master', + "--engine-base-branch", + default="official-stockfish/Stockfish/master", type=str, - metavar='BRANCH_OR_COMMIT', - dest='engine_base_branch', - help='Path to the commit/branch to use for the engine baseline. It is recommended to use a specific commit for consistency.' + metavar="BRANCH_OR_COMMIT", + dest="engine_base_branch", + help="Path to the commit/branch to use for the engine baseline. It is recommended to use a specific commit for consistency.", ) parser.add_argument( - '--engine-test-branch', - default='official-stockfish/Stockfish/master', + "--engine-test-branch", + default="official-stockfish/Stockfish/master", type=str, - metavar='BRANCH_OR_COMMIT', - dest='engine_test_branch', - help='Path to the commit/branch to use for the engine being tested. It is recommended to use a specific commit for consistency.' + metavar="BRANCH_OR_COMMIT", + dest="engine_test_branch", + help="Path to the commit/branch to use for the engine being tested. It is recommended to use a specific commit for consistency.", ) parser.add_argument( - '--nnue-pytorch-branch', - default='official-stockfish/nnue-pytorch/master', + "--nnue-pytorch-branch", + default="official-stockfish/nnue-pytorch/master", type=str, - metavar='BRANCH_OR_COMMIT', - dest='nnue_pytorch_branch', - help='Path to the commit/branch to use for the trainer being tested. It is recommended to use a specific commit for consistency.' + metavar="BRANCH_OR_COMMIT", + dest="nnue_pytorch_branch", + help="Path to the commit/branch to use for the trainer being tested. It is recommended to use a specific commit for consistency.", ) parser.add_argument( - '--build-engine-arch', - default='x86-64-modern', + "--build-engine-arch", + default="x86-64-modern", type=str, - metavar='ARCH', - dest='build_engine_arch', - help='ARCH to use for engine compilation, e.g. x86-64-avx2 for recent hardware.' + metavar="ARCH", + dest="build_engine_arch", + help="ARCH to use for engine compilation, e.g. x86-64-avx2 for recent hardware.", ) parser.add_argument( - '--build-threads', + "--build-threads", default=default_build_threads, type=int, - metavar='INTEGER', - dest='build_threads', - help='Number of threads to use for engine compilation.' + metavar="INTEGER", + dest="build_threads", + help="Number of threads to use for engine compilation.", ) parser.add_argument( - '--fail-on-experiment-exists', + "--fail-on-experiment-exists", default=True, type=str2bool, - metavar='BOOL', - dest='fail_on_experiment_exists', - help='By default an experiment must be created in an empty directory. Ignored when --resume-training is True. Care should be taken when the directory already exists as it might create consistency issue when not everything gets resetup.' + metavar="BOOL", + dest="fail_on_experiment_exists", + help="By default an experiment must be created in an empty directory. Ignored when --resume-training is True. Care should be taken when the directory already exists as it might create consistency issue when not everything gets resetup.", ) parser.add_argument( - '--epoch-size', + "--epoch-size", default=100000000, type=int, - metavar='INTEGER', - dest='epoch_size', - help='Number of positions per epoch (training step).' + metavar="INTEGER", + dest="epoch_size", + help="Number of positions per epoch (training step).", ) parser.add_argument( - '--validation-size', + "--validation-size", default=1000000, type=int, - metavar='INTEGER', - dest='validation_size', - help='Number of positions per validation step.' + metavar="INTEGER", + dest="validation_size", + help="Number of positions per validation step.", ) parser.add_argument( - '--tui', + "--tui", default=True, type=str2bool, - metavar='BOOL', - dest='tui', - help='Whether to show a nice terminal user interface.' + metavar="BOOL", + dest="tui", + help="Whether to show a nice terminal user interface.", ) parser.add_argument( - '--do-network-testing', + "--do-network-testing", default=True, type=str2bool, - metavar='BOOL', - dest='do_network_testing', - help='Whether to test networks as they are generated.' + metavar="BOOL", + dest="do_network_testing", + help="Whether to test networks as they are generated.", ) parser.add_argument( - '--do-network-training', + "--do-network-training", default=True, type=str2bool, - metavar='BOOL', - dest='do_network_training', - help='Whether to train networks.' + metavar="BOOL", + dest="do_network_training", + help="Whether to train networks.", ) parser.add_argument( - '--network-testing-threads', + "--network-testing-threads", default=default_testing_threads, type=int, - metavar='INTEGER', - dest='network_testing_threads', - help='Number of threads to use for network testing. By default the available number of threads minus default data loader and pytorch threads. The optimal value might depend on the --threads, --num-workers and other machine load.' + metavar="INTEGER", + dest="network_testing_threads", + help="Number of threads to use for network testing. By default the available number of threads minus default data loader and pytorch threads. The optimal value might depend on the --threads, --num-workers and other machine load.", ) parser.add_argument( - '--network-testing-explore-factor', + "--network-testing-explore-factor", default=1.5, type=float, - metavar='FLOAT', - dest='network_testing_explore_factor', - help='Elo error estimates are multiplied by this amount to determine testing candidates.' + metavar="FLOAT", + dest="network_testing_explore_factor", + help="Elo error estimates are multiplied by this amount to determine testing candidates.", ) parser.add_argument( '--network-testing-book', default='https://github.com/official-stockfish/books/raw/master/UHO_Lichess_4852_v1.epd.zip', type=str, - metavar='PATH_OR_URL', - dest='network_testing_book', - help='Path to a suitable book, or suitable link (URL). See https://github.com/official-stockfish/books.' + metavar="PATH_OR_URL", + dest="network_testing_book", + help="Path to a suitable book, or suitable link (URL). See https://github.com/official-stockfish/books.", ) parser.add_argument( - '--network-testing-time-per-game', + "--network-testing-time-per-game", default=None, type=float, - metavar='FLOAT', - dest='network_testing_time_per_game', - help='Number of seconds per game for each engine.' + metavar="FLOAT", + dest="network_testing_time_per_game", + help="Number of seconds per game for each engine.", ) parser.add_argument( - '--network-testing-time-increment-per-move', + "--network-testing-time-increment-per-move", default=None, type=float, - metavar='FLOAT', - dest='network_testing_time_increment_per_move', - help='Number of seconds added to the clock of an engine per move.' + metavar="FLOAT", + dest="network_testing_time_increment_per_move", + help="Number of seconds added to the clock of an engine per move.", ) parser.add_argument( - '--network-testing-nodes-per-move', + "--network-testing-nodes-per-move", default=None, type=int, - metavar='INTEGER', - dest='network_testing_nodes_per_move', - help='Number of nodes per move to use for testing. Overrides time control. Recommended over time control for better consistency.' + metavar="INTEGER", + dest="network_testing_nodes_per_move", + help="Number of nodes per move to use for testing. Overrides time control. Recommended over time control for better consistency.", ) parser.add_argument( - '--network-testing-hash-mb', + "--network-testing-hash-mb", default=8, type=int, - metavar='INTEGER', - dest='network_testing_hash_mb', - help='Number of MiB of memory to use for hash allocation for each engine being tested.' + metavar="INTEGER", + dest="network_testing_hash_mb", + help="Number of MiB of memory to use for hash allocation for each engine being tested.", ) parser.add_argument( - '--network-testing-games-per-round', + "--network-testing-games-per-round", default=20 * default_testing_threads, type=int, - metavar='INTEGER', - dest='network_testing_games_per_round', - help='Number of games per round to use. Essentially a testing batch size.' + metavar="INTEGER", + dest="network_testing_games_per_round", + help="Number of games per round to use. Essentially a testing batch size.", ) parser.add_argument( - '--resume-training', + "--resume-training", default=True, type=str2bool, - metavar='BOOL', - dest='resume_training', - help='Attempts to resume each run from its latest checkpoint.' + metavar="BOOL", + dest="resume_training", + help="Attempts to resume each run from its latest checkpoint.", ) parser.add_argument( - '--do-approximate-ordo', + "--do-approximate-ordo", default=True, type=str2bool, - metavar='BOOL', - dest='do_approximate_ordo', - help='If true then does not launch ordo and instead does a fast approximate computation. Workaround for ordo memory usage issues.' + metavar="BOOL", + dest="do_approximate_ordo", + help="If true then does not launch ordo and instead does a fast approximate computation. Workaround for ordo memory usage issues.", ) parser.add_argument( - '--auto-exit-timeout', + "--auto-exit-timeout", default=None, type=str, - metavar='DURATION', - dest='auto_exit_timeout', - help='Automatically exit the script after a specified time has passed since its start. Duration format "h:m:s", "m:s", or "s".' + metavar="DURATION", + dest="auto_exit_timeout", + help='Automatically exit the script after a specified time has passed since its start. Duration format "h:m:s", "m:s", or "s".', ) parser.add_argument( - '--auto-exit-timeout-on-training-finished', + "--auto-exit-timeout-on-training-finished", default=None, type=str, - metavar='DURATION', - dest='auto_exit_timeout_on_training_finished', - help='Automatically exit the script after a specified time has passed after training finished. Duration format "h:m:s", "m:s", or "s"' + metavar="DURATION", + dest="auto_exit_timeout_on_training_finished", + help='Automatically exit the script after a specified time has passed after training finished. Duration format "h:m:s", "m:s", or "s"', ) args = parser.parse_args() @@ -2210,65 +2368,91 @@ def parse_cli_args(): args.validation_datasets = [] if len(args.training_datasets) == 0: - raise Exception('No training data specified') + raise Exception("No training data specified") if args.lambda_ < 0.0 or args.lambda_ > 1.0: - raise Exception('lambda must be within [0, 1]') + raise Exception("lambda must be within [0, 1]") args.validation_datasets = args.validation_datasets or args.training_datasets for dataset in args.validation_datasets: if not Path(dataset).is_file(): - raise Exception(f'Invalid validation data set file name: {dataset}') + raise Exception(f"Invalid validation data set file name: {dataset}") for dataset in args.training_datasets: if not Path(dataset).is_file(): - raise Exception(f'Invalid training data set file name: {dataset}') + raise Exception(f"Invalid training data set file name: {dataset}") # these are not required because testing is optional - if args.engine_base_branch and args.engine_base_branch.count('/') != 2: - raise Exception(f'Invalid base engine repo path: {args.engine_base_branch}') + if args.engine_base_branch and args.engine_base_branch.count("/") != 2: + raise Exception(f"Invalid base engine repo path: {args.engine_base_branch}") - if args.engine_test_branch and args.engine_test_branch.count('/') != 2: - raise Exception(f'Invalid test engine repo path: {args.engine_test_branch}') + if args.engine_test_branch and args.engine_test_branch.count("/") != 2: + raise Exception(f"Invalid test engine repo path: {args.engine_test_branch}") # this one is required because it has other important scripts - if not args.nnue_pytorch_branch or args.nnue_pytorch_branch.count('/') != 2: - raise Exception(f'Invalid test trainer repo path: {args.nnue_pytorch_branch}') + if not args.nnue_pytorch_branch or args.nnue_pytorch_branch.count("/") != 2: + raise Exception(f"Invalid test trainer repo path: {args.nnue_pytorch_branch}") - if not args.network_testing_time_per_game and not args.network_testing_nodes_per_move: - args.network_testing_nodes_per_move=25000 - LOGGER.info(f'No time control specified. Using a default {args.network_testing_nodes_per_move} nodes per move') + if ( + not args.network_testing_time_per_game + and not args.network_testing_nodes_per_move + ): + args.network_testing_nodes_per_move = 25000 + LOGGER.info( + f"No time control specified. Using a default {args.network_testing_nodes_per_move} nodes per move" + ) - if [args.start_from_model, args.start_from_engine_test_net, args.start_from_experiment].count(True) > 1: - raise Exception('Only one of --start-from-model, --start-from-engine-test-net, and --start-from-experiment can be specified at a time.') + if [ + args.start_from_model, + args.start_from_engine_test_net, + args.start_from_experiment, + ].count(True) > 1: + raise Exception( + "Only one of --start-from-model, --start-from-engine-test-net, and --start-from-experiment can be specified at a time." + ) if args.start_from_engine_test_net and not args.engine_test_branch: - raise Exception('--start-from-engine-test-net but --engine-test-branch not given') + raise Exception( + "--start-from-engine-test-net but --engine-test-branch not given" + ) - if args.start_from_experiment and not args.start_from_experiment.startswith('experiment_'): - args.start_from_experiment = 'experiment_' + args.start_from_experiment + if args.start_from_experiment and not args.start_from_experiment.startswith( + "experiment_" + ): + args.start_from_experiment = "experiment_" + args.start_from_experiment return args + def log_args(directory, args): os.makedirs(directory, exist_ok=True) - args_dump_file_path = os.path.join(directory, 'args_dump.txt') - with open(args_dump_file_path, 'w') as file: + args_dump_file_path = os.path.join(directory, "args_dump.txt") + with open(args_dump_file_path, "w") as file: file.write(repr(args)) - logs_file_path = os.path.join(directory, 'easy_train.log') + logs_file_path = os.path.join(directory, "easy_train.log") + + LOGGER.addHandler(logging.FileHandler(logs_file_path, encoding="utf-8")) - LOGGER.addHandler(logging.FileHandler(logs_file_path, encoding='utf-8')) def is_url(path): - return path.startswith('http://') or path.startswith('https://') or path.startswith('ftp://') or path.startswith('sftp://') + return ( + path.startswith("http://") + or path.startswith("https://") + or path.startswith("ftp://") + or path.startswith("sftp://") + ) + class TqdmDownloadProgressBar(tqdm): def update_to(self, blocks_transferred=1, block_size=1, total_size=None): if total_size is not None: self.total = total_size - return self.update(blocks_transferred * block_size - self.n) # also sets self.n = b * bsize + return self.update( + blocks_transferred * block_size - self.n + ) # also sets self.n = b * bsize + class TqdmToLogger(io.StringIO): def __init__(self): @@ -2280,12 +2464,13 @@ def write(self, buf): def flush(self): LOGGER.info(self.buf) + def setup_book(directory, args): - ''' + """ If the args.network_testing_book is a URL then it downloads the book and reassigns args.network_testing_book to the actual book path. Otherwise does nothing. - ''' + """ if not is_url(args.network_testing_book): return @@ -2293,15 +2478,17 @@ def setup_book(directory, args): os.makedirs(directory, exist_ok=True) url = args.network_testing_book - temp_filename = urllib.parse.unquote(url.split('/')[-1]) - if temp_filename.endswith('.zip'): + temp_filename = urllib.parse.unquote(url.split("/")[-1]) + if temp_filename.endswith(".zip"): filename = temp_filename[:-4] - elif temp_filename.endswith('.epd'): + elif temp_filename.endswith(".epd"): filename = temp_filename - if not filename.endswith('.epd'): - LOGGER.error('Cannot handle the book. Currently only .epd books are supported. If compressed with .zip the name must be a.epd.zip. No other compression format is supported right now.') - raise Exception('Cannot handle opening book') + if not filename.endswith(".epd"): + LOGGER.error( + "Cannot handle the book. Currently only .epd books are supported. If compressed with .zip the name must be a.epd.zip. No other compression format is supported right now." + ) + raise Exception("Cannot handle opening book") destination_temp_file_path = os.path.abspath(os.path.join(directory, temp_filename)) destination_file_path = os.path.abspath(os.path.join(directory, filename)) @@ -2310,119 +2497,139 @@ def setup_book(directory, args): if not os.path.exists(destination_file_path): if temp_filename != filename and not os.path.exists(destination_temp_file_path): with TqdmDownloadProgressBar( - unit='B', + unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=temp_filename, file=TqdmToLogger(), - mininterval=0.1 # at least 0.1s between update so the logfile doesn't get polluted. + mininterval=0.1, # at least 0.1s between update so the logfile doesn't get polluted. ) as progress_bar: urllib.request.urlretrieve( url, filename=destination_temp_file_path, reporthook=progress_bar.update_to, - data=None + data=None, ) progress_bar.total = progress_bar.n - if temp_filename.endswith('.zip'): - zipped = zipfile.ZipFile(destination_temp_file_path, mode='r') + if temp_filename.endswith(".zip"): + zipped = zipfile.ZipFile(destination_temp_file_path, mode="r") names = zipped.namelist() if len(names) > 1 or names[0] != filename: - LOGGER.error(f'Expected only a book with name {filename} in the archive but did not find it or found more') - raise Exception('Unexpected opening book archive content.') - LOGGER.info(f'Extracting {temp_filename} to {filename}') + LOGGER.error( + f"Expected only a book with name {filename} in the archive but did not find it or found more" + ) + raise Exception("Unexpected opening book archive content.") + LOGGER.info(f"Extracting {temp_filename} to {filename}") zipped.extract(filename, directory) - LOGGER.info('Book setup completed.') + LOGGER.info("Book setup completed.") -def prepare_start_model(directory, model_path, run_id, nnue_pytorch_directory, features): - ''' + +def prepare_start_model( + directory, model_path, run_id, nnue_pytorch_directory, features +): + """ Copies the specified model to the desired directory. Performs conversion to .pt if necessary. - ''' + """ os.makedirs(directory, exist_ok=True) - LOGGER.info(f'Starting from model: {model_path}') + LOGGER.info(f"Starting from model: {model_path}") - destination_filename = 'start_model' + destination_filename = "start_model" if run_id: - destination_filename += 'run_' + str(run_id) - destination_filename += '.pt' + destination_filename += "run_" + str(run_id) + destination_filename += ".pt" destination_model_path = os.path.join(directory, destination_filename) - if model_path.endswith('.pt'): + if model_path.endswith(".pt"): shutil.copyfile(model_path, destination_model_path) - elif model_path.endswith('.nnue') or model_path.endswith('.ckpt'): - if model_path.endswith('.nnue') and features.endswith('^'): + elif model_path.endswith(".nnue") or model_path.endswith(".ckpt"): + if model_path.endswith(".nnue") and features.endswith("^"): features = features[:-1] with subprocess.Popen( [ sys.executable, - 'serialize.py', + "serialize.py", os.path.abspath(model_path), destination_model_path, - f'--features={features}', + f"--features={features}", ], cwd=nnue_pytorch_directory, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT + stderr=subprocess.STDOUT, ) as process: if process.wait(): - raise Exception('Failed to run serialize.py for start model.') + raise Exception("Failed to run serialize.py for start model.") if not os.path.exists(destination_model_path): - raise Exception('Failed to convert start model.') + raise Exception("Failed to convert start model.") return destination_model_path -def prepare_start_model_from_experiment(directory, experiment_path, run_id, nnue_pytorch_directory, features): - ''' + +def prepare_start_model_from_experiment( + directory, experiment_path, run_id, nnue_pytorch_directory, features +): + """ Prepares start model with the best or (if no ordo found) last checkpoint from the given experiment. - ''' - root_dir = os.path.join(experiment_path, 'training') + """ + root_dir = os.path.join(experiment_path, "training") best_model = find_best_checkpoint(root_dir) if best_model is None: best_model = find_latest_checkpoint(root_dir) if best_model is None: - raise Exception('Could not find any viable .ckpt nor .nnue files in the start experiment.') - return prepare_start_model(directory, best_model, run_id, nnue_pytorch_directory, features) + raise Exception( + "Could not find any viable .ckpt nor .nnue files in the start experiment." + ) + return prepare_start_model( + directory, best_model, run_id, nnue_pytorch_directory, features + ) + def get_default_feature_set_from_nnue_pytorch(nnue_pytorch_directory): - ''' + """ features.py in nnue-pytorch defines the default feature set to use. We scrape it for the feature set name. Normally we could import that file and let it add the argument to argparse, but we setup argparse before nnue-pytorch is setup so we have to do it like that. - ''' + """ try: - with open(os.path.join(nnue_pytorch_directory, 'features.py'), 'r') as features_file: + with open( + os.path.join(nnue_pytorch_directory, "features.py"), "r" + ) as features_file: for line in features_file: line = line.strip() - if line.startswith('_default_feature_set_name'): + if line.startswith("_default_feature_set_name"): return line.split()[-1][1:-1] except: - raise Exception('Could not infer the default feature set from the nnue-pytorch installation.') + raise Exception( + "Could not infer the default feature set from the nnue-pytorch installation." + ) + def parse_duration_hms_to_s(duration_str): - ''' + """ Parses a duration of the form [h:][m:]s - ''' - parts = duration_str.split(':') + """ + parts = duration_str.split(":") s = int(parts[-1]) m = 0 if len(parts) < 2 else int(parts[-2]) h = 0 if len(parts) < 3 else int(parts[-3]) return h * 3600 + m * 60 + s + def spawn_training_watcher(training_runs, exit_timeout_after_finished): - ''' + """ Spawns a daemon thread that awaits training end and the schedules the script to exit after the specified amount of seconds. - ''' + """ + def f(): while True: finished = True @@ -2444,27 +2651,30 @@ def f(): thread.daemon = True thread.start() + def main(): - LOGGER.info('Initializing...') + LOGGER.info("Initializing...") args = parse_cli_args() # if we ask to resume don't fail on existing directory if args.resume_training: - args.fail_on_experiment_exists = False + args.fail_on_experiment_exists = False absolute_workspace_path = os.path.abspath(args.workspace_path) os.makedirs(absolute_workspace_path, exist_ok=True) - do_network_testing = args.engine_base_branch and args.engine_test_branch and args.do_network_testing + do_network_testing = ( + args.engine_base_branch and args.engine_test_branch and args.do_network_testing + ) do_network_training = args.do_network_training and args.training_datasets # Global (workspace) setup - with SystemWideMutex(os.path.join(absolute_workspace_path, f'.lock')) as mutex: - ordo_directory = os.path.join(absolute_workspace_path, 'ordo') - c_chess_cli_directory = os.path.join(absolute_workspace_path, 'c-chess-cli') - books_directory = os.path.join(absolute_workspace_path, 'books') + with SystemWideMutex(os.path.join(absolute_workspace_path, f".lock")) as mutex: + ordo_directory = os.path.join(absolute_workspace_path, "ordo") + c_chess_cli_directory = os.path.join(absolute_workspace_path, "c-chess-cli") + books_directory = os.path.join(absolute_workspace_path, "books") if not args.do_approximate_ordo: setup_ordo(ordo_directory) @@ -2476,43 +2686,66 @@ def main(): # Local (experiment) setup - experiment_directory = os.path.join(absolute_workspace_path, f'experiments/experiment_{args.experiment_name}') + experiment_directory = os.path.join( + absolute_workspace_path, f"experiments/experiment_{args.experiment_name}" + ) try: os.makedirs(experiment_directory, exist_ok=False) except FileExistsError as e: if args.fail_on_experiment_exists and os.listdir(experiment_directory): - LOGGER.error(f'Directory {experiment_directory} already exists. An experiment must use a new directory.') - LOGGER.error(f'Alternatively, override this with the option --resume-training=True or --fail-on-experiment-exists=False.') + LOGGER.error( + f"Directory {experiment_directory} already exists. An experiment must use a new directory." + ) + LOGGER.error( + f"Alternatively, override this with the option --resume-training=True or --fail-on-experiment-exists=False." + ) return - stockfish_base_directory = os.path.join(experiment_directory, 'stockfish_base') - stockfish_test_directory = os.path.join(experiment_directory, 'stockfish_test') - nnue_pytorch_directory = os.path.join(experiment_directory, 'nnue-pytorch') - logging_directory = os.path.join(experiment_directory, 'logging') - start_model_directory = os.path.join(experiment_directory, 'start_models') + stockfish_base_directory = os.path.join(experiment_directory, "stockfish_base") + stockfish_test_directory = os.path.join(experiment_directory, "stockfish_test") + nnue_pytorch_directory = os.path.join(experiment_directory, "nnue-pytorch") + logging_directory = os.path.join(experiment_directory, "logging") + start_model_directory = os.path.join(experiment_directory, "start_models") log_args(logging_directory, args) if do_network_testing: - LOGGER.info('Engines provided. Enabling network testing.') - stockfish_base_repo = '/'.join(args.engine_base_branch.split('/')[:2]) - stockfish_test_repo = '/'.join(args.engine_test_branch.split('/')[:2]) - stockfish_base_branch_or_commit = args.engine_base_branch.split('/')[2] - stockfish_test_branch_or_commit = args.engine_test_branch.split('/')[2] - setup_stockfish(stockfish_base_directory, stockfish_base_repo, stockfish_base_branch_or_commit, args.build_engine_arch, args.build_threads) - setup_stockfish(stockfish_test_directory, stockfish_test_repo, stockfish_test_branch_or_commit, args.build_engine_arch, args.build_threads) + LOGGER.info("Engines provided. Enabling network testing.") + stockfish_base_repo = "/".join(args.engine_base_branch.split("/")[:2]) + stockfish_test_repo = "/".join(args.engine_test_branch.split("/")[:2]) + stockfish_base_branch_or_commit = args.engine_base_branch.split("/")[2] + stockfish_test_branch_or_commit = args.engine_test_branch.split("/")[2] + setup_stockfish( + stockfish_base_directory, + stockfish_base_repo, + stockfish_base_branch_or_commit, + args.build_engine_arch, + args.build_threads, + ) + setup_stockfish( + stockfish_test_directory, + stockfish_test_repo, + stockfish_test_branch_or_commit, + args.build_engine_arch, + args.build_threads, + ) else: - LOGGER.info('Not doing network testing. Either engines no provided or explicitely disabled.') + LOGGER.info( + "Not doing network testing. Either engines no provided or explicitely disabled." + ) - nnue_pytorch_repo = '/'.join(args.nnue_pytorch_branch.split('/')[:2]) - nnue_pytorch_branch_or_commit = args.nnue_pytorch_branch.split('/')[2] - setup_nnue_pytorch(nnue_pytorch_directory, nnue_pytorch_repo, nnue_pytorch_branch_or_commit) + nnue_pytorch_repo = "/".join(args.nnue_pytorch_branch.split("/")[:2]) + nnue_pytorch_branch_or_commit = args.nnue_pytorch_branch.split("/")[2] + setup_nnue_pytorch( + nnue_pytorch_directory, nnue_pytorch_repo, nnue_pytorch_branch_or_commit + ) if args.features is None: - args.features = get_default_feature_set_from_nnue_pytorch(nnue_pytorch_directory) - - LOGGER.info('Initialization completed.') + args.features = get_default_feature_set_from_nnue_pytorch( + nnue_pytorch_directory + ) + LOGGER.info("Initialization completed.") # Directory layout: # tmp/experiments/experiment_{name}/training/run_{i} @@ -2525,7 +2758,9 @@ def main(): start_model = None if args.start_from_engine_test_net: - args.start_from_model = str(next(Path(os.path.join(stockfish_test_directory,"src/")).rglob("*.nnue"))) + args.start_from_model = str( + next(Path(os.path.join(stockfish_test_directory, "src/")).rglob("*.nnue")) + ) if args.start_from_model: start_model = prepare_start_model( @@ -2533,62 +2768,74 @@ def main(): model_path=args.start_from_model, run_id=None, nnue_pytorch_directory=nnue_pytorch_directory, - features=args.features + features=args.features, ) elif args.start_from_experiment: start_model = prepare_start_model_from_experiment( directory=start_model_directory, - experiment_path=os.path.join(absolute_workspace_path, 'experiments', args.start_from_experiment), + experiment_path=os.path.join( + absolute_workspace_path, "experiments", args.start_from_experiment + ), run_id=None, nnue_pytorch_directory=nnue_pytorch_directory, - features=args.features + features=args.features, ) training_runs = [] if do_network_training: - gpu_ids = [int(v) for v in args.gpus.split(',') if v] + gpu_ids = [int(v) for v in args.gpus.split(",") if v] for gpu_id in gpu_ids: for j in range(args.runs_per_gpu): - run_id = gpu_id*args.runs_per_gpu+j - - training_runs.append(TrainingRun( - gpu_id=gpu_id, - run_id=run_id, - nnue_pytorch_directory=nnue_pytorch_directory, - training_datasets=args.training_datasets, - validation_datasets=args.validation_datasets, - num_data_loader_threads=args.num_workers, - num_pytorch_threads=args.threads, - num_epochs=args.max_epoch, - batch_size=args.batch_size, - random_fen_skipping=args.random_fen_skipping, - smart_fen_skipping=args.smart_fen_skipping, - wld_fen_skipping=args.wld_fen_skipping, - early_fen_skipping=args.early_fen_skipping, - features=args.features, - lr=args.lr, - gamma=args.gamma, - lambda_=args.lambda_, - start_lambda=args.start_lambda, - end_lambda=args.end_lambda, - network_save_period=args.network_save_period, - save_last_network=args.save_last_network, - seed=args.seed + run_id, - start_from_model=start_model, - root_dir=os.path.join(experiment_directory, 'training', f'run_{run_id}'), - epoch_size=args.epoch_size, - validation_size=args.validation_size, - resume_training=args.resume_training, - additional_args=[arg for arg in args.additional_training_args or []] - )) - LOGGER.info(f'Doing network training on gpus {gpu_ids}. {len(training_runs)} runs in total.') + run_id = gpu_id * args.runs_per_gpu + j + + training_runs.append( + TrainingRun( + gpu_id=gpu_id, + run_id=run_id, + nnue_pytorch_directory=nnue_pytorch_directory, + training_datasets=args.training_datasets, + validation_datasets=args.validation_datasets, + num_data_loader_threads=args.num_workers, + num_pytorch_threads=args.threads, + num_epochs=args.max_epoch, + batch_size=args.batch_size, + random_fen_skipping=args.random_fen_skipping, + smart_fen_skipping=args.smart_fen_skipping, + wld_fen_skipping=args.wld_fen_skipping, + early_fen_skipping=args.early_fen_skipping, + features=args.features, + lr=args.lr, + gamma=args.gamma, + lambda_=args.lambda_, + start_lambda=args.start_lambda, + end_lambda=args.end_lambda, + network_save_period=args.network_save_period, + save_last_network=args.save_last_network, + seed=args.seed + run_id, + start_from_model=start_model, + root_dir=os.path.join( + experiment_directory, "training", f"run_{run_id}" + ), + epoch_size=args.epoch_size, + validation_size=args.validation_size, + resume_training=args.resume_training, + additional_args=[ + arg for arg in args.additional_training_args or [] + ], + ) + ) + LOGGER.info( + f"Doing network training on gpus {gpu_ids}. {len(training_runs)} runs in total." + ) else: - LOGGER.info('Not training networks.') + LOGGER.info("Not training networks.") network_testing = NetworkTesting( nnue_pytorch_directory=nnue_pytorch_directory, - root_dir=os.path.join(experiment_directory, 'training'), - ordo_exe=None if args.do_approximate_ordo else make_ordo_executable_path(ordo_directory), + root_dir=os.path.join(experiment_directory, "training"), + ordo_exe=None + if args.do_approximate_ordo + else make_ordo_executable_path(ordo_directory), c_chess_cli_exe=make_c_chess_cli_executable_path(c_chess_cli_directory), stockfish_base_exe=make_stockfish_executable_path(stockfish_base_directory), stockfish_test_exe=make_stockfish_executable_path(stockfish_test_directory), @@ -2602,7 +2849,7 @@ def main(): hash=args.network_testing_hash_mb, games_per_round=args.network_testing_games_per_round, active=do_network_testing, - additional_args=[arg for arg in args.additional_testing_args or []] + additional_args=[arg for arg in args.additional_testing_args or []], ) for tr in training_runs: @@ -2625,7 +2872,11 @@ def main(): last_scene = None while True: try: - Screen.wrapper(app, catch_interrupt=True, arguments=[last_scene, training_runs, network_testing]) + Screen.wrapper( + app, + catch_interrupt=True, + arguments=[last_scene, training_runs, network_testing], + ) break except ResizeScreenError as e: last_scene = e.scene @@ -2633,19 +2884,19 @@ def main(): while True: try: v = input() - if v == 'quit': + if v == "quit": break else: - print('Type `quit` to stop.') + print("Type `quit` to stop.") except EOFError: # For non-interactive environments time.sleep(1) - LOGGER.info('Stopping training runs.') + LOGGER.info("Stopping training runs.") for tr in training_runs: tr.stop() - LOGGER.info('Stopping network testing.') + LOGGER.info("Stopping network testing.") network_testing.stop() any_training_error = False @@ -2657,5 +2908,6 @@ def main(): if any_training_error: sys.exit(EXITCODE_TRAINING_NOT_FINISHED) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/serialize.py b/serialize.py index d48dc834..f3e21002 100644 --- a/serialize.py +++ b/serialize.py @@ -13,349 +13,438 @@ import numpy as np from numba import njit + def ascii_hist(name, x, bins=6): - N,X = np.histogram(x, bins=bins) - total = 1.0*len(x) - width = 50 - nmax = N.max() + N, X = np.histogram(x, bins=bins) + total = 1.0 * len(x) + width = 50 + nmax = N.max() + + print(name) + for (xi, n) in zip(X, N): + bar = "#" * int(n * 1.0 * width / nmax) + xi = "{0: <8.4g}".format(xi).ljust(10) + print("{0}| {1}".format(xi, bar)) - print(name) - for (xi, n) in zip(X,N): - bar = '#'*int(n*1.0*width/nmax) - xi = '{0: <8.4g}'.format(xi).ljust(10) - print('{0}| {1}'.format(xi,bar)) @njit def encode_leb_128_array(arr): - res = [] - for v in arr: - while True: - byte = v & 0x7f - v = v >> 7 - if (v == 0 and byte & 0x40 == 0) or (v == -1 and byte & 0x40 != 0): - res.append(byte) - break - res.append(byte | 0x80) - return res + res = [] + for v in arr: + while True: + byte = v & 0x7F + v = v >> 7 + if (v == 0 and byte & 0x40 == 0) or (v == -1 and byte & 0x40 != 0): + res.append(byte) + break + res.append(byte | 0x80) + return res + @njit def decode_leb_128_array(arr, n): - ints = np.zeros(n) - k = 0 - for i in range(n): - r = 0 - shift = 0 - while True: - byte = arr[k] - k = k + 1 - r |= (byte & 0x7f) << shift - shift += 7 - if (byte & 0x80) == 0: - ints[i] = r if (byte & 0x40) == 0 else r | ~((1 << shift) - 1) - break - return ints + ints = np.zeros(n) + k = 0 + for i in range(n): + r = 0 + shift = 0 + while True: + byte = arr[k] + k = k + 1 + r |= (byte & 0x7F) << shift + shift += 7 + if (byte & 0x80) == 0: + ints[i] = r if (byte & 0x40) == 0 else r | ~((1 << shift) - 1) + break + return ints + # hardcoded for now VERSION = 0x7AF32F20 DEFAULT_DESCRIPTION = "Network trained with the https://github.com/official-stockfish/nnue-pytorch trainer." -class NNUEWriter(): - """ - All values are stored in little endian. - """ - def __init__(self, model, description=None, ft_compression='none'): - if description is None: - description = DEFAULT_DESCRIPTION - - self.buf = bytearray() - - # NOTE: model._clip_weights() should probably be called here. It's not necessary now - # because it doesn't have more restrictive bounds than these defined by quantization, - # but it might be necessary in the future. - fc_hash = self.fc_hash(model) - self.write_header(model, fc_hash, description) - self.int32(model.feature_set.hash ^ (M.L1*2)) # Feature transformer hash - self.write_feature_transformer(model, ft_compression) - for l1, l2, output in model.layer_stacks.get_coalesced_layer_stacks(): - self.int32(fc_hash) # FC layers hash - self.write_fc_layer(model, l1) - self.write_fc_layer(model, l2) - self.write_fc_layer(model, output, is_output=True) - - @staticmethod - def fc_hash(model): - # InputSlice hash - prev_hash = 0xEC42E90D - prev_hash ^= (M.L1 * 2) - - # Fully connected layers - layers = [model.layer_stacks.l1, model.layer_stacks.l2, model.layer_stacks.output] - for layer in layers: - layer_hash = 0xCC03DAE4 - layer_hash += layer.out_features // model.num_ls_buckets - layer_hash ^= prev_hash >> 1 - layer_hash ^= (prev_hash << 31) & 0xFFFFFFFF - if layer.out_features // model.num_ls_buckets != 1: - # Clipped ReLU hash - layer_hash = (layer_hash + 0x538D24C7) & 0xFFFFFFFF - prev_hash = layer_hash - return layer_hash - - def write_header(self, model, fc_hash, description): - self.int32(VERSION) # version - self.int32(fc_hash ^ model.feature_set.hash ^ (M.L1*2)) # halfkp network hash - encoded_description = description.encode('utf-8') - self.int32(len(encoded_description)) # Network definition - self.buf.extend(encoded_description) - - def write_leb_128_array(self, arr): - buf = encode_leb_128_array(arr) - self.int32(len(buf)) - self.buf.extend(buf) - - def write_tensor(self, arr, compression='none'): - if compression == 'none': - self.buf.extend(arr.tobytes()) - elif compression == 'leb128': - self.buf.extend('COMPRESSED_LEB128'.encode('utf-8')) - self.write_leb_128_array(arr) - else: - raise Exception('Invalid compression method.') - - def write_feature_transformer(self, model, ft_compression): - layer = model.input - - bias = layer.bias.data[:M.L1] - bias = bias.mul(model.quantized_one).round().to(torch.int16) - - all_weight = M.coalesce_ft_weights(model, layer) - weight = all_weight[:, :M.L1] - psqt_weight = all_weight[:, M.L1:] - - weight = weight.mul(model.quantized_one).round().to(torch.int16) - psqt_weight = psqt_weight.mul(model.nnue2score * model.weight_scale_out).round().to(torch.int32) - - ascii_hist('ft bias:', bias.numpy()) - ascii_hist('ft weight:', weight.numpy()) - ascii_hist('ft psqt weight:', psqt_weight.numpy()) - - # Weights stored as [num_features][outputs] - - self.write_tensor(bias.flatten().numpy(), ft_compression) - self.write_tensor(weight.flatten().numpy(), ft_compression) - self.write_tensor(psqt_weight.flatten().numpy(), ft_compression) - - def write_fc_layer(self, model, layer, is_output=False): - # FC layers are stored as int8 weights, and int32 biases - kWeightScaleHidden = model.weight_scale_hidden - kWeightScaleOut = model.nnue2score * model.weight_scale_out / model.quantized_one - kWeightScale = kWeightScaleOut if is_output else kWeightScaleHidden - kBiasScaleOut = model.weight_scale_out * model.nnue2score - kBiasScaleHidden = model.weight_scale_hidden * model.quantized_one - kBiasScale = kBiasScaleOut if is_output else kBiasScaleHidden - kMaxWeight = model.quantized_one / kWeightScale - - bias = layer.bias.data - bias = bias.mul(kBiasScale).round().to(torch.int32) - - weight = layer.weight.data - clipped = torch.count_nonzero(weight.clamp(-kMaxWeight, kMaxWeight) - weight) - total_elements = torch.numel(weight) - clipped_max = torch.max(torch.abs(weight.clamp(-kMaxWeight, kMaxWeight) - weight)) - - weight = weight.clamp(-kMaxWeight, kMaxWeight).mul(kWeightScale).round().to(torch.int8) - - ascii_hist('fc bias:', bias.numpy()) - print("layer has {}/{} clipped weights. Exceeding by {} the maximum {}.".format(clipped, total_elements, clipped_max, kMaxWeight)) - ascii_hist('fc weight:', weight.numpy()) - - # FC inputs are padded to 32 elements by spec. - num_input = weight.shape[1] - if num_input % 32 != 0: - num_input += 32 - (num_input % 32) - new_w = torch.zeros(weight.shape[0], num_input, dtype=torch.int8) - new_w[:, :weight.shape[1]] = weight - weight = new_w - - self.buf.extend(bias.flatten().numpy().tobytes()) - # Weights stored as [outputs][inputs], so we can flatten - self.buf.extend(weight.flatten().numpy().tobytes()) - - def int32(self, v): - self.buf.extend(struct.pack("> 1 + layer_hash ^= (prev_hash << 31) & 0xFFFFFFFF + if layer.out_features // model.num_ls_buckets != 1: + # Clipped ReLU hash + layer_hash = (layer_hash + 0x538D24C7) & 0xFFFFFFFF + prev_hash = layer_hash + return layer_hash + + def write_header(self, model, fc_hash, description): + self.int32(VERSION) # version + self.int32(fc_hash ^ model.feature_set.hash ^ (M.L1 * 2)) # halfkp network hash + encoded_description = description.encode("utf-8") + self.int32(len(encoded_description)) # Network definition + self.buf.extend(encoded_description) + + def write_leb_128_array(self, arr): + buf = encode_leb_128_array(arr) + self.int32(len(buf)) + self.buf.extend(buf) + + def write_tensor(self, arr, compression="none"): + if compression == "none": + self.buf.extend(arr.tobytes()) + elif compression == "leb128": + self.buf.extend("COMPRESSED_LEB128".encode("utf-8")) + self.write_leb_128_array(arr) + else: + raise Exception("Invalid compression method.") + + def write_feature_transformer(self, model, ft_compression): + layer = model.input + + bias = layer.bias.data[: M.L1] + bias = bias.mul(model.quantized_one).round().to(torch.int16) + + all_weight = M.coalesce_ft_weights(model, layer) + weight = all_weight[:, : M.L1] + psqt_weight = all_weight[:, M.L1 :] + + weight = weight.mul(model.quantized_one).round().to(torch.int16) + psqt_weight = ( + psqt_weight.mul(model.nnue2score * model.weight_scale_out) + .round() + .to(torch.int32) + ) + + ascii_hist("ft bias:", bias.numpy()) + ascii_hist("ft weight:", weight.numpy()) + ascii_hist("ft psqt weight:", psqt_weight.numpy()) + + # Weights stored as [num_features][outputs] + + self.write_tensor(bias.flatten().numpy(), ft_compression) + self.write_tensor(weight.flatten().numpy(), ft_compression) + self.write_tensor(psqt_weight.flatten().numpy(), ft_compression) + + def write_fc_layer(self, model, layer, is_output=False): + # FC layers are stored as int8 weights, and int32 biases + kWeightScaleHidden = model.weight_scale_hidden + kWeightScaleOut = ( + model.nnue2score * model.weight_scale_out / model.quantized_one + ) + kWeightScale = kWeightScaleOut if is_output else kWeightScaleHidden + kBiasScaleOut = model.weight_scale_out * model.nnue2score + kBiasScaleHidden = model.weight_scale_hidden * model.quantized_one + kBiasScale = kBiasScaleOut if is_output else kBiasScaleHidden + kMaxWeight = model.quantized_one / kWeightScale + + bias = layer.bias.data + bias = bias.mul(kBiasScale).round().to(torch.int32) + + weight = layer.weight.data + clipped = torch.count_nonzero(weight.clamp(-kMaxWeight, kMaxWeight) - weight) + total_elements = torch.numel(weight) + clipped_max = torch.max( + torch.abs(weight.clamp(-kMaxWeight, kMaxWeight) - weight) + ) + + weight = ( + weight.clamp(-kMaxWeight, kMaxWeight) + .mul(kWeightScale) + .round() + .to(torch.int8) + ) + + ascii_hist("fc bias:", bias.numpy()) + print( + "layer has {}/{} clipped weights. Exceeding by {} the maximum {}.".format( + clipped, total_elements, clipped_max, kMaxWeight + ) + ) + ascii_hist("fc weight:", weight.numpy()) + + # FC inputs are padded to 32 elements by spec. + num_input = weight.shape[1] + if num_input % 32 != 0: + num_input += 32 - (num_input % 32) + new_w = torch.zeros(weight.shape[0], num_input, dtype=torch.int8) + new_w[:, : weight.shape[1]] = weight + weight = new_w + + self.buf.extend(bias.flatten().numpy().tobytes()) + # Weights stored as [outputs][inputs], so we can flatten + self.buf.extend(weight.flatten().numpy().tobytes()) + + def int32(self, v): + self.buf.extend(struct.pack(" 0: - val_datasets = args.validation_datasets - - if (args.start_lambda is not None) != (args.end_lambda is not None): - raise Exception('Either both or none of start_lambda and end_lambda must be specified.') - - feature_set = features.get_feature_set_from_name(args.features) - - start_lambda = args.start_lambda or args.lambda_ - end_lambda = args.end_lambda or args.lambda_ - max_epoch = args.max_epochs or 800 - if args.resume_from_model is None: - nnue = M.NNUE( - feature_set=feature_set, - start_lambda=start_lambda, - max_epoch=max_epoch, - end_lambda=end_lambda, - gamma=args.gamma, - lr=args.lr, - param_index=args.param_index - ) - else: - nnue = torch.load(args.resume_from_model) - nnue.set_feature_set(feature_set) - nnue.start_lambda = start_lambda - nnue.end_lambda = end_lambda - nnue.max_epoch = max_epoch - # we can set the following here just like that because when resuming - # from .pt the optimizer is only created after the training is started - nnue.gamma = args.gamma - nnue.lr = args.lr - nnue.param_index=args.param_index - - print("Feature set: {}".format(feature_set.name)) - print("Num real features: {}".format(feature_set.num_real_features)) - print("Num virtual features: {}".format(feature_set.num_virtual_features)) - print("Num features: {}".format(feature_set.num_features)) - - print("Training with: {}".format(train_datasets)) - print("Validating with: {}".format(val_datasets)) - - pl.seed_everything(args.seed) - print("Seed {}".format(args.seed)) - - batch_size = args.batch_size - if batch_size <= 0: - batch_size = 16384 - print('Using batch size {}'.format(batch_size)) - - print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping)) - print('WLD fen skipping: {}'.format(not args.no_wld_fen_skipping)) - print('Random fen skipping: {}'.format(args.random_fen_skipping)) - print('Skip early plies: {}'.format(args.early_fen_skipping)) - print('Param index: {}'.format(args.param_index)) - - if args.threads > 0: - print('limiting torch to {} threads.'.format(args.threads)) - t_set_num_threads(args.threads) - - logdir = args.default_root_dir if args.default_root_dir else 'logs/' - print('Using log dir {}'.format(logdir), flush=True) - - tb_logger = pl_loggers.TensorBoardLogger(logdir) - checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=args.save_last_network, every_n_epochs=args.network_save_period, save_top_k=-1) - trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=tb_logger) - - main_device = trainer.strategy.root_device if trainer.strategy.root_device.index is None else 'cuda:' + str(trainer.strategy.root_device.index) - - nnue.to(device=main_device) - - print('Using c++ data loader') - train, val = make_data_loaders( - train_datasets, - val_datasets, - feature_set, - args.num_workers, - batch_size, - not args.no_smart_fen_skipping, - args.random_fen_skipping, - not args.no_wld_fen_skipping, - args.early_fen_skipping, - args.param_index, - main_device, - args.epoch_size, - args.validation_size) + parser = argparse.ArgumentParser(description="Trains the network.") + parser.add_argument( + "datasets", + action="append", + nargs="+", + help="Training datasets (.binpack). Interleaved at chunk level if multiple specified. Same data is used for training and validation if not validation data is specified.", + ) + parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument( + "--validation-data", + type=str, + action="append", + nargs="+", + dest="validation_datasets", + help="Validation data to use for validation instead of the training data.", + ) + parser.add_argument( + "--lambda", + default=1.0, + type=float, + dest="lambda_", + help="lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0).", + ) + parser.add_argument( + "--start-lambda", + default=None, + type=float, + dest="start_lambda", + help="lambda to use at first epoch.", + ) + parser.add_argument( + "--end-lambda", + default=None, + type=float, + dest="end_lambda", + help="lambda to use at last epoch.", + ) + parser.add_argument( + "--gamma", + default=0.992, + type=float, + dest="gamma", + help="Multiplicative factor applied to the learning rate after every epoch.", + ) + parser.add_argument( + "--lr", default=8.75e-4, type=float, dest="lr", help="Initial learning rate." + ) + parser.add_argument( + "--num-workers", + default=1, + type=int, + dest="num_workers", + help="Number of worker threads to use for data loading. Currently only works well for binpack.", + ) + parser.add_argument( + "--batch-size", + default=-1, + type=int, + dest="batch_size", + help="Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128.", + ) + parser.add_argument( + "--threads", + default=-1, + type=int, + dest="threads", + help="Number of torch threads to use. Default automatic (cores) .", + ) + parser.add_argument( + "--seed", default=42, type=int, dest="seed", help="torch seed to use." + ) + parser.add_argument( + "--smart-fen-skipping", + action="store_true", + dest="smart_fen_skipping_deprecated", + help="If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored", + ) + parser.add_argument( + "--no-smart-fen-skipping", + action="store_true", + dest="no_smart_fen_skipping", + help="If used then no smart fen skipping will be done. By default smart fen skipping is done.", + ) + parser.add_argument( + "--no-wld-fen-skipping", + action="store_true", + dest="no_wld_fen_skipping", + help="If used then no wld fen skipping will be done. By default wld fen skipping is done.", + ) + parser.add_argument( + "--random-fen-skipping", + default=3, + type=int, + dest="random_fen_skipping", + help="skip fens randomly on average random_fen_skipping before using one.", + ) + parser.add_argument( + "--resume-from-model", + dest="resume_from_model", + help="Initializes training using the weights from the given .pt model", + ) + parser.add_argument( + "--network-save-period", + type=int, + default=20, + dest="network_save_period", + help="Number of epochs between network snapshots. None to disable.", + ) + parser.add_argument( + "--save-last-network", + type=str2bool, + default=True, + dest="save_last_network", + help="Whether to always save the last produced network.", + ) + parser.add_argument( + "--epoch-size", + type=int, + default=100000000, + dest="epoch_size", + help="Number of positions per epoch.", + ) + parser.add_argument( + "--validation-size", + type=int, + default=1000000, + dest="validation_size", + help="Number of positions per validation step.", + ) + parser.add_argument( + "--param-index", + type=int, + default=0, + dest="param_index", + help="Indexing for parameter scans.", + ) + parser.add_argument( + "--early-fen-skipping", + type=int, + default=-1, + dest="early_fen_skipping", + help="Skip n plies from the start.", + ) + features.add_argparse_args(parser) + args = parser.parse_args() + + args.datasets = flatten_once(args.datasets) + if args.validation_datasets: + args.validation_datasets = flatten_once(args.validation_datasets) + else: + args.validation_datasets = [] + + for dataset in args.datasets: + if not os.path.exists(dataset): + raise Exception("{0} does not exist".format(dataset)) + + for val_dataset in args.validation_datasets: + if not os.path.exists(val_dataset): + raise Exception("{0} does not exist".format(val_dataset)) + + train_datasets = args.datasets + val_datasets = train_datasets + if len(args.validation_datasets) > 0: + val_datasets = args.validation_datasets + + if (args.start_lambda is not None) != (args.end_lambda is not None): + raise Exception( + "Either both or none of start_lambda and end_lambda must be specified." + ) + + feature_set = features.get_feature_set_from_name(args.features) + + start_lambda = args.start_lambda or args.lambda_ + end_lambda = args.end_lambda or args.lambda_ + max_epoch = args.max_epochs or 800 + if args.resume_from_model is None: + nnue = M.NNUE( + feature_set=feature_set, + start_lambda=start_lambda, + max_epoch=max_epoch, + end_lambda=end_lambda, + gamma=args.gamma, + lr=args.lr, + param_index=args.param_index, + ) + else: + nnue = torch.load(args.resume_from_model) + nnue.set_feature_set(feature_set) + nnue.start_lambda = start_lambda + nnue.end_lambda = end_lambda + nnue.max_epoch = max_epoch + # we can set the following here just like that because when resuming + # from .pt the optimizer is only created after the training is started + nnue.gamma = args.gamma + nnue.lr = args.lr + nnue.param_index = args.param_index + + print("Feature set: {}".format(feature_set.name)) + print("Num real features: {}".format(feature_set.num_real_features)) + print("Num virtual features: {}".format(feature_set.num_virtual_features)) + print("Num features: {}".format(feature_set.num_features)) + + print("Training with: {}".format(train_datasets)) + print("Validating with: {}".format(val_datasets)) + + pl.seed_everything(args.seed) + print("Seed {}".format(args.seed)) + + batch_size = args.batch_size + if batch_size <= 0: + batch_size = 16384 + print("Using batch size {}".format(batch_size)) + + print("Smart fen skipping: {}".format(not args.no_smart_fen_skipping)) + print("WLD fen skipping: {}".format(not args.no_wld_fen_skipping)) + print("Random fen skipping: {}".format(args.random_fen_skipping)) + print("Skip early plies: {}".format(args.early_fen_skipping)) + print("Param index: {}".format(args.param_index)) + + if args.threads > 0: + print("limiting torch to {} threads.".format(args.threads)) + t_set_num_threads(args.threads) + + logdir = args.default_root_dir if args.default_root_dir else "logs/" + print("Using log dir {}".format(logdir), flush=True) + + tb_logger = pl_loggers.TensorBoardLogger(logdir) + checkpoint_callback = pl.callbacks.ModelCheckpoint( + save_last=args.save_last_network, + every_n_epochs=args.network_save_period, + save_top_k=-1, + ) + trainer = pl.Trainer.from_argparse_args( + args, callbacks=[checkpoint_callback], logger=tb_logger + ) + + main_device = ( + trainer.strategy.root_device + if trainer.strategy.root_device.index is None + else "cuda:" + str(trainer.strategy.root_device.index) + ) + + nnue.to(device=main_device) + + print("Using c++ data loader") + train, val = make_data_loaders( + train_datasets, + val_datasets, + feature_set, + args.num_workers, + batch_size, + not args.no_smart_fen_skipping, + args.random_fen_skipping, + not args.no_wld_fen_skipping, + args.early_fen_skipping, + args.param_index, + main_device, + args.epoch_size, + args.validation_size, + ) + + trainer.fit(nnue, train, val) - trainer.fit(nnue, train, val) + with open(os.path.join(logdir, "training_finished"), "w"): + pass - with open(os.path.join(logdir, 'training_finished'), 'w'): - pass -if __name__ == '__main__': - main() - if sys.platform == "win32": - os.system(f'wmic process where processid="{os.getpid()}" call terminate >nul') +if __name__ == "__main__": + main() + if sys.platform == "win32": + os.system(f'wmic process where processid="{os.getpid()}" call terminate >nul') diff --git a/training_data_loader.cpp b/training_data_loader.cpp index 7dcdc3f5..e678ea94 100644 --- a/training_data_loader.cpp +++ b/training_data_loader.cpp @@ -13,24 +13,23 @@ #include "lib/nnue_training_data_stream.h" #include "lib/rng.h" -#if defined (__x86_64__) -#define EXPORT -#define CDECL +#if defined(__x86_64__) + #define EXPORT + #define CDECL #else -#if defined (_MSC_VER) -#define EXPORT __declspec(dllexport) -#define CDECL __cdecl -#else -#define EXPORT -#define CDECL __attribute__ ((__cdecl__)) -#endif + #if defined(_MSC_VER) + #define EXPORT __declspec(dllexport) + #define CDECL __cdecl + #else + #define EXPORT + #define CDECL __attribute__((__cdecl__)) + #endif #endif using namespace binpack; using namespace chess; -static Square orient(Color color, Square sq) -{ +static Square orient(Color color, Square sq) { if (color == Color::White) { return sq; @@ -44,8 +43,7 @@ static Square orient(Color color, Square sq) } } -static Square orient_flip(Color color, Square sq) -{ +static Square orient_flip(Color color, Square sq) { if (color == Color::White) { return sq; @@ -57,210 +55,216 @@ static Square orient_flip(Color color, Square sq) } struct HalfKP { - static constexpr int NUM_SQ = 64; - static constexpr int NUM_PT = 10; + static constexpr int NUM_SQ = 64; + static constexpr int NUM_PT = 10; static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1); - static constexpr int INPUTS = NUM_PLANES * NUM_SQ; + static constexpr int INPUTS = NUM_PLANES * NUM_SQ; static constexpr int MAX_ACTIVE_FEATURES = 32; - static int feature_index(Color color, Square ksq, Square sq, Piece p) - { + static int feature_index(Color color, Square ksq, Square sq, Piece p) { auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - return 1 + static_cast(orient(color, sq)) + p_idx * NUM_SQ + static_cast(ksq) * NUM_PLANES; + return 1 + static_cast(orient(color, sq)) + p_idx * NUM_SQ + + static_cast(ksq) * NUM_PLANES; } - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { - auto& pos = e.pos; - auto pieces = pos.piecesBB() & ~(pos.piecesBB(Piece(PieceType::King, Color::White)) | pos.piecesBB(Piece(PieceType::King, Color::Black))); + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { + auto& pos = e.pos; + auto pieces = pos.piecesBB() + & ~(pos.piecesBB(Piece(PieceType::King, Color::White)) + | pos.piecesBB(Piece(PieceType::King, Color::Black))); auto ksq = pos.kingSquare(color); // We order the features so that the resulting sparse // tensor is coalesced. int j = 0; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - values[j] = 1.0f; + auto p = pos.pieceAt(sq); + values[j] = 1.0f; features[j] = feature_index(color, orient(color, ksq), sq, p); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKPFactorized { // Factorized features - static constexpr int K_INPUTS = HalfKP::NUM_SQ; + static constexpr int K_INPUTS = HalfKP::NUM_SQ; static constexpr int PIECE_INPUTS = HalfKP::NUM_SQ * HalfKP::NUM_PT; - static constexpr int INPUTS = HalfKP::INPUTS + K_INPUTS + PIECE_INPUTS; + static constexpr int INPUTS = HalfKP::INPUTS + K_INPUTS + PIECE_INPUTS; - static constexpr int MAX_K_FEATURES = 1; + static constexpr int MAX_K_FEATURES = 1; static constexpr int MAX_PIECE_FEATURES = 32; - static constexpr int MAX_ACTIVE_FEATURES = HalfKP::MAX_ACTIVE_FEATURES + MAX_K_FEATURES + MAX_PIECE_FEATURES; + static constexpr int MAX_ACTIVE_FEATURES = + HalfKP::MAX_ACTIVE_FEATURES + MAX_K_FEATURES + MAX_PIECE_FEATURES; - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { auto [start_j, offset] = HalfKP::fill_features_sparse(e, features, values, color); - int j = start_j; - auto& pos = e.pos; + int j = start_j; + auto& pos = e.pos; { // king square factor - auto ksq = pos.kingSquare(color); + auto ksq = pos.kingSquare(color); features[j] = offset + static_cast(orient(color, ksq)); - values[j] = static_cast(start_j); + values[j] = static_cast(start_j); ++j; } offset += K_INPUTS; - auto pieces = pos.piecesBB() & ~(pos.piecesBB(Piece(PieceType::King, Color::White)) | pos.piecesBB(Piece(PieceType::King, Color::Black))); + auto pieces = pos.piecesBB() + & ~(pos.piecesBB(Piece(PieceType::King, Color::White)) + | pos.piecesBB(Piece(PieceType::King, Color::Black))); // We order the features so that the resulting sparse // tensor is coalesced. Note that we can just sort // the parts where values are all 1.0f and leave the // halfk feature where it was. - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - values[j] = 1.0f; + auto p = pos.pieceAt(sq); + auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); + values[j] = 1.0f; features[j] = offset + (p_idx * HalfKP::NUM_SQ) + static_cast(orient(color, sq)); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKA { - static constexpr int NUM_SQ = 64; - static constexpr int NUM_PT = 12; + static constexpr int NUM_SQ = 64; + static constexpr int NUM_PT = 12; static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1); - static constexpr int INPUTS = NUM_PLANES * NUM_SQ; + static constexpr int INPUTS = NUM_PLANES * NUM_SQ; static constexpr int MAX_ACTIVE_FEATURES = 32; - static int feature_index(Color color, Square ksq, Square sq, Piece p) - { + static int feature_index(Color color, Square ksq, Square sq, Piece p) { auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - return 1 + static_cast(orient_flip(color, sq)) + p_idx * NUM_SQ + static_cast(ksq) * NUM_PLANES; + return 1 + static_cast(orient_flip(color, sq)) + p_idx * NUM_SQ + + static_cast(ksq) * NUM_PLANES; } - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { - auto& pos = e.pos; - auto pieces = pos.piecesBB(); - auto ksq = pos.kingSquare(color); + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { + auto& pos = e.pos; + auto pieces = pos.piecesBB(); + auto ksq = pos.kingSquare(color); int j = 0; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - values[j] = 1.0f; + auto p = pos.pieceAt(sq); + values[j] = 1.0f; features[j] = feature_index(color, orient_flip(color, ksq), sq, p); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKAFactorized { // Factorized features static constexpr int PIECE_INPUTS = HalfKA::NUM_SQ * HalfKA::NUM_PT; - static constexpr int INPUTS = HalfKA::INPUTS + PIECE_INPUTS; + static constexpr int INPUTS = HalfKA::INPUTS + PIECE_INPUTS; - static constexpr int MAX_PIECE_FEATURES = 32; + static constexpr int MAX_PIECE_FEATURES = 32; static constexpr int MAX_ACTIVE_FEATURES = HalfKA::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES; - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { const auto [start_j, offset] = HalfKA::fill_features_sparse(e, features, values, color); - auto& pos = e.pos; - auto pieces = pos.piecesBB(); + auto& pos = e.pos; + auto pieces = pos.piecesBB(); int j = start_j; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); + auto p = pos.pieceAt(sq); auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - values[j] = 1.0f; - features[j] = offset + (p_idx * HalfKA::NUM_SQ) + static_cast(orient_flip(color, sq)); + values[j] = 1.0f; + features[j] = + offset + (p_idx * HalfKA::NUM_SQ) + static_cast(orient_flip(color, sq)); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKAv2 { - static constexpr int NUM_SQ = 64; - static constexpr int NUM_PT = 11; + static constexpr int NUM_SQ = 64; + static constexpr int NUM_PT = 11; static constexpr int NUM_PLANES = NUM_SQ * NUM_PT; - static constexpr int INPUTS = NUM_PLANES * NUM_SQ; + static constexpr int INPUTS = NUM_PLANES * NUM_SQ; static constexpr int MAX_ACTIVE_FEATURES = 32; - static int feature_index(Color color, Square ksq, Square sq, Piece p) - { + static int feature_index(Color color, Square ksq, Square sq, Piece p) { auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); if (p_idx == 11) - --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ - return static_cast(orient_flip(color, sq)) + p_idx * NUM_SQ + static_cast(ksq) * NUM_PLANES; + --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ + return static_cast(orient_flip(color, sq)) + p_idx * NUM_SQ + + static_cast(ksq) * NUM_PLANES; } - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { - auto& pos = e.pos; - auto pieces = pos.piecesBB(); - auto ksq = pos.kingSquare(color); + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { + auto& pos = e.pos; + auto pieces = pos.piecesBB(); + auto ksq = pos.kingSquare(color); int j = 0; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - values[j] = 1.0f; + auto p = pos.pieceAt(sq); + values[j] = 1.0f; features[j] = feature_index(color, orient_flip(color, ksq), sq, p); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKAv2Factorized { // Factorized features - static constexpr int NUM_PT = 12; + static constexpr int NUM_PT = 12; static constexpr int PIECE_INPUTS = HalfKAv2::NUM_SQ * NUM_PT; - static constexpr int INPUTS = HalfKAv2::INPUTS + PIECE_INPUTS; + static constexpr int INPUTS = HalfKAv2::INPUTS + PIECE_INPUTS; - static constexpr int MAX_PIECE_FEATURES = 32; + static constexpr int MAX_PIECE_FEATURES = 32; static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES; - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { const auto [start_j, offset] = HalfKAv2::fill_features_sparse(e, features, values, color); - auto& pos = e.pos; - auto pieces = pos.piecesBB(); + auto& pos = e.pos; + auto pieces = pos.piecesBB(); int j = start_j; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); + auto p = pos.pieceAt(sq); auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - values[j] = 1.0f; - features[j] = offset + (p_idx * HalfKAv2::NUM_SQ) + static_cast(orient_flip(color, sq)); + values[j] = 1.0f; + features[j] = + offset + (p_idx * HalfKAv2::NUM_SQ) + static_cast(orient_flip(color, sq)); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; // ksq must not be oriented -static Square orient_flip_2(Color color, Square sq, Square ksq) -{ +static Square orient_flip_2(Color color, Square sq, Square ksq) { bool h = ksq.file() < fileE; if (color == Color::Black) sq = sq.flippedVertically(); @@ -270,118 +274,112 @@ static Square orient_flip_2(Color color, Square sq, Square ksq) } struct HalfKAv2_hm { - static constexpr int NUM_SQ = 64; - static constexpr int NUM_PT = 11; + static constexpr int NUM_SQ = 64; + static constexpr int NUM_PT = 11; static constexpr int NUM_PLANES = NUM_SQ * NUM_PT; - static constexpr int INPUTS = NUM_PLANES * NUM_SQ / 2; + static constexpr int INPUTS = NUM_PLANES * NUM_SQ / 2; static constexpr int MAX_ACTIVE_FEATURES = 32; static constexpr int KingBuckets[64] = { - -1, -1, -1, -1, 31, 30, 29, 28, - -1, -1, -1, -1, 27, 26, 25, 24, - -1, -1, -1, -1, 23, 22, 21, 20, - -1, -1, -1, -1, 19, 18, 17, 16, - -1, -1, -1, -1, 15, 14, 13, 12, - -1, -1, -1, -1, 11, 10, 9, 8, - -1, -1, -1, -1, 7, 6, 5, 4, - -1, -1, -1, -1, 3, 2, 1, 0 - }; - - static int feature_index(Color color, Square ksq, Square sq, Piece p) - { + -1, -1, -1, -1, 31, 30, 29, 28, -1, -1, -1, -1, 27, 26, 25, 24, -1, -1, -1, -1, 23, 22, + 21, 20, -1, -1, -1, -1, 19, 18, 17, 16, -1, -1, -1, -1, 15, 14, 13, 12, -1, -1, -1, -1, + 11, 10, 9, 8, -1, -1, -1, -1, 7, 6, 5, 4, -1, -1, -1, -1, 3, 2, 1, 0}; + + static int feature_index(Color color, Square ksq, Square sq, Piece p) { Square o_ksq = orient_flip_2(color, ksq, ksq); - auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); + auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); if (p_idx == 11) - --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ - return static_cast(orient_flip_2(color, sq, ksq)) + p_idx * NUM_SQ + KingBuckets[static_cast(o_ksq)] * NUM_PLANES; + --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ + return static_cast(orient_flip_2(color, sq, ksq)) + p_idx * NUM_SQ + + KingBuckets[static_cast(o_ksq)] * NUM_PLANES; } - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { - auto& pos = e.pos; - auto pieces = pos.piecesBB(); - auto ksq = pos.kingSquare(color); + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { + auto& pos = e.pos; + auto pieces = pos.piecesBB(); + auto ksq = pos.kingSquare(color); int j = 0; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - values[j] = 1.0f; + auto p = pos.pieceAt(sq); + values[j] = 1.0f; features[j] = feature_index(color, ksq, sq, p); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; struct HalfKAv2_hmFactorized { // Factorized features - static constexpr int NUM_PT = 12; + static constexpr int NUM_PT = 12; static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * NUM_PT; - static constexpr int INPUTS = HalfKAv2_hm::INPUTS + PIECE_INPUTS; + static constexpr int INPUTS = HalfKAv2_hm::INPUTS + PIECE_INPUTS; static constexpr int MAX_PIECE_FEATURES = 32; - static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2_hm::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES; + static constexpr int MAX_ACTIVE_FEATURES = + HalfKAv2_hm::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES; - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { - const auto [start_j, offset] = HalfKAv2_hm::fill_features_sparse(e, features, values, color); - auto& pos = e.pos; - auto pieces = pos.piecesBB(); - auto ksq = pos.kingSquare(color); + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { + const auto [start_j, offset] = + HalfKAv2_hm::fill_features_sparse(e, features, values, color); + auto& pos = e.pos; + auto pieces = pos.piecesBB(); + auto ksq = pos.kingSquare(color); int j = start_j; - for(Square sq : pieces) + for (Square sq : pieces) { - auto p = pos.pieceAt(sq); - auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); - values[j] = 1.0f; - features[j] = offset + (p_idx * HalfKAv2_hm::NUM_SQ) + static_cast(orient_flip_2(color, sq, ksq)); + auto p = pos.pieceAt(sq); + auto p_idx = static_cast(p.type()) * 2 + (p.color() != color); + values[j] = 1.0f; + features[j] = offset + (p_idx * HalfKAv2_hm::NUM_SQ) + + static_cast(orient_flip_2(color, sq, ksq)); ++j; } - return { j, INPUTS }; + return {j, INPUTS}; } }; -template -struct FeatureSet -{ +template +struct FeatureSet { static_assert(sizeof...(Ts) == 0, "Currently only one feature subset supported."); - static constexpr int INPUTS = T::INPUTS; + static constexpr int INPUTS = T::INPUTS; static constexpr int MAX_ACTIVE_FEATURES = T::MAX_ACTIVE_FEATURES; - static std::pair fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) - { + static std::pair + fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) { return T::fill_features_sparse(e, features, values, color); } }; -struct SparseBatch -{ +struct SparseBatch { static constexpr bool IS_BATCH = true; - template - SparseBatch(FeatureSet, const std::vector& entries) - { - num_inputs = FeatureSet::INPUTS; - size = entries.size(); - is_white = new float[size]; - outcome = new float[size]; - score = new float[size]; - white = new int[size * FeatureSet::MAX_ACTIVE_FEATURES]; - black = new int[size * FeatureSet::MAX_ACTIVE_FEATURES]; - white_values = new float[size * FeatureSet::MAX_ACTIVE_FEATURES]; - black_values = new float[size * FeatureSet::MAX_ACTIVE_FEATURES]; - psqt_indices = new int[size]; + template + SparseBatch(FeatureSet, const std::vector& entries) { + num_inputs = FeatureSet::INPUTS; + size = entries.size(); + is_white = new float[size]; + outcome = new float[size]; + score = new float[size]; + white = new int[size * FeatureSet::MAX_ACTIVE_FEATURES]; + black = new int[size * FeatureSet::MAX_ACTIVE_FEATURES]; + white_values = new float[size * FeatureSet::MAX_ACTIVE_FEATURES]; + black_values = new float[size * FeatureSet::MAX_ACTIVE_FEATURES]; + psqt_indices = new int[size]; layer_stack_indices = new int[size]; num_active_white_features = 0; num_active_black_features = 0; - max_active_features = FeatureSet::MAX_ACTIVE_FEATURES; + max_active_features = FeatureSet::MAX_ACTIVE_FEATURES; for (std::size_t i = 0; i < size * FeatureSet::MAX_ACTIVE_FEATURES; ++i) white[i] = -1; @@ -392,7 +390,7 @@ struct SparseBatch for (std::size_t i = 0; i < size * FeatureSet::MAX_ACTIVE_FEATURES; ++i) black_values[i] = 0.0f; - for(int i = 0; i < entries.size(); ++i) + for (int i = 0; i < entries.size(); ++i) { fill_entry(FeatureSet{}, i, entries[i]); } @@ -404,18 +402,17 @@ struct SparseBatch float* is_white; float* outcome; float* score; - int num_active_white_features; - int num_active_black_features; - int max_active_features; - int* white; - int* black; + int num_active_white_features; + int num_active_black_features; + int max_active_features; + int* white; + int* black; float* white_values; float* black_values; - int* psqt_indices; - int* layer_stack_indices; + int* psqt_indices; + int* layer_stack_indices; - ~SparseBatch() - { + ~SparseBatch() { delete[] is_white; delete[] outcome; delete[] score; @@ -427,106 +424,98 @@ struct SparseBatch delete[] layer_stack_indices; } -private: - - template - void fill_entry(FeatureSet, int i, const TrainingDataEntry& e) - { - is_white[i] = static_cast(e.pos.sideToMove() == Color::White); - outcome[i] = (e.result + 1.0f) / 2.0f; - score[i] = e.score; - psqt_indices[i] = (e.pos.piecesBB().count() - 1) / 4; + private: + template + void fill_entry(FeatureSet, int i, const TrainingDataEntry& e) { + is_white[i] = static_cast(e.pos.sideToMove() == Color::White); + outcome[i] = (e.result + 1.0f) / 2.0f; + score[i] = e.score; + psqt_indices[i] = (e.pos.piecesBB().count() - 1) / 4; layer_stack_indices[i] = psqt_indices[i]; fill_features(FeatureSet{}, i, e); } - template - void fill_features(FeatureSet, int i, const TrainingDataEntry& e) - { + template + void fill_features(FeatureSet, int i, const TrainingDataEntry& e) { const int offset = i * FeatureSet::MAX_ACTIVE_FEATURES; - num_active_white_features += - FeatureSet::fill_features_sparse(e, white + offset, white_values + offset, Color::White) - .first; - num_active_black_features += - FeatureSet::fill_features_sparse(e, black + offset, black_values + offset, Color::Black) - .first; + num_active_white_features += FeatureSet::fill_features_sparse( + e, white + offset, white_values + offset, Color::White) + .first; + num_active_black_features += FeatureSet::fill_features_sparse( + e, black + offset, black_values + offset, Color::Black) + .first; } }; -struct AnyStream -{ +struct AnyStream { virtual ~AnyStream() = default; }; -template -struct Stream : AnyStream -{ +template +struct Stream: AnyStream { using StorageType = StorageT; - Stream(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate) : - m_stream(training_data::open_sfen_input_file_parallel(concurrency, filenames, cyclic, skipPredicate)) - { - } + Stream(int concurrency, + const std::vector& filenames, + bool cyclic, + std::function skipPredicate) : + m_stream(training_data::open_sfen_input_file_parallel( + concurrency, filenames, cyclic, skipPredicate)) {} virtual StorageT* next() = 0; -protected: + protected: std::unique_ptr m_stream; }; -template -struct AsyncStream : Stream -{ +template +struct AsyncStream: Stream { using BaseType = Stream; - AsyncStream(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate) : - BaseType(1, filenames, cyclic, skipPredicate) - { - } + AsyncStream(int concurrency, + const std::vector& filenames, + bool cyclic, + std::function skipPredicate) : + BaseType(1, filenames, cyclic, skipPredicate) {} - ~AsyncStream() - { + ~AsyncStream() { if (m_next.valid()) { delete m_next.get(); } } -protected: + protected: std::future m_next; }; -template -struct FeaturedBatchStream : Stream -{ +template +struct FeaturedBatchStream: Stream { static_assert(StorageT::IS_BATCH); using FeatureSet = FeatureSetT; - using BaseType = Stream; + using BaseType = Stream; static constexpr int num_feature_threads_per_reading_thread = 2; - FeaturedBatchStream(int concurrency, const std::vector& filenames, int batch_size, bool cyclic, std::function skipPredicate) : - BaseType( - std::max( - 1, - concurrency / num_feature_threads_per_reading_thread - ), - filenames, - cyclic, - skipPredicate - ), + FeaturedBatchStream(int concurrency, + const std::vector& filenames, + int batch_size, + bool cyclic, + std::function skipPredicate) : + BaseType(std::max(1, concurrency / num_feature_threads_per_reading_thread), + filenames, + cyclic, + skipPredicate), m_concurrency(concurrency), - m_batch_size(batch_size) - { + m_batch_size(batch_size) { m_stop_flag.store(false); - auto worker = [this]() - { + auto worker = [this]() { std::vector entries; entries.reserve(m_batch_size); - while(!m_stop_flag.load()) + while (!m_stop_flag.load()) { entries.clear(); @@ -543,23 +532,22 @@ struct FeaturedBatchStream : Stream { std::unique_lock lock(m_batch_mutex); - m_batches_not_full.wait(lock, [this]() { return m_batches.size() < m_concurrency + 1 || m_stop_flag.load(); }); + m_batches_not_full.wait(lock, [this]() { + return m_batches.size() < m_concurrency + 1 || m_stop_flag.load(); + }); m_batches.emplace_back(batch); lock.unlock(); m_batches_any.notify_one(); } - } m_num_workers.fetch_sub(1); m_batches_any.notify_one(); }; const int num_feature_threads = std::max( - 1, - concurrency - std::max(1, concurrency / num_feature_threads_per_reading_thread) - ); + 1, concurrency - std::max(1, concurrency / num_feature_threads_per_reading_thread)); for (int i = 0; i < num_feature_threads; ++i) { @@ -573,10 +561,10 @@ struct FeaturedBatchStream : Stream } } - StorageT* next() override - { + StorageT* next() override { std::unique_lock lock(m_batch_mutex); - m_batches_any.wait(lock, [this]() { return !m_batches.empty() || m_num_workers.load() == 0; }); + m_batches_any.wait(lock, + [this]() { return !m_batches.empty() || m_num_workers.load() == 0; }); if (!m_batches.empty()) { @@ -591,8 +579,7 @@ struct FeaturedBatchStream : Stream return nullptr; } - ~FeaturedBatchStream() - { + ~FeaturedBatchStream() { m_stop_flag.store(true); m_batches_not_full.notify_all(); @@ -610,108 +597,91 @@ struct FeaturedBatchStream : Stream } } -private: - int m_batch_size; - int m_concurrency; - std::deque m_batches; - std::mutex m_batch_mutex; - std::mutex m_stream_mutex; + private: + int m_batch_size; + int m_concurrency; + std::deque m_batches; + std::mutex m_batch_mutex; + std::mutex m_stream_mutex; std::condition_variable m_batches_not_full; std::condition_variable m_batches_any; - std::atomic_bool m_stop_flag; - std::atomic_int m_num_workers; + std::atomic_bool m_stop_flag; + std::atomic_int m_num_workers; std::vector m_workers; }; // Very simple fixed size string wrapper with a stable ABI to pass to python. -struct Fen -{ +struct Fen { Fen() : - m_fen(nullptr) - { - } + m_fen(nullptr) {} Fen(const std::string& fen) : m_size(fen.size()), - m_fen(new char[fen.size() + 1]) - { + m_fen(new char[fen.size() + 1]) { std::memcpy(m_fen, fen.c_str(), fen.size() + 1); } - Fen& operator=(const std::string& fen) - { + Fen& operator=(const std::string& fen) { if (m_fen != nullptr) { delete m_fen; } m_size = fen.size(); - m_fen = new char[fen.size() + 1]; + m_fen = new char[fen.size() + 1]; std::memcpy(m_fen, fen.c_str(), fen.size() + 1); return *this; } - ~Fen() - { - delete[] m_fen; - } + ~Fen() { delete[] m_fen; } -private: - int m_size; + private: + int m_size; char* m_fen; }; -struct FenBatch -{ +struct FenBatch { FenBatch(const std::vector& entries) : m_size(entries.size()), - m_fens(new Fen[entries.size()]) - { + m_fens(new Fen[entries.size()]) { for (int i = 0; i < m_size; ++i) { m_fens[i] = entries[i].pos.fen(); } } - ~FenBatch() - { - delete[] m_fens; - } + ~FenBatch() { delete[] m_fens; } -private: - int m_size; + private: + int m_size; Fen* m_fens; }; -struct FenBatchStream : Stream -{ +struct FenBatchStream: Stream { static constexpr int num_feature_threads_per_reading_thread = 2; using BaseType = Stream; - FenBatchStream(int concurrency, const std::vector& filenames, int batch_size, bool cyclic, std::function skipPredicate) : - BaseType( - std::max( - 1, - concurrency / num_feature_threads_per_reading_thread - ), - filenames, - cyclic, - skipPredicate - ), + FenBatchStream(int concurrency, + const std::vector& filenames, + int batch_size, + bool cyclic, + std::function skipPredicate) : + BaseType(std::max(1, concurrency / num_feature_threads_per_reading_thread), + filenames, + cyclic, + skipPredicate), m_concurrency(concurrency), - m_batch_size(batch_size) - { + m_batch_size(batch_size) { m_stop_flag.store(false); - auto worker = [this]() - { + auto worker = [this]() { std::vector entries; entries.reserve(m_batch_size); - while(!m_stop_flag.load()) + while (!m_stop_flag.load()) { entries.clear(); @@ -728,23 +698,22 @@ struct FenBatchStream : Stream { std::unique_lock lock(m_batch_mutex); - m_batches_not_full.wait(lock, [this]() { return m_batches.size() < m_concurrency + 1 || m_stop_flag.load(); }); + m_batches_not_full.wait(lock, [this]() { + return m_batches.size() < m_concurrency + 1 || m_stop_flag.load(); + }); m_batches.emplace_back(batch); lock.unlock(); m_batches_any.notify_one(); } - } m_num_workers.fetch_sub(1); m_batches_any.notify_one(); }; const int num_feature_threads = std::max( - 1, - concurrency - std::max(1, concurrency / num_feature_threads_per_reading_thread) - ); + 1, concurrency - std::max(1, concurrency / num_feature_threads_per_reading_thread)); for (int i = 0; i < num_feature_threads; ++i) { @@ -758,10 +727,10 @@ struct FenBatchStream : Stream } } - FenBatch* next() - { + FenBatch* next() { std::unique_lock lock(m_batch_mutex); - m_batches_any.wait(lock, [this]() { return !m_batches.empty() || m_num_workers.load() == 0; }); + m_batches_any.wait(lock, + [this]() { return !m_batches.empty() || m_num_workers.load() == 0; }); if (!m_batches.empty()) { @@ -776,8 +745,7 @@ struct FenBatchStream : Stream return nullptr; } - ~FenBatchStream() - { + ~FenBatchStream() { m_stop_flag.store(true); m_batches_not_full.notify_all(); @@ -795,32 +763,29 @@ struct FenBatchStream : Stream } } -private: - int m_batch_size; - int m_concurrency; - std::deque m_batches; - std::mutex m_batch_mutex; - std::mutex m_stream_mutex; + private: + int m_batch_size; + int m_concurrency; + std::deque m_batches; + std::mutex m_batch_mutex; + std::mutex m_stream_mutex; std::condition_variable m_batches_not_full; std::condition_variable m_batches_any; - std::atomic_bool m_stop_flag; - std::atomic_int m_num_workers; + std::atomic_bool m_stop_flag; + std::atomic_int m_num_workers; std::vector m_workers; }; -std::function make_skip_predicate(bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index) -{ +std::function make_skip_predicate(bool filtered, + int random_fen_skipping, + bool wld_filtered, + int early_fen_skipping, + int param_index) { if (filtered || random_fen_skipping || wld_filtered || early_fen_skipping) { - return [ - random_fen_skipping, - prob = double(random_fen_skipping) / (random_fen_skipping + 1), - filtered, - wld_filtered, - early_fen_skipping - ](const TrainingDataEntry& e){ - + return [random_fen_skipping, prob = double(random_fen_skipping) / (random_fen_skipping + 1), + filtered, wld_filtered, early_fen_skipping](const TrainingDataEntry& e) { // VALUE_NONE from Stockfish. // We need to allow a way to skip predetermined positions without // having to remove them from the dataset, as otherwise the we lose some @@ -828,14 +793,13 @@ std::function make_skip_predicate(bool filtered, static constexpr int VALUE_NONE = 32002; static constexpr double desired_piece_count_weights[33] = { - 1.000000, - 1.121094, 1.234375, 1.339844, 1.437500, 1.527344, 1.609375, 1.683594, 1.750000, - 1.808594, 1.859375, 1.902344, 1.937500, 1.964844, 1.984375, 1.996094, 2.000000, - 1.996094, 1.984375, 1.964844, 1.937500, 1.902344, 1.859375, 1.808594, 1.750000, - 1.683594, 1.609375, 1.527344, 1.437500, 1.339844, 1.234375, 1.121094, 1.000000 - }; + 1.000000, 1.121094, 1.234375, 1.339844, 1.437500, 1.527344, 1.609375, + 1.683594, 1.750000, 1.808594, 1.859375, 1.902344, 1.937500, 1.964844, + 1.984375, 1.996094, 2.000000, 1.996094, 1.984375, 1.964844, 1.937500, + 1.902344, 1.859375, 1.808594, 1.750000, 1.683594, 1.609375, 1.527344, + 1.437500, 1.339844, 1.234375, 1.121094, 1.000000}; - static constexpr double desired_piece_count_weights_total = [](){ + static constexpr double desired_piece_count_weights_total = []() { double tot = 0; for (auto w : desired_piece_count_weights) tot += w; @@ -845,10 +809,10 @@ std::function make_skip_predicate(bool filtered, static thread_local std::mt19937 gen(std::random_device{}()); // keep stats on passing pieces - static thread_local double alpha = 1; - static thread_local double piece_count_history_all[33] = {0}; - static thread_local double piece_count_history_passed[33] = {0}; - static thread_local double piece_count_history_all_total = 0; + static thread_local double alpha = 1; + static thread_local double piece_count_history_all[33] = {0}; + static thread_local double piece_count_history_passed[33] = {0}; + static thread_local double piece_count_history_all_total = 0; static thread_local double piece_count_history_passed_total = 0; // max skipping rate @@ -856,19 +820,17 @@ std::function make_skip_predicate(bool filtered, auto do_wld_skip = [&]() { std::bernoulli_distribution distrib(1.0 - e.score_result_prob()); - auto& prng = rng::get_thread_local_rng(); + auto& prng = rng::get_thread_local_rng(); return distrib(prng); }; auto do_skip = [&]() { std::bernoulli_distribution distrib(prob); - auto& prng = rng::get_thread_local_rng(); + auto& prng = rng::get_thread_local_rng(); return distrib(prng); }; - auto do_filter = [&]() { - return (e.isCapturingMove() || e.isInCheck()); - }; + auto do_filter = [&]() { return (e.isCapturingMove() || e.isInCheck()); }; // Allow for predermined filtering without the need to remove positions from the dataset. if (e.score == VALUE_NONE) @@ -887,8 +849,10 @@ std::function make_skip_predicate(bool filtered, return true; constexpr bool do_debug_print = false; - if (do_debug_print) { - if (uint64_t(piece_count_history_all_total) % 10000 == 0) { + if (do_debug_print) + { + if (uint64_t(piece_count_history_all_total) % 10000 == 0) + { std::cout << "Total : " << piece_count_history_all_total << '\n'; std::cout << "Passed: " << piece_count_history_passed_total << '\n'; for (int i = 0; i < 33; ++i) @@ -901,14 +865,16 @@ std::function make_skip_predicate(bool filtered, piece_count_history_all_total += 1; // update alpha, which scales the filtering probability, to a maximum rate. - if (uint64_t(piece_count_history_all_total) % 10000 == 0) { + if (uint64_t(piece_count_history_all_total) % 10000 == 0) + { double pass = piece_count_history_all_total * desired_piece_count_weights_total; for (int i = 0; i < 33; ++i) { if (desired_piece_count_weights[pc] > 0) { - double tmp = piece_count_history_all_total * desired_piece_count_weights[pc] / - (desired_piece_count_weights_total * piece_count_history_all[pc]); + double tmp = + piece_count_history_all_total * desired_piece_count_weights[pc] + / (desired_piece_count_weights_total * piece_count_history_all[pc]); if (tmp < pass) pass = tmp; } @@ -916,11 +882,11 @@ std::function make_skip_predicate(bool filtered, alpha = 1.0 / (pass * max_skipping_rate); } - double tmp = alpha * piece_count_history_all_total * desired_piece_count_weights[pc] / - (desired_piece_count_weights_total * piece_count_history_all[pc]); + double tmp = alpha * piece_count_history_all_total * desired_piece_count_weights[pc] + / (desired_piece_count_weights_total * piece_count_history_all[pc]); tmp = std::min(1.0, tmp); std::bernoulli_distribution distrib(1.0 - tmp); - auto& prng = rng::get_thread_local_rng(); + auto& prng = rng::get_thread_local_rng(); if (distrib(prng)) return true; @@ -936,147 +902,153 @@ std::function make_skip_predicate(bool filtered, extern "C" { - EXPORT SparseBatch* get_sparse_batch_from_fens( - const char* feature_set_c, - int num_fens, - const char* const* fens, - int* scores, - int* plies, - int* results - ) +EXPORT SparseBatch* get_sparse_batch_from_fens(const char* feature_set_c, + int num_fens, + const char* const* fens, + int* scores, + int* plies, + int* results) { + std::vector entries; + entries.reserve(num_fens); + for (int i = 0; i < num_fens; ++i) { - std::vector entries; - entries.reserve(num_fens); - for (int i = 0; i < num_fens; ++i) - { - auto& e = entries.emplace_back(); - e.pos = Position::fromFen(fens[i]); - movegen::forEachLegalMove(e.pos, [&](Move m){e.move = m;}); - e.score = scores[i]; - e.ply = plies[i]; - e.result = results[i]; - } - - std::string_view feature_set(feature_set_c); - if (feature_set == "HalfKP") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKP^") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKA") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKA^") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKAv2") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKAv2^") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKAv2_hm") - { - return new SparseBatch(FeatureSet{}, entries); - } - else if (feature_set == "HalfKAv2_hm^") - { - return new SparseBatch(FeatureSet{}, entries); - } - fprintf(stderr, "Unknown feature_set %s\n", feature_set_c); - return nullptr; + auto& e = entries.emplace_back(); + e.pos = Position::fromFen(fens[i]); + movegen::forEachLegalMove(e.pos, [&](Move m) { e.move = m; }); + e.score = scores[i]; + e.ply = plies[i]; + e.result = results[i]; } - // changing the signature needs matching changes in nnue_dataset.py - EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index) + std::string_view feature_set(feature_set_c); + if (feature_set == "HalfKP") { - auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index); - auto filenames_vec = std::vector(filenames, filenames + num_files); - - return new FenBatchStream(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + return new SparseBatch(FeatureSet{}, entries); } - - EXPORT void CDECL destroy_fen_batch_stream(FenBatchStream* stream) + else if (feature_set == "HalfKP^") { - delete stream; + return new SparseBatch(FeatureSet{}, entries); } - - // changing the signature needs matching changes in nnue_dataset.py - EXPORT Stream* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, - bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index) + else if (feature_set == "HalfKA") { - auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index); - auto filenames_vec = std::vector(filenames, filenames + num_files); - - std::string_view feature_set(feature_set_c); - if (feature_set == "HalfKP") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKP^") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKA") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKA^") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKAv2") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKAv2^") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKAv2_hm") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - else if (feature_set == "HalfKAv2_hm^") - { - return new FeaturedBatchStream, SparseBatch>(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); - } - fprintf(stderr, "Unknown feature_set %s\n", feature_set_c); - return nullptr; + return new SparseBatch(FeatureSet{}, entries); } - - EXPORT void CDECL destroy_sparse_batch_stream(Stream* stream) + else if (feature_set == "HalfKA^") { - delete stream; + return new SparseBatch(FeatureSet{}, entries); } - - EXPORT SparseBatch* CDECL fetch_next_sparse_batch(Stream* stream) + else if (feature_set == "HalfKAv2") { - return stream->next(); + return new SparseBatch(FeatureSet{}, entries); } - - EXPORT FenBatch* CDECL fetch_next_fen_batch(Stream* stream) + else if (feature_set == "HalfKAv2^") { - return stream->next(); + return new SparseBatch(FeatureSet{}, entries); } - - EXPORT void CDECL destroy_sparse_batch(SparseBatch* e) + else if (feature_set == "HalfKAv2_hm") { - delete e; + return new SparseBatch(FeatureSet{}, entries); } + else if (feature_set == "HalfKAv2_hm^") + { + return new SparseBatch(FeatureSet{}, entries); + } + fprintf(stderr, "Unknown feature_set %s\n", feature_set_c); + return nullptr; +} - EXPORT void CDECL destroy_fen_batch(FenBatch* e) +// changing the signature needs matching changes in nnue_dataset.py +EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, + int num_files, + const char* const* filenames, + int batch_size, + bool cyclic, + bool filtered, + int random_fen_skipping, + bool wld_filtered, + int early_fen_skipping, + int param_index) { + auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered, + early_fen_skipping, param_index); + auto filenames_vec = std::vector(filenames, filenames + num_files); + + return new FenBatchStream(concurrency, filenames_vec, batch_size, cyclic, skipPredicate); +} + +EXPORT void CDECL destroy_fen_batch_stream(FenBatchStream* stream) { delete stream; } + +// changing the signature needs matching changes in nnue_dataset.py +EXPORT Stream* CDECL create_sparse_batch_stream(const char* feature_set_c, + int concurrency, + int num_files, + const char* const* filenames, + int batch_size, + bool cyclic, + bool filtered, + int random_fen_skipping, + bool wld_filtered, + int early_fen_skipping, + int param_index) { + auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered, + early_fen_skipping, param_index); + auto filenames_vec = std::vector(filenames, filenames + num_files); + + std::string_view feature_set(feature_set_c); + if (feature_set == "HalfKP") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKP^") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKA") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKA^") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKAv2") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKAv2^") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + else if (feature_set == "HalfKAv2_hm") { - delete e; + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); } + else if (feature_set == "HalfKAv2_hm^") + { + return new FeaturedBatchStream, SparseBatch>( + concurrency, filenames_vec, batch_size, cyclic, skipPredicate); + } + fprintf(stderr, "Unknown feature_set %s\n", feature_set_c); + return nullptr; +} + +EXPORT void CDECL destroy_sparse_batch_stream(Stream* stream) { delete stream; } + +EXPORT SparseBatch* CDECL fetch_next_sparse_batch(Stream* stream) { + return stream->next(); +} + +EXPORT FenBatch* CDECL fetch_next_fen_batch(Stream* stream) { return stream->next(); } + +EXPORT void CDECL destroy_sparse_batch(SparseBatch* e) { delete e; } +EXPORT void CDECL destroy_fen_batch(FenBatch* e) { delete e; } } /* benches */ /* diff --git a/visualize.py b/visualize.py index 41f31d6b..63cf8091 100644 --- a/visualize.py +++ b/visualize.py @@ -10,23 +10,31 @@ from serialize import NNUEReader -class NNUEVisualizer(): +class NNUEVisualizer: def __init__(self, model, ref_model, args): self.model = model self.ref_model = ref_model self.args = args import matplotlib as mpl + self.dpi = 100 mpl.rcParams["figure.figsize"] = ( - self.args.default_width//self.dpi, self.args.default_height//self.dpi) + self.args.default_width // self.dpi, + self.args.default_height // self.dpi, + ) mpl.rcParams["figure.dpi"] = self.dpi def _process_fig(self, name, fig=None): if self.args.save_dir: from os.path import join + destname = join( - self.args.save_dir, "{}{}.jpg".format("" if self.args.label is None else self.args.label + "_", name)) + self.args.save_dir, + "{}{}.jpg".format( + "" if self.args.label is None else self.args.label + "_", name + ), + ) print("Saving {}".format(destname)) if fig is not None: fig.savefig(destname) @@ -36,13 +44,12 @@ def _process_fig(self, name, fig=None): def plot_input_weights(self): # Coalesce weights and transform them to Numpy domain. weights = M.coalesce_ft_weights(self.model, self.model.input) - weights = weights[:, :M.L1] + weights = weights[:, : M.L1] weights = weights.flatten().numpy() if self.args.ref_model: - ref_weights = M.coalesce_ft_weights( - self.ref_model, self.ref_model.input) - ref_weights = ref_weights[:, :M.L1] + ref_weights = M.coalesce_ft_weights(self.ref_model, self.ref_model.input) + ref_weights = ref_weights[:, : M.L1] ref_weights = ref_weights.flatten().numpy() weights -= ref_weights @@ -56,10 +63,10 @@ def plot_input_weights(self): # Find a factor of hd such that the aspect ratio # is as close to the preferred ratio as possible. factor, smallest_diff = 0, hd - for n in range(1, hd+1): + for n in range(1, hd + 1): if hd % n == 0: - ratio = hd / (n*n) - diff = abs(preferred_ratio-ratio) + ratio = hd / (n * n) + diff = abs(preferred_ratio - ratio) if diff < smallest_diff: factor = n smallest_diff = diff @@ -72,43 +79,124 @@ def plot_input_weights(self): for i in range(hd): neuron_weights_norm[i] = np.sum(np.abs(weights[i::hd])) - self.sorted_input_neurons = np.flip( - np.argsort(neuron_weights_norm)) + self.sorted_input_neurons = np.flip(np.argsort(neuron_weights_norm)) else: self.sorted_input_neurons = np.arange(hd, dtype=int) KingBuckets = [ - -1, -1, -1, -1, 31, 30, 29, 28, - -1, -1, -1, -1, 27, 26, 25, 24, - -1, -1, -1, -1, 23, 22, 21, 20, - -1, -1, -1, -1, 19, 18, 17, 16, - -1, -1, -1, -1, 15, 14, 13, 12, - -1, -1, -1, -1, 11, 10, 9, 8, - -1, -1, -1, -1, 7, 6, 5, 4, - -1, -1, -1, -1, 3, 2, 1, 0 + -1, + -1, + -1, + -1, + 31, + 30, + 29, + 28, + -1, + -1, + -1, + -1, + 27, + 26, + 25, + 24, + -1, + -1, + -1, + -1, + 23, + 22, + 21, + 20, + -1, + -1, + -1, + -1, + 19, + 18, + 17, + 16, + -1, + -1, + -1, + -1, + 15, + 14, + 13, + 12, + -1, + -1, + -1, + -1, + 11, + 10, + 9, + 8, + -1, + -1, + -1, + -1, + 7, + 6, + 5, + 4, + -1, + -1, + -1, + -1, + 3, + 2, + 1, + 0, ] BucketToSquare = [ - 0, 1, 2, 3, - 8, 9, 10, 11, - 16, 17, 18, 19, - 24, 25, 26, 27, - 32, 33, 34, 35, - 40, 41, 42, 43, - 48, 49, 50, 51, - 56, 57, 58, 59 + 0, + 1, + 2, + 3, + 8, + 9, + 10, + 11, + 16, + 17, + 18, + 19, + 24, + 25, + 26, + 27, + 32, + 33, + 34, + 35, + 40, + 41, + 42, + 43, + 48, + 49, + 50, + 51, + 56, + 57, + 58, + 59, ] # Derived/fixed constants. - numy = hd//numx + numy = hd // numx widthx = 128 widthy = 368 totalx = numx * widthx totaly = numy * widthy - totaldim = totalx*totaly + totaldim = totalx * totaly if not self.args.no_input_weights: - default_order = self.args.input_weights_order == "piece-centric-flipped-king" + default_order = ( + self.args.input_weights_order == "piece-centric-flipped-king" + ) # Calculate masks for first input neuron. img_mask = [] @@ -123,7 +211,7 @@ def plot_input_weights(self): rank = (pi % 64) // 8 ki = BucketToSquare[ki] - if ((rank == 0 or rank == 7) and (piece == 0 or piece == 1)): + if (rank == 0 or rank == 7) and (piece == 0 or piece == 1): # Ignore unused weights for pawns on first/last rank. continue @@ -134,18 +222,32 @@ def plot_input_weights(self): # Piece centric, but with flipped king position. # Same order as used by https://github.com/hxim/Stockfish-Evaluation-Guide. # See also https://github.com/official-stockfish/nnue-pytorch/issues/42#issuecomment-753604393. - inpos = [[(7- kipos[0]) + pipos[0] *8, kipos[1]+(7-pipos[1])*8], - [(7-(kipos[0]^7))+(pipos[0]^7)*8, kipos[1]+(7-pipos[1])*8]] - d = - 8 if piece < 2 else 48 + (piece // 2 - 1) * 64 + inpos = [ + [(7 - kipos[0]) + pipos[0] * 8, kipos[1] + (7 - pipos[1]) * 8], + [ + (7 - (kipos[0] ^ 7)) + (pipos[0] ^ 7) * 8, + kipos[1] + (7 - pipos[1]) * 8, + ], + ] + d = -8 if piece < 2 else 48 + (piece // 2 - 1) * 64 else: # King centric. - inpos = [[8* kipos[0] + pipos[0], 8*(7-kipos[1])+(7-pipos[1])], - [8*(kipos[0]^7)+(pipos[0]^7), 8*(7-kipos[1])+(7-pipos[1])]] - d = -2*(7-kipos[1]) - 1 if piece < 2 else 48 + (piece // 2 - 1) * 64 + inpos = [ + [8 * kipos[0] + pipos[0], 8 * (7 - kipos[1]) + (7 - pipos[1])], + [ + 8 * (kipos[0] ^ 7) + (pipos[0] ^ 7), + 8 * (7 - kipos[1]) + (7 - pipos[1]), + ], + ] + d = ( + -2 * (7 - kipos[1]) - 1 + if piece < 2 + else 48 + (piece // 2 - 1) * 64 + ) jhd = j % hd for k in range(2): - x = inpos[k][0] + widthx * (jhd % numx) + (piece % 2)*64 + x = inpos[k][0] + widthx * (jhd % numx) + (piece % 2) * 64 y = inpos[k][1] + d + widthy * (jhd // numx) ii = x + y * totalx @@ -160,8 +262,9 @@ def plot_input_weights(self): for k in range(hd): offset_x = k % numx offset_y = k // numx - img[img_mask + offset_x*widthx + totalx*widthy * - offset_y] = weights[weights_mask + self.sorted_input_neurons[k]] + img[ + img_mask + offset_x * widthx + totalx * widthy * offset_y + ] = weights[weights_mask + self.sorted_input_neurons[k]] if self.args.input_weights_auto_scale: vmin = None @@ -184,34 +287,42 @@ def plot_input_weights(self): if self.args.input_weights_auto_scale or self.args.input_weights_vmin < 0: title_template = "input weights [{LABEL}" + extra_info + "]" hist_title_template = "input weights histogram [{LABEL}]" - cmap = 'coolwarm' + cmap = "coolwarm" else: img = np.abs(img) - title_template = "abs(input weights) [{LABEL}" + \ - extra_info + "]" + title_template = "abs(input weights) [{LABEL}" + extra_info + "]" hist_title_template = "abs(input weights) histogram [{LABEL}]" - cmap = 'viridis' + cmap = "viridis" # Input weights. scalex = (numx / numy) / preferred_ratio - plt.figure(figsize=((scalex*self.args.default_width) // - self.dpi, self.args.default_height//self.dpi)) - plt.matshow(img.reshape((totaldim//totalx, totalx)), - fignum=0, vmin=vmin, vmax=vmax, cmap=cmap) + plt.figure( + figsize=( + (scalex * self.args.default_width) // self.dpi, + self.args.default_height // self.dpi, + ) + ) + plt.matshow( + img.reshape((totaldim // totalx, totalx)), + fignum=0, + vmin=vmin, + vmax=vmax, + cmap=cmap, + ) plt.colorbar(fraction=0.046, pad=0.04) - line_options = {'color': 'black', 'linewidth': 0.5} + line_options = {"color": "black", "linewidth": 0.5} for i in range(1, numx): - plt.axvline(x=widthx*i-0.5, **line_options) + plt.axvline(x=widthx * i - 0.5, **line_options) for j in range(1, numy): - plt.axhline(y=widthy*j-0.5, **line_options) + plt.axhline(y=widthy * j - 0.5, **line_options) plt.xlim([0, totalx]) plt.ylim([totaly, 0]) - plt.xticks(ticks=widthx*np.arange(1, numx) - 0.5) - plt.yticks(ticks=widthy*np.arange(1, numy) - 0.5) - plt.axis('off') + plt.xticks(ticks=widthx * np.arange(1, numx) - 0.5) + plt.yticks(ticks=widthy * np.arange(1, numy) - 0.5) + plt.axis("off") plt.title(title_template.format(LABEL=self.args.label)) plt.tight_layout() @@ -220,38 +331,46 @@ def format_coord(x, y): x_ = x % widthx y_ = y % widthy - piece_type = (y_+16)//64 + piece_type = (y_ + 16) // 64 piece_name = "{} {}".format( - "white" if x_ // (widthx//2) == 0 else "black", chess.piece_name(piece_type+1)) + "white" if x_ // (widthx // 2) == 0 else "black", + chess.piece_name(piece_type + 1), + ) - x_ = x_ % (widthx//2) - y_ = (y_+16) % 64 if y_ >= 48 else y_+8 + x_ = x_ % (widthx // 2) + y_ = (y_ + 16) % 64 if y_ >= 48 else y_ + 8 if default_order: # Piece centric, flipped king. - piece_square_name = chess.square_name(x_//8 + 8*(7-y_//8)) - king_square_name = chess.square_name( - 7-(x_ % 8) + 8*(y_ % 8)) + piece_square_name = chess.square_name(x_ // 8 + 8 * (7 - y_ // 8)) + king_square_name = chess.square_name(7 - (x_ % 8) + 8 * (y_ % 8)) else: # King centric. if piece_type == 0: piece_square_name = chess.square_name( - x_ % 8 + 8*(6-((y_-8) % 6))) + x_ % 8 + 8 * (6 - ((y_ - 8) % 6)) + ) king_square_name = chess.square_name( - x_//8 + 8*(7-(y_-8)//6)) + x_ // 8 + 8 * (7 - (y_ - 8) // 6) + ) else: piece_square_name = chess.square_name( - x_ % 8 + 8*(7-(y_ % 8))) + x_ % 8 + 8 * (7 - (y_ % 8)) + ) king_square_name = chess.square_name( - x_//8 + 8*(7-y_//8)) + x_ // 8 + 8 * (7 - y_ // 8) + ) neuron_id = int(numx * (y // widthy) + x // widthx) if self.args.sort_input_neurons: neuron_label = "sorted neuron {} (original {})".format( - neuron_id, self.sorted_input_neurons[neuron_id]) + neuron_id, self.sorted_input_neurons[neuron_id] + ) else: neuron_label = "neuron {}".format(neuron_id) - return "{}, {} on {}, white king on {}".format(neuron_label, piece_name, piece_square_name, king_square_name) + return "{}, {} on {}, white king on {}".format( + neuron_label, piece_name, piece_square_name, king_square_name + ) ax = plt.gca() ax.format_coord = format_coord @@ -260,8 +379,17 @@ def format_coord(x, y): if not self.args.no_hist: # Input weights histogram. plt.figure() - plt.hist(img, log=True, bins=( - np.arange(int(np.min(img)*127)-1, int(np.max(img)*127)+3)-0.5)/127) + plt.hist( + img, + log=True, + bins=( + np.arange( + int(np.min(img) * 127) - 1, int(np.max(img) * 127) + 3 + ) + - 0.5 + ) + / 127, + ) plt.title(hist_title_template.format(LABEL=self.args.label)) plt.tight_layout() self._process_fig("input-weights-histogram") @@ -279,7 +407,9 @@ def plot_fc_weights(self): fig.suptitle(title_template.format(LABEL=self.args.label)) if self.args.ref_model: - ref_layers = list(self.ref_model.layer_stacks.get_coalesced_layer_stacks()) + ref_layers = list( + self.ref_model.layer_stacks.get_coalesced_layer_stacks() + ) def get_l1_weights(bucket_id, l1): l1_weights_ = l1.weight.data.numpy() @@ -287,14 +417,15 @@ def get_l1_weights(bucket_id, l1): if self.args.ref_model: l1_weights_ -= ref_layers[bucket_id][0].weight.data.numpy() - N = l1_weights_.size // (2*self.M) + N = l1_weights_.size // (2 * self.M) - l1_weights = np.zeros((2*N, self.M)) + l1_weights = np.zeros((2 * N, self.M)) for i in range(N): - l1_weights[2*i] = l1_weights_[i][self.sorted_input_neurons] - l1_weights[2*i+1] = l1_weights_[i][self.M + - self.sorted_input_neurons] + l1_weights[2 * i] = l1_weights_[i][self.sorted_input_neurons] + l1_weights[2 * i + 1] = l1_weights_[i][ + self.M + self.sorted_input_neurons + ] return l1_weights, N def get_l2_weights(bucket_id, l2): @@ -314,14 +445,16 @@ def get_l2_weights(bucket_id, l2): if self.args.fc_weights_auto_scale or self.args.fc_weights_vmin < 0: plot_abs = False - cmap = 'coolwarm' + cmap = "coolwarm" else: plot_abs = True - cmap = 'viridis' + cmap = "viridis" - line_options = {'color': 'gray', 'linewidth': 0.5} + line_options = {"color": "gray", "linewidth": 0.5} - for bucket_id, (l1, l2, output) in enumerate(self.model.layer_stacks.get_coalesced_layer_stacks()): + for bucket_id, (l1, l2, output) in enumerate( + self.model.layer_stacks.get_coalesced_layer_stacks() + ): l1_weights, N = get_l1_weights(bucket_id, l1) l2_weights = get_l2_weights(bucket_id, l2) output_weights = output.weight.data.numpy() @@ -330,23 +463,34 @@ def get_l2_weights(bucket_id, l2): output_weights -= ref_layers[bucket_id][2].weight.data.numpy() ax = axs[0, bucket_id] - im = ax.matshow(np.abs(l1_weights) if plot_abs else l1_weights, - vmin=vmin, vmax=vmax, cmap=cmap) + im = ax.matshow( + np.abs(l1_weights) if plot_abs else l1_weights, + vmin=vmin, + vmax=vmax, + cmap=cmap, + ) for j in range(1, N): - ax.axhline(y=2*j-0.5, **line_options) + ax.axhline(y=2 * j - 0.5, **line_options) ax = axs[1, bucket_id] - im = ax.matshow(np.abs(l2_weights) if plot_abs else l2_weights, - vmin=None if vmin == float("-inf") else vmin, - vmax=vmax, cmap=cmap) + im = ax.matshow( + np.abs(l2_weights) if plot_abs else l2_weights, + vmin=None if vmin == float("-inf") else vmin, + vmax=vmax, + cmap=cmap, + ) ax = axs[2, bucket_id] - im = ax.matshow(np.abs(output_weights) if plot_abs else output_weights, - vmin=vmin, vmax=vmax, cmap=cmap) - - row_names = ['bucket {}'.format(i) for i in range(num_buckets)] - col_names = ['l1', 'l2', 'output'] + im = ax.matshow( + np.abs(output_weights) if plot_abs else output_weights, + vmin=vmin, + vmax=vmax, + cmap=cmap, + ) + + row_names = ["bucket {}".format(i) for i in range(num_buckets)] + col_names = ["l1", "l2", "output"] for i in range(3): for j in range(num_buckets): ax = axs[i, j] @@ -354,11 +498,13 @@ def get_l2_weights(bucket_id, l2): ax.set_yticks([]) if i == 0 and row_names[j]: ax.set_xlabel(row_names[j]) - ax.xaxis.set_label_position('top') + ax.xaxis.set_label_position("top") if j == 0 and col_names[i]: ax.set_ylabel(col_names[i]) - fig.colorbar(im, fraction=0.046, pad=0.04, ax=axs[i, :].ravel().tolist()) + fig.colorbar( + im, fraction=0.046, pad=0.04, ax=axs[i, :].ravel().tolist() + ) self._process_fig("fc-weights", fig) @@ -366,29 +512,55 @@ def get_l2_weights(bucket_id, l2): fig, axs = plt.subplots(num_buckets, 1, sharex=True, dpi=self.dpi) title_template = "L1 weights histogram [{LABEL}]" fig.suptitle(title_template.format(LABEL=self.args.label)) - for bucket_id, (l1, l2, output) in enumerate(self.model.layer_stacks.get_coalesced_layer_stacks()): + for bucket_id, (l1, l2, output) in enumerate( + self.model.layer_stacks.get_coalesced_layer_stacks() + ): # L1 weights histogram. ax = axs[bucket_id] l1_weights, N = get_l1_weights(bucket_id, l1) - ax.hist(l1_weights.flatten(), log=True, bins=( - np.arange(int(np.min(l1_weights)*64)-1, int(np.max(l1_weights)*64)+3)-0.5)/64) + ax.hist( + l1_weights.flatten(), + log=True, + bins=( + np.arange( + int(np.min(l1_weights) * 64) - 1, + int(np.max(l1_weights) * 64) + 3, + ) + - 0.5 + ) + / 64, + ) self._process_fig("l1-weights-histogram", fig) fig, axs = plt.subplots(num_buckets, 1, sharex=True, dpi=self.dpi) title_template = "L2 weights histogram [{LABEL}]" fig.suptitle(title_template.format(LABEL=self.args.label)) - for bucket_id, (l1, l2, output) in enumerate(self.model.layer_stacks.get_coalesced_layer_stacks()): + for bucket_id, (l1, l2, output) in enumerate( + self.model.layer_stacks.get_coalesced_layer_stacks() + ): # L2 weights histogram. ax = axs[bucket_id] l2_weights = get_l2_weights(bucket_id, l2) - ax.hist(l2_weights.flatten(), log=True, bins=( - np.arange(int(np.min(l2_weights)*64)-1, int(np.max(l2_weights)*64)+3)-0.5)/64) + ax.hist( + l2_weights.flatten(), + log=True, + bins=( + np.arange( + int(np.min(l2_weights) * 64) - 1, + int(np.max(l2_weights) * 64) + 3, + ) + - 0.5 + ) + / 64, + ) self._process_fig("l2-weights-histogram", fig) def plot_fc_biases(self): if not self.args.no_biases: if self.args.ref_model: - ref_layers = list(self.ref_model.layer_stacks.get_coalesced_layer_stacks()) + ref_layers = list( + self.ref_model.layer_stacks.get_coalesced_layer_stacks() + ) num_buckets = self.model.feature_set.num_ls_buckets fig, axs = plt.subplots(3, num_buckets, dpi=self.dpi) @@ -407,12 +579,14 @@ def plot_fc_biases(self): if self.args.fc_weights_auto_scale or self.args.fc_weights_vmin < 0: plot_abs = False - cmap = 'coolwarm' + cmap = "coolwarm" else: plot_abs = True - cmap = 'viridis' + cmap = "viridis" - for bucket_id, (l1, l2, output) in enumerate(self.model.layer_stacks.get_coalesced_layer_stacks()): + for bucket_id, (l1, l2, output) in enumerate( + self.model.layer_stacks.get_coalesced_layer_stacks() + ): l1_biases = l1.bias.data.numpy() l2_biases = l2.bias.data.numpy() output_bias = output.bias.data.numpy() @@ -423,19 +597,22 @@ def plot_fc_biases(self): output_bias -= ref_layers[bucket_id][2].bias.data.numpy() ax = axs[0, bucket_id] - im = ax.matshow(np.expand_dims(l1_biases, axis=0), - vmin=vmin, vmax=vmax, cmap=cmap) + im = ax.matshow( + np.expand_dims(l1_biases, axis=0), vmin=vmin, vmax=vmax, cmap=cmap + ) ax = axs[1, bucket_id] - im = ax.matshow(np.expand_dims(l2_biases, axis=0), - vmin=vmin, vmax=vmax, cmap=cmap) + im = ax.matshow( + np.expand_dims(l2_biases, axis=0), vmin=vmin, vmax=vmax, cmap=cmap + ) ax = axs[2, bucket_id] - im = ax.matshow(np.expand_dims(output_bias, axis=0), - vmin=vmin, vmax=vmax, cmap=cmap) + im = ax.matshow( + np.expand_dims(output_bias, axis=0), vmin=vmin, vmax=vmax, cmap=cmap + ) - row_names = ['bucket {}'.format(i) for i in range(num_buckets)] - col_names = ['l1', 'l2', 'output'] + row_names = ["bucket {}".format(i) for i in range(num_buckets)] + col_names = ["l1", "l2", "output"] for i in range(3): for j in range(num_buckets): ax = axs[i, j] @@ -443,11 +620,13 @@ def plot_fc_biases(self): ax.set_yticks([]) if i == 0 and row_names[j]: ax.set_xlabel(row_names[j]) - ax.xaxis.set_label_position('top') + ax.xaxis.set_label_position("top") if j == 0 and col_names[i]: ax.set_ylabel(col_names[i]) - fig.colorbar(im, fraction=0.046, pad=0.04, ax=axs[i, :].ravel().tolist()) + fig.colorbar( + im, fraction=0.046, pad=0.04, ax=axs[i, :].ravel().tolist() + ) self._process_fig("biases", fig) @@ -457,11 +636,10 @@ def load_model(filename, feature_set): if filename.endswith(".pt"): model = torch.load(filename) else: - model = M.NNUE.load_from_checkpoint( - filename, feature_set=feature_set) + model = M.NNUE.load_from_checkpoint(filename, feature_set=feature_set) model.eval() elif filename.endswith(".nnue"): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: reader = NNUEReader(f, feature_set) model = reader.model else: @@ -472,74 +650,116 @@ def load_model(filename, feature_set): def main(): parser = argparse.ArgumentParser( - description="Visualizes networks in ckpt, pt and nnue format.") + description="Visualizes networks in ckpt, pt and nnue format." + ) + parser.add_argument("model", help="Source model (can be .ckpt, .pt or .nnue)") parser.add_argument( - "model", help="Source model (can be .ckpt, .pt or .nnue)") + "--ref-model", + type=str, + required=False, + help="Visualize the difference between the given reference model (can be .ckpt, .pt or .nnue).", + ) parser.add_argument( - "--ref-model", type=str, required=False, - help="Visualize the difference between the given reference model (can be .ckpt, .pt or .nnue).") + "--ref-features", + type=str, + required=False, + help="The reference feature set to use (default = same as source model).", + ) parser.add_argument( - "--ref-features", type=str, required=False, - help="The reference feature set to use (default = same as source model).") + "--input-weights-vmin", + default=-1, + type=float, + help="Minimum of color map range for input weights (absolute values are plotted if this is positive or zero).", + ) parser.add_argument( - "--input-weights-vmin", default=-1, type=float, - help="Minimum of color map range for input weights (absolute values are plotted if this is positive or zero).") + "--input-weights-vmax", + default=1, + type=float, + help="Maximum of color map range for input weights.", + ) parser.add_argument( - "--input-weights-vmax", default=1, type=float, - help="Maximum of color map range for input weights.") + "--input-weights-auto-scale", + action="store_true", + help="Use auto-scale for the color map range for input weights. This ignores input-weights-vmin and input-weights-vmax.", + ) parser.add_argument( - "--input-weights-auto-scale", action="store_true", - help="Use auto-scale for the color map range for input weights. This ignores input-weights-vmin and input-weights-vmax.") + "--input-weights-order", + type=str, + choices=["piece-centric-flipped-king", "king-centric"], + default="piece-centric-flipped-king", + help="Order of the input weights for each input neuron.", + ) parser.add_argument( - "--input-weights-order", type=str, choices=["piece-centric-flipped-king", "king-centric"], default="piece-centric-flipped-king", - help="Order of the input weights for each input neuron.") + "--sort-input-neurons", + action="store_true", + help="Sort the neurons of the input layer by the L1-norm (sum of absolute values) of their weights.", + ) parser.add_argument( - "--sort-input-neurons", action="store_true", - help="Sort the neurons of the input layer by the L1-norm (sum of absolute values) of their weights.") + "--fc-weights-vmin", + default=-2, + type=float, + help="Minimum of color map range for fully-connected layer weights (absolute values are plotted if this is positive or zero).", + ) parser.add_argument( - "--fc-weights-vmin", default=-2, type=float, - help="Minimum of color map range for fully-connected layer weights (absolute values are plotted if this is positive or zero).") + "--fc-weights-vmax", + default=2, + type=float, + help="Maximum of color map range for fully-connected layer weights.", + ) parser.add_argument( - "--fc-weights-vmax", default=2, type=float, - help="Maximum of color map range for fully-connected layer weights.") + "--fc-weights-auto-scale", + action="store_true", + help="Use auto-scale for the color map range for fully-connected layer weights. This ignores fc-weights-vmin and fc-weights-vmax.", + ) parser.add_argument( - "--fc-weights-auto-scale", action="store_true", - help="Use auto-scale for the color map range for fully-connected layer weights. This ignores fc-weights-vmin and fc-weights-vmax.") + "--no-hist", action="store_true", help="Don't generate any histograms." + ) parser.add_argument( - "--no-hist", action="store_true", - help="Don't generate any histograms.") + "--no-biases", action="store_true", help="Don't generate plots for biases." + ) parser.add_argument( - "--no-biases", action="store_true", - help="Don't generate plots for biases.") + "--no-input-weights", + action="store_true", + help="Don't generate plots or histograms for input weights.", + ) parser.add_argument( - "--no-input-weights", action="store_true", - help="Don't generate plots or histograms for input weights.") + "--no-fc-weights", + action="store_true", + help="Don't generate plots or histograms for fully-connected layer weights.", + ) parser.add_argument( - "--no-fc-weights", action="store_true", - help="Don't generate plots or histograms for fully-connected layer weights.") + "--default-width", + default=1600, + type=int, + help="Default width of all plots (in pixels).", + ) parser.add_argument( - "--default-width", default=1600, type=int, - help="Default width of all plots (in pixels).") + "--default-height", + default=900, + type=int, + help="Default height of all plots (in pixels).", + ) parser.add_argument( - "--default-height", default=900, type=int, - help="Default height of all plots (in pixels).") + "--save-dir", type=str, required=False, help="Save the plots in this directory." + ) parser.add_argument( - "--save-dir", type=str, required=False, - help="Save the plots in this directory.") + "--dont-show", action="store_true", help="Don't show the plots." + ) parser.add_argument( - "--dont-show", action="store_true", - help="Don't show the plots.") - parser.add_argument( - "--label", type=str, required=False, - help="Override the label used in plot titles and as prefix of saved files.") + "--label", + type=str, + required=False, + help="Override the label used in plot titles and as prefix of saved files.", + ) features.add_argparse_args(parser) args = parser.parse_args() - supported_features = ('HalfKAv2_hm', 'HalfKAv2_hm^') + supported_features = ("HalfKAv2_hm", "HalfKAv2_hm^") assert args.features in supported_features feature_set = features.get_feature_set_from_name(args.features) from os.path import basename + label = basename(args.model) model = load_model(args.model, feature_set) @@ -547,17 +767,20 @@ def main(): if args.ref_model: if args.ref_features: assert args.ref_features in supported_features - ref_feature_set = features.get_feature_set_from_name( - args.ref_features) + ref_feature_set = features.get_feature_set_from_name(args.ref_features) else: ref_feature_set = feature_set ref_model = load_model(args.ref_model, ref_feature_set) - print("Visualizing difference between {} and {}".format( - args.model, args.ref_model)) + print( + "Visualizing difference between {} and {}".format( + args.model, args.ref_model + ) + ) from os.path import basename + label = "diff " + label + "-" + basename(args.ref_model) else: ref_model = None @@ -576,5 +799,5 @@ def main(): plt.show() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/visualize_multi_hist.py b/visualize_multi_hist.py index fd8a799b..ddb164cb 100644 --- a/visualize_multi_hist.py +++ b/visualize_multi_hist.py @@ -9,16 +9,16 @@ from serialize import NNUEReader + def load_model(filename, feature_set): if filename.endswith(".pt") or filename.endswith(".ckpt"): if filename.endswith(".pt"): model = torch.load(filename) else: - model = M.NNUE.load_from_checkpoint( - filename, feature_set=feature_set) + model = M.NNUE.load_from_checkpoint(filename, feature_set=feature_set) model.eval() elif filename.endswith(".nnue"): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: reader = NNUEReader(f, feature_set) model = reader.model else: @@ -26,67 +26,113 @@ def load_model(filename, feature_set): return model + def get_bins(inputs_columns, num_bins): - a = float('+inf') - b = float('-inf') + a = float("+inf") + b = float("-inf") for inputs in inputs_columns: for inp in inputs: a = min(a, float(np.min(inp))) b = max(b, float(np.max(inp))) a -= 0.001 b += 0.001 - return [a + (b-a) / num_bins * i for i in range(num_bins+1)] + return [a + (b - a) / num_bins * i for i in range(num_bins + 1)] + -def plot_hists(tensors_columns, row_names, col_names, w=8.0, h=3.0, title=None, num_bins=256, filename='a.png'): - fig, axs = plt.subplots(len(tensors_columns[0]), len(tensors_columns), sharex=True, sharey=True, squeeze=False, figsize=(w * len(tensors_columns), h * len(tensors_columns[0])), dpi=100) +def plot_hists( + tensors_columns, + row_names, + col_names, + w=8.0, + h=3.0, + title=None, + num_bins=256, + filename="a.png", +): + fig, axs = plt.subplots( + len(tensors_columns[0]), + len(tensors_columns), + sharex=True, + sharey=True, + squeeze=False, + figsize=(w * len(tensors_columns), h * len(tensors_columns[0])), + dpi=100, + ) if title: fig.suptitle(title) bins = get_bins(tensors_columns, num_bins) for i, tensors in enumerate(tensors_columns): - print('Processing column {}/{}.'.format(i+1, len(tensors_columns))) + print("Processing column {}/{}.".format(i + 1, len(tensors_columns))) for j, tensor in enumerate(tensors): ax = axs[j, i] - print(' Processing tensor {}/{}.'.format(j+1, len(tensors))) + print(" Processing tensor {}/{}.".format(j + 1, len(tensors))) ax.hist(tensor, log=True, bins=bins) if i == 0 and row_names[j]: ax.set_ylabel(row_names[j]) if j == 0 and col_names[i]: ax.set_xlabel(col_names[i]) - ax.xaxis.set_label_position('top') + ax.xaxis.set_label_position("top") fig.savefig(filename) + def main(): parser = argparse.ArgumentParser( - description="Visualizes networks in ckpt, pt and nnue format.") + description="Visualizes networks in ckpt, pt and nnue format." + ) parser.add_argument( - "models", nargs='+', help="Source model (can be .ckpt, .pt or .nnue)") + "models", nargs="+", help="Source model (can be .ckpt, .pt or .nnue)" + ) parser.add_argument( - "--dont-show", action="store_true", - help="Don't show the plots.") + "--dont-show", action="store_true", help="Don't show the plots." + ) features.add_argparse_args(parser) args = parser.parse_args() - supported_features = ('HalfKAv2', 'HalfKAv2^', 'HalfKAv2_hm', 'HalfKAv2_hm^') + supported_features = ("HalfKAv2", "HalfKAv2^", "HalfKAv2_hm", "HalfKAv2_hm^") assert args.features in supported_features feature_set = features.get_feature_set_from_name(args.features) from os.path import basename + labels = [] for m in args.models: label = basename(m) - if label.startswith('nn-'): + if label.startswith("nn-"): label = label[3:] - if label.endswith('.nnue'): + if label.endswith(".nnue"): label = label[:-5] - labels.append('\n'.join(label.split('-'))) + labels.append("\n".join(label.split("-"))) models = [load_model(m, feature_set) for m in args.models] coalesced_ins = [M.coalesce_ft_weights(model, model.input) for model in models] - input_weights = [coalesced_in[:, :M.L1].flatten().numpy() for coalesced_in in coalesced_ins] - input_weights_psqt = [(coalesced_in[:, M.L1:] * 600).flatten().numpy() for coalesced_in in coalesced_ins] - plot_hists([input_weights], labels, [None], w=10.0, h=3.0, num_bins=8*128, title='Distribution of feature transformer weights among different nets', filename='input_weights_hist.png') - plot_hists([input_weights_psqt], labels, [None], w=10.0, h=3.0, num_bins=8*128, title='Distribution of feature transformer PSQT weights among different nets (in stockfish internal units)', filename='input_weights_psqt_hist.png') + input_weights = [ + coalesced_in[:, : M.L1].flatten().numpy() for coalesced_in in coalesced_ins + ] + input_weights_psqt = [ + (coalesced_in[:, M.L1 :] * 600).flatten().numpy() + for coalesced_in in coalesced_ins + ] + plot_hists( + [input_weights], + labels, + [None], + w=10.0, + h=3.0, + num_bins=8 * 128, + title="Distribution of feature transformer weights among different nets", + filename="input_weights_hist.png", + ) + plot_hists( + [input_weights_psqt], + labels, + [None], + w=10.0, + h=3.0, + num_bins=8 * 128, + title="Distribution of feature transformer PSQT weights among different nets (in stockfish internal units)", + filename="input_weights_psqt_hist.png", + ) layer_stacks = [model.layer_stacks for model in models] layers_l1 = [[] for i in range(layer_stacks[0].count)] @@ -98,13 +144,41 @@ def main(): layers_l1[i].append(l1.weight.flatten().numpy()) layers_l2[i].append(l2.weight.flatten().numpy()) layers_l3[i].append(l3.weight.flatten().numpy()) - col_names = ['Subnet {}'.format(i) for i in range(layer_stacks[0].count)] - plot_hists(layers_l1, labels, col_names, w=2.0, h=2.0, num_bins=128, title='Distribution of l1 weights among different nets and buckets', filename='l1_weights_hist.png') - plot_hists(layers_l2, labels, col_names, w=2.0, h=2.0, num_bins=32, title='Distribution of l2 weights among different nets and buckets', filename='l2_weights_hist.png') - plot_hists(layers_l3, labels, col_names, w=2.0, h=2.0, num_bins=16, title='Distribution of output weights among different nets and buckets', filename='output_weights_hist.png') + col_names = ["Subnet {}".format(i) for i in range(layer_stacks[0].count)] + plot_hists( + layers_l1, + labels, + col_names, + w=2.0, + h=2.0, + num_bins=128, + title="Distribution of l1 weights among different nets and buckets", + filename="l1_weights_hist.png", + ) + plot_hists( + layers_l2, + labels, + col_names, + w=2.0, + h=2.0, + num_bins=32, + title="Distribution of l2 weights among different nets and buckets", + filename="l2_weights_hist.png", + ) + plot_hists( + layers_l3, + labels, + col_names, + w=2.0, + h=2.0, + num_bins=16, + title="Distribution of output weights among different nets and buckets", + filename="output_weights_hist.png", + ) if not args.dont_show: plt.show() -if __name__ == '__main__': + +if __name__ == "__main__": main() From 63cda7ca9d5d71da680bff1fc188b44c48679ebf Mon Sep 17 00:00:00 2001 From: Disservin Date: Fri, 15 Mar 2024 17:25:03 +0100 Subject: [PATCH 3/3] Add badges --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d0eb67c7..bf40d5f9 100644 --- a/README.md +++ b/README.md @@ -85,4 +85,7 @@ This script runs in a loop, and will monitor the directory for new checkpoints. * syzygy - http://www.talkchess.com/forum3/viewtopic.php?f=7&t=75506 * https://github.com/DanielUranga/TensorFlowNNUE * https://hxim.github.io/Stockfish-Evaluation-Guide/ -* dkappe - Suggesting ranger (https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) \ No newline at end of file +* dkappe - Suggesting ranger (https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) + +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Code style: clang-format](https://img.shields.io/badge/code%20style-clang%20format-000000.svg)](https://github.com/llvm/llvm-project) \ No newline at end of file