diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d64ccb..a9f7aa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ * Tune all evaluation parameters with https://github.com/GediminasMasaitis/texel-tuner (53.17 +- 16.76) * Evaluate piece mobility (41.36 +- 13.82) +* Add a texel tuner in-repo and tune, resolving an issue where mobility scores were not computed correctly (28.53 +- 11.44) ### Misc diff --git a/Cargo.lock b/Cargo.lock index 9608937..24f8203 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "byteorder" version = "1.5.0" @@ -134,6 +140,56 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width 0.1.14", + "windows-sys 0.52.0", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "engine" version = "4.1.0" @@ -142,9 +198,11 @@ dependencies = [ "cc", "clap", "colored", + "indicatif", "nom", "paste", "rand", + "rayon", ] [[package]] @@ -164,12 +222,35 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "indicatif" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "js-sys" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -182,6 +263,12 @@ version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + [[package]] name = "memchr" version = "2.7.4" @@ -204,12 +291,30 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -267,6 +372,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "shlex" version = "1.3.0" @@ -296,6 +421,18 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "utf8parse" version = "0.2.2" @@ -308,6 +445,71 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -317,6 +519,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" diff --git a/Cargo.toml b/Cargo.toml index 7003973..79c622f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" rust-version = "1.83" [features] -default = ["dep:clap"] +default = ["dep:clap", "dep:rayon", "dep:indicatif"] release = [] [build-dependencies] @@ -15,8 +15,10 @@ cc = "1.2.2" arrayvec = "0.7.6" clap = { version = "4.5.21", features = ["derive"], optional = true } colored = "2.1.0" +indicatif = { version = "0.17.9", optional = true } nom = "7.1.1" rand = "0.8.5" +rayon = { version = "1.8.1", optional = true } [dev-dependencies] paste = "1.0.15" diff --git a/src/chess/bitboard.rs b/src/chess/bitboard.rs index f9fa4ad..6d20498 100644 --- a/src/chess/bitboard.rs +++ b/src/chess/bitboard.rs @@ -162,6 +162,11 @@ impl Bitboard { // If we go west and land on H, we wrapped around. Self(self.0 << 7) & Self::NOT_H_FILE } + + #[inline(always)] + pub fn flip_vertically(self) -> Self { + Self(u64::swap_bytes(self.0)) + } } pub struct SquareIterator(Bitboard); diff --git a/src/chess/board.rs b/src/chess/board.rs index 47db450..d44bfb9 100644 --- a/src/chess/board.rs +++ b/src/chess/board.rs @@ -127,6 +127,35 @@ impl Board { let enemy_attackers = movegen::generate_attackers_of(self, player, king); enemy_attackers.any() } + + pub fn flip_vertically(&self) -> Self { + let [white_colors, black_colors] = self.colors.inner(); + let [pawns, knights, bishops, rooks, queens, king] = self.pieces; + + let squares = self.squares; + let mut flipped_squares: [Option; Square::N] = [None; Square::N]; + for rank in 0..8 { + for file in 0..8 { + flipped_squares[(8 - rank - 1) * 8 + file] = squares[rank * 8 + file]; + } + } + + Self { + colors: ByPlayer::new( + white_colors.flip_vertically(), + black_colors.flip_vertically(), + ), + pieces: [ + pawns.flip_vertically(), + knights.flip_vertically(), + bishops.flip_vertically(), + rooks.flip_vertically(), + queens.flip_vertically(), + king.flip_vertically(), + ], + squares: flipped_squares, + } + } } impl std::fmt::Debug for Board { @@ -236,3 +265,37 @@ impl TryFrom<[Option; Square::N]> for Board { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::chess::game::Game; + + #[test] + fn test_flip_vertically() { + let game = + Game::from_fen("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq -") + .unwrap(); + + let our_flipped_board = game.board.flip_vertically(); + + let flipped_kiwipete = + Game::from_fen("R3K2R/PPPBBPPP/2N2Q1p/1p2P3/3PN3/bn2pnp1/p1ppqpb1/r3k2r w - - 0 1") + .unwrap() + .board; + + assert_eq!( + our_flipped_board.colors.for_player(Player::White), + flipped_kiwipete.colors.for_player(Player::White) + ); + assert_eq!( + our_flipped_board.colors.for_player(Player::Black), + flipped_kiwipete.colors.for_player(Player::Black) + ); + assert_eq!(our_flipped_board.pieces, flipped_kiwipete.pieces); + + for (i, &p) in our_flipped_board.squares.iter().enumerate() { + assert_eq!(p, flipped_kiwipete.squares[i]); + } + } +} diff --git a/src/engine/eval/params.rs b/src/engine/eval/params.rs index 5b7b573..5e9ac50 100644 --- a/src/engine/eval/params.rs +++ b/src/engine/eval/params.rs @@ -1,163 +1,165 @@ #![cfg_attr(any(), rustfmt::skip)] -use crate::chess::piece::PieceKind; use crate::chess::square::{File, Rank}; -use crate::engine::eval::phased_eval::s; use crate::engine::eval::PhasedEval; +pub const fn s(mg: i16, eg: i16) -> PhasedEval { + PhasedEval::new(mg, eg) +} + pub type PieceSquareTableDefinition = [[PhasedEval; File::N]; Rank::N]; -pub const PIECE_VALUES: [PhasedEval; PieceKind::N] = [ - s( 114, 224), +pub const PIECE_VALUES: [PhasedEval; 6] = [ + s( 113, 225), s( 269, 306), - s( 274, 329), - s( 366, 592), - s( 755, 1126), - s( 0, 0) + s( 285, 339), + s( 376, 591), + s( 771, 1125), + s( 0, 0), ]; pub const PAWNS: PieceSquareTableDefinition = [ [s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0)], - [s( 80, 183), s( 110, 171), s( 79, 173), s( 118, 104), s( 97, 97), s( 70, 113), s( -25, 178), s( -60, 197)], - [s( -31, 98), s( -10, 109), s( 35, 61), s( 43, 31), s( 49, 19), s( 77, -2), s( 47, 64), s( -11, 61)], - [s( -52, -2), s( -16, -17), s( -11, -44), s( -10, -57), s( 20, -69), s( 10, -65), s( 14, -38), s( -20, -37)], - [s( -66, -37), s( -26, -40), s( -28, -65), s( -5, -69), s( -4, -72), s( -14, -70), s( -5, -54), s( -36, -64)], - [s( -67, -45), s( -33, -43), s( -33, -67), s( -30, -49), s( -8, -59), s( -25, -64), s( 17, -57), s( -26, -70)], - [s( -67, -39), s( -32, -37), s( -38, -57), s( -50, -46), s( -22, -41), s( -1, -57), s( 30, -58), s( -36, -69)], + [s( 56, 187), s( 98, 175), s( 64, 177), s( 99, 107), s( 80, 101), s( 58, 120), s( -34, 182), s( -72, 200)], + [s( -36, 98), s( -13, 111), s( 26, 65), s( 31, 35), s( 36, 21), s( 67, 2), s( 47, 66), s( -14, 61)], + [s( -50, -3), s( -15, -17), s( -15, -43), s( -9, -56), s( 17, -69), s( 4, -64), s( 14, -38), s( -19, -38)], + [s( -61, -38), s( -23, -41), s( -21, -67), s( -1, -70), s( 4, -74), s( -5, -73), s( -3, -55), s( -33, -65)], + [s( -57, -48), s( -28, -44), s( -24, -69), s( -16, -54), s( 3, -63), s( -10, -68), s( 20, -58), s( -19, -72)], + [s( -53, -43), s( -23, -39), s( -32, -59), s( -24, -43), s( -10, -41), s( 10, -60), s( 35, -59), s( -30, -71)], [s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0), s( 0, 0)], ]; pub const KNIGHTS: PieceSquareTableDefinition = [ - [s( -119, -62), s( -121, 8), s( -63, -18), s( -16, -29), s( 31, -25), s( -53, -58), s( -94, 16), s( -43, -96)], - [s( 20, -2), s( 12, -17), s( 13, -5), s( 38, -7), s( 14, -18), s( 101, -39), s( 9, -23), s( 74, -24)], - [s( 8, -25), s( 20, -2), s( 60, 39), s( 75, 41), s( 126, 18), s( 129, 10), s( 55, -18), s( 46, -40)], - [s( 3, -7), s( -12, 21), s( 36, 55), s( 64, 58), s( 39, 60), s( 74, 51), s( 4, 20), s( 51, -19)], - [s( -15, -5), s( -27, 7), s( 6, 59), s( 9, 59), s( 22, 63), s( 13, 48), s( -3, 10), s( 1, -18)], - [s( -40, -27), s( -43, 0), s( -13, 32), s( -8, 50), s( 9, 48), s( -6, 27), s( -14, -7), s( -17, -25)], - [s( -29, 5), s( -41, -19), s( -53, -4), s( -36, 1), s( -35, -1), s( -32, -7), s( -16, -32), s( 9, 19)], - [s( -44, 6), s( -13, -21), s( -61, -27), s( -41, -22), s( -34, -20), s( -16, -37), s( -7, -11), s( -4, -13)], + [s( -113, -61), s( -140, 7), s( -79, -14), s( -31, -25), s( 18, -21), s( -60, -56), s( -104, 13), s( -45, -93)], + [s( 3, -4), s( 4, -14), s( 3, -3), s( 25, -5), s( 2, -14), s( 87, -35), s( 6, -22), s( 61, -27)], + [s( 1, -22), s( 14, -1), s( 55, 42), s( 65, 44), s( 117, 22), s( 122, 14), s( 50, -17), s( 45, -39)], + [s( 3, -7), s( -14, 21), s( 33, 57), s( 65, 58), s( 37, 60), s( 70, 55), s( 3, 19), s( 52, -19)], + [s( -15, -5), s( -28, 7), s( 13, 58), s( 18, 59), s( 32, 63), s( 22, 48), s( -1, 8), s( 1, -18)], + [s( -39, -27), s( -43, -1), s( -2, 30), s( 10, 49), s( 28, 45), s( 8, 23), s( -12, -8), s( -15, -26)], + [s( -26, 0), s( -35, -19), s( -47, -4), s( -20, -3), s( -22, -2), s( -21, -10), s( -10, -32), s( 10, 13)], + [s( -44, 7), s( 2, -11), s( -42, -28), s( -21, -23), s( -17, -19), s( -8, -35), s( 4, -2), s( 1, -13)], ]; pub const BISHOPS: PieceSquareTableDefinition = [ - [s( -26, -6), s( -53, 6), s( -44, 2), s( -103, 19), s( -78, 11), s( -60, -2), s( -18, -6), s( -67, -12)], - [s( -7, -27), s( 19, -5), s( 4, -2), s( -17, 3), s( 24, -13), s( 20, -9), s( 15, 1), s( 5, -26)], - [s( 4, 11), s( 28, 1), s( 25, 12), s( 52, -6), s( 34, 3), s( 84, 4), s( 54, -1), s( 45, 2)], - [s( -7, 5), s( 3, 20), s( 27, 10), s( 40, 35), s( 37, 20), s( 31, 17), s( 4, 14), s( -4, 6)], - [s( -17, -1), s( -7, 17), s( -7, 26), s( 23, 22), s( 17, 25), s( -4, 18), s( -6, 13), s( -3, -15)], - [s( 1, -2), s( 0, 9), s( -2, 14), s( -5, 15), s( -2, 20), s( -2, 16), s( 5, -4), s( 19, -15)], - [s( 3, -5), s( 4, -10), s( 9, -17), s( -19, 3), s( -6, 4), s( 9, -9), s( 26, -2), s( 8, -33)], - [s( -21, -30), s( 0, -7), s( -17, -35), s( -30, -7), s( -23, -10), s( -23, -10), s( 7, -27), s( -7, -49)], + [s( -22, 10), s( -66, 18), s( -58, 13), s( -116, 29), s( -85, 20), s( -66, 5), s( -34, 3), s( -57, 3)], + [s( -17, -10), s( -2, -5), s( -13, -4), s( -35, 0), s( 5, -15), s( 0, -12), s( 0, -1), s( -2, -10)], + [s( -3, 21), s( 18, 1), s( 11, 0), s( 32, -16), s( 16, -8), s( 70, -7), s( 47, -4), s( 43, 12)], + [s( -7, 12), s( -3, 10), s( 14, -2), s( 29, 15), s( 23, 1), s( 17, 4), s( 2, 3), s( -6, 12)], + [s( -7, 4), s( -12, 8), s( -12, 9), s( 19, 3), s( 13, 5), s( -9, 1), s( -14, 5), s( 10, -8)], + [s( 6, 7), s( 3, 4), s( 3, 1), s( -1, 1), s( 3, 5), s( 5, 1), s( 8, -6), s( 25, -5)], + [s( 21, 9), s( 15, -9), s( 17, -21), s( -6, -1), s( 4, -4), s( 21, -13), s( 37, 1), s( 20, -13)], + [s( 8, -3), s( 28, 5), s( 15, -6), s( -5, 1), s( 7, 1), s( 1, 11), s( 29, -15), s( 24, -20)], ]; pub const ROOKS: PieceSquareTableDefinition = [ - [s( 28, 12), s( 7, 23), s( 14, 34), s( 21, 27), s( 46, 16), s( 74, 5), s( 58, 6), s( 88, -2)], - [s( 11, 13), s( 7, 28), s( 32, 33), s( 59, 20), s( 40, 21), s( 78, 1), s( 71, -4), s( 114, -22)], - [s( -15, 14), s( 15, 17), s( 14, 18), s( 13, 16), s( 53, -3), s( 63, -11), s( 118, -22), s( 91, -29)], - [s( -31, 18), s( -14, 15), s( -13, 26), s( -9, 22), s( 0, 1), s( 11, -7), s( 24, -11), s( 28, -18)], - [s( -57, 8), s( -58, 14), s( -46, 15), s( -33, 14), s( -31, 7), s( -45, 4), s( -9, -13), s( -19, -19)], - [s( -64, 1), s( -57, -1), s( -49, -2), s( -49, 4), s( -39, -3), s( -39, -15), s( 11, -42), s( -16, -40)], - [s( -66, -8), s( -54, -4), s( -36, -3), s( -40, -1), s( -32, -13), s( -26, -21), s( -1, -33), s( -39, -24)], - [s( -37, -12), s( -41, -3), s( -30, 4), s( -23, -1), s( -15, -11), s( -22, -12), s( -4, -19), s( -29, -28)], + [s( 31, 11), s( 3, 24), s( 2, 37), s( 4, 30), s( 28, 20), s( 57, 9), s( 56, 7), s( 85, 0)], + [s( 1, 16), s( -5, 33), s( 18, 38), s( 43, 24), s( 27, 25), s( 64, 7), s( 61, -1), s( 104, -20)], + [s( -25, 16), s( 5, 18), s( -4, 21), s( -1, 18), s( 39, 0), s( 47, -8), s( 114, -21), s( 85, -28)], + [s( -34, 19), s( -21, 15), s( -25, 26), s( -18, 20), s( -12, 1), s( 2, -7), s( 19, -11), s( 24, -17)], + [s( -51, 8), s( -54, 12), s( -42, 13), s( -31, 10), s( -28, 5), s( -38, 1), s( -9, -13), s( -20, -17)], + [s( -53, 2), s( -51, 0), s( -39, -3), s( -33, 0), s( -23, -7), s( -21, -19), s( 17, -42), s( -9, -38)], + [s( -52, -6), s( -43, -4), s( -22, -4), s( -21, -2), s( -14, -13), s( -9, -22), s( 11, -32), s( -26, -23)], + [s( -28, -5), s( -25, -4), s( -12, 3), s( -2, -4), s( 5, -13), s( -5, -9), s( 8, -21), s( -21, -22)], ]; pub const QUEENS: PieceSquareTableDefinition = [ - [s( -36, -5), s( -39, 19), s( 2, 41), s( 48, 21), s( 46, 19), s( 56, 11), s( 95, -54), s( 20, -6)], - [s( 4, -42), s( -33, 13), s( -24, 58), s( -34, 81), s( -28, 108), s( 25, 51), s( -2, 32), s( 63, 6)], - [s( 2, -28), s( -7, -5), s( -9, 47), s( 10, 50), s( 19, 70), s( 78, 42), s( 83, -6), s( 82, -16)], - [s( -22, -10), s( -18, 12), s( -14, 31), s( -17, 59), s( -16, 80), s( 7, 60), s( 7, 48), s( 17, 26)], - [s( -17, -19), s( -25, 19), s( -27, 25), s( -14, 47), s( -17, 44), s( -18, 37), s( -1, 18), s( 6, 7)], - [s( -22, -32), s( -14, -14), s( -23, 12), s( -24, 3), s( -19, 10), s( -10, 4), s( 9, -22), s( 2, -32)], - [s( -20, -40), s( -17, -37), s( -4, -45), s( -5, -33), s( -7, -28), s( 6, -65), s( 16, -101), s( 34, -133)], - [s( -20, -50), s( -36, -42), s( -27, -40), s( -8, -49), s( -19, -47), s( -35, -45), s( -3, -83), s( -6, -85)], + [s( -13, -11), s( -45, 26), s( -18, 50), s( 24, 30), s( 27, 30), s( 36, 28), s( 96, -51), s( 34, -3)], + [s( -4, -23), s( -42, 17), s( -39, 62), s( -51, 85), s( -43, 110), s( 9, 56), s( -4, 33), s( 67, 21)], + [s( -3, -16), s( -15, -5), s( -23, 40), s( -10, 50), s( 4, 65), s( 63, 40), s( 79, -3), s( 84, 3)], + [s( -25, -5), s( -25, 6), s( -28, 21), s( -28, 41), s( -29, 63), s( -7, 53), s( 5, 46), s( 16, 37)], + [s( -19, -13), s( -31, 15), s( -30, 10), s( -20, 30), s( -20, 25), s( -20, 24), s( -7, 20), s( 7, 18)], + [s( -19, -26), s( -14, -16), s( -17, -3), s( -17, -12), s( -12, -5), s( -6, -4), s( 9, -19), s( 8, -21)], + [s( -4, -37), s( -10, -37), s( 2, -47), s( 8, -43), s( 3, -34), s( 17, -69), s( 24, -98), s( 52, -122)], + [s( -11, -38), s( -8, -46), s( 2, -50), s( 12, -30), s( 9, -54), s( -14, -45), s( 21, -78), s( 15, -79)], ]; pub const KING: PieceSquareTableDefinition = [ - [s( 23, -111), s( -7, -41), s( 44, -28), s( -156, 42), s( -76, 13), s( -5, 16), s( 59, 6), s( 151, -134)], - [s( -139, 19), s( -79, 59), s( -137, 77), s( 12, 50), s( -65, 80), s( -58, 96), s( -4, 82), s( -32, 38)], - [s( -170, 40), s( -20, 66), s( -115, 93), s( -141, 108), s( -88, 107), s( 19, 96), s( -9, 94), s( -60, 53)], - [s( -124, 25), s( -141, 74), s( -160, 99), s( -222, 117), s( -205, 116), s( -154, 108), s( -153, 96), s( -188, 60)], - [s( -114, 8), s( -127, 52), s( -170, 86), s( -209, 107), s( -208, 106), s( -156, 87), s( -160, 71), s( -199, 48)], - [s( -55, -7), s( -32, 27), s( -111, 57), s( -132, 74), s( -123, 74), s( -120, 61), s( -54, 34), s( -79, 17)], - [s( 65, -37), s( 10, 2), s( -9, 20), s( -58, 35), s( -60, 40), s( -35, 27), s( 32, 0), s( 43, -24)], - [s( 57, -87), s( 90, -59), s( 58, -32), s( -82, -6), s( 8, -40), s( -47, -8), s( 66, -46), s( 67, -87)], + [s( 33, -113), s( 16, -44), s( 68, -32), s( -137, 38), s( -68, 12), s( 2, 14), s( 60, 3), s( 158, -138)], + [s( -138, 18), s( -65, 57), s( -120, 74), s( 27, 48), s( -48, 78), s( -46, 95), s( -4, 81), s( -40, 37)], + [s( -162, 39), s( -8, 65), s( -102, 92), s( -127, 107), s( -76, 106), s( 25, 96), s( -8, 93), s( -60, 52)], + [s( -120, 25), s( -134, 73), s( -156, 99), s( -215, 116), s( -204, 116), s( -152, 108), s( -152, 96), s( -191, 60)], + [s( -114, 8), s( -125, 52), s( -166, 86), s( -208, 106), s( -209, 106), s( -155, 86), s( -161, 70), s( -198, 47)], + [s( -58, -6), s( -36, 28), s( -113, 57), s( -132, 74), s( -123, 73), s( -120, 61), s( -59, 34), s( -83, 17)], + [s( 62, -34), s( 5, 3), s( -15, 21), s( -60, 36), s( -64, 41), s( -39, 27), s( 26, 1), s( 37, -24)], + [s( 46, -83), s( 90, -58), s( 58, -30), s( -71, -5), s( 14, -34), s( -38, -9), s( 63, -47), s( 58, -88)], ]; pub const KNIGHT_MOBILITY: [PhasedEval; 9] = [ s( 0, 0), s( 0, 0), - s( 38, 138), - s( 82, 149), - s( 113, 194), + s( 38, 139), + s( 89, 154), + s( 115, 194), s( 0, 0), - s( 148, 194), + s( 151, 195), s( 0, 0), - s( 136, 175) + s( 136, 175), ]; pub const BISHOP_MOBILITY: [PhasedEval; 14] = [ s( 0, 0), - s( 123, 178), - s( 116, 158), - s( 119, 166), - s( 125, 171), - s( 127, 168), - s( 135, 179), - s( 141, 181), - s( 147, 187), - s( 148, 187), - s( 153, 191), - s( 158, 181), - s( 158, 176), - s( 194, 157) + s( 56, 57), + s( 74, 70), + s( 84, 114), + s( 101, 127), + s( 109, 139), + s( 126, 162), + s( 138, 169), + s( 147, 185), + s( 149, 190), + s( 156, 199), + s( 160, 193), + s( 161, 193), + s( 198, 179), ]; pub const ROOK_MOBILITY: [PhasedEval; 15] = [ s( 0, 0), s( 0, 0), - s( 154, 314), - s( 164, 308), - s( 168, 310), - s( 172, 307), - s( 172, 309), - s( 174, 310), - s( 179, 310), - s( 186, 310), - s( 194, 313), - s( 200, 316), - s( 207, 319), - s( 222, 322), - s( 227, 321) + s( 127, 255), + s( 139, 271), + s( 149, 282), + s( 156, 287), + s( 159, 292), + s( 164, 300), + s( 172, 302), + s( 182, 306), + s( 192, 313), + s( 201, 315), + s( 207, 321), + s( 219, 325), + s( 225, 326), ]; pub const QUEEN_MOBILITY: [PhasedEval; 28] = [ s( 0, 0), s( 0, 0), s( 0, 0), - s( 376, 516), - s( 345, 541), - s( 378, 537), - s( 375, 540), - s( 375, 545), - s( 384, 545), - s( 384, 551), - s( 386, 553), - s( 387, 554), - s( 388, 561), - s( 387, 565), - s( 390, 567), - s( 391, 568), - s( 388, 578), - s( 389, 583), - s( 391, 582), - s( 390, 589), - s( 391, 591), - s( 396, 584), - s( 404, 580), - s( 419, 569), - s( 419, 569), - s( 445, 564), - s( 433, 572), - s( 585, 513) + s( 374, 117), + s( 326, 240), + s( 361, 325), + s( 361, 366), + s( 361, 449), + s( 368, 469), + s( 369, 494), + s( 375, 502), + s( 379, 518), + s( 382, 533), + s( 386, 538), + s( 391, 545), + s( 392, 557), + s( 393, 567), + s( 396, 577), + s( 396, 586), + s( 398, 593), + s( 400, 606), + s( 406, 605), + s( 407, 609), + s( 420, 605), + s( 422, 604), + s( 440, 603), + s( 504, 566), + s( 569, 547), ]; -pub const BISHOP_PAIR_BONUS: PhasedEval = s( 29, 92); +pub const BISHOP_PAIR_BONUS: PhasedEval = s( 27, 90); diff --git a/src/engine/eval/phased_eval.rs b/src/engine/eval/phased_eval.rs index b9adaae..5108d50 100644 --- a/src/engine/eval/phased_eval.rs +++ b/src/engine/eval/phased_eval.rs @@ -5,12 +5,8 @@ use crate::engine::eval::WhiteEval; const PHASE_COUNT_MAX: i64 = 24; -pub const fn s(mg: i16, eg: i16) -> PhasedEval { - PhasedEval::new(mg, eg) -} - /// A midgame and endgame evaluation -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub struct PhasedEval(i32); impl PhasedEval { @@ -47,6 +43,12 @@ impl PhasedEval { } } +impl std::fmt::Debug for PhasedEval { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "PhasedEval({}, {})", self.midgame().0, self.endgame().0) + } +} + impl std::ops::Add for PhasedEval { type Output = Self; diff --git a/src/utils/cli/mod.rs b/src/utils/cli/mod.rs index 2d44611..b80f707 100644 --- a/src/utils/cli/mod.rs +++ b/src/utils/cli/mod.rs @@ -1,6 +1,8 @@ use crate::engine::uci; use crate::engine::uci::UciInputMode; +use crate::utils; use clap::{Parser, Subcommand}; +use std::path::{Path, PathBuf}; use std::process::ExitCode; #[derive(Parser)] @@ -12,6 +14,13 @@ struct Cli { #[derive(Subcommand)] enum Command { Uci, + + Tune { + file: PathBuf, + + #[clap(default_value_t = 5000)] + epochs: usize, + }, } pub fn uci_command() -> ExitCode { @@ -26,12 +35,18 @@ pub fn uci_command() -> ExitCode { } } +pub fn tune_command(file: &Path, epochs: usize) -> ExitCode { + utils::tuner::tune(file, epochs); + ExitCode::SUCCESS +} + pub fn run() -> ExitCode { let cli = Cli::parse(); match cli.command { Some(c) => match c { Command::Uci => uci_command(), + Command::Tune { file, epochs } => tune_command(&file, epochs), }, _ => uci_command(), } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 4f77372..2b02014 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1 +1,2 @@ pub mod cli; +pub mod tuner; diff --git a/src/utils/tuner/mod.rs b/src/utils/tuner/mod.rs new file mode 100644 index 0000000..ee74bb2 --- /dev/null +++ b/src/utils/tuner/mod.rs @@ -0,0 +1,229 @@ +// A huge thanks to GediminasMasaitis and Andrew Grant. +// This code borrows heavily from https://github.com/GediminasMasaitis/texel-tuner +// which is in turn based on https://github.com/AndyGrant/Ethereal/blob/master/Tuning.pdf + +use crate::chess::game::Game; +use crate::utils::tuner::parameters::Parameters; +use crate::utils::tuner::trace::Trace; +use crate::utils::tuner::tuner_eval::TunerEval; +use indicatif::{ProgressBar, ProgressStyle}; +use rayon::prelude::*; +use std::path::Path; + +mod parameters; +mod trace; +mod tuner_eval; + +enum Outcome { + Win, + Draw, + Loss, +} + +impl Outcome { + fn numeric_outcome(&self) -> f32 { + match self { + Self::Win => 1.0, + Self::Draw => 0.5, + Self::Loss => 0.0, + } + } +} + +#[derive(Clone)] +struct NonZeroCoefficient { + idx: usize, + value: f32, +} + +impl NonZeroCoefficient { + pub fn new(idx: usize, value: f32) -> Self { + Self { idx, value } + } +} + +struct Entry { + outcome: Outcome, + coefficients: Vec, + + midgame_percentage: f32, + endgame_percentage: f32, +} + +fn start_progress_bar(size: usize, label: &str) -> ProgressBar { + let p = ProgressBar::new(size as u64); + p.set_prefix(label.to_owned()); + p.set_style( + ProgressStyle::with_template( + "{prefix} [{wide_bar:.cyan/blue}] {pos}/{len} ({elapsed} {per_sec} ETA {eta})", + ) + .unwrap() + .progress_chars("#>-"), + ); + p +} + +fn load_entries_from_file(path: &Path) -> Vec { + let file_contents = std::fs::read_to_string(path).expect("Unable to read file"); + let lines = file_contents.lines().collect::>(); + + let number_of_positions = lines.len(); + + let parsing_progress = start_progress_bar(number_of_positions, "Loading positions"); + let mut parse_results: Vec<(Game, Outcome)> = Vec::new(); + + for (i, line) in lines.iter().enumerate() { + let (fen_str, outcome_str) = line.split_once('[').expect("Unexpected file format"); + let fen_str = fen_str.trim(); + let outcome_str = outcome_str.trim().replace(']', ""); + + let game = Game::from_fen(fen_str).expect("Unexpected fen"); + + let outcome = match outcome_str.as_str() { + "1.0" => Outcome::Win, + "0.5" => Outcome::Draw, + "0.0" => Outcome::Loss, + _ => panic!("Unexpected outcome format"), + }; + + parse_results.push((game, outcome)); + + if i % 1000 == 0 { + parsing_progress.set_position(i as u64); + } + } + + parsing_progress.finish(); + + let coefficients_progress = start_progress_bar(number_of_positions, "Calculating coefficients"); + let mut entries: Vec = Vec::new(); + + for (i, (game, outcome)) in parse_results.into_iter().enumerate() { + let coefficients = Trace::for_game(&game).non_zero_coefficients(); + + let midgame_percentage = + f32::from(game.incremental_eval.phase_value) / f32::from(tuner_eval::PHASE_COUNT_MAX); + let endgame_percentage = 1.0 - midgame_percentage; + + entries.push(Entry { + outcome, + coefficients, + + midgame_percentage, + endgame_percentage, + }); + + if i % 1000 == 0 { + coefficients_progress.set_position(i as u64); + } + } + + coefficients_progress.finish(); + + entries +} + +fn evaluate(entry: &Entry, parameters: &[TunerEval]) -> f32 { + let mut s = TunerEval::ZERO; + + for coefficient in &entry.coefficients { + s += parameters[coefficient.idx] * coefficient.value; + } + + s.midgame().mul_add( + entry.midgame_percentage, + s.endgame() * entry.endgame_percentage, + ) +} + +fn sigmoid(x: f32) -> f32 { + 1.0 / (1.0 + f32::exp(-x)) +} + +fn calculate_gradient( + entries: &[Entry], + parameters: &[TunerEval; Trace::SIZE], + k: f32, +) -> [TunerEval; Trace::SIZE] { + // Break the entries into chunks, aiming for as many chunks as we have CPUs. + // + // I previously tried using Rayon's .fold() but this results in a huge number of array copies + // due to Rayon's default binary-tree-of-tasks strategy. Each leaf gets its own array, which means + // we have to allocate a huge number of trace-sized arrays. + // With this approach, we allocate only a single array per chunk, and sum them at the end. + let entry_chunks = entries.chunks(entries.len() / 10).collect::>(); + + entry_chunks + .par_iter() + .map(|&entries| { + let mut gradient = [TunerEval::ZERO; Trace::SIZE]; + + for entry in entries { + let eval = evaluate(entry, parameters); + let sigmoid = sigmoid(k * eval / 400.0); + let result = + (entry.outcome.numeric_outcome() - sigmoid) * sigmoid * (1.0 - sigmoid); + + for coefficient in &entry.coefficients { + gradient[coefficient.idx] += TunerEval::new( + entry.midgame_percentage * coefficient.value, + entry.endgame_percentage * coefficient.value, + ) * result; + } + } + + gradient + }) + .reduce( + || [TunerEval::ZERO; Trace::SIZE], + |mut gradient: [TunerEval; Trace::SIZE], thread_gradient: [TunerEval; Trace::SIZE]| { + for (idx, score) in thread_gradient.iter().enumerate() { + gradient[idx] += *score; + } + + gradient + }, + ) +} + +#[expect(clippy::cast_precision_loss, reason = "Known imprecise calculations")] +pub fn tune(path: &Path, epochs: usize) { + rayon::ThreadPoolBuilder::new() + .stack_size(5_000_000) + .build_global() + .unwrap(); + + let entries = load_entries_from_file(path); + + // TODO: Using the same k as was determined by texel-tuner until we compute it here. + let k = 2.5; + + let learning_rate = 1.0; + let beta1 = 0.9; + let beta2 = 0.999; + + let mut parameters: [TunerEval; Trace::SIZE] = [TunerEval::ZERO; Trace::SIZE]; + let mut momentum: [TunerEval; Trace::SIZE] = [TunerEval::ZERO; Trace::SIZE]; + let mut velocities: [TunerEval; Trace::SIZE] = [TunerEval::ZERO; Trace::SIZE]; + + let epoch_progress = start_progress_bar(epochs, "Running epochs"); + + for epoch in 0..epochs { + let gradient = calculate_gradient(&entries, ¶meters, k); + + for param in 0..Trace::SIZE { + let grad = TunerEval::v(-k) / TunerEval::v(400.0) * gradient[param] + / TunerEval::v(entries.len() as f32); + momentum[param] = momentum[param] * beta1 + grad * (1.0 - beta1); + velocities[param] = velocities[param] * beta2 + (grad * grad) * (1.0 - beta2); + + parameters[param] -= + momentum[param] * learning_rate / (TunerEval::v(1e-8) + velocities[param].sqrt()); + } + + epoch_progress.set_position((epoch + 1) as u64); + } + + let parameters = Parameters::from_array(¶meters); + println!("{}", ¶meters); +} diff --git a/src/utils/tuner/parameters.rs b/src/utils/tuner/parameters.rs new file mode 100644 index 0000000..c805d7d --- /dev/null +++ b/src/utils/tuner/parameters.rs @@ -0,0 +1,233 @@ +use crate::chess::bitboard::{bitboards, Bitboard}; +use crate::chess::piece::PieceKind; +use crate::chess::square::Square; +use crate::engine::eval::PhasedEval; +use crate::utils::tuner::trace::Trace; +use crate::utils::tuner::tuner_eval::TunerEval; +use std::fmt::Formatter; + +pub struct ParametersBuilder { + parameters: [PhasedEval; Trace::SIZE], + idx: usize, +} + +impl ParametersBuilder { + pub fn new(parameters: &[PhasedEval; Trace::SIZE]) -> Self { + Self { + parameters: *parameters, + idx: 0, + } + } + + pub fn copy_to(mut self, ps: &mut [PhasedEval]) -> Self { + ps.copy_from_slice(&self.parameters[self.idx..self.idx + ps.len()]); + self.idx += ps.len(); + self + } +} + +#[derive(Clone)] +pub struct Parameters { + material: [PhasedEval; PieceKind::N], + + pawn_pst: [PhasedEval; Square::N], + knight_pst: [PhasedEval; Square::N], + bishop_pst: [PhasedEval; Square::N], + rook_pst: [PhasedEval; Square::N], + queen_pst: [PhasedEval; Square::N], + king_pst: [PhasedEval; Square::N], + + knight_mobility: [PhasedEval; 9], + bishop_mobility: [PhasedEval; 14], + rook_mobility: [PhasedEval; 15], + queen_mobility: [PhasedEval; 28], + + bishop_pair: [PhasedEval; 1], +} + +impl Parameters { + pub fn new() -> Self { + Self { + material: [PhasedEval::ZERO; PieceKind::N], + + pawn_pst: [PhasedEval::ZERO; Square::N], + knight_pst: [PhasedEval::ZERO; Square::N], + bishop_pst: [PhasedEval::ZERO; Square::N], + rook_pst: [PhasedEval::ZERO; Square::N], + queen_pst: [PhasedEval::ZERO; Square::N], + king_pst: [PhasedEval::ZERO; Square::N], + + knight_mobility: [PhasedEval::ZERO; 9], + bishop_mobility: [PhasedEval::ZERO; 14], + rook_mobility: [PhasedEval::ZERO; 15], + queen_mobility: [PhasedEval::ZERO; 28], + + bishop_pair: [PhasedEval::ZERO; 1], + } + } + + pub fn from_array(arr: &[TunerEval; Trace::SIZE]) -> Self { + let mut evals = [PhasedEval::ZERO; Trace::SIZE]; + + for (i, param) in arr.iter().enumerate() { + evals[i] = param.to_phased_eval(); + } + + let mut parameter_components = Self::new(); + + ParametersBuilder::new(&evals) + .copy_to(&mut parameter_components.material) + .copy_to(&mut parameter_components.pawn_pst) + .copy_to(&mut parameter_components.knight_pst) + .copy_to(&mut parameter_components.bishop_pst) + .copy_to(&mut parameter_components.rook_pst) + .copy_to(&mut parameter_components.queen_pst) + .copy_to(&mut parameter_components.king_pst) + .copy_to(&mut parameter_components.knight_mobility) + .copy_to(&mut parameter_components.bishop_mobility) + .copy_to(&mut parameter_components.rook_mobility) + .copy_to(&mut parameter_components.queen_mobility) + .copy_to(&mut parameter_components.bishop_pair); + + parameter_components.rebalance(); + + parameter_components + } + + fn rebalance_pst( + pst: &mut [PhasedEval; Square::N], + material: &mut [PhasedEval; PieceKind::N], + piece: PieceKind, + ignore_mask: Bitboard, + ) { + let mut midgame_sum = 0; + let mut endgame_sum = 0; + + // We won't include some squares in the calculation - e.g. ranks where pawns can never be + let squares = Bitboard::FULL & !ignore_mask; + + // First, calculate the average across all non-zero squares in the PST + // Do our computations in 32 bits to avoid overflowing i16 for large PST values. + for sq in squares { + let v = pst[sq.array_idx()]; + midgame_sum += i32::from(v.midgame().0); + endgame_sum += i32::from(v.endgame().0); + } + + let mg_average = midgame_sum / i32::from(squares.count()); + let eg_average = endgame_sum / i32::from(squares.count()); + + let average = PhasedEval::new( + i16::try_from(mg_average).unwrap(), + i16::try_from(eg_average).unwrap(), + ); + material[piece.array_idx()] += average; + + for sq in squares { + pst[sq.array_idx()] -= average; + } + } + + pub fn rebalance(&mut self) { + Self::rebalance_pst( + &mut self.pawn_pst, + &mut self.material, + PieceKind::Pawn, + bitboards::RANK_1 | bitboards::RANK_8, + ); + Self::rebalance_pst( + &mut self.knight_pst, + &mut self.material, + PieceKind::Knight, + Bitboard::EMPTY, + ); + Self::rebalance_pst( + &mut self.bishop_pst, + &mut self.material, + PieceKind::Bishop, + Bitboard::EMPTY, + ); + Self::rebalance_pst( + &mut self.rook_pst, + &mut self.material, + PieceKind::Rook, + Bitboard::EMPTY, + ); + Self::rebalance_pst( + &mut self.queen_pst, + &mut self.material, + PieceKind::Queen, + Bitboard::EMPTY, + ); + } +} + +fn print_param(f: &mut Formatter<'_>, p: PhasedEval) -> std::fmt::Result { + let (mg, eg) = (p.midgame().0, p.endgame().0); + write!(f, "s({mg: >5}, {eg: >5})") +} + +fn print_array(f: &mut Formatter<'_>, ps: &[PhasedEval], name: &str) -> std::fmt::Result { + let size = ps.len(); + writeln!(f, "pub const {name}: [PhasedEval; {size}] = [")?; + + for param in ps { + write!(f, " ")?; + print_param(f, *param)?; + writeln!(f, ",")?; + } + + writeln!(f, "];\n")?; + + Ok(()) +} + +fn print_pst(f: &mut Formatter<'_>, pst: &[PhasedEval; Square::N], name: &str) -> std::fmt::Result { + writeln!(f, "pub const {name}: PieceSquareTableDefinition = [")?; + + for rank in (0..8).rev() { + write!(f, " [")?; + + for file in 0..8 { + let idx = Square::from_idxs(file, rank).array_idx(); + print_param(f, pst[idx])?; + + if file != 7 { + write!(f, ", ")?; + } + } + + writeln!(f, "],")?; + } + + writeln!(f, "];\n")?; + + Ok(()) +} + +fn print_single(f: &mut Formatter<'_>, p: [PhasedEval; 1], name: &str) -> std::fmt::Result { + write!(f, "pub const {name}: PhasedEval = ")?; + print_param(f, p[0])?; + writeln!(f, ";\n")?; + + Ok(()) +} + +impl std::fmt::Display for Parameters { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + print_array(f, &self.material, "PIECE_VALUES")?; + print_pst(f, &self.pawn_pst, "PAWNS")?; + print_pst(f, &self.knight_pst, "KNIGHTS")?; + print_pst(f, &self.bishop_pst, "BISHOPS")?; + print_pst(f, &self.rook_pst, "ROOKS")?; + print_pst(f, &self.queen_pst, "QUEENS")?; + print_pst(f, &self.king_pst, "KING")?; + print_array(f, &self.knight_mobility, "KNIGHT_MOBILITY")?; + print_array(f, &self.bishop_mobility, "BISHOP_MOBILITY")?; + print_array(f, &self.rook_mobility, "ROOK_MOBILITY")?; + print_array(f, &self.queen_mobility, "QUEEN_MOBILITY")?; + print_single(f, self.bishop_pair, "BISHOP_PAIR_BONUS")?; + + Ok(()) + } +} diff --git a/src/utils/tuner/trace.rs b/src/utils/tuner/trace.rs new file mode 100644 index 0000000..f19cd3c --- /dev/null +++ b/src/utils/tuner/trace.rs @@ -0,0 +1,175 @@ +use crate::chess::board::Board; +use crate::chess::game::Game; +use crate::chess::movegen::tables; +use crate::chess::piece::PieceKind; +use crate::chess::player::Player; +use crate::chess::square::Square; +use crate::utils::tuner::NonZeroCoefficient; + +#[derive(Default, Copy, Clone)] +pub struct TraceComponent(i32); + +impl TraceComponent { + pub fn incr(&mut self, player: Player) { + self.add(player, 1); + } + + pub fn add(&mut self, player: Player, n: i32) { + let multiplier = if player == Player::White { 1 } else { -1 }; + + self.0 += n * multiplier; + } +} + +pub struct Trace { + pub material: [TraceComponent; PieceKind::N], + + pub pawn_pst: [TraceComponent; Square::N], + pub knight_pst: [TraceComponent; Square::N], + pub bishop_pst: [TraceComponent; Square::N], + pub rook_pst: [TraceComponent; Square::N], + pub queen_pst: [TraceComponent; Square::N], + pub king_pst: [TraceComponent; Square::N], + + knight_mobility: [TraceComponent; 9], + bishop_mobility: [TraceComponent; 14], + rook_mobility: [TraceComponent; 15], + queen_mobility: [TraceComponent; 28], + + pub bishop_pair: TraceComponent, +} + +impl Trace { + pub const SIZE: usize = size_of::() / size_of::(); + + fn trace_for_player(trace: &mut Self, board: &Board, player: Player) { + let mut number_of_bishops = 0; + + let occupancy = board.occupancy(); + + for sq in board.pawns(player) { + trace.material[PieceKind::Pawn.array_idx()].incr(player); + trace.pawn_pst[sq.array_idx()].incr(player); + } + + for sq in board.knights(player) { + trace.material[PieceKind::Knight.array_idx()].incr(player); + trace.knight_pst[sq.array_idx()].incr(player); + + let mobility = tables::knight_attacks(sq).count(); + trace.knight_mobility[mobility as usize].incr(player); + } + + for sq in board.bishops(player) { + trace.material[PieceKind::Bishop.array_idx()].incr(player); + trace.bishop_pst[sq.array_idx()].incr(player); + + let mobility = tables::bishop_attacks(sq, occupancy).count(); + trace.bishop_mobility[mobility as usize].incr(player); + + number_of_bishops += 1; + } + + if number_of_bishops > 1 { + trace.bishop_pair.incr(player); + } + + for sq in board.rooks(player) { + trace.material[PieceKind::Rook.array_idx()].incr(player); + trace.rook_pst[sq.array_idx()].incr(player); + + let mobility = tables::rook_attacks(sq, occupancy).count(); + trace.rook_mobility[mobility as usize].incr(player); + } + + for sq in board.queens(player) { + trace.material[PieceKind::Queen.array_idx()].incr(player); + trace.queen_pst[sq.array_idx()].incr(player); + + let mobility = (tables::rook_attacks(sq, occupancy) + | tables::bishop_attacks(sq, occupancy)) + .count(); + trace.queen_mobility[mobility as usize].incr(player); + } + + for sq in board.king(player) { + trace.material[PieceKind::King.array_idx()].incr(player); + trace.king_pst[sq.array_idx()].incr(player); + } + } + + pub fn for_game(game: &Game) -> Self { + let mut trace = Self { + material: [TraceComponent::default(); PieceKind::N], + pawn_pst: [TraceComponent::default(); Square::N], + knight_pst: [TraceComponent::default(); Square::N], + bishop_pst: [TraceComponent::default(); Square::N], + rook_pst: [TraceComponent::default(); Square::N], + queen_pst: [TraceComponent::default(); Square::N], + king_pst: [TraceComponent::default(); Square::N], + + knight_mobility: [TraceComponent::default(); 9], + bishop_mobility: [TraceComponent::default(); 14], + rook_mobility: [TraceComponent::default(); 15], + queen_mobility: [TraceComponent::default(); 28], + + bishop_pair: TraceComponent::default(), + }; + + Self::trace_for_player(&mut trace, &game.board, Player::White); + Self::trace_for_player(&mut trace, &game.board.flip_vertically(), Player::Black); + + trace + } + + pub fn non_zero_coefficients(&self) -> Vec { + CoefficientBuilder::new() + .add(&self.material) + .add(&self.pawn_pst) + .add(&self.knight_pst) + .add(&self.bishop_pst) + .add(&self.rook_pst) + .add(&self.queen_pst) + .add(&self.king_pst) + .add(&self.knight_mobility) + .add(&self.bishop_mobility) + .add(&self.rook_mobility) + .add(&self.queen_mobility) + .add(&[self.bishop_pair]) + .get() + } +} + +struct CoefficientBuilder { + value: Vec, + idx: usize, +} + +impl CoefficientBuilder { + pub fn new() -> Self { + Self { + value: Vec::new(), + idx: 0, + } + } + + #[expect(clippy::cast_precision_loss, reason = "known cast from i32 to f32")] + pub fn add(&mut self, s: &[TraceComponent]) -> &mut Self { + for (i, component) in s.iter().enumerate() { + let coefficient = component.0; + + if coefficient != 0 { + self.value + .push(NonZeroCoefficient::new(self.idx + i, coefficient as f32)); + } + } + + self.idx += s.len(); + self + } + + pub fn get(&self) -> Vec { + assert_eq!(Trace::SIZE, self.idx); + self.value.clone() + } +} diff --git a/src/utils/tuner/tuner_eval.rs b/src/utils/tuner/tuner_eval.rs new file mode 100644 index 0000000..b2fb98b --- /dev/null +++ b/src/utils/tuner/tuner_eval.rs @@ -0,0 +1,105 @@ +use crate::engine::eval::PhasedEval; + +// A non-packed version of PhasedEval +#[derive(Debug, PartialEq, PartialOrd, Clone, Copy)] +pub struct TunerEval(f32, f32); + +pub const PHASE_COUNT_MAX: i16 = 24; + +impl TunerEval { + pub const ZERO: Self = Self(0.0, 0.0); + + pub const fn new(midgame: f32, endgame: f32) -> Self { + Self(midgame, endgame) + } + + pub const fn v(val: f32) -> Self { + Self(val, val) + } + + pub fn midgame(self) -> f32 { + self.0 + } + + pub fn endgame(self) -> f32 { + self.1 + } + + pub fn sqrt(self) -> Self { + Self(self.0.sqrt(), self.1.sqrt()) + } + + #[expect( + clippy::cast_possible_truncation, + reason = "Intentionally truncating down to integers" + )] + pub fn to_phased_eval(self) -> PhasedEval { + PhasedEval::new(self.0.round() as i16, self.1.round() as i16) + } +} + +impl std::ops::Add for TunerEval { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0, self.1 + rhs.1) + } +} + +impl std::ops::AddAssign for TunerEval { + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + self.1 += rhs.1; + } +} + +impl std::ops::Sub for TunerEval { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0, self.1 - rhs.1) + } +} + +impl std::ops::SubAssign for TunerEval { + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + self.1 -= rhs.1; + } +} + +impl std::ops::Mul for TunerEval { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self(self.0 * rhs.0, self.1 * rhs.1) + } +} + +impl std::ops::Mul for TunerEval { + type Output = Self; + + fn mul(self, rhs: f32) -> Self::Output { + Self(self.0 * rhs, self.1 * rhs) + } +} + +impl std::ops::Div for TunerEval { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + Self(self.0 / rhs.0, self.1 / rhs.1) + } +} + +impl std::iter::Sum for TunerEval { + fn sum>(iter: I) -> Self { + let mut result = Self::ZERO; + + for i in iter { + result += i; + } + + result + } +}