Skip to content

Commit

Permalink
Merge pull request #320 from robertknight/convert-captured-values-to-…
Browse files Browse the repository at this point in the history
…const

Add an optimization pass to convert captured values to constants
  • Loading branch information
robertknight authored Aug 24, 2024
2 parents fbe2efd + 07ab5cd commit 8bd8147
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 25 deletions.
12 changes: 12 additions & 0 deletions src/constant_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ pub struct ArcSlice<T> {
phantom: PhantomData<T>,
}

// Manual implementation of `ArcSlice<T>` avoids adding a `T: Clone` bound.
impl<T> Clone for ArcSlice<T> {
fn clone(&self) -> ArcSlice<T> {
ArcSlice {
storage: self.storage.clone(),
byte_offset: self.byte_offset,
len: self.len,
phantom: PhantomData,
}
}
}

impl<T> ArcSlice<T> {
/// Return an ArcSlice which references the subslice of `storage` specified
/// by `data`.
Expand Down
69 changes: 60 additions & 9 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ pub enum ConstantNodeData<T> {
Arc(ArcTensorView<T>),
}

impl<T> ConstantNodeData<T> {
fn clone_ref(&self) -> Option<ConstantNodeData<T>> {
match self {
ConstantNodeData::Owned(_) => None,
ConstantNodeData::Arc(view) => Some(ConstantNodeData::Arc(view.clone())),
}
}
}

impl<T> From<Tensor<T>> for ConstantNodeData<T> {
fn from(val: Tensor<T>) -> ConstantNodeData<T> {
ConstantNodeData::Owned(val)
Expand All @@ -125,6 +134,14 @@ impl<T> ConstantNode<T> {
}
}

/// Clone this node if its data supports `clone_ref`.
///
/// Returns `None` when `self.data.clone_ref()` returns `None`; in that case
/// the name is not cloned either.
fn clone_ref(&self) -> Option<ConstantNode<T>> {
    self.data.clone_ref().map(|data| ConstantNode {
        name: self.name.clone(),
        data,
    })
}

fn layout(&self) -> &DynLayout {
match &self.data {
ConstantNodeData::Owned(data) => data.layout(),
Expand All @@ -146,6 +163,15 @@ impl Constant {
}
}

/// Clone this constant, but only if it can be done so cheaply by
/// incrementing a ref count on the underlying data.
pub fn clone_ref(&self) -> Option<Constant> {
match self {
Constant::Float(f) => f.clone_ref().map(Constant::Float),
Constant::Int(i) => i.clone_ref().map(Constant::Int),
}
}

fn layout(&self) -> &DynLayout {
match self {
Constant::Float(f) => f.layout(),
Expand Down Expand Up @@ -424,18 +450,25 @@ pub struct CaptureEnv<'a> {
graph: &'a Graph,

// Values passed as inputs to the graph run.
inputs: &'a FxHashMap<NodeId, InputOrOutput<'a>>,
inputs: Option<&'a FxHashMap<NodeId, InputOrOutput<'a>>>,

// Temporary values computed during the graph run.
temp_values: &'a FxHashMap<NodeId, Output>,
temp_values: Option<&'a FxHashMap<NodeId, Output>>,
}

impl<'a> CaptureEnv<'a> {
fn new(
/// Create a new capture environment.
///
/// Lookups will first match nodes in `graph` and then try the `parent`
/// environment if that fails. Lookups that match constant nodes will be
/// resolved from the node directly. Lookups that match value nodes will
/// be resolved from `temp_values` first and then `inputs` if there is no
/// match there.
pub fn new(
parent: Option<&'a CaptureEnv<'a>>,
graph: &'a Graph,
inputs: &'a FxHashMap<NodeId, InputOrOutput<'a>>,
temp_values: &'a FxHashMap<NodeId, Output>,
inputs: Option<&'a FxHashMap<NodeId, InputOrOutput<'a>>>,
temp_values: Option<&'a FxHashMap<NodeId, Output>>,
) -> CaptureEnv<'a> {
CaptureEnv {
parent,
Expand All @@ -445,6 +478,19 @@ impl<'a> CaptureEnv<'a> {
}
}

/// Look up a node by name in this environment.
///
/// The name is resolved against this environment's graph first. If the graph
/// has no node with this name, or the matching node is listed in the graph's
/// captures, the lookup is delegated to the parent environment, if any.
pub fn get_node(&self, name: &str) -> Option<&'a Node> {
    match self.graph.get_node_id(name) {
        // A node that exists here and is not a placeholder for a value
        // captured from a parent graph is resolved from this graph directly.
        Some(id) if !self.graph.captures().contains(&id) => self.graph.get_node(id),
        // Unknown name, or a capture placeholder: defer to the parent.
        _ => self.parent?.get_node(name),
    }
}

/// Look up an operator input value by name in this environment.
pub fn get_input(&self, name: &str) -> Option<Input<'a>> {
if let Some(node_id) = self.graph.get_node_id(name) {
Expand All @@ -456,9 +502,13 @@ impl<'a> CaptureEnv<'a> {
Some(Node::Constant(c)) => Some(c.as_input()),
Some(Node::Value(_)) => self
.temp_values
.get(&node_id)
.and_then(|tv| tv.get(&node_id))
.map(|i| i.as_input())
.or_else(|| self.inputs.get(&node_id).map(|i| i.as_input())),
.or_else(|| {
self.inputs
.and_then(|i| i.get(&node_id))
.map(|i| i.as_input())
}),
_ => None,
};
}
Expand Down Expand Up @@ -609,7 +659,7 @@ impl Graph {
captures
}

fn add_node(&mut self, node: Node) -> NodeId {
pub fn add_node(&mut self, node: Node) -> NodeId {
self.nodes.push(node);
let node_id = self.nodes.len() - 1;

Expand Down Expand Up @@ -1024,7 +1074,8 @@ impl Graph {
.map(|out| [out].into())
.map_err(op_error_to_run_error)
} else if op_node.operator.has_subgraph() {
let capture_env = CaptureEnv::new(captures, self, &inputs_by_id, &temp_values);
let capture_env =
CaptureEnv::new(captures, self, Some(&inputs_by_id), Some(&temp_values));
let result = op_node.operator.run_subgraph(
pool,
InputList::from_optional(&op_inputs),
Expand Down
59 changes: 50 additions & 9 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ use rten_tensor::Tensor;

use crate::constant_storage::{ArcSlice, ArcTensorView, ConstantStorage};
use crate::env::str_as_bool;
use crate::graph::{ConstantNodeData, Dimension, Graph, Node, NodeId, RunError, RunOptions};
use crate::graph::{
CaptureEnv, ConstantNodeData, Dimension, Graph, Node, NodeId, RunError, RunOptions,
};
use crate::header::{Header, HeaderError};
use crate::model_metadata::ModelMetadata;
use crate::number::{LeBytes, Pod};
Expand Down Expand Up @@ -147,6 +149,22 @@ fn parse_timing_config(config: &str, opts: &mut RunOptions) {
}
}

/// Configuration for loading subgraphs.
///
/// Bundles the settings that are forwarded to `Model::load_graph` when an
/// operator's subgraph is loaded.
struct SubgraphOptions<'a> {
    /// Tensor data storage. Cloned (as an `Arc`) for each subgraph load.
    storage: Arc<ConstantStorage>,

    /// Offset of tensor data within the storage, if any.
    tensor_data_offset: Option<u64>,

    /// Whether to apply optimizations when loading the subgraph.
    optimize: bool,

    /// Provides access to info about nodes captured from parent graphs.
    /// This is needed for some optimization passes.
    ///
    /// `None` when there is no enclosing capture environment.
    capture_env: Option<&'a CaptureEnv<'a>>,
}

/// Options which customize how a model is loaded.
///
/// This enables more advanced use cases such as loading a model with only
Expand Down Expand Up @@ -311,6 +329,7 @@ impl Model {
storage.clone(),
tensor_data_offset,
options.optimize,
None, /* capture_env */
)?;

let metadata = model
Expand All @@ -328,6 +347,7 @@ impl Model {
storage: Arc<ConstantStorage>,
tensor_data_offset: Option<u64>,
optimize: bool,
capture_env: Option<&CaptureEnv>,
) -> Result<Graph, ModelLoadError> {
let node_count = serialized_graph.nodes().map(|ns| ns.len()).unwrap_or(0);

Expand All @@ -353,10 +373,6 @@ impl Model {
graph.set_captures(&captures);
}

let load_subgraph = |g: sg::Graph| -> Result<Graph, ModelLoadError> {
Self::load_graph(g, registry, storage.clone(), tensor_data_offset, optimize)
};

if let Some(nodes) = serialized_graph.nodes() {
for (node_index, node) in nodes.iter().enumerate() {
let graph_node = if let Some(operator) = node.data_as_operator_node() {
Expand All @@ -366,7 +382,12 @@ impl Model {
operator,
registry,
&node_id_from_index,
&load_subgraph,
SubgraphOptions {
storage: storage.clone(),
tensor_data_offset,
optimize,
capture_env,
},
)?
} else if let Some(value) = node.data_as_value_node() {
Self::add_graph_value(&mut graph, node.name(), value)?
Expand All @@ -388,7 +409,7 @@ impl Model {
if optimize {
let optimizer = GraphOptimizer::new();
optimizer
.optimize(graph)
.optimize(graph, capture_env)
.map_err(|err| ModelLoadError::OptimizeError(Box::new(err)))
} else {
Ok(graph)
Expand All @@ -401,8 +422,26 @@ impl Model {
operator: sg::OperatorNode,
registry: &OpRegistry,
node_id_from_index: &HashMap<usize, NodeId>,
load_graph: &dyn Fn(sg::Graph) -> Result<Graph, ModelLoadError>,
subgraph_opts: SubgraphOptions,
) -> Result<NodeId, ModelLoadError> {
let load_subgraph = |g: sg::Graph| -> Result<Graph, ModelLoadError> {
let SubgraphOptions {
storage,
tensor_data_offset,
optimize,
capture_env,
} = &subgraph_opts;
let capture_env = CaptureEnv::new(*capture_env, graph, None, None);
Self::load_graph(
g,
registry,
storage.clone(),
*tensor_data_offset,
*optimize,
Some(&capture_env),
)
};

struct LoadContext<'a> {
load_graph: &'a dyn Fn(sg::Graph) -> Result<Graph, ModelLoadError>,
}
Expand All @@ -413,7 +452,9 @@ impl Model {
}
}

let ctx = LoadContext { load_graph };
let ctx = LoadContext {
load_graph: &load_subgraph,
};
let op = registry
.read_op(&operator, &ctx)
.map_err(ModelLoadError::OperatorInvalid)?;
Expand Down
Loading

0 comments on commit 8bd8147

Please sign in to comment.