Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature 4 kimchi #11

Merged
merged 18 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
1,114 changes: 780 additions & 334 deletions Cargo.lock

Large diffs are not rendered by default.

37 changes: 34 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
[package]
name = "onnx-parser"
name = "mina-zkml"
version = "0.1.0"
edition = "2021"

[lib]
name = "kimchi"
name = "mina_zkml"
crate-type = ["cdylib", "rlib"]

[[example]]
name = "perceptron"
path = "examples/perceptron.rs"

[[example]]
name = "mnist_inference"
path = "examples/mnist_inference.rs"

[[example]]
name = "zk_inference"
path = "examples/zk_inference.rs"

[dependencies]
anyhow = "1.0.90"
bincode = "1.3"
image = "0.25.4"
image = "0.24.7"
instant = "0.1.13"
log = "0.4.22"
ndarray = "0.15.4"
Expand All @@ -19,3 +31,22 @@ serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0.64"
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false }
kimchi = { git = "https://github.com/o1-labs/proof-systems", package = "kimchi" }
ark-ff = "0.4.0"
ark-poly = "0.4.0"
ark-ec = "0.4.0"
mina-curves = { git = "https://github.com/o1-labs/proof-systems" }
chrono = "0.4.38"
rand = "0.8.5"
groupmap = { git = "https://github.com/o1-labs/proof-systems" }
poly-commitment = { git = "https://github.com/o1-labs/proof-systems" }
mina-poseidon = { git = "https://github.com/o1-labs/proof-systems" }

[dev-dependencies]
pretty_assertions = "1.4.0"
test-case = "3.3.1"
rstest = "0.18.2"

[features]
default = []
test-utils = []
98 changes: 98 additions & 0 deletions examples/mnist_inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility};
use std::collections::HashMap;

fn preprocess_image(img_path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
// Load and convert image to grayscale
let img = image::open(img_path)?.into_luma8();

// Ensure image is 28x28
let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3);

// Convert to f32 and normalize to [0, 1]
let pixels: Vec<f32> = resized.into_raw().into_iter().map(|x| x as f32).collect();

//Apply normalization
let pixels: Vec<f32> = pixels
.into_iter()
.map(|x| (x / 255.0 - 0.1307) / 0.3081)
.collect();

// Create a batch dimension by wrapping the flattened pixels
let mut input = Vec::with_capacity(28 * 28);
input.extend_from_slice(&pixels);
Ok(input)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create run args with batch size
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

// Create visibility settings
let visibility = VarVisibility {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can parameterize this option when we implement CLI.

input: Visibility::Public,
output: Visibility::Public,
};

// Load the MNIST model
println!("Loading MNIST model...");
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can use the logger for it.

let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility).map_err(|e| {
println!("Error loading model: {:?}", e);
e
})?;

// Print model structure
println!("\nModel structure:");
println!("Number of nodes: {}", model.graph.nodes.len());
println!("Input nodes: {:?}", model.graph.inputs);
println!("Output nodes: {:?}", model.graph.outputs);

// Load and preprocess the image
println!("\nLoading and preprocessing image...");
let input = preprocess_image("models/data/1052.png")?;

// Execute the model
println!("\nRunning inference...");
let result = model.graph.execute(&[input])?;

//Result
println!("Result: {:?}", result);

// Print the output probabilities
println!("\nOutput probabilities for digits 0-9:");
if let Some(probabilities) = result.first() {
// The model outputs logits, so we need to apply softmax
let max_logit = probabilities
.iter()
.take(10)
.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_sum: f32 = probabilities
.iter()
.take(10)
.map(|&x| (x - max_logit).exp())
.sum();

let softmax: Vec<f32> = probabilities
.iter()
.take(10)
.map(|&x| ((x - max_logit).exp()) / exp_sum)
.collect();

for (digit, &prob) in softmax.iter().enumerate() {
println!("Digit {}: {:.4}", digit, prob);
}

// Find the predicted digit
let predicted_digit = softmax
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(digit, _)| digit)
.unwrap();

println!("\nPredicted digit: {}", predicted_digit);
}

Ok(())
}
55 changes: 55 additions & 0 deletions examples/perceptron.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create run args with batch size
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

// Create visibility settings
let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

