-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from chris-chris/feature-4-kimchi
Feature 4 kimchi
- Loading branch information
Showing
23 changed files
with
3,582 additions
and
709 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; | ||
use std::collections::HashMap; | ||
|
||
fn preprocess_image(img_path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> { | ||
// Load and convert image to grayscale | ||
let img = image::open(img_path)?.into_luma8(); | ||
|
||
// Ensure image is 28x28 | ||
let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3); | ||
|
||
// Convert to f32 and normalize to [0, 1] | ||
let pixels: Vec<f32> = resized.into_raw().into_iter().map(|x| x as f32).collect(); | ||
|
||
//Apply normalization | ||
let pixels: Vec<f32> = pixels | ||
.into_iter() | ||
.map(|x| (x / 255.0 - 0.1307) / 0.3081) | ||
.collect(); | ||
|
||
// Create a batch dimension by wrapping the flattened pixels | ||
let mut input = Vec::with_capacity(28 * 28); | ||
input.extend_from_slice(&pixels); | ||
Ok(input) | ||
} | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// Create run args with batch size | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
// Create visibility settings | ||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
// Load the MNIST model | ||
println!("Loading MNIST model..."); | ||
let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility).map_err(|e| { | ||
println!("Error loading model: {:?}", e); | ||
e | ||
})?; | ||
|
||
// Print model structure | ||
println!("\nModel structure:"); | ||
println!("Number of nodes: {}", model.graph.nodes.len()); | ||
println!("Input nodes: {:?}", model.graph.inputs); | ||
println!("Output nodes: {:?}", model.graph.outputs); | ||
|
||
// Load and preprocess the image | ||
println!("\nLoading and preprocessing image..."); | ||
let input = preprocess_image("models/data/1052.png")?; | ||
|
||
// Execute the model | ||
println!("\nRunning inference..."); | ||
let result = model.graph.execute(&[input])?; | ||
|
||
//Result | ||
println!("Result: {:?}", result); | ||
|
||
// Print the output probabilities | ||
println!("\nOutput probabilities for digits 0-9:"); | ||
if let Some(probabilities) = result.first() { | ||
// The model outputs logits, so we need to apply softmax | ||
let max_logit = probabilities | ||
.iter() | ||
.take(10) | ||
.fold(f32::NEG_INFINITY, |a, &b| a.max(b)); | ||
let exp_sum: f32 = probabilities | ||
.iter() | ||
.take(10) | ||
.map(|&x| (x - max_logit).exp()) | ||
.sum(); | ||
|
||
let softmax: Vec<f32> = probabilities | ||
.iter() | ||
.take(10) | ||
.map(|&x| ((x - max_logit).exp()) / exp_sum) | ||
.collect(); | ||
|
||
for (digit, &prob) in softmax.iter().enumerate() { | ||
println!("Digit {}: {:.4}", digit, prob); | ||
} | ||
|
||
// Find the predicted digit | ||
let predicted_digit = softmax | ||
.iter() | ||
.enumerate() | ||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) | ||
.map(|(digit, _)| digit) | ||
.unwrap(); | ||
|
||
println!("\nPredicted digit: {}", predicted_digit); | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// Create run args with batch size | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
// Create visibility settings | ||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
// Load the perceptron model | ||
println!("Loading perceptron model..."); | ||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// Print model structure | ||
println!("\nModel structure:"); | ||
println!("Number of nodes: {}", model.graph.nodes.len()); | ||
println!("Input nodes: {:?}", model.graph.inputs); | ||
println!("Output nodes: {:?}", model.graph.outputs); | ||
|
||
// Print node connections | ||
println!("\nNode connections:"); | ||
for (id, node) in &model.graph.nodes { | ||
match node { | ||
mina_zkml::graph::model::NodeType::Node(n) => { | ||
println!("Node {}: {:?} inputs: {:?}", id, n.op_type, n.inputs); | ||
println!("Output dimensions: {:?}", n.out_dims); | ||
println!("Weight Tensor: {:?}", n.weights); | ||
println!("Bias Tensor: {:?}", n.bias); | ||
} | ||
mina_zkml::graph::model::NodeType::SubGraph { .. } => { | ||
println!("Node {}: SubGraph", id); | ||
} | ||
} | ||
} | ||
|
||
// Create a sample input vector of size 10 | ||
let input = vec![1.0, 0.5, -0.3, 0.8, -0.2, 0.7, 0.1, -0.4, 0.9, 0.6]; | ||
println!("\nInput vector (size 10):"); | ||
println!("{:?}", input); | ||
|
||
// Execute the model | ||
let result = model.graph.execute(&[input])?; | ||
|
||
// Print the output | ||
println!("\nOutput vector (size 3, after ReLU):"); | ||
println!("{:?}", result[0]); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
use mina_zkml::{ | ||
graph::model::{Model, RunArgs, VarVisibility, Visibility}, | ||
zk::proof::ProofSystem, | ||
}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// 1. Load the model | ||
println!("Loading model..."); | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// 2. Create proof system | ||
println!("Creating proof system..."); | ||
let proof_system = ProofSystem::new(&model); | ||
|
||
// 3. Create sample input (with proper padding to size 10) | ||
let input = vec![vec![ | ||
1.0, 0.5, -0.3, 0.8, -0.2, // Original values | ||
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 | ||
]]; | ||
|
||
// 4. Generate output and proof | ||
println!("Generating output and proof..."); | ||
let prover_output = proof_system.prove(&input)?; | ||
println!("Model output: {:?}", prover_output.output); | ||
|
||
// 5. Verify the proof with output and proof | ||
println!("Verifying proof..."); | ||
let is_valid = proof_system.verify(&prover_output.output, &prover_output.proof)?; | ||
|
||
println!("\nResults:"); | ||
println!("Model execution successful: ✓"); | ||
println!("Proof creation successful: ✓"); | ||
println!( | ||
"Proof verification: {}", | ||
if is_valid { "✓ Valid" } else { "✗ Invalid" } | ||
); | ||
|
||
// 6. Demonstrate invalid verification with modified output | ||
println!("\nTesting invalid case with modified output..."); | ||
let mut modified_output = prover_output.output.clone(); | ||
modified_output[0][0] += 1.0; // Modify first output value | ||
|
||
let is_valid_modified = proof_system.verify(&modified_output, &prover_output.proof)?; | ||
println!( | ||
"Modified output verification: {}", | ||
if !is_valid_modified { | ||
"✗ Invalid (Expected)" | ||
} else { | ||
"✓ Valid (Unexpected!)" | ||
} | ||
); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use mina_zkml::{ | ||
graph::model::{Model, RunArgs, VarVisibility, Visibility}, | ||
zk::proof::ProofSystem, | ||
}; | ||
use std::collections::HashMap; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
// 1. Load the model | ||
println!("Loading model..."); | ||
let mut variables = HashMap::new(); | ||
variables.insert("batch_size".to_string(), 1); | ||
let run_args = RunArgs { variables }; | ||
|
||
let visibility = VarVisibility { | ||
input: Visibility::Public, | ||
output: Visibility::Public, | ||
}; | ||
|
||
let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; | ||
|
||
// 2. Create proof system | ||
println!("Creating proof system..."); | ||
let proof_system = ProofSystem::new(&model); | ||
|
||
// 3. Create sample input (with proper padding to size 10) | ||
let input = vec![vec![ | ||
1.0, 0.5, -0.3, 0.8, -0.2, // Original values | ||
0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 | ||
]]; | ||
|
||
// 4. Generate output and proof | ||
println!("Generating output and proof..."); | ||
let prover_output = proof_system.prove(&input)?; | ||
println!("Model output: {:?}", prover_output.output); | ||
|
||
// 5. Create modified output (simulating malicious behavior) | ||
let mut modified_output = prover_output.output.clone(); | ||
modified_output[0][0] += 1.0; // Modify first output value | ||
|
||
// 6. Try to verify with modified output (should fail) | ||
println!("Verifying proof with modified output..."); | ||
let is_valid = proof_system.verify(&modified_output, &prover_output.proof)?; | ||
|
||
println!("\nResults:"); | ||
println!("Model execution successful: ✓"); | ||
println!("Proof creation successful: ✓"); | ||
println!( | ||
"Modified output verification: {}", | ||
if !is_valid { | ||
"✗ Invalid (Expected)" | ||
} else { | ||
"✓ Valid (Unexpected!)" | ||
} | ||
); | ||
|
||
Ok(()) | ||
} |
Oops, something went wrong.