-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature 4 kimchi #11
Merged
+3,582
−709
Merged
Feature 4 kimchi #11
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
6ffcfdd
feat: simple perceptron works
sshivaditya 7406c61
feat: mnist classifier works
sshivaditya d1d16e7
fix: model inference works
sshivaditya 52624d7
fix: model outputs same as pytorch
sshivaditya 2b88fee
fix: kimchi wiring and proof index working proof generation not working
sshivaditya 28a5b85
fix: proof verification works
sshivaditya 728dd6c
fix: proof verification with output
sshivaditya a7de625
fix: proof verification with mnist
sshivaditya f979a34
fix: remove proof systems
sshivaditya b4f3437
fix: cargo tests and some unit test
sshivaditya d5f4940
fix: formatting
sshivaditya 00a3f5b
fix: clippy warnings
sshivaditya f7132ad
fix: formatting
sshivaditya 15fd55f
fix: clippy warnings
sshivaditya 2ae8e86
fix: clippy warnings
sshivaditya 298fcaf
fix: clippy warnings
sshivaditya 994ebf9
fix: clippy warnings
sshivaditya be71efc
fix: clippy warnings
sshivaditya File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; | ||
use std::collections::HashMap; | ||
|
||
fn preprocess_image(img_path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> { | ||
// Load and convert image to grayscale | ||
let img = image::open(img_path)?.into_luma8(); | ||
|
||
// Ensure image is 28x28 | ||
let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3); | ||
|
||
// Convert to f32 and normalize to [0, 1] | ||
let pixels: Vec<f32> = resized.into_raw().into_iter().map(|x| x as f32).collect(); | ||
|
||
//Apply normalization | ||
let pixels: Vec<f32> = pixels | ||
.into_iter() | ||
.map(|x| (x / 255.0 - 0.1307) / 0.3081) | ||
.collect(); | ||
|
||
// Create a batch dimension by wrapping the flattened pixels | ||
let mut input = Vec::with_capacity(28 * 28); | ||
input.extend_from_slice(&pixels); | ||
Ok(input) | ||
} | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// Create run args with batch size | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
// Create visibility settings | ||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
// Load the MNIST model | ||
println!("Loading MNIST model..."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can use the logger for it. |
||
let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility).map_err(|e| { | ||
println!("Error loading model: {:?}", e); | ||
e | ||
})?; | ||
|
||
// Print model structure | ||
println!("\nModel structure:"); | ||
println!("Number of nodes: {}", model.graph.nodes.len()); | ||
println!("Input nodes: {:?}", model.graph.inputs); | ||
println!("Output nodes: {:?}", model.graph.outputs); | ||
|
||
// Load and preprocess the image | ||
println!("\nLoading and preprocessing image..."); | ||
let input = preprocess_image("models/data/1052.png")?; | ||
|
||
// Execute the model | ||
println!("\nRunning inference..."); | ||
let result = model.graph.execute(&[input])?; | ||
|
||
//Result | ||
println!("Result: {:?}", result); | ||
|
||
// Print the output probabilities | ||
println!("\nOutput probabilities for digits 0-9:"); | ||
if let Some(probabilities) = result.first() { | ||
// The model outputs logits, so we need to apply softmax | ||
let max_logit = probabilities | ||
.iter() | ||
.take(10) | ||
.fold(f32::NEG_INFINITY, |a, &b| a.max(b)); | ||
let exp_sum: f32 = probabilities | ||
.iter() | ||
.take(10) | ||
.map(|&x| (x - max_logit).exp()) | ||
.sum(); | ||
|
||
let softmax: Vec<f32> = probabilities | ||
.iter() | ||
.take(10) | ||
.map(|&x| ((x - max_logit).exp()) / exp_sum) | ||
.collect(); | ||
|
||
for (digit, &prob) in softmax.iter().enumerate() { | ||
println!("Digit {}: {:.4}", digit, prob); | ||
} | ||
|
||
// Find the predicted digit | ||
let predicted_digit = softmax | ||
.iter() | ||
.enumerate() | ||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) | ||
.map(|(digit, _)| digit) | ||
.unwrap(); | ||
|
||
println!("\nPredicted digit: {}", predicted_digit); | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// Create run args with batch size | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
// Create visibility settings | ||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
// Load the perceptron model | ||
println!("Loading perceptron model..."); | ||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// Print model structure | ||
println!("\nModel structure:"); | ||
println!("Number of nodes: {}", model.graph.nodes.len()); | ||
println!("Input nodes: {:?}", model.graph.inputs); | ||
println!("Output nodes: {:?}", model.graph.outputs); | ||
|
||
// Print node connections | ||
println!("\nNode connections:"); | ||
for (id, node) in &model.graph.nodes { | ||
match node { | ||
mina_zkml::graph::model::NodeType::Node(n) => { | ||
println!("Node {}: {:?} inputs: {:?}", id, n.op_type, n.inputs); | ||
println!("Output dimensions: {:?}", n.out_dims); | ||
println!("Weight Tensor: {:?}", n.weights); | ||
println!("Bias Tensor: {:?}", n.bias); | ||
} | ||
mina_zkml::graph::model::NodeType::SubGraph { .. } => { | ||
println!("Node {}: SubGraph", id); | ||
} | ||
} | ||
} | ||
|
||
// Create a sample input vector of size 10 | ||
let input = vec![1.0, 0.5, -0.3, 0.8, -0.2, 0.7, 0.1, -0.4, 0.9, 0.6]; | ||
println!("\nInput vector (size 10):"); | ||
println!("{:?}", input); | ||
|
||
// Execute the model | ||
let result = model.graph.execute(&[input])?; | ||
|
||
// Print the output | ||
println!("\nOutput vector (size 3, after ReLU):"); | ||
println!("{:?}", result[0]); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
use mina_zkml::{ | ||
graph::model::{Model, RunArgs, VarVisibility, Visibility}, | ||
zk::proof::ProofSystem, | ||
}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// 1. Load the model | ||
println!("Loading model..."); | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// 2. Create proof system | ||
println!("Creating proof system..."); | ||
let proof_system = ProofSystem::new(&model); | ||
|
||
// 3. Create sample input (with proper padding to size 10) | ||
let input = vec![vec![ | ||
1.0, 0.5, -0.3, 0.8, -0.2, // Original values | ||
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 | ||
]]; | ||
|
||
// 4. Generate output and proof | ||
println!("Generating output and proof..."); | ||
let prover_output = proof_system.prove(&input)?; | ||
println!("Model output: {:?}", prover_output.output); | ||
|
||
// 5. Verify the proof with output and proof | ||
println!("Verifying proof..."); | ||
let is_valid = proof_system.verify(&prover_output.output, &prover_output.proof)?; | ||
|
||
println!("\nResults:"); | ||
println!("Model execution successful: ✓"); | ||
println!("Proof creation successful: ✓"); | ||
println!( | ||
"Proof verification: {}", | ||
if is_valid { "✓ Valid" } else { "✗ Invalid" } | ||
); | ||
|
||
// 6. Demonstrate invalid verification with modified output | ||
println!("\nTesting invalid case with modified output..."); | ||
let mut modified_output = prover_output.output.clone(); | ||
modified_output[0][0] += 1.0; // Modify first output value | ||
|
||
let is_valid_modified = proof_system.verify(&modified_output, &prover_output.proof)?; | ||
println!( | ||
"Modified output verification: {}", | ||
if !is_valid_modified { | ||
"✗ Invalid (Expected)" | ||
} else { | ||
"✓ Valid (Unexpected!)" | ||
} | ||
); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use mina_zkml::{ | ||
graph::model::{Model, RunArgs, VarVisibility, Visibility}, | ||
zk::proof::ProofSystem, | ||
}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// 1. Load the model | ||
println!("Loading model..."); | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// 2. Create proof system | ||
println!("Creating proof system..."); | ||
let proof_system = ProofSystem::new(&model); | ||
|
||
// 3. Create sample input (with proper padding to size 10) | ||
let input = vec![vec![ | ||
1.0, 0.5, -0.3, 0.8, -0.2, // Original values | ||
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 | ||
]]; | ||
|
||
// 4. Generate output and proof | ||
println!("Generating output and proof..."); | ||
let prover_output = proof_system.prove(&input)?; | ||
println!("Model output: {:?}", prover_output.output); | ||
|
||
// 5. Create modified output (simulating malicious behavior) | ||
let mut modified_output = prover_output.output.clone(); | ||
modified_output[0][0] += 1.0; // Modify first output value | ||
|
||
// 6. Try to verify with modified output (should fail) | ||
println!("Verifying proof with modified output..."); | ||
let is_valid = proof_system.verify(&modified_output, &prover_output.proof)?; | ||
|
||
println!("\nResults:"); | ||
println!("Model execution successful: ✓"); | ||
println!("Proof creation successful: ✓"); | ||
println!( | ||
"Modified output verification: {}", | ||
if !is_valid { | ||
"✗ Invalid (Expected)" | ||
} else { | ||
"✓ Valid (Unexpected!)" | ||
} | ||
); | ||
|
||
Ok(()) | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can parameterize this option when we implement CLI.