Skip to content

Commit

Permalink
Merge pull request #11 from chris-chris/feature-4-kimchi
Browse files Browse the repository at this point in the history
Feature 4 kimchi
  • Loading branch information
chris-chris authored Dec 3, 2024
2 parents 403ebaf + be71efc commit 620c5b7
Show file tree
Hide file tree
Showing 23 changed files with 3,582 additions and 709 deletions.
Binary file added .DS_Store
Binary file not shown.
1,114 changes: 780 additions & 334 deletions Cargo.lock

Large diffs are not rendered by default.

37 changes: 34 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
[package]
name = "onnx-parser"
name = "mina-zkml"
version = "0.1.0"
edition = "2021"

[lib]
name = "kimchi"
name = "mina_zkml"
crate-type = ["cdylib", "rlib"]

[[example]]
name = "perceptron"
path = "examples/perceptron.rs"

[[example]]
name = "mnist_inference"
path = "examples/mnist_inference.rs"

[[example]]
name = "zk_inference"
path = "examples/zk_inference.rs"

[dependencies]
anyhow = "1.0.90"
bincode = "1.3"
image = "0.25.4"
image = "0.24.7"
instant = "0.1.13"
log = "0.4.22"
ndarray = "0.15.4"
Expand All @@ -19,3 +31,22 @@ serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0.64"
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false }
kimchi = { git = "https://github.com/o1-labs/proof-systems", package = "kimchi" }
ark-ff = "0.4.0"
ark-poly = "0.4.0"
ark-ec = "0.4.0"
mina-curves = { git = "https://github.com/o1-labs/proof-systems" }
chrono = "0.4.38"
rand = "0.8.5"
groupmap = { git = "https://github.com/o1-labs/proof-systems" }
poly-commitment = { git = "https://github.com/o1-labs/proof-systems" }
mina-poseidon = { git = "https://github.com/o1-labs/proof-systems" }

[dev-dependencies]
pretty_assertions = "1.4.0"
test-case = "3.3.1"
rstest = "0.18.2"

[features]
default = []
test-utils = []
98 changes: 98 additions & 0 deletions examples/mnist_inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility};
use std::collections::HashMap;

fn preprocess_image(img_path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
// Load and convert image to grayscale
let img = image::open(img_path)?.into_luma8();

// Ensure image is 28x28
let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3);

// Convert to f32 and normalize to [0, 1]
let pixels: Vec<f32> = resized.into_raw().into_iter().map(|x| x as f32).collect();

//Apply normalization
let pixels: Vec<f32> = pixels
.into_iter()
.map(|x| (x / 255.0 - 0.1307) / 0.3081)
.collect();

// Create a batch dimension by wrapping the flattened pixels
let mut input = Vec::with_capacity(28 * 28);
input.extend_from_slice(&pixels);
Ok(input)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create run args with batch size
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

// Create visibility settings
let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

// Load the MNIST model
println!("Loading MNIST model...");
let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility).map_err(|e| {
println!("Error loading model: {:?}", e);
e
})?;

// Print model structure
println!("\nModel structure:");
println!("Number of nodes: {}", model.graph.nodes.len());
println!("Input nodes: {:?}", model.graph.inputs);
println!("Output nodes: {:?}", model.graph.outputs);

// Load and preprocess the image
println!("\nLoading and preprocessing image...");
let input = preprocess_image("models/data/1052.png")?;

// Execute the model
println!("\nRunning inference...");
let result = model.graph.execute(&[input])?;

//Result
println!("Result: {:?}", result);

// Print the output probabilities
println!("\nOutput probabilities for digits 0-9:");
if let Some(probabilities) = result.first() {
// The model outputs logits, so we need to apply softmax
let max_logit = probabilities
.iter()
.take(10)
.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_sum: f32 = probabilities
.iter()
.take(10)
.map(|&x| (x - max_logit).exp())
.sum();

let softmax: Vec<f32> = probabilities
.iter()
.take(10)
.map(|&x| ((x - max_logit).exp()) / exp_sum)
.collect();

for (digit, &prob) in softmax.iter().enumerate() {
println!("Digit {}: {:.4}", digit, prob);
}

// Find the predicted digit
let predicted_digit = softmax
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(digit, _)| digit)
.unwrap();

println!("\nPredicted digit: {}", predicted_digit);
}

