[SR] assignStorageToManagedTensors returns a vector (pytorch#69568)
Summary:
Pull Request resolved: pytorch#69568

Non-empty vectors should never be passed to `assignStorageToManagedTensors` and `assignStorageToOutputTensors`. Presumably, this out-parameter convention was adopted to avoid move-assigning the corresponding attributes in `MemoryPlanner`. But a vector move-assign is cheap, and returning the result by value gives a safer type signature.
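
A minimal sketch of the two conventions the summary contrasts, using hypothetical names (`Group`, `assignGroupsOut`, `assignGroups`, `Planner`) rather than the actual Static Runtime types:

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for StorageGroup; only the API shape matters here.
struct Group {
  int id;
};

// Old convention: the caller must pre-construct an (always empty) vector
// and pass it by reference to be filled in.
void assignGroupsOut(int n, std::vector<Group>& out) {
  for (int i = 0; i < n; ++i) {
    out.push_back(Group{i});
  }
}

// New convention: build the vector locally and return it by value. The
// return is at worst a move, and the signature makes it impossible to
// hand the function a non-empty vector by mistake.
std::vector<Group> assignGroups(int n) {
  std::vector<Group> groups;
  for (int i = 0; i < n; ++i) {
    groups.push_back(Group{i});
  }
  return groups;
}

struct Planner {
  std::vector<Group> groups_;
  explicit Planner(int n) {
    // Move-assigning the member costs a few pointer swaps, not an
    // element-by-element copy.
    groups_ = assignGroups(n);
  }
};

int main() {
  Planner planner(3);
  std::printf("managed groups: %zu\n", planner.groups_.size());
  return 0;
}
```

The out-parameter form saves nothing in practice: the move-assign in the constructor is O(1), while the by-value signature rules out the accidental non-empty-input case at compile time.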

Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: donaldong

Differential Revision: D32729289

fbshipit-source-id: 88f19de8eb89d8a4f1dd8bbd4d9e7f686e41888b
Mike Iovine authored and facebook-github-bot committed Dec 10, 2021
1 parent 9aa1b3e commit f87f1d0
Showing 3 changed files with 16 additions and 19 deletions.
benchmarks/static_runtime/test_static_module.cc (5 changes: 2 additions & 3 deletions)
@@ -1230,9 +1230,8 @@ void testAssignStorageToManagedTensors(
   ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());

   auto ranges = ManagedTensorRanges(graph, managed_tensor_values);
-  std::vector<StorageGroup> groups;
-  assignStorageToManagedTensors(
-      graph->block()->nodes(), ranges, tensor_value_to_tensor, groups);
+  auto groups = assignStorageToManagedTensors(
+      graph->block()->nodes(), ranges, tensor_value_to_tensor);

   checkStorageGroups(
       groups, ranges, tensor_value_to_tensor, min_reused_tensors);
torch/csrc/jit/runtime/static/memory_planner.cpp (25 changes: 12 additions & 13 deletions)
@@ -56,11 +56,11 @@ FastMap<const Value*, at::Tensor*> tensorValueToTensor(

 } // namespace

-void assignStorageToManagedTensors(
+std::vector<StorageGroup> assignStorageToManagedTensors(
     graph_node_list nodes,
     const ManagedTensorRanges& ranges,
-    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
-    std::vector<StorageGroup>& managed_tensor_groups) {
+    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor) {
+  std::vector<StorageGroup> managed_tensor_groups;
   // This set maps each Value* to its assigned storage group.
   FastMap<const Value*, size_t> storage_group_mapping;
   // On each iteration, this vector stores the set of storage groups that
@@ -119,6 +119,7 @@ void assignStorageToManagedTensors(
       }
     }
   }
+  return managed_tensor_groups;
 }

 namespace {
@@ -127,10 +128,10 @@ bool setIncludes(const FastSet<const Value*>& set, const Value* v) {
   return set.find(v) != set.end();
 }

-void assignStorageToOutputTensors(
+std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
     StaticRuntime* runtime,
-    const FastSet<const Value*>& managed_output_tensor_values,
-    std::vector<std::pair<size_t, at::Tensor*>>& managed_output_tensors) {
+    const FastSet<const Value*>& managed_output_tensor_values) {
+  std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
   for (auto& pnode : runtime->nodes()) {
     for (const auto i : c10::irange(pnode.outputs().size())) {
       auto& ival = pnode.Output(i);
@@ -144,6 +145,7 @@ void assignStorageToOutputTensors(
       managed_output_tensors.emplace_back(0, tensor);
     }
   }
+  return managed_output_tensors;
 }

 } // namespace
@@ -213,11 +215,8 @@ MemoryPlanner::MemoryPlanner(
   const auto tensor_value_to_tensor =
       tensorValueToTensor(runtime->nodes(), managed_tensor_values);
   if (optimize_memory) {
-    ::torch::jit::assignStorageToManagedTensors(
-        runtime->node_ptrs(),
-        ranges,
-        tensor_value_to_tensor,
-        managed_tensors_);
+    managed_tensors_ = assignStorageToManagedTensors(
+        runtime->node_ptrs(), ranges, tensor_value_to_tensor);
   } else {
     for (auto& tensor : tensor_value_to_tensor) {
       managed_tensors_.emplace_back(tensor.second);
@@ -226,8 +225,8 @@ MemoryPlanner::MemoryPlanner(
   }

   if (enable_out_variant && manage_output_tensors) {
-    ::torch::jit::assignStorageToOutputTensors(
-        runtime, managed_output_tensor_values, managed_output_tensors_);
+    managed_output_tensors_ =
+        assignStorageToOutputTensors(runtime, managed_output_tensor_values);
   }

   num_managed_tensors_ = 0;
torch/csrc/jit/runtime/static/memory_planner.h (5 changes: 2 additions & 3 deletions)
@@ -39,11 +39,10 @@ class StorageGroup {
   std::vector<at::Tensor*> group_{};
 };

-TORCH_API void assignStorageToManagedTensors(
+TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
     graph_node_list nodes,
     const ManagedTensorRanges& ranges,
-    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
-    std::vector<StorageGroup>& managed_tensor_groups);
+    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor);

 // There are three types of ops in a processed graph in Static Runtime:
 // 1. op with _out variant
