Skip to content

Commit

Permalink
Simplify grappler optimizers.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 2, 2024
1 parent 321a163 commit 84958c4
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 95 deletions.
45 changes: 23 additions & 22 deletions tf_shell/cc/optimizers/ct_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,16 @@ struct ReorderArith {
int inner_pt_node_index;
};

void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) {
auto const* outer_node =
ctx.graph_view.GetNode(reorder.outer_node_index)->node();
auto const* inner_node =
ctx.graph_view.GetNode(reorder.inner_node_index)->node();
void PrintReorderArith(utils::MutableGraphView& graph_view,
ReorderArith const& reorder) {
auto const* outer_node = graph_view.GetNode(reorder.outer_node_index)->node();
auto const* inner_node = graph_view.GetNode(reorder.inner_node_index)->node();
auto const* outer_pt_node =
ctx.graph_view.GetNode(reorder.outer_pt_node_index)->node();
graph_view.GetNode(reorder.outer_pt_node_index)->node();
auto const* inner_ct_node =
ctx.graph_view.GetNode(reorder.inner_ct_node_index)->node();
graph_view.GetNode(reorder.inner_ct_node_index)->node();
auto const* inner_pt_node =
ctx.graph_view.GetNode(reorder.inner_pt_node_index)->node();
graph_view.GetNode(reorder.inner_pt_node_index)->node();

std::cout << outer_node->name() << " ( " << inner_node->name() << " ( "
<< inner_ct_node->name() << " , " << inner_pt_node->name() << " ), "
Expand All @@ -60,9 +59,10 @@ void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) {
// If the outer_op is add or sub, the inner_op must be add or sub.
// If instead the outer_op is mul, the inner_op must be mul.
// If the inner op is used elsewhere (has fanout>1), the pattern is not matched.
bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) {
bool FindAddOrSub(utils::MutableGraphView& graph_view, int node_index,
ReorderArith* reorder) {
// Check given node is op(ct, pt).
auto const* outer_node_view = ctx.graph_view.GetNode(node_index);
auto const* outer_node_view = graph_view.GetNode(node_index);
auto const* outer_node_def = outer_node_view->node();

if (!IsAddCtPt(*outer_node_def) && !IsSubCtPt(*outer_node_def) &&
Expand Down Expand Up @@ -123,7 +123,7 @@ bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) {

if constexpr (debug) {
std::cout << "Found pattern:";
PrintReorderArith(ctx, new_reorder);
PrintReorderArith(graph_view, new_reorder);
}

*reorder = new_reorder;
Expand All @@ -133,10 +133,11 @@ bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) {

// This function replaces the pattern outer_op(inner_op(ct, pt), pt) with
// outer_op(ct, inner_op(pt, pt)).
Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder,
Status ApplyReorderArith(utils::MutableGraphView& graph_view,
ReorderArith const& reorder,
std::vector<bool>* nodes_to_delete) {
GraphDef const* graph = ctx->graph_view.graph();
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
GraphDef const* graph = graph_view.graph();
utils::Mutation* mutation = graph_view.GetMutationBuilder();
Status status;

// First replace the inner node with a pt pt node.
Expand Down Expand Up @@ -210,14 +211,13 @@ Status CtPtOptimizer::Init(

Status CtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item,
GraphDef* optimized_graph) {
GrapplerItem mutable_item = item;
GrapplerItem mutable_item(item);
Status status;
RemapperContext ctx(&mutable_item, &status);
utils::MutableGraphView graph_view(&mutable_item.graph, &status);
TF_RETURN_IF_ERROR(status);

// Topological sort and process the nodes in reverse.
TF_RETURN_IF_ERROR(
ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));

bool finished = false;
while (!finished) {
Expand All @@ -232,17 +232,18 @@ Status CtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item,

// Remap op( op(ct, pt), pt) to op(ct, op(pt, pt)).
ReorderArith reorder;
if (FindAddOrSub(ctx, i, &reorder)) {
TF_RETURN_IF_ERROR(ApplyReorderArith(&ctx, reorder, &nodes_to_delete));
if (FindAddOrSub(graph_view, i, &reorder)) {
TF_RETURN_IF_ERROR(
ApplyReorderArith(graph_view, reorder, &nodes_to_delete));
finished = false;
}
}

// Remove nodes.
utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
utils::Mutation* mutation = graph_view.GetMutationBuilder();
for (int i = 0; i < num_nodes; ++i) {
if (nodes_to_delete[i]) {
mutation->RemoveNode(ctx.graph_view.GetNode(i));
mutation->RemoveNode(graph_view.GetNode(i));
}
}
TF_RETURN_IF_ERROR(mutation->Apply());
Expand Down
56 changes: 28 additions & 28 deletions tf_shell/cc/optimizers/moduli_autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation,
return status;
}

Status GetAutoShellContextParams(RemapperContext& ctx,
Status GetAutoShellContextParams(utils::MutableGraphView& graph_view,
ShellAutoParams& params) {
// Get the plaintext modulus t.
auto const* autocontext_node_view = ctx.graph_view.GetNode(kShellAutoContext);
auto const* autocontext_node_view = graph_view.GetNode(kShellAutoContext);
// auto const* autocontext_node_def = autocontext_node_view->node();

auto const* cleartext_bits_node =
Expand Down Expand Up @@ -156,14 +156,14 @@ Status GetAutoShellContextParams(RemapperContext& ctx,
return OkStatus();
}

int GetMulDepth(RemapperContext& ctx) {
int GetMulDepth(utils::MutableGraphView& graph_view) {
// Traverse the graph and return the maximum multiplicative depth.
int const num_nodes = ctx.graph_view.NumNodes();
int const num_nodes = graph_view.NumNodes();
std::vector<uint64_t> node_mul_depth(num_nodes);

uint64_t max_noise = 0;
for (int i = 0; i < num_nodes; ++i) {
auto const* this_node_view = ctx.graph_view.GetNode(i);
auto const* this_node_view = graph_view.GetNode(i);
auto const* this_node_def = this_node_view->node();

if (IsArithmetic(*this_node_def) || IsMatMul(*this_node_def) ||
Expand Down Expand Up @@ -369,11 +369,11 @@ Status ChooseShellParams(ShellParams& params, uint64_t const total_pt_bits,
// Returns the noise budget of the current node.
template <typename T>
Status EstimateNodeNoise(
RemapperContext& ctx, int node_index, std::vector<uint64_t>& node_noise,
ShellParams const& params,
utils::MutableGraphView& graph_view, int node_index,
std::vector<uint64_t>& node_noise, ShellParams const& params,
rlwe::RnsErrorParams<rlwe::MontgomeryInt<T>>& error_params) {
// Get the current node and its fanin nodes.
auto const* node_view = ctx.graph_view.GetNode(node_index);
auto const* node_view = graph_view.GetNode(node_index);
auto const* node_def = node_view->node();

uint64_t* this_noise = &node_noise[node_index];
Expand Down Expand Up @@ -491,10 +491,11 @@ Status EstimateNodeNoise(
}

template <typename T>
Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params,
Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
ShellParams const& params,
uint64_t const noise_varaince, uint64_t* log_noise) {
// Estimate the ciphertext noise growth by traversing the graph.
int const num_nodes = ctx.graph_view.NumNodes();
int const num_nodes = graph_view.NumNodes();
std::vector<uint64_t> node_noise(num_nodes);

// Create RnsErrorParams.
Expand Down Expand Up @@ -527,7 +528,7 @@ Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params,
for (int i = 0; i < num_nodes; ++i) {
// Estimate the noise budget of this node.
TF_RETURN_IF_ERROR(
EstimateNodeNoise<T>(ctx, i, node_noise, params, error_params));
EstimateNodeNoise<T>(graph_view, i, node_noise, params, error_params));

// Update the maximum noise budget.
log_max_noise = std::max(log_max_noise, node_noise[i]);
Expand All @@ -544,11 +545,11 @@ Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params,
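// Replaces the kShellAutoContext node in the graph: adds scalar constant nodes
// (e.g. "ContextImport64/noise_variance", "ContextImport64/seed") built from
// the chosen parameters, then removes the auto-context node together with its
// regular fanins. Returns a non-OK Status if adding a node fails.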
Status ReplaceAutoparamWithContext(RemapperContext& ctx,
Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
ShellParams const& params,
ShellAutoParams const& auto_params) {
utils::MutableNodeView* autocontext_node_view =
ctx.graph_view.GetNode(kShellAutoContext);
graph_view.GetNode(kShellAutoContext);
int autocontext_node_index = autocontext_node_view->node_index();

if constexpr (debug_graph) {
Expand All @@ -564,7 +565,7 @@ Status ReplaceAutoparamWithContext(RemapperContext& ctx,
std::string noise_var_name = "ContextImport64/noise_variance";
std::string seed_str_name = "ContextImport64/seed";

utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
utils::Mutation* mutation = graph_view.GetMutationBuilder();
std::string device = autocontext_node_view->GetDevice();
TF_RETURN_IF_ERROR(
AddScalarConstNode<uint64_t>(params.log_n, mutation, log_n_name, device));
Expand Down Expand Up @@ -615,7 +616,7 @@ Status ReplaceAutoparamWithContext(RemapperContext& ctx,
}
}

mutation->RemoveNode(ctx.graph_view.GetNode(autocontext_node_index));
mutation->RemoveNode(graph_view.GetNode(autocontext_node_index));
for (auto const& fanin : autocontext_node_view->GetRegularFanins()) {
mutation->RemoveNode(fanin.node_view());
}
Expand All @@ -637,14 +638,14 @@ Status ModuliAutotuneOptimizer::Init(
Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
GrapplerItem const& item,
GraphDef* optimized_graph) {
GrapplerItem mutable_item = item;
GrapplerItem mutable_item(item);
Status status;
RemapperContext ctx(&mutable_item, &status);
utils::MutableGraphView graph_view(&mutable_item.graph, &status);
TF_RETURN_IF_ERROR(status);

// See if an autocontext node exists in the graph. If not, there is nothing
// to do.
auto const* autocontext_view = ctx.graph_view.GetNode(kShellAutoContext);
auto const* autocontext_view = graph_view.GetNode(kShellAutoContext);
if (autocontext_view == nullptr) {
*optimized_graph = std::move(mutable_item.graph);
return OkStatus();
Expand All @@ -654,25 +655,24 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
std::string duplicate_autocontext = kShellAutoContext;
duplicate_autocontext += "_1";
auto const* duplicate_autocontext_view =
ctx.graph_view.GetNode(duplicate_autocontext);
graph_view.GetNode(duplicate_autocontext);
if (duplicate_autocontext_view != nullptr) {
return errors::FailedPrecondition("Multiple autocontext nodes found.");
}

// Use GetScalarConstValue to get value of plaintext modulus,
// etc.
ShellAutoParams auto_params;
TF_RETURN_IF_ERROR(GetAutoShellContextParams(ctx, auto_params));
TF_RETURN_IF_ERROR(GetAutoShellContextParams(graph_view, auto_params));

// Topological sort so all subsequent traversals are in order.
TF_RETURN_IF_ERROR(
ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));

// Find the maximum multiplicative depth of the graph and use this to set
// the plaintext modulus t, based on the scaling factor and depth.
// Note the mul_depth + 1 accounts for the first multiplication by the
// scaling factor during encoding.
int mul_depth = GetMulDepth(ctx);
int mul_depth = GetMulDepth(graph_view);
uint64_t total_plaintext_bits =
auto_params.cleartext_bits +
std::ceil(std::log2(
Expand Down Expand Up @@ -713,7 +713,7 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,

uint64_t log_max_noise = 0;
TF_RETURN_IF_ERROR(EstimateNoiseGrowth<uint64_t>(
ctx, params, auto_params.noise_variance, &log_max_noise));
graph_view, params, auto_params.noise_variance, &log_max_noise));

if (log_max_noise == 0) {
// No encryption in this graph. Smallest parameters will do.
Expand Down Expand Up @@ -756,14 +756,14 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
std::cout << std::endl;
}

TF_RETURN_IF_ERROR(ReplaceAutoparamWithContext(ctx, params, auto_params));
TF_RETURN_IF_ERROR(
ReplaceAutoparamWithContext(graph_view, params, auto_params));

if constexpr (debug_graph) {
std::cout << "Optimized graph: " << std::endl;
int const num_nodes = ctx.graph_view.NumNodes();
int const num_nodes = graph_view.NumNodes();
for (int i = 0; i < num_nodes; ++i) {
std::cout << ctx.graph_view.GetNode(i)->node()->DebugString()
<< std::endl;
std::cout << graph_view.GetNode(i)->node()->DebugString() << std::endl;
}
}

Expand Down
Loading

0 comments on commit 84958c4

Please sign in to comment.