From c7ebdced62082bc07263b3a8c2374400e0ff1ced Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 20 Sep 2024 21:31:54 +0100 Subject: [PATCH] Add `--no-optimize` flag to rten CLI Add a flag that turns off graph optimizations when loading the model, thus providing a way to test their impact. --- rten-cli/src/main.rs | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/rten-cli/src/main.rs b/rten-cli/src/main.rs index d05cd1a0..ad32b342 100644 --- a/rten-cli/src/main.rs +++ b/rten-cli/src/main.rs @@ -2,7 +2,9 @@ use std::collections::VecDeque; use std::error::Error; use std::time::Instant; -use rten::{Dimension, InputOrOutput, Model, ModelMetadata, NodeId, Output, RunOptions}; +use rten::{ + Dimension, InputOrOutput, Model, ModelMetadata, ModelOptions, NodeId, Output, RunOptions, +}; use rten_tensor::prelude::*; use rten_tensor::Tensor; @@ -10,6 +12,9 @@ struct Args { /// Model file to load. model: String, + /// Whether to enable graph optimizations + optimize: bool, + /// Run model and don't produce other output quiet: bool, @@ -108,6 +113,7 @@ fn parse_args() -> Result { let mut timing = false; let mut verbose = false; let mut input_sizes = Vec::new(); + let mut optimize = true; let mut parser = lexopt::Parser::from_env(); while let Some(arg) = parser.next()? { @@ -120,6 +126,7 @@ fn parse_args() -> Result { .parse() .map_err(|_| "Unable to parse `n_iters`".to_string())?; } + Long("no-optimize") => optimize = false, Short('q') | Long("quiet") => quiet = true, Short('v') | Long("verbose") => verbose = true, Short('V') | Long("version") => { @@ -151,6 +158,8 @@ Options: -n, --n_iters Number of times to evaluate model + --no-optimize Disable graph optimizations + -q, --quiet Run model and don't produce other output -t, --timing Output timing info @@ -176,6 +185,7 @@ Options: model, n_iters, mmap, + optimize, quiet, timing, verbose, @@ -404,10 +414,14 @@ fn print_input_output_list(model: &Model, node_ids: &[NodeId]) { /// running. See `docs/profiling.md`. fn main() -> Result<(), Box> { let args = parse_args()?; + + let mut model_opts = ModelOptions::with_all_ops(); + model_opts.enable_optimization(args.optimize); + let model = if args.mmap { - unsafe { Model::load_mmap(args.model)? } + unsafe { model_opts.load_mmap(args.model)? } } else { - Model::load_file(args.model)? + model_opts.load_file(args.model)? }; if !args.quiet {