diff --git a/game/.gitignore b/game/.gitignore index 7f02987..b90dde5 100644 --- a/game/.gitignore +++ b/game/.gitignore @@ -1,4 +1,5 @@ /target .cargo .vscode -/data \ No newline at end of file +/data +/Microsoft.WebView2.FixedVersionRuntime.107.0.1418.24.x64 \ No newline at end of file diff --git a/game/Cargo.toml b/game/Cargo.toml index 761a2f4..3643d05 100644 --- a/game/Cargo.toml +++ b/game/Cargo.toml @@ -21,7 +21,7 @@ serde_json = "1.0" serde-pickle = "1.1.1" serde = { version = "1.0.144", features = ["derive"] } tauri = { version = "1.1.1", features = ["api-all"] } -pyo3 = { version = "0.17.1", features = [] } +tempfile = "3.3.0" num_cpus = "1.13.1" crossbeam = "0.8.2" flume = "0.10.14" @@ -31,6 +31,8 @@ rand = "0.8.5" rand_distr = "0.4.3" probability = "0.18.0" +pyo3 = { version = "0.17.1", features = [], optional = true } + [dependencies.uuid] version = "1.1.2" features = [ @@ -41,16 +43,9 @@ features = [ [build-dependencies] -tauri-build = { version = "1.1.1" } +tauri-build = { version = "1.1.1", features = [] } copy_to_output = "2.0.0" -[features] -# by default Tauri runs in production mode -# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL -default = ["custom-protocol"] -# this feature is used used for production builds where `devPath` points to the filesystem -# DO NOT remove this -custom-protocol = ["tauri/custom-protocol"] # BLAS source choices: @@ -70,8 +65,19 @@ blas-src = { version = "0.8.0", default-features = false, features = [ ndarray = { version = "0.15.6" } -#[profile.release] -#strip = true # Automatically strip symbols from the binary. -#opt-level = "z" # Optimize for size. -#lto = true -#codegen-units = 1 +[profile.release] +strip = true # Automatically strip symbols from the binary. +opt-level = "z" # Optimize for size. +lto = true +codegen-units = 1 + + + +[features] +# by default Tauri runs in production mode +# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL +default = ["custom-protocol"] +# this feature is used used for production builds where `devPath` points to the filesystem +# DO NOT remove this +custom-protocol = ["tauri/custom-protocol"] +train = [ "dep:pyo3" ] diff --git a/game/build_windows.bat b/game/build_windows.bat index 27d8221..20009e9 100644 --- a/game/build_windows.bat +++ b/game/build_windows.bat @@ -1,3 +1,5 @@ +REM cargo install tauri-cli + CALL "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" SET LIBCLANG_PATH=C:\Program Files\Microsoft^ Visual^ Studio\2022\Enterprise\VC\Tools\Llvm\x64\bin @@ -7,3 +9,4 @@ SET TFLITEC_PREBUILT_PATH_X86_64_PC_WINDOWS_GNU=%CD%\tensorflowlite_c.dll SET RUSTFLAGS=-C target-feature=+crt-static cargo build --release +cargo tauri build \ No newline at end of file diff --git a/game/src/game.rs b/game/src/game.rs index a001cc7..48ce5ad 100644 --- a/game/src/game.rs +++ b/game/src/game.rs @@ -48,7 +48,7 @@ impl From for Color { } impl From for Color { - fn from(v: f32) -> Self { + fn from(_v: f32) -> Self { unimplemented!() } } diff --git a/game/src/human.rs b/game/src/human.rs index 23f0a84..4db51b6 100644 --- a/game/src/human.rs +++ b/game/src/human.rs @@ -21,12 +21,15 @@ use crate::model::TfLiteModel; use crossbeam::atomic::AtomicCell; use std::sync::atomic::Ordering; - +use std::include_bytes; use rand::seq::SliceRandom; use std::cell::RefCell; use std::sync::{atomic::AtomicBool, Arc}; use tokio::sync::oneshot::{self, Receiver}; use tokio::task::JoinHandle; +use tempfile::{tempdir, TempDir}; +use std::fs::File; +use std::io::prelude::*; lazy_static! { static ref SINGLETON_CHANNEL: AtomicCell> = { @@ -37,6 +40,19 @@ lazy_static! { } AtomicCell::new(rx) }; + + static ref MODEL_FILE : (String, TempDir) = { + let mut dir = tempdir().expect("Unable to create template file"); + let filename = dir.path().join("model.tflite").display().to_string(); + println!("{}", &filename); + + let mut file = File::create(&filename).expect("Unable to create temp file"); + + let bytes = include_bytes!("../best.tflite"); + file.write_all(bytes).expect("Unable to save model to file"); + file.flush().expect("Unable to flush model to file"); + ( filename, dir ) + }; } pub async fn access(mut f: F) -> MatchState @@ -96,7 +112,7 @@ impl BoardInfo { impl HumanVsMachineMatch { fn new(human_play_black: bool) -> Self { let ai_player = AiPlayer::new(); - let max_threads = 2; //(num_cpus::get() - 0).max(1); + let max_threads = 1; //(num_cpus::get() - 0).max(1); Self { ai_player: Arc::new(ai_player), board: RenjuBoard::default(), @@ -268,27 +284,26 @@ impl AiPlayer { choices: &Vec<(usize, usize)>, ) { self.tree - .rollout( - breadth_first, - board, - choices, - |state_tensor: StateTensor| { - MODEL.with(|ref_cell| { - let mut model = ref_cell.borrow_mut(); - if model.is_none() { - *model = Some( - TfLiteModel::load("best.tflite") - .expect("Unable to load saved model"), - ) - } - model - .as_ref() - .unwrap() - .predict_one(state_tensor) - .expect("Unable to predict_one") - }) - }, - ) + .rollout(breadth_first, board, choices, |state_tensor: StateTensor| { + MODEL.with(|ref_cell| { + let mut model = ref_cell.borrow_mut(); + if model.is_none() { + #[cfg(feature="train")] + let filepath = "best.tflite"; + + #[cfg(not(feature="train"))] + let filepath = &MODEL_FILE.0; + *model = Some( + TfLiteModel::load(filepath).expect("Unable to load model"), + ) + } + model + .as_ref() + .unwrap() + .predict_one(state_tensor) + .expect("Unable to predict_one") + }) + }) .await .expect("rollout failed") } diff --git a/game/src/lib.rs b/game/src/lib.rs index 125d3db..677e908 100644 --- a/game/src/lib.rs +++ b/game/src/lib.rs @@ -20,4 +20,5 @@ pub mod model; pub use game::{RenjuBoard, SquareMatrix, SquaredMatrixExtension, StateTensor, TerminalState}; pub use mcts::{MonteCarloTree, TreeNode}; +#[cfg(feature="train")] pub use model::PolicyValueModel; diff --git a/game/src/main.rs b/game/src/main.rs index 152994c..589640a 100644 --- a/game/src/main.rs +++ b/game/src/main.rs @@ -31,11 +31,13 @@ mod game; mod human; mod mcts; mod model; +#[cfg(feature="train")] mod selfplay; use human::MatchState; +#[cfg(feature="train")] use selfplay::Trainer; -static ABOUT_TEXT: &str = "Renju game "; +static ABOUT_TEXT: &str = "Renju Game"; static SELF_PLAY_MATCH_HELP_TEXT: &str = " Produce matches by self-play @@ -65,21 +67,12 @@ enum Verb { export_dir: String, }, + #[cfg(feature="train")] /// Self play and train #[clap(after_help=TRAIN_HELP_TEXT)] Train {}, - /// contest between two model - #[clap()] - Contest { - /// old model name - #[clap(required = true)] - old_model: String, - /// new model name - #[clap(required = true)] - new_model: String, - }, } #[tokio::main(flavor = "multi_thread")] @@ -87,6 +80,7 @@ async fn main() { let args = Arguments::parse(); match args.verb { + #[cfg(feature="train")] Some(Verb::SelfPlay { model_file, export_dir, @@ -98,6 +92,7 @@ async fn main() { .await; } + #[cfg(feature="train")] Some(Verb::Train {}) => { let mut trainer = Trainer::new(); trainer.run().await; diff --git a/game/src/model.rs b/game/src/model.rs index a861a61..0399a03 100644 --- a/game/src/model.rs +++ b/game/src/model.rs @@ -15,9 +15,13 @@ */ use bytes::Bytes; use num_cpus; +#[cfg(feature="train")] use pyo3::prelude::*; +#[cfg(feature="train")] use pyo3::types::{PyBytes, PyList, PyTuple}; +#[cfg(feature="train")] use std::fs::OpenOptions; +#[cfg(feature="train")] use std::io::prelude::*; use tflitec::interpreter::{Interpreter, Options}; @@ -25,11 +29,12 @@ use tflitec::tensor; use crate::game::*; +#[cfg(feature="train")] pub struct PolicyValueModel { module: Py, } -// +#[cfg(feature="train")] impl PolicyValueModel { pub fn new(filename: &str) -> PyResult { let mut source_code = String::new(); @@ -185,7 +190,7 @@ impl TfLiteModel { pub fn load(tflite_model_path: &str) -> Result { // Create interpreter options let mut options = Options::default(); - options.thread_count = num_cpus::get() as i32 / 2; + options.thread_count = (num_cpus::get() as i32 - 2).max(1); options.is_xnnpack_enabled = true; println!("is_xnnpack_enabled={}", options.is_xnnpack_enabled); diff --git a/game/tauri.conf.json b/game/tauri.conf.json index cc17c79..da620f8 100644 --- a/game/tauri.conf.json +++ b/game/tauri.conf.json @@ -1,12 +1,12 @@ { "build": { - "beforeBuildCommand": "npm run build --prefix=ui", + "beforeBuildCommand": "npm run build", "beforeDevCommand": "npm run build --prefix=ui", "devPath": "http://localhost:8080", "distDir": "ui/dist" }, "package": { - "productName": "renju_board", + "productName": "renju", "version": "0.1.0" }, "tauri": { @@ -15,7 +15,7 @@ }, "bundle": { "active": true, - "category": "DeveloperTool", + "category": "games", "copyright": "", "deb": { "depends": [] @@ -28,8 +28,8 @@ "icons/icon.icns", "icons/icon.ico" ], - "identifier": "com.tauri.dev", - "longDescription": "", + "identifier": "com.github.wangjia184.renju", + "longDescription": "Renju Game https://github.com/wangjia184/renju", "macOS": { "entitlements": null, "exceptionDomain": "", @@ -38,12 +38,16 @@ "signingIdentity": null }, "resources": [], - "shortDescription": "", + "shortDescription": "https://github.com/wangjia184/renju", "targets": "all", "windows": { "certificateThumbprint": null, "digestAlgorithm": "sha256", - "timestampUrl": "" + "timestampUrl": "", + "webviewInstallMode": { + "type": "fixedRuntime", + "path": "./Microsoft.WebView2.FixedVersionRuntime.107.0.1418.24.x64/" + } } }, "security": {