From 2728c087fa9e848d1b808314002bfddd2801a03d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 9 Apr 2024 11:08:05 +0200 Subject: [PATCH 1/3] first working wip --- core/src/ops/konst.rs | 61 +++++++++++++++++++++ libcli/src/terminal.rs | 2 +- nnef/src/ast.rs | 2 + nnef/src/framework.rs | 108 +++++++++++++++++++------------------ nnef/src/ops/nnef/deser.rs | 48 ++++++++++------- nnef/src/resource.rs | 65 ++++++++++++++++++++-- nnef/src/ser.rs | 8 ++- 7 files changed, 214 insertions(+), 80 deletions(-) diff --git a/core/src/ops/konst.rs b/core/src/ops/konst.rs index b5a18ec7ea..c8ba242fe8 100644 --- a/core/src/ops/konst.rs +++ b/core/src/ops/konst.rs @@ -1,3 +1,7 @@ +use std::fmt::Debug; + +use dyn_clone::DynClone; + use crate::internal::*; #[derive(Debug, Clone, new, Hash, Eq, PartialEq)] @@ -72,3 +76,60 @@ impl TypedOp for Const { target.wire_node(&node.name, op, &[]) } } + +#[derive(Debug, Clone, new)] +pub struct LazyConst(pub Arc); + +impl Op for LazyConst { + fn name(&self) -> Cow { + "LazyConst".into() + } + + fn info(&self) -> TractResult> { + Ok(vec!(format!("{:?}", self.0))) + } + + op_as_typed_op!(); +} + +impl EvalOp for LazyConst { + fn is_stateless(&self) -> bool { + false + } + + fn state( + &self, + _session: &mut SessionState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(self.clone()))) + } +} + +impl OpState for LazyConst { + fn eval( + &mut self, + _session: &mut SessionState, + _op: &dyn Op, + _inputs: TVec, + ) -> TractResult> { + Ok(tvec!(self.0.eval()?)) + } +} + +trivial_op_state_freeeze!(LazyConst); + +impl TypedOp for LazyConst { + as_op!(); + + fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult> { + Ok(tvec!(self.0.output_fact()?)) + } +} + +pub trait LazyConstProvider: DynClone + Debug + Send + Sync + 'static { + fn output_fact(&self) -> TractResult; + fn eval(&self) -> TractResult; +} + +dyn_clone::clone_trait_object!(LazyConstProvider); diff --git a/libcli/src/terminal.rs b/libcli/src/terminal.rs index d78ae3b909..d72e0b8d5e 100644 --- a/libcli/src/terminal.rs +++ b/libcli/src/terminal.rs @@ -129,7 +129,7 @@ fn render_node_prefixed( // flops column let mut flops_column = if options.profile && options.cost { - let timing: f64 = tags.profile.as_ref().unwrap().as_secs_f64(); + let timing: f64 = tags.profile.as_ref().map(Duration::as_secs_f64).unwrap_or(0.0); let flops_column_pad = flops_column_pad.clone(); let it = tags.cost.iter().map(move |c| { if c.0.is_compute() { diff --git a/nnef/src/ast.rs b/nnef/src/ast.rs index f7691141d2..4d9c7485f7 100644 --- a/nnef/src/ast.rs +++ b/nnef/src/ast.rs @@ -1,5 +1,6 @@ use crate::internal::*; use tract_itertools::Itertools; +use crate::resource::LazyDat; pub mod dump; pub mod dump_doc; @@ -10,6 +11,7 @@ pub mod quant; pub struct ProtoModel { pub doc: Document, pub tensors: HashMap>, + pub lazy_tensors: HashMap>, pub quantization: Option>, pub resources: HashMap>, } diff --git a/nnef/src/framework.rs b/nnef/src/framework.rs index 096ecc04c8..440895bff5 100644 --- a/nnef/src/framework.rs +++ b/nnef/src/framework.rs @@ -3,17 +3,23 @@ use tract_core::tract_data::itertools::Itertools; use crate::ast::quant::write_quant_format; use crate::ast::{Document, Identifier, ProtoModel, QuantFormat}; +use crate::resource::{LazyDat, LazyDatLoader}; use crate::{internal::*, nnef}; use std::io::Read; #[cfg(target_family = "unix")] use std::os::unix::prelude::OsStrExt; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::str::FromStr; pub fn stdlib() -> Vec { crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap() } +pub enum LazyDataProvider { + None, + File(PathBuf), +} + pub struct Nnef { pub stdlib: Vec, pub registries: Vec, @@ -27,6 +33,7 @@ impl Default for Nnef { stdlib: stdlib(), registries: vec![crate::ops::tract_nnef()], resource_loaders: vec![ + LazyDatLoader.into_boxed(), GraphNnefLoader.into_boxed(), DatLoader.into_boxed(), GraphQuantLoader.into_boxed(), @@ -267,7 +274,13 @@ impl tract_core::prelude::Framework for Nnef { .skip(path.components().count()) .collect::(); let mut stream = std::fs::File::open(entry.path())?; - read_stream(&subpath, &mut stream, &mut resources, self)?; + read_stream( + &subpath, + &LazyDataProvider::File(entry.path().to_owned()), + &mut stream, + &mut resources, + self, + )?; } proto_model_from_resources(resources) } @@ -293,7 +306,7 @@ impl tract_core::prelude::Framework for Nnef { for entry in tar.entries()? { let mut entry = entry?; let path = entry.path()?.to_path_buf(); - read_stream(&path, &mut entry, &mut resources, self)?; + read_stream(&path, &LazyDataProvider::None, &mut entry, &mut resources, self)?; } proto_model_from_resources(resources) } @@ -313,7 +326,7 @@ fn proto_model_from_resources( // Iter resources IDs to detect submodels. Submodels are IDs with // - two path compoents (ex: XXX/file) // - a graph.nnef file as filename - let sub_models = resources + let sub_model_ids = resources .keys() .clone() .filter_map(|id| { @@ -330,8 +343,8 @@ fn proto_model_from_resources( // If there are submodels, we use the associated resources to create a TypedModel resource and add // it as a new resource. - let mut new_resources = if sub_models.len() > 0 { - sub_models.into_iter().try_fold(resources, |r, it| -> TractResult> { + let new_resources = if sub_model_ids.len() > 0 { + sub_model_ids.into_iter().try_fold(resources, |r, it| -> TractResult> { let (submodel_resources, mut resources): (HashMap>, _) = r.into_iter().partition(|(k, _v)| k.starts_with(&it)); let submodel_resources = submodel_resources @@ -347,56 +360,44 @@ fn proto_model_from_resources( resources }; - // NNEF document extraction - let doc = new_resources - .remove(crate::resource::GRAPH_NNEF_FILENAME) - .with_context(|| { - anyhow!("Resource {} was not found in the model", crate::resource::GRAPH_NNEF_FILENAME) - })? - .downcast_arc::() - .map_err(|_| anyhow!("Error while downcasting NNEF document resource"))?; - - let doc = Arc::try_unwrap(doc) - .map_err(|_| anyhow!("Error while extracting NNEF Document from shared reference. Only one reference to the document is expected"))?; - - // Collect all resources that can be downcastable to Arc. - let tensors: HashMap<_, _> = new_resources - .iter() - .filter_map(|(key, resource)| { - Arc::clone(resource) - .downcast_arc::() - .ok() - .map(|r| (Identifier::from(&**key), r)) - }) - .collect(); - // Iterate over tensors keys to remove them from the global resources hash map. - tensors.keys().for_each(|k| { - new_resources.remove(&*k.0); - }); - - // Quantization format resources extraction if present. - let quantization = if let Some(q_r) = - new_resources.remove(crate::resource::GRAPH_QUANT_FILENAME) - { - let Ok(q_r) = q_r.downcast_arc::>() else { - bail!("Error while downcasting quantization format resource") - }; - let Ok(q_r) = Arc::try_unwrap(q_r) else { - bail!("Error while extracting quantization format resource from shared reference. Only one reference to it is expected") - }; - Some(q_r.into_iter().map(|(k, v)| (Identifier(k), v)).collect()) - } else { - None - }; + let mut resources = HashMap::default(); + let mut tensors = HashMap::default(); + let mut lazy_tensors = HashMap::default(); + let mut doc: Option> = None; + let mut quantization = None; + for (k, res) in new_resources { + if let Ok(t) = res.clone().downcast_arc::() { + tensors.insert(Identifier(k), t); + } else if let Ok(t) = res.clone().downcast_arc::() { + lazy_tensors.insert(Identifier(k), t); + } else if k == crate::resource::GRAPH_NNEF_FILENAME { + doc = Some( + res.downcast_arc::() + .map_err(|_| anyhow!("graph.nnef must be a Document"))?, + ); + } else if k == crate::resource::GRAPH_QUANT_FILENAME { + let map = res + .downcast_arc::>() + .map_err(|_| anyhow!("graph.quant must be quantization information"))?; + quantization = + Some(map.iter().map(|(k, v)| (Identifier::from(&**k), v.clone())).collect()) + } else { + resources.insert(k, res); + } + } + + let Some(doc) = doc else { bail!("Could not find graph.nnef") }; + let doc = Arc::try_unwrap(doc).unwrap(); - let proto = ProtoModel { doc, tensors, quantization, resources: new_resources }; + let proto = ProtoModel { doc, tensors, lazy_tensors, quantization, resources }; proto.validate()?; Ok(proto) } -fn read_stream( +fn read_stream( path: &Path, - reader: &mut R, + lazy_data_provider: &LazyDataProvider, + reader: &mut impl std::io::Read, resources: &mut HashMap>, framework: &Nnef, ) -> TractResult<()> { @@ -408,9 +409,10 @@ fn read_stream( let mut last_loader_name; for loader in framework.resource_loaders.iter() { last_loader_name = Some(loader.name()); - let loaded = loader.try_load(path, reader, framework).with_context(|| { - anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path) - })?; + let loaded = + loader.try_load(path, lazy_data_provider, reader, framework).with_context(|| { + anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path) + })?; if let Some((id, resource)) = loaded { ensure!( !resources.contains_key(&id), diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index fc3614ac37..ed1b22c7d9 100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -39,42 +39,50 @@ pub fn external(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> pub fn variable(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { let shape: TVec = invocation.named_arg_as(builder, "shape")?; let label = Identifier(invocation.named_arg_as(builder, "label")?); + let sanitized_label = Identifier(label.0.trim_start_matches('/').to_owned()); + let lazy_tensors = &builder.proto_model.lazy_tensors; let tensors = &builder.proto_model.tensors; - let mut tensor = Arc::clone( - tensors - .get(&label) - .or_else(|| tensors.get(&Identifier(label.0.trim_start_matches('/').to_owned()))) - .ok_or_else(|| format_err!("No data for tensor {:?}", label))?, - ); + + let mut wire = if let Some(t) = tensors.get(&label).or_else(|| tensors.get(&sanitized_label)) { + builder.wire_as_outlets(tract_core::ops::konst::Const::new(t.clone()), &[])? + } else if let Some(lt) = lazy_tensors.get(&label).or_else(|| lazy_tensors.get(&sanitized_label)) + { + builder.wire_as_outlets(tract_core::ops::konst::LazyConst::new(lt.clone()), &[])? + } else { + bail!("No data for tensor {:?}", label) + }; + let fact = builder.model.outlet_fact(wire[0])?; + if fact.shape.as_concrete().unwrap() != &*shape { + bail!( + "Wrong shape for tensor: {:?}, tensor file says {:?}, graph files says {:?}", + label, + fact.shape, + shape + ); + } + if let Some(Some(dt)) = invocation.dt_from_quant_file.first() { - if dt.size_of() != tensor.datum_type().size_of() { + if dt.size_of() != fact.datum_type.size_of() { bail!( "Mismatched tensor type for tensor {}: expected {:?}, got {:?}", label.0, *dt, - tensor.datum_type() + fact.datum_type ); } - if *dt != tensor.datum_type() { + if *dt != fact.datum_type { trace!( "Casting tensor {} from {:?} to {:?} when deserializing", label.0, - tensor.datum_type(), + fact.datum_type, *dt ); //FIXME: avoid cast by late-loading tensors ? - tensor = tensor.cast_to_dt(*dt)?.into_owned().into_arc_tensor() + wire = builder.wire_as_outlets(cast(*dt), &wire)?; } } - if tensor.shape() != &*shape { - bail!( - "Wrong shape for tensor: {:?}, tensor file says {:?}, graph files says {:?}", - label, - tensor.shape(), - shape - ); - } - builder.wire(tract_core::ops::konst::Const::new(tensor), &[]) + + Ok(wire.into()) } // fragment reshape( input: tensor, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) diff --git a/nnef/src/resource.rs b/nnef/src/resource.rs index 4ceccee341..a6b93eb628 100644 --- a/nnef/src/resource.rs +++ b/nnef/src/resource.rs @@ -1,8 +1,10 @@ -use std::path::Path; +use std::path::{Path, PathBuf}; use crate::ast::{Document, QuantFormat}; +use crate::framework::LazyDataProvider; use crate::internal::*; use tract_core::downcast_rs::{impl_downcast, DowncastSync}; +use tract_core::ops::konst::LazyConstProvider; pub const GRAPH_NNEF_FILENAME: &str = "graph.nnef"; pub const GRAPH_QUANT_FILENAME: &str = "graph.quant"; @@ -32,6 +34,7 @@ pub trait ResourceLoader: Send + Sync { fn try_load( &self, path: &Path, + lazy_data_provider: &LazyDataProvider, reader: &mut dyn std::io::Read, framework: &Nnef, ) -> TractResult)>>; @@ -57,6 +60,7 @@ impl ResourceLoader for GraphNnefLoader { fn try_load( &self, path: &Path, + _lazy_data_provider: &LazyDataProvider, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -84,6 +88,7 @@ impl ResourceLoader for DatLoader { fn try_load( &self, path: &Path, + _lazy_data_provider: &LazyDataProvider, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -97,6 +102,56 @@ impl ResourceLoader for DatLoader { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct LazyDat { + path: PathBuf, + fact: TypedFact, +} + +impl Resource for LazyDat {} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] +pub struct LazyDatLoader; + +impl ResourceLoader for LazyDatLoader { + fn name(&self) -> Cow { + "LazyDatLoader".into() + } + + fn try_load( + &self, + path: &Path, + lazy_data_provider: &LazyDataProvider, + reader: &mut dyn std::io::Read, + _framework: &Nnef, + ) -> TractResult)>> { + let LazyDataProvider::File(f) = lazy_data_provider else { return Ok(None) }; + if path.extension().map(|e| e == "dat").unwrap_or(false) { + let tensor = crate::tensors::read_tensor(reader) + .with_context(|| format!("Error while reading tensor {path:?}"))?; + let lazy_dat = LazyDat { + fact: TypedFact::dt_shape(tensor.datum_type(), tensor.shape()), + path: f.clone(), + }; + Ok(Some((resource_path_to_id(path)?, Arc::new(lazy_dat)))) + } else { + Ok(None) + } + } +} + +impl LazyConstProvider for LazyDat { + fn eval(&self) -> TractResult { + let tensor = crate::tensors::read_tensor(&std::fs::File::open(&self.path)?) + .with_context(|| format!("Error while reading tensor {:?}", self))?; + Ok(tensor.into_tvalue()) + } + + fn output_fact(&self) -> TractResult { + Ok(self.fact.clone()) + } +} + impl Resource for HashMap {} #[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] @@ -110,6 +165,7 @@ impl ResourceLoader for GraphQuantLoader { fn try_load( &self, path: &Path, + _lazy_data_provider: &LazyDataProvider, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -144,6 +200,7 @@ impl ResourceLoader for TypedModelLoader { fn try_load( &self, path: &Path, + _lazy_data_provider: &LazyDataProvider, reader: &mut dyn std::io::Read, framework: &Nnef, ) -> TractResult)>> { @@ -158,13 +215,11 @@ impl ResourceLoader for TypedModelLoader { }; let label = if path_str.ends_with(NNEF_TGZ) { - path - .to_str() + path.to_str() .ok_or_else(|| anyhow!("invalid model resource path"))? .trim_end_matches(NNEF_TGZ) } else { - path - .to_str() + path.to_str() .ok_or_else(|| anyhow!("invalid model resource path"))? .trim_end_matches(NNEF_TAR) }; diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index 70d11878df..10a3c226d1 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -224,7 +224,13 @@ impl<'a> IntoAst<'a> { graph_def: GraphDef { id: Identifier("network".into()), parameters, results, body }, }; let quantization = if self.quantization.len() > 0 { Some(self.quantization) } else { None }; - Ok(ProtoModel { doc, tensors, quantization, resources: self.resources }) + Ok(ProtoModel { + doc, + lazy_tensors: Default::default(), + tensors, + quantization, + resources: self.resources, + }) } fn node(&mut self, node: &TypedNode) -> TractResult>> { From 0711913cf7bb5692e868b1efd867df7cc230f381 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 9 Apr 2024 16:53:52 +0200 Subject: [PATCH 2/3] simpler types --- nnef/src/ast.rs | 19 ++++++++++++++++++ nnef/src/framework.rs | 20 ++++++++----------- nnef/src/resource.rs | 45 +++++++++++++++++++++---------------------- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/nnef/src/ast.rs b/nnef/src/ast.rs index 4d9c7485f7..0907eb16ba 100644 --- a/nnef/src/ast.rs +++ b/nnef/src/ast.rs @@ -1,3 +1,7 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + use crate::internal::*; use tract_itertools::Itertools; use crate::resource::LazyDat; @@ -7,6 +11,21 @@ pub mod dump_doc; pub mod parse; pub mod quant; +#[derive(Clone,Debug)] +pub enum LazyReader { + File(PathBuf), +} + +impl LazyReader { + pub fn read(&self) -> TractResult> { + match self { + LazyReader::File(p) => { + Ok(Box::new(File::open(p).with_context(|| format!("Opening {p:?}"))?)) + } + } + } +} + #[derive(Clone, Debug)] pub struct ProtoModel { pub doc: Document, diff --git a/nnef/src/framework.rs b/nnef/src/framework.rs index 440895bff5..e991cc0392 100644 --- a/nnef/src/framework.rs +++ b/nnef/src/framework.rs @@ -2,24 +2,19 @@ use tar::Builder; use tract_core::tract_data::itertools::Itertools; use crate::ast::quant::write_quant_format; -use crate::ast::{Document, Identifier, ProtoModel, QuantFormat}; +use crate::ast::{Document, Identifier, LazyReader, ProtoModel, QuantFormat}; use crate::resource::{LazyDat, LazyDatLoader}; use crate::{internal::*, nnef}; use std::io::Read; #[cfg(target_family = "unix")] use std::os::unix::prelude::OsStrExt; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::str::FromStr; pub fn stdlib() -> Vec { crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap() } -pub enum LazyDataProvider { - None, - File(PathBuf), -} - pub struct Nnef { pub stdlib: Vec, pub registries: Vec, @@ -276,7 +271,7 @@ impl tract_core::prelude::Framework for Nnef { let mut stream = std::fs::File::open(entry.path())?; read_stream( &subpath, - &LazyDataProvider::File(entry.path().to_owned()), + Some(LazyReader::File(entry.path().to_owned())), &mut stream, &mut resources, self, @@ -306,7 +301,7 @@ impl tract_core::prelude::Framework for Nnef { for entry in tar.entries()? { let mut entry = entry?; let path = entry.path()?.to_path_buf(); - read_stream(&path, &LazyDataProvider::None, &mut entry, &mut resources, self)?; + read_stream(&path, None, &mut entry, &mut resources, self)?; } proto_model_from_resources(resources) } @@ -396,7 +391,7 @@ fn proto_model_from_resources( fn read_stream( path: &Path, - lazy_data_provider: &LazyDataProvider, + lazy_data_provider: Option, reader: &mut impl std::io::Read, resources: &mut HashMap>, framework: &Nnef, @@ -409,8 +404,9 @@ fn read_stream( let mut last_loader_name; for loader in framework.resource_loaders.iter() { last_loader_name = Some(loader.name()); - let loaded = - loader.try_load(path, lazy_data_provider, reader, framework).with_context(|| { + let loaded = loader + .try_load(path, lazy_data_provider.clone(), reader, framework) + .with_context(|| { anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path) })?; if let Some((id, resource)) = loaded { diff --git a/nnef/src/resource.rs b/nnef/src/resource.rs index a6b93eb628..fccc1731a1 100644 --- a/nnef/src/resource.rs +++ b/nnef/src/resource.rs @@ -1,7 +1,6 @@ -use std::path::{Path, PathBuf}; +use std::path::Path; -use crate::ast::{Document, QuantFormat}; -use crate::framework::LazyDataProvider; +use crate::ast::{Document, LazyReader, QuantFormat}; use crate::internal::*; use tract_core::downcast_rs::{impl_downcast, DowncastSync}; use tract_core::ops::konst::LazyConstProvider; @@ -34,7 +33,7 @@ pub trait ResourceLoader: Send + Sync { fn try_load( &self, path: &Path, - lazy_data_provider: &LazyDataProvider, + lazy_data_provider: Option, reader: &mut dyn std::io::Read, framework: &Nnef, ) -> TractResult)>>; @@ -60,7 +59,7 @@ impl ResourceLoader for GraphNnefLoader { fn try_load( &self, path: &Path, - _lazy_data_provider: &LazyDataProvider, + _lazy_data_provider: Option, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -88,7 +87,7 @@ impl ResourceLoader for DatLoader { fn try_load( &self, path: &Path, - _lazy_data_provider: &LazyDataProvider, + _lazy_data_provider: Option, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -102,9 +101,9 @@ impl ResourceLoader for DatLoader { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug)] pub struct LazyDat { - path: PathBuf, + reader: LazyReader, fact: TypedFact, } @@ -121,28 +120,28 @@ impl ResourceLoader for LazyDatLoader { fn try_load( &self, path: &Path, - lazy_data_provider: &LazyDataProvider, + lazy_data_provider: Option, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { - let LazyDataProvider::File(f) = lazy_data_provider else { return Ok(None) }; - if path.extension().map(|e| e == "dat").unwrap_or(false) { - let tensor = crate::tensors::read_tensor(reader) - .with_context(|| format!("Error while reading tensor {path:?}"))?; - let lazy_dat = LazyDat { - fact: TypedFact::dt_shape(tensor.datum_type(), tensor.shape()), - path: f.clone(), - }; - Ok(Some((resource_path_to_id(path)?, Arc::new(lazy_dat)))) - } else { - Ok(None) + if let Some(lazy) = lazy_data_provider { + if path.extension().map(|e| e == "dat").unwrap_or(false) { + let tensor = crate::tensors::read_tensor(reader) + .with_context(|| format!("Error while reading tensor {path:?}"))?; + let dat = LazyDat { + fact: TypedFact::dt_shape(tensor.datum_type(), tensor.shape()), + reader: lazy, + }; + return Ok(Some((resource_path_to_id(path)?, Arc::new(dat)))); + } } + Ok(None) } } impl LazyConstProvider for LazyDat { fn eval(&self) -> TractResult { - let tensor = crate::tensors::read_tensor(&std::fs::File::open(&self.path)?) + let tensor = crate::tensors::read_tensor(self.reader.read()?) .with_context(|| format!("Error while reading tensor {:?}", self))?; Ok(tensor.into_tvalue()) } @@ -165,7 +164,7 @@ impl ResourceLoader for GraphQuantLoader { fn try_load( &self, path: &Path, - _lazy_data_provider: &LazyDataProvider, + _lazy_data_provider: Option, reader: &mut dyn std::io::Read, _framework: &Nnef, ) -> TractResult)>> { @@ -200,7 +199,7 @@ impl ResourceLoader for TypedModelLoader { fn try_load( &self, path: &Path, - _lazy_data_provider: &LazyDataProvider, + _lazy_data_provider: Option, reader: &mut dyn std::io::Read, framework: &Nnef, ) -> TractResult)>> { From f3d4906687b06c4069bc8257a3424d50642d567b Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 17 Apr 2024 09:15:26 +0200 Subject: [PATCH 3/3] fixes --- nnef/nnef-resources/src/json_loader.rs | 1 + nnef/src/lib.rs | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nnef/nnef-resources/src/json_loader.rs b/nnef/nnef-resources/src/json_loader.rs index 03520043fb..6a05ece59d 100644 --- a/nnef/nnef-resources/src/json_loader.rs +++ b/nnef/nnef-resources/src/json_loader.rs @@ -25,6 +25,7 @@ impl ResourceLoader for JsonLoader { fn try_load( &self, path: &Path, + _lazy_data_provider: Option, reader: &mut dyn std::io::Read, _framework: &tract_nnef::framework::Nnef, ) -> TractResult)>> { diff --git a/nnef/src/lib.rs b/nnef/src/lib.rs index ad273466da..3f471578a0 100644 --- a/nnef/src/lib.rs +++ b/nnef/src/lib.rs @@ -25,14 +25,15 @@ pub mod prelude { pub mod internal { pub use crate::ast::parse::parse_parameters; pub use crate::ast::{ - param, FragmentDecl, FragmentDef, Identifier, Parameter, RValue, TypeName, + param, FragmentDecl, FragmentDef, Identifier, LazyReader, Parameter, RValue, TypeName, }; pub use crate::deser::{ModelBuilder, ResolvedInvocation, Value}; pub use crate::framework::Nnef; pub use crate::prelude::*; pub use crate::registry::*; pub use crate::resource::{ - DatLoader, GraphNnefLoader, GraphQuantLoader, Resource, ResourceLoader, TypedModelResource, TypedModelLoader, + DatLoader, GraphNnefLoader, GraphQuantLoader, Resource, ResourceLoader, TypedModelLoader, + TypedModelResource, }; pub use crate::ser::{invocation, logical, numeric, string, IntoAst}; pub use std::any::TypeId;