Unify boxed function signature between jit and c10 (pytorch#37034)
Summary:
Pull Request resolved: pytorch#37034

c10 boxed functions take a Stack*, while JIT operations took a Stack&.
c10 boxed functions return void, while JIT operations returned an int that was always zero.

This changes JIT to follow the c10 behavior.
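
For illustration, a minimal before/after sketch of what the unified convention means for a registered operator. The schema `foo::my_op` and the body are made up for this example; the registration pattern mirrors the test code changed below, and the two forms compile against the tree before and after this commit, respectively.

```cpp
// Before: boxed JIT operations took Stack& and returned an int that was always 0.
RegisterOperators old_reg({Operator(
    "foo::my_op(Tensor a) -> Tensor",
    [](Stack& stack) {
      at::Tensor a;
      pop(stack, a);          // pop the single input
      push(stack, a.relu());  // push the single output
      return 0;               // dummy value required by the old signature
    },
    c10::AliasAnalysisKind::FROM_SCHEMA)});

// After: boxed JIT operations take Stack* and return void, matching c10.
RegisterOperators new_reg({Operator(
    "foo::my_op(Tensor a) -> Tensor",
    [](Stack* stack) {
      at::Tensor a;
      pop(stack, a);
      push(stack, a.relu());
    },
    c10::AliasAnalysisKind::FROM_SCHEMA)});
```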
ghstack-source-id: 106834069

Test Plan: unit tests

Differential Revision: D20567950

fbshipit-source-id: 1a7aea291023afc52ae706957e9a5ca576fbb53b
smessmer authored and facebook-github-bot committed Jun 30, 2020
1 parent 320164f commit 53af9df
Showing 33 changed files with 361 additions and 578 deletions.
29 changes: 28 additions & 1 deletion aten/src/ATen/core/stack.h
@@ -9,7 +9,7 @@ namespace jit {

using c10::IValue;
using Stack = std::vector<IValue>;
using Operation = std::function<int(Stack&)>;
using Operation = std::function<void(Stack*)>;

// An operation with N inputs and M outputs pops the last N inputs off
// the stack and pushes its M outputs onto the stack
@@ -29,9 +29,15 @@ using Operation = std::function<int(Stack&)>;
static inline IValue& peek(Stack& stack, size_t i, size_t N) {
return *(stack.end() - N + i);
}
static inline IValue& peek(Stack* stack, size_t i, size_t N) {
return peek(*stack, i, N);
}
static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
return *(stack.end() - N + i);
}
static inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
return peek(*stack, i, N);
}
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
static inline at::ArrayRef<IValue> peekSlice(
@@ -44,14 +50,23 @@ static inline at::ArrayRef<IValue> peekSlice(
static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
return peekSlice(stack, 0, N, N);
}
static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
return last(*stack, N);
}
static inline void drop(Stack& stack, size_t n) {
stack.erase(stack.end() - n, stack.end());
}
static inline void drop(Stack* stack, size_t n) {
drop(*stack, n);
}
static inline IValue pop(Stack& stack) {
auto r = std::move(stack.back());
stack.pop_back();
return r;
}
static inline IValue pop(Stack* stack) {
return pop(*stack);
}
static inline std::vector<IValue> pop(Stack& stack, size_t n) {
std::vector<IValue> result;
result.reserve(n);
@@ -76,6 +91,10 @@ static inline void pop(Stack& stack, Types&... args) {
(args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
drop(stack, N);
}
template <typename... Types>
static inline void pop(Stack* stack, Types&... args) {
pop(*stack, args...);
}
template <typename Type>
static inline void push_one(Stack& stack, Type&& arg) {
stack.emplace_back(std::forward<Type>(arg));
@@ -92,6 +111,10 @@ template <typename... Types>
static inline void push(Stack& stack, Types&&... args) {
(void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
}
template <typename... Types>
static inline void push(Stack* stack, Types&&... args) {
return push(*stack, std::forward<Types>(args)...);
}
template <class T>
static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
for (T elem : elements) {
@@ -107,6 +130,10 @@ template <typename T>
inline void pack(Stack& stack, T&& v) {
stack.emplace_back(std::forward<T>(v));
}
template <typename T>
inline void pack(Stack* stack, T&& v) {
pack(*stack, std::forward<T>(v));
}

template <std::size_t remaining, typename... Args>
struct TuplePacker {
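Note that the Stack* overloads added above simply dereference and forward to the existing Stack& helpers, so boxed functions written against the new pointer convention keep using the same stack-manipulation utilities. A hedged usage sketch, not part of this diff; it assumes the torch::jit namespace and that the two topmost stack entries are Tensors:

```cpp
void example_sum_last_two(Stack* stack) {
  // peek(Stack*, i, N) and drop(Stack*, n) forward to the Stack& versions above.
  at::Tensor a = peek(stack, 0, 2).toTensor();
  at::Tensor b = peek(stack, 1, 2).toTensor();
  drop(stack, 2);      // remove the two inputs
  push(stack, a + b);  // push the single output
}
```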
9 changes: 9 additions & 0 deletions aten/src/ATen/record_function.h
@@ -84,6 +84,15 @@ struct TORCH_API RecordFunction {
RecordFunction(
RecordScope scope = RecordScope::FUNCTION);

template <typename F>
void before(
F fn,
const std::vector<c10::IValue>* args,
int64_t current_sequence_nr = -1) {
inputs_ = *args;
before(fn, current_sequence_nr);
}

// Destructor calls end callbacks
virtual ~RecordFunction();

6 changes: 3 additions & 3 deletions test/cpp/jit/test_alias_analysis.cpp
@@ -495,7 +495,7 @@ void testAliasAnalysis() {
void testWriteTracking() {
RegisterOperators reg({Operator(
"prim::creates_alias(Tensor(a) x) -> Tensor(a)",
[](Stack& s) { return 0; },
[](Stack* s) {},
aliasAnalysisFromSchema())});
const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
{
@@ -920,11 +920,11 @@ graph():
void testWildcards() {
RegisterOperators reg({Operator(
"prim::returns_wildcard(Tensor a) -> Tensor(*)",
[](Stack& stack) { return 0; },
[](Stack* stack) {},
aliasAnalysisFromSchema()),
Operator(
"prim::writes(Tensor(z!) a) -> Tensor(a)",
[](Stack& stack) { return 0; },
[](Stack* stack) {},
aliasAnalysisFromSchema())});
const auto returns_wildcard =
Symbol::fromQualString("prim::returns_wildcard");
5 changes: 1 addition & 4 deletions test/cpp/jit/test_base.cpp
@@ -17,10 +17,7 @@ RegisterOperators reg({
// because it always produces empty Tensors.
Operator(
"prim::MakeTestTensor() -> Tensor",
[](Stack& stack) {
push(stack, at::Tensor());
return 0;
},
[](Stack* stack) { push(stack, at::Tensor()); },
aliasAnalysisFromSchema()),
});
}
8 changes: 4 additions & 4 deletions test/cpp/jit/test_custom_operators.cpp
@@ -30,7 +30,7 @@ void testCustomOperators() {

Stack stack;
push(stack, 2.0f, at::ones(5));
op->getOperation()(stack);
op->getOperation()(&stack);
at::Tensor output;
pop(stack, output);

@@ -59,7 +59,7 @@ void testCustomOperators() {

Stack stack;
push(stack, 2.0f, at::ones(5));
op->getOperation()(stack);
op->getOperation()(&stack);
at::Tensor output;
pop(stack, output);

@@ -98,7 +98,7 @@ void testCustomOperators() {
push(stack, c10::List<int64_t>({1, 2}));
push(stack, c10::List<double>({1.0, 2.0}));
push(stack, c10::List<at::Tensor>({at::ones(5)}));
op->getOperation()(stack);
op->getOperation()(&stack);
c10::List<double> output;
pop(stack, output);

@@ -128,7 +128,7 @@

Stack stack;
push(stack, c10::List<at::Tensor>({at::ones(5)}));
op->getOperation()(stack);
op->getOperation()(&stack);
c10::List<at::Tensor> output;
pop(stack, output);

8 changes: 2 additions & 6 deletions test/cpp/jit/test_misc.cpp
@@ -1234,21 +1234,17 @@ void testNoneSchemaMatch() {
RegisterOperators reg({
Operator(
"prim::test_none() -> int?",
[](Stack& stack) {
push(stack, IValue());
return 0;
},
[](Stack* stack) { push(stack, IValue()); },
aliasAnalysisFromSchema()),
Operator(
"prim::is_none(int? a) -> bool",
[](Stack& stack) {
[](Stack* stack) {
IValue a = pop(stack);
if (a.isNone()) {
push(stack, true);
} else {
push(stack, false);
}
return 0;
},
aliasAnalysisFromSchema()),
});
6 changes: 2 additions & 4 deletions test/cpp/jit/test_schema_matching.cpp
@@ -15,12 +15,11 @@ void testSchemaMatching() {
RegisterOperators reg({
Operator(
"aten::test_vartype(t[] a, t b) -> (t)",
[](Stack& stack) {
[](Stack* stack) {
c10::List<double> list;
double a;
pop(stack, list, a);
push(stack, a);
return 0;
},
c10::AliasAnalysisKind::FROM_SCHEMA),
});
@@ -53,12 +52,11 @@
RegisterOperators reg({
Operator(
"aten::test_vartype2(t a, t[] b) -> (t[])",
[](Stack& stack) {
[](Stack* stack) {
double a;
c10::List<double> list;
pop(stack, a, list);
push(stack, a);
return 0;
},
AliasAnalysisKind::FROM_SCHEMA),
});
2 changes: 1 addition & 1 deletion test/custom_operator/test_custom_ops.cpp
@@ -30,7 +30,7 @@ Result get_operator_from_registry_and_execute(const char* op_name, Args&&... arg

torch::jit::Stack stack;
torch::jit::push(stack, std::forward<Args>(args)...);
op->getOperation()(stack);
op->getOperation()(&stack);

TORCH_INTERNAL_ASSERT(1 == stack.size());
return torch::jit::pop(stack).to<Result>();
3 changes: 1 addition & 2 deletions torch/csrc/autograd/record_function_ops.cpp
@@ -92,14 +92,13 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
jit::RegisterOperators reg_fut_ops({
jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
[](jit::Stack& stack) {
[](jit::Stack* stack) {
// Pop inputs, which should be a future and a tensor
auto fut = jit::pop(stack).toFuture();
auto tensor = jit::pop(stack).toTensor();
auto profiledFut = _call_end_callbacks_on_fut(tensor, fut);
// return future that completes when profiling callbacks have run.
jit::push(stack, std::move(profiledFut));
return 0;
},
aliasAnalysisFromSchema()),
});
4 changes: 2 additions & 2 deletions torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -200,7 +200,7 @@ void RequestCallbackImpl::processRpc(
// scriptCall is only alive within this block, use reference to avoid copy
auto& stack = scriptCall.stackRef();
if (scriptCall.hasOp()) {
scriptCall.op()->getOperation()(stack);
scriptCall.op()->getOperation()(&stack);
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
"Return value of a builtin operator or a "
@@ -340,7 +340,7 @@ void RequestCallbackImpl::processRpc(
auto& stack = scriptRemoteCall.stackRef();
if (scriptRemoteCall.hasOp()) {
try {
scriptRemoteCall.op()->getOperation()(stack);
scriptRemoteCall.op()->getOperation()(&stack);
} catch (const std::exception& e) {
// Don't throw in this call, but rather transfer the exception
// to the rref.
4 changes: 2 additions & 2 deletions torch/csrc/jit/OVERVIEW.md
@@ -761,10 +761,10 @@ All builtin operators are represented using a stack machine concept. An operator

```cpp
using Stack = std::vector<IValue>;
using Operation = std::function<int(Stack&)>;
using Operation = std::function<void(Stack*)>;

// schema: example_add(Tensor a, Tensor b) -> Tensor
int example_add(Stack& stack) {
void example_add(Stack* stack) {
Tensor a, b;
// stack before: ? ? ? a b <- back
pop(stack, a, b); //Templated helper function
5 changes: 2 additions & 3 deletions torch/csrc/jit/codegen/cuda/interface.cpp
@@ -41,9 +41,8 @@ RegisterOperators reg({
Operator(
prim::CudaFusionGroup,
[](const Node* node) -> Operation {
return [node](Stack& stack) {
fuser::cuda::runFusionGroup(node, stack);
return 0;
return [node](Stack* stack) {
fuser::cuda::runFusionGroup(node, *stack);
};
},
c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/fuser/fallback.cpp
@@ -26,15 +26,14 @@ RegisterOperators reg_fused_operators({Operator(
[](const Node* node) -> Operation {
int64_t dim = node->i(attr::dim);
int64_t num_inputs = node->inputs().size();
return [dim, num_inputs](Stack& stack) {
return [dim, num_inputs](Stack* stack) {
auto result = at::cat(
fmap(
last(stack, num_inputs),
[](const IValue& i) { return i.toTensor(); }),
dim);
drop(stack, num_inputs);
pack(stack, std::move(result));
return 0;
};
},
aliasAnalysisIsSpecialCase())});
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/function.cpp
@@ -34,7 +34,7 @@ bool Function::append_operator(

auto jit_op = findOperatorFor(opname);
if (jit_op) {
fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
} else {
auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
if (op.has_value()) {
19 changes: 8 additions & 11 deletions torch/csrc/jit/passes/batch_mm.cpp
@@ -108,11 +108,11 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {

RegisterOperators mm_tree_reduction_reg({Operator(
"prim::MMTreeReduce(...) -> Tensor",
[](Stack& stack) {
[](Stack* stack) {
auto num_inputs = pop(stack).toInt();
std::vector<at::Tensor> inputs;
inputs.reserve(num_inputs);
for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
for (auto it = stack->end() - num_inputs; it != stack->end(); ++it) {
inputs.push_back(std::move(*it).toTensor());
}
drop(stack, num_inputs);
@@ -156,7 +156,6 @@ RegisterOperators mm_tree_reduction_reg({Operator(
}
push(stack, std::move(acc));
}
return 0;
},
aliasAnalysisIsSpecialCase())});

@@ -320,11 +319,11 @@ RegisterOperators mm_batch_side_reg({Operator(
[](const Node* node) -> Operation {
size_t num_other_side_inputs = node->inputs().size() - 1;
Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
return [num_other_side_inputs, single_side](Stack& stack) {
return [num_other_side_inputs, single_side](Stack* stack) {
at::Tensor side_input;
std::vector<at::Tensor> other_side_inputs;
other_side_inputs.reserve(num_other_side_inputs);
for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
for (auto it = stack->end() - num_other_side_inputs; it != stack->end();
++it) {
other_side_inputs.push_back(std::move(*it).toTensor());
}
mm_out,
num_other_side_inputs,
/*dim=*/single_side == Side::LHS ? 1 : 0);
stack.insert(
stack.end(),
stack->insert(
stack->end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
} else {
if (single_side == Side::LHS) {
for (at::Tensor& other : other_side_inputs) {
stack.emplace_back(side_input.mm(other));
stack->emplace_back(side_input.mm(other));
}
} else {
for (at::Tensor& other : other_side_inputs) {
stack.emplace_back(other.mm(side_input));
stack->emplace_back(other.mm(side_input));
}
}
}

return 0;
};
},
aliasAnalysisIsSpecialCase())});
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/constant_propagation.cpp
@@ -57,7 +57,7 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
default: {
auto op = n->getOperation();
try {
op(stack);
op(&stack);
} catch (...) {
return c10::nullopt;
}