Skip to content

Commit

Permalink
+ prev coommit
Browse files Browse the repository at this point in the history
  • Loading branch information
wangjia184 committed Oct 30, 2022
1 parent a1a6fe6 commit 6a34bb6
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 59 deletions.
3 changes: 2 additions & 1 deletion game/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/target
.cargo
.vscode
/data
/data
/Microsoft.WebView2.FixedVersionRuntime.107.0.1418.24.x64
34 changes: 20 additions & 14 deletions game/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ serde_json = "1.0"
serde-pickle = "1.1.1"
serde = { version = "1.0.144", features = ["derive"] }
tauri = { version = "1.1.1", features = ["api-all"] }
pyo3 = { version = "0.17.1", features = [] }
tempfile = "3.3.0"
num_cpus = "1.13.1"
crossbeam = "0.8.2"
flume = "0.10.14"
Expand All @@ -31,6 +31,8 @@ rand = "0.8.5"
rand_distr = "0.4.3"
probability = "0.18.0"

pyo3 = { version = "0.17.1", features = [], optional = true }

[dependencies.uuid]
version = "1.1.2"
features = [
Expand All @@ -41,16 +43,9 @@ features = [


[build-dependencies]
tauri-build = { version = "1.1.1" }
tauri-build = { version = "1.1.1", features = [] }
copy_to_output = "2.0.0"

[features]
# by default Tauri runs in production mode
# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
default = ["custom-protocol"]
# this feature is used used for production builds where `devPath` points to the filesystem
# DO NOT remove this
custom-protocol = ["tauri/custom-protocol"]


# BLAS source choices:
Expand All @@ -70,8 +65,19 @@ blas-src = { version = "0.8.0", default-features = false, features = [
ndarray = { version = "0.15.6" }


#[profile.release]
#strip = true # Automatically strip symbols from the binary.
#opt-level = "z" # Optimize for size.
#lto = true
#codegen-units = 1
[profile.release]
strip = true # Automatically strip symbols from the binary.
opt-level = "z" # Optimize for size.
lto = true
codegen-units = 1



[features]
# by default Tauri runs in production mode
# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
default = ["custom-protocol"]
# this feature is used used for production builds where `devPath` points to the filesystem
# DO NOT remove this
custom-protocol = ["tauri/custom-protocol"]
train = [ "dep:pyo3" ]
3 changes: 3 additions & 0 deletions game/build_windows.bat
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
REM cargo install tauri-cli

CALL "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
SET LIBCLANG_PATH=C:\Program Files\Microsoft^ Visual^ Studio\2022\Enterprise\VC\Tools\Llvm\x64\bin

Expand All @@ -7,3 +9,4 @@ SET TFLITEC_PREBUILT_PATH_X86_64_PC_WINDOWS_GNU=%CD%\tensorflowlite_c.dll
SET RUSTFLAGS=-C target-feature=+crt-static
cargo build --release

cargo tauri build
2 changes: 1 addition & 1 deletion game/src/game.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl From<u8> for Color {
}

impl From<f32> for Color {
fn from(v: f32) -> Self {
fn from(_v: f32) -> Self {
unimplemented!()
}
}
Expand Down
61 changes: 38 additions & 23 deletions game/src/human.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ use crate::model::TfLiteModel;
use crossbeam::atomic::AtomicCell;

use std::sync::atomic::Ordering;

use std::include_bytes;
use rand::seq::SliceRandom;
use std::cell::RefCell;
use std::sync::{atomic::AtomicBool, Arc};
use tokio::sync::oneshot::{self, Receiver};
use tokio::task::JoinHandle;
use tempfile::{tempdir, TempDir};
use std::fs::File;
use std::io::prelude::*;

lazy_static! {
static ref SINGLETON_CHANNEL: AtomicCell<Receiver<HumanVsMachineMatch>> = {
Expand All @@ -37,6 +40,19 @@ lazy_static! {
}
AtomicCell::new(rx)
};

static ref MODEL_FILE : (String, TempDir) = {
let mut dir = tempdir().expect("Unable to create template file");
let filename = dir.path().join("model.tflite").display().to_string();
println!("{}", &filename);

let mut file = File::create(&filename).expect("Unable to create temp file");

let bytes = include_bytes!("../best.tflite");
file.write_all(bytes).expect("Unable to save model to file");
file.flush().expect("Unable to flush model to file");
( filename, dir )
};
}

pub async fn access<Fut, F>(mut f: F) -> MatchState
Expand Down Expand Up @@ -96,7 +112,7 @@ impl BoardInfo {
impl HumanVsMachineMatch {
fn new(human_play_black: bool) -> Self {
let ai_player = AiPlayer::new();
let max_threads = 2; //(num_cpus::get() - 0).max(1);
let max_threads = 1; //(num_cpus::get() - 0).max(1);
Self {
ai_player: Arc::new(ai_player),
board: RenjuBoard::default(),
Expand Down Expand Up @@ -268,27 +284,26 @@ impl AiPlayer {
choices: &Vec<(usize, usize)>,
) {
self.tree
.rollout(
breadth_first,
board,
choices,
|state_tensor: StateTensor| {
MODEL.with(|ref_cell| {
let mut model = ref_cell.borrow_mut();
if model.is_none() {
*model = Some(
TfLiteModel::load("best.tflite")
.expect("Unable to load saved model"),
)
}
model
.as_ref()
.unwrap()
.predict_one(state_tensor)
.expect("Unable to predict_one")
})
},
)
.rollout(breadth_first, board, choices, |state_tensor: StateTensor| {
MODEL.with(|ref_cell| {
let mut model = ref_cell.borrow_mut();
if model.is_none() {
#[cfg(feature="train")]
let filepath = "best.tflite";

#[cfg(not(feature="train"))]
let filepath = &MODEL_FILE.0;
*model = Some(
TfLiteModel::load(filepath).expect("Unable to load model"),
)
}
model
.as_ref()
.unwrap()
.predict_one(state_tensor)
.expect("Unable to predict_one")
})
})
.await
.expect("rollout failed")
}
Expand Down
1 change: 1 addition & 0 deletions game/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ pub mod model;

pub use game::{RenjuBoard, SquareMatrix, SquaredMatrixExtension, StateTensor, TerminalState};
pub use mcts::{MonteCarloTree, TreeNode};
#[cfg(feature="train")]
pub use model::PolicyValueModel;
17 changes: 6 additions & 11 deletions game/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ mod game;
mod human;
mod mcts;
mod model;
#[cfg(feature="train")]
mod selfplay;
use human::MatchState;
#[cfg(feature="train")]
use selfplay::Trainer;

static ABOUT_TEXT: &str = "Renju game ";
static ABOUT_TEXT: &str = "Renju Game";

static SELF_PLAY_MATCH_HELP_TEXT: &str = "
Produce matches by self-play
Expand Down Expand Up @@ -65,28 +67,20 @@ enum Verb {
export_dir: String,
},

#[cfg(feature="train")]
/// Self play and train
#[clap(after_help=TRAIN_HELP_TEXT)]
Train {},

/// contest between two model
#[clap()]
Contest {
/// old model name
#[clap(required = true)]
old_model: String,

/// new model name
#[clap(required = true)]
new_model: String,
},
}

#[tokio::main(flavor = "multi_thread")]
async fn main() {
let args = Arguments::parse();

match args.verb {
#[cfg(feature="train")]
Some(Verb::SelfPlay {
model_file,
export_dir,
Expand All @@ -98,6 +92,7 @@ async fn main() {
.await;
}

#[cfg(feature="train")]
Some(Verb::Train {}) => {
let mut trainer = Trainer::new();
trainer.run().await;
Expand Down
9 changes: 7 additions & 2 deletions game/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,26 @@
*/
use bytes::Bytes;
use num_cpus;
#[cfg(feature="train")]
use pyo3::prelude::*;
#[cfg(feature="train")]
use pyo3::types::{PyBytes, PyList, PyTuple};
#[cfg(feature="train")]
use std::fs::OpenOptions;
#[cfg(feature="train")]
use std::io::prelude::*;

use tflitec::interpreter::{Interpreter, Options};
use tflitec::tensor;

use crate::game::*;

#[cfg(feature="train")]
pub struct PolicyValueModel {
module: Py<PyModule>,
}

//
#[cfg(feature="train")]
impl PolicyValueModel {
pub fn new(filename: &str) -> PyResult<Self> {
let mut source_code = String::new();
Expand Down Expand Up @@ -185,7 +190,7 @@ impl TfLiteModel {
pub fn load(tflite_model_path: &str) -> Result<Self, tflitec::Error> {
// Create interpreter options
let mut options = Options::default();
options.thread_count = num_cpus::get() as i32 / 2;
options.thread_count = (num_cpus::get() as i32 - 2).max(1);
options.is_xnnpack_enabled = true;
println!("is_xnnpack_enabled={}", options.is_xnnpack_enabled);

Expand Down
18 changes: 11 additions & 7 deletions game/tauri.conf.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
{
"build": {
"beforeBuildCommand": "npm run build --prefix=ui",
"beforeBuildCommand": "npm run build",
"beforeDevCommand": "npm run build --prefix=ui",
"devPath": "http://localhost:8080",
"distDir": "ui/dist"
},
"package": {
"productName": "renju_board",
"productName": "renju",
"version": "0.1.0"
},
"tauri": {
Expand All @@ -15,7 +15,7 @@
},
"bundle": {
"active": true,
"category": "DeveloperTool",
"category": "games",
"copyright": "",
"deb": {
"depends": []
Expand All @@ -28,8 +28,8 @@
"icons/icon.icns",
"icons/icon.ico"
],
"identifier": "com.tauri.dev",
"longDescription": "",
"identifier": "com.github.wangjia184.renju",
"longDescription": "Renju Game https://github.com/wangjia184/renju",
"macOS": {
"entitlements": null,
"exceptionDomain": "",
Expand All @@ -38,12 +38,16 @@
"signingIdentity": null
},
"resources": [],
"shortDescription": "",
"shortDescription": "https://github.com/wangjia184/renju",
"targets": "all",
"windows": {
"certificateThumbprint": null,
"digestAlgorithm": "sha256",
"timestampUrl": ""
"timestampUrl": "",
"webviewInstallMode": {
"type": "fixedRuntime",
"path": "./Microsoft.WebView2.FixedVersionRuntime.107.0.1418.24.x64/"
}
}
},
"security": {
Expand Down

0 comments on commit 6a34bb6

Please sign in to comment.