Ok(())
}
55 changes: 55 additions & 0 deletions examples/perceptron.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create run args with batch size
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

// Create visibility settings
let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

// Load the perceptron model
println!("Loading perceptron model...");
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// Print model structure
println!("\nModel structure:");
println!("Number of nodes: {}", model.graph.nodes.len());
println!("Input nodes: {:?}", model.graph.inputs);
println!("Output nodes: {:?}", model.graph.outputs);

// Print node connections
println!("\nNode connections:");
for (id, node) in &model.graph.nodes {
match node {
mina_zkml::graph::model::NodeType::Node(n) => {
println!("Node {}: {:?} inputs: {:?}", id, n.op_type, n.inputs);
println!("Output dimensions: {:?}", n.out_dims);
println!("Weight Tensor: {:?}", n.weights);
println!("Bias Tensor: {:?}", n.bias);
}
mina_zkml::graph::model::NodeType::SubGraph { .. } => {
println!("Node {}: SubGraph", id);
}
}
}

// Create a sample input vector of size 10
let input = vec![1.0, 0.5, -0.3, 0.8, -0.2, 0.7, 0.1, -0.4, 0.9, 0.6];
println!("\nInput vector (size 10):");
println!("{:?}", input);

// Execute the model
let result = model.graph.execute(&[input])?;

// Print the output
println!("\nOutput vector (size 3, after ReLU):");
println!("{:?}", result[0]);

Ok(())
}
64 changes: 64 additions & 0 deletions examples/zk_inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use mina_zkml::{
graph::model::{Model, RunArgs, VarVisibility, Visibility},
zk::proof::ProofSystem,
};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. Load the model
println!("Loading model...");
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// 2. Create proof system
println!("Creating proof system...");
let proof_system = ProofSystem::new(&model);

// 3. Create sample input (with proper padding to size 10)
let input = vec![vec![
1.0, 0.5, -0.3, 0.8, -0.2, // Original values
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10
]];

// 4. Generate output and proof
println!("Generating output and proof...");
let prover_output = proof_system.prove(&input)?;
println!("Model output: {:?}", prover_output.output);

// 5. Verify the proof with output and proof
println!("Verifying proof...");
let is_valid = proof_system.verify(&prover_output.output, &prover_output.proof)?;

println!("\nResults:");
println!("Model execution successful: ✓");
println!("Proof creation successful: ✓");
println!(
"Proof verification: {}",
if is_valid { "✓ Valid" } else { "✗ Invalid" }
);

// 6. Demonstrate invalid verification with modified output
println!("\nTesting invalid case with modified output...");
let mut modified_output = prover_output.output.clone();
modified_output[0][0] += 1.0; // Modify first output value

let is_valid_modified = proof_system.verify(&modified_output, &prover_output.proof)?;
println!(
"Modified output verification: {}",
if !is_valid_modified {
"✗ Invalid (Expected)"
} else {
"✓ Valid (Unexpected!)"
}
);

Ok(())
}
57 changes: 57 additions & 0 deletions examples/zk_inference_fail.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use mina_zkml::{
graph::model::{Model, RunArgs, VarVisibility, Visibility},
zk::proof::ProofSystem,
};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. Load the model
println!("Loading model...");
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// 2. Create proof system
println!("Creating proof system...");
let proof_system = ProofSystem::new(&model);

// 3. Create sample input (with proper padding to size 10)
let input = vec![vec![
1.0, 0.5, -0.3, 0.8, -0.2, // Original values
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10
]];

// 4. Generate output and proof
println!("Generating output and proof...");
let prover_output = proof_system.prove(&input)?;
println!("Model output: {:?}", prover_output.output);

// 5. Create modified output (simulating malicious behavior)
let mut modified_output = prover_output.output.clone();
modified_output[0][0] += 1.0; // Modify first output value

// 6. Try to verify with modified output (should fail)
println!("Verifying proof with modified output...");
let is_valid = proof_system.verify(&modified_output, &prover_output.proof)?;

println!("\nResults:");
println!("Model execution successful: ✓");
println!("Proof creation successful: ✓");
println!(
"Modified output verification: {}",
if !is_valid {
"✗ Invalid (Expected)"
} else {
"✓ Valid (Unexpected!)"
}
);

Ok(())
}
Loading

0 comments on commit 620c5b7

Please sign in to comment.