[JIT] Clone runOptimizations and similar functions for profiling executor. (pytorch#42656)

Summary:
Pull Request resolved: pytorch#42656

This change will allow us to experiment more freely with pass pipelines
in the profiling executor without affecting passes in the legacy
executor. It also helps to keep all the passes in one place, making it
easier to tell what's going on.

Currently this change should not affect any behavior, as I copied the
passes exactly as they were invoked before, but we will probably want
to change these pipelines in the near future.
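
For reference, the custom pre- and post-fusion hooks that the cloned
pipelines iterate over (getCustomPrePasses() / getCustomPostPasses()) come
from torch/csrc/jit/passes/pass_manager.h. Below is a minimal sketch of how
a backend could hook into these pipelines; the RegisterPass helper and the
no-op pass body are illustrative assumptions, not part of this diff.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>

namespace {

// Illustrative no-op pass: a real backend pass would rewrite `graph` in
// place (e.g. replace patterns with backend-specific fused ops).
void exampleBackendPass(std::shared_ptr<torch::jit::Graph>& graph) {
  (void)graph;
}

// Assumed registration helper from pass_manager.h: constructing it appends
// an entry to the list returned by getCustomPostPasses(), which the
// pipelines in this file run after fusion.
static torch::jit::RegisterPass reg(exampleBackendPass);

} // namespace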

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D22971050

Pulled By: ZolotukhinM

fbshipit-source-id: f5bb60783a553c7b51c5343eec7f8fe40037ff99
Mikhail Zolotukhin authored and facebook-github-bot committed Aug 6, 2020
1 parent a4dbc64 commit 57854e7
Showing 1 changed file with 134 additions and 19 deletions.
153 changes: 134 additions & 19 deletions torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
@@ -18,11 +19,14 @@
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

C10_DECLARE_bool();

@@ -79,25 +83,134 @@ static bool needsGradientInProfilingMode(Block* b) {
  return false;
}

void runNooptPassPipeline(std::shared_ptr<Graph>& graph) {
  LowerGradOf(*graph);
  RemoveExpands(graph);
  CanonicalizeOps(graph);
  EliminateDeadCode(graph);
}

void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
  InsertGuards(graph);
  LowerGradOf(*graph);
  EliminateRedundantGuards(graph);
  InsertBailOuts(graph);
  specializeAutogradZero(*graph);
  // runRequiredPasses
  {
    RemoveExpands(graph);
    CanonicalizeOps(graph);
    EliminateDeadCode(graph);
  }
  PeepholeOptimize(graph);
  ConstantPropagation(graph);

  // runOptimization:
  {
    EliminateDeadCode(graph);
    EliminateCommonSubexpression(graph);

    PeepholeOptimize(graph);
    ConstantPropagation(graph);
    ConstantPooling(graph);

    UnrollLoops(graph);
    // run again with unrolled loops
    RemoveListMutation(graph);
    PeepholeOptimize(graph);
    ConstantPropagation(graph);

    EliminateCommonSubexpression(graph);

    CheckInplace(graph);
  }
}

void runDiffGraphPasses(std::shared_ptr<Graph>& graph) {
  // runOptimization:
  {
    // Basic graph preprocessing to eliminate noise.
    EliminateDeadCode(graph);
    EliminateCommonSubexpression(graph);

    PeepholeOptimize(graph);
    ConstantPropagation(graph);
    ConstantPooling(graph);

    UnrollLoops(graph);
    // run again with unrolled loops
    RemoveListMutation(graph);
    PeepholeOptimize(graph);
    ConstantPropagation(graph);

    EliminateCommonSubexpression(graph);

    CheckInplace(graph);
  }

  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);

    // Rewrite subgraphs with many MMs into expressions that batch them.
    BatchMM(graph);

    if (tensorExprFuserEnabled()) {
      FuseTensorExprs(graph);
    } else {
      FuseGraph(graph, true);
    }

    // Run custom post-fusion passes
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
  }
}

void runNoGradOptimizations(std::shared_ptr<Graph>& graph) {
  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);

    // Rewrite subgraphs with many MMs into expressions that batch them.
    BatchMM(graph);

    if (tensorExprFuserEnabled()) {
      FuseTensorExprs(graph);
    } else {
      FuseGraph(graph, true);
    }

    // Run custom post-fusion passes
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
  }
}

void ProfilingGraphExecutorImpl::runProfilingOptimizations(
    std::shared_ptr<Graph>& copy) {
  if (!getGraphExecutorOptimize()) {
    LowerGradOf(*copy);
    runRequiredPasses(copy);
    runNooptPassPipeline(copy);
    return;
  }

  InsertGuards(copy);
  LowerGradOf(*copy);
  EliminateRedundantGuards(copy);
  InsertBailOuts(copy);
  GRAPH_DUMP("After InsertBailOuts: ", copy);
  specializeAutogradZero(*copy);

  runRequiredPasses(copy);
  PeepholeOptimize(copy);
  ConstantPropagation(copy);
  runOptimization(copy);
  runPreAutodiffPassPipeline(copy);

  if (needsGradientInProfilingMode(copy->block())) {
    auto diff_nodes = CreateAutodiffSubgraphs(
@@ -106,17 +219,14 @@ void ProfilingGraphExecutorImpl::runProfilingOptimizations(
    for (Node* dnode : diff_nodes) {
      auto diff_graph = std::move(dnode->g(attr::Subgraph));
      Gradient gradient = differentiate(diff_graph);
      runOptimization(gradient.f);
      // run non diff optimization on the forward graph
      runNondiffOptimization(gradient.f, true);
      runDiffGraphPasses(gradient.f);
      packGradient(gradient, dnode);
    }
    InlineAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);

  } else {
    runNondiffOptimization(copy, true);
    runNoGradOptimizations(copy);
  }
  EliminateDeadCode(copy);
  GRAPH_DUMP("Optimized Graph : ", copy);
@@ -132,7 +242,12 @@ void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
  // may carry over undefinedness
  // from profiled backward graphs
  ClearUndefinedness(copy);
  runRequiredPasses(copy);
  // runRequiredPasses
  {
    RemoveExpands(copy);
    CanonicalizeOps(copy);
    EliminateDeadCode(copy);
  }
  if (!getGraphExecutorOptimize()) {
    return;
  }
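
Usage note (not part of this diff): which pipeline runs is still gated by
the executor-level optimize flag read above via getGraphExecutorOptimize().
A hedged sketch, assuming the companion setter declared in
torch/csrc/jit/runtime/graph_executor.h:

#include <torch/csrc/jit/runtime/graph_executor.h>

// Turning the flag off routes runProfilingOptimizations through
// runNooptPassPipeline instead of the full pipelines above.
// setGraphExecutorOptimize is assumed to be the setter paired with the
// getGraphExecutorOptimize() call used in this file.
void runWithMinimalPasses() {
  const bool saved = torch::jit::getGraphExecutorOptimize();
  torch::jit::setGraphExecutorOptimize(false);
  // ... run the scripted module / graph executor here ...
  torch::jit::setGraphExecutorOptimize(saved);
}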