// Load the perceptron model
println!("Loading perceptron model...");
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// Print model structure
println!("\nModel structure:");
println!("Number of nodes: {}", model.graph.nodes.len());
println!("Input nodes: {:?}", model.graph.inputs);
println!("Output nodes: {:?}", model.graph.outputs);

// Print node connections
println!("\nNode connections:");
for (id, node) in &model.graph.nodes {
match node {
mina_zkml::graph::model::NodeType::Node(n) => {
println!("Node {}: {:?} inputs: {:?}", id, n.op_type, n.inputs);
println!("Output dimensions: {:?}", n.out_dims);
println!("Weight Tensor: {:?}", n.weights);
println!("Bias Tensor: {:?}", n.bias);
}
mina_zkml::graph::model::NodeType::SubGraph { .. } => {
println!("Node {}: SubGraph", id);
}
}
}

// Create a sample input vector of size 10
let input = vec![1.0, 0.5, -0.3, 0.8, -0.2, 0.7, 0.1, -0.4, 0.9, 0.6];
println!("\nInput vector (size 10):");
println!("{:?}", input);

// Execute the model
let result = model.graph.execute(&[input])?;

// Print the output
println!("\nOutput vector (size 3, after ReLU):");
println!("{:?}", result[0]);

Ok(())
}
64 changes: 64 additions & 0 deletions examples/zk_inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use mina_zkml::{
graph::model::{Model, RunArgs, VarVisibility, Visibility},
zk::proof::ProofSystem,
};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. Load the model
println!("Loading model...");
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// 2. Create proof system
println!("Creating proof system...");
let proof_system = ProofSystem::new(&model);

// 3. Create sample input (with proper padding to size 10)
let input = vec![vec![
1.0, 0.5, -0.3, 0.8, -0.2, // Original values
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10
]];

// 4. Generate output and proof
println!("Generating output and proof...");
let prover_output = proof_system.prove(&input)?;
println!("Model output: {:?}", prover_output.output);

// 5. Verify the proof with output and proof
println!("Verifying proof...");
let is_valid = proof_system.verify(&prover_output.output, &prover_output.proof)?;

println!("\nResults:");
println!("Model execution successful: ✓");
println!("Proof creation successful: ✓");
println!(
"Proof verification: {}",
if is_valid { "✓ Valid" } else { "✗ Invalid" }
);

// 6. Demonstrate invalid verification with modified output
println!("\nTesting invalid case with modified output...");
let mut modified_output = prover_output.output.clone();
modified_output[0][0] += 1.0; // Modify first output value

let is_valid_modified = proof_system.verify(&modified_output, &prover_output.proof)?;
println!(
"Modified output verification: {}",
if !is_valid_modified {
"✗ Invalid (Expected)"
} else {
"✓ Valid (Unexpected!)"
}
);

Ok(())
}
57 changes: 57 additions & 0 deletions examples/zk_inference_fail.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use mina_zkml::{
graph::model::{Model, RunArgs, VarVisibility, Visibility},
zk::proof::ProofSystem,
};
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. Load the model
println!("Loading model...");
let mut variables = HashMap::new();
variables.insert("batch_size".to_string(), 1);
let run_args = RunArgs { variables };

let visibility = VarVisibility {
input: Visibility::Public,
output: Visibility::Public,
};

let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?;

// 2. Create proof system
println!("Creating proof system...");
let proof_system = ProofSystem::new(&model);

// 3. Create sample input (with proper padding to size 10)
let input = vec![vec![
1.0, 0.5, -0.3, 0.8, -0.2, // Original values
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10
]];

// 4. Generate output and proof
println!("Generating output and proof...");
let prover_output = proof_system.prove(&input)?;
println!("Model output: {:?}", prover_output.output);

// 5. Create modified output (simulating malicious behavior)
let mut modified_output = prover_output.output.clone();
modified_output[0][0] += 1.0; // Modify first output value

// 6. Try to verify with modified output (should fail)
println!("Verifying proof with modified output...");
let is_valid = proof_system.verify(&modified_output, &prover_output.proof)?;

println!("\nResults:");
println!("Model execution successful: ✓");
println!("Proof creation successful: ✓");
println!(
"Modified output verification: {}",
if !is_valid {
"✗ Invalid (Expected)"
} else {
"✓ Valid (Unexpected!)"
}
);

Ok(())
}
Loading