Support multiple autocontexts per graph.
james-choncholas committed Oct 12, 2024
1 parent 1dfc02f commit 0dd1d9a
Showing 4 changed files with 171 additions and 69 deletions.
192 changes: 131 additions & 61 deletions tf_shell/cc/optimizers/moduli_autotune.cc
@@ -120,45 +120,37 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation,
return status;
}

StatusOr<utils::MutableNodeView*> GetAutoShellContextNode(
utils::MutableNodeView* GetNextAutoShellContextNode(
utils::MutableGraphView& graph_view) {
utils::MutableNodeView* found = nullptr;
for (int i = 0; i < graph_view.NumNodes(); ++i) {
auto const* node_view = graph_view.GetNode(i);
auto const* node_def = node_view->node();
if (node_def->op() == kShellAutoContext) {
if (found != nullptr) {
return errors::InvalidArgument(
"Multiple AutoShellContext nodes found in the graph.");
}
found = graph_view.GetNode(i);
return graph_view.GetNode(i);
}
}
if (found == nullptr) {
return errors::NotFound("AutoShellContext node not found.");
}
return found;
return nullptr;
}

Status GetAutoShellContextParams(utils::MutableNodeView* autocontext_node_view,
Status GetAutoShellContextParams(utils::MutableNodeView* autocontext,
ShellAutoParams& params) {
auto const* cleartext_bits_node =
autocontext_node_view->GetRegularFanin(0).node_view()->node();
autocontext->GetRegularFanin(0).node_view()->node();
TF_RETURN_IF_ERROR(GetScalarConstValue<uint64_t, DT_UINT64>(
*cleartext_bits_node, &params.cleartext_bits));

auto const* scaling_factor_node =
autocontext_node_view->GetRegularFanin(1).node_view()->node();
autocontext->GetRegularFanin(1).node_view()->node();
TF_RETURN_IF_ERROR(GetScalarConstValue<uint64_t, DT_UINT64>(
*scaling_factor_node, &params.scaling_factor));

auto const* noise_offset_node =
autocontext_node_view->GetRegularFanin(2).node_view()->node();
autocontext->GetRegularFanin(2).node_view()->node();
TF_RETURN_IF_ERROR(GetScalarConstValue<uint64_t, DT_UINT64>(
*noise_offset_node, &params.noise_offset_bits));

auto const* noise_variance_node =
autocontext_node_view->GetRegularFanin(3).node_view()->node();
autocontext->GetRegularFanin(3).node_view()->node();
TF_RETURN_IF_ERROR(GetScalarConstValue<uint64_t, DT_UINT64>(
*noise_variance_node, &params.noise_variance));

@@ -172,12 +164,59 @@ Status GetAutoShellContextParams(utils::MutableNodeView* autocontext_node_view,
return OkStatus();
}

int GetMulDepth(utils::MutableGraphView& graph_view) {
StatusOr<bool> DecryptUsesSameContext(utils::MutableNodeView const* node_view,
utils::MutableNodeView const* context) {
utils::MutableNodeView const* trace = node_view;
if (trace == nullptr || !IsDecrypt(*trace->node())) {
return errors::InvalidArgument(
"Expected the node to be a decrypt node, but found ", trace->GetOp());
}
trace = trace->GetRegularFanin(0).node_view();

// The next op should be a strided slice.
if (trace->GetOp() != "StridedSlice") {
return errors::InvalidArgument(
"Traceback to context expected the first op to be a strided slice, "
"but found ",
trace->GetOp());
}
trace = trace->GetRegularFanin(0).node_view();

// The next op should be a tensor list gather. This is how the context was
// created.
if (trace->GetOp() != "TensorListGather") {
return errors::InvalidArgument(
"Traceback to context expected the second op to be a tensor list "
"gather, but found ",
trace->GetOp());
}

// Tracing further back in the graph is difficult because of how TensorFlow
// decides to optimize the graph before this optimizer is run. It is
// difficult because the context may be cached and read from disk.
// Instead of handling all possible cases, take advantage of the name scope.
// The context part of the name scopes of the TensorListGather should
// match that of the context node. This is more fragile on the tf-shell
// side, but will not break if the TensorFlow graph optimizers change.
std::string actx_name = context->GetName();
int actx_ns_start = actx_name.find("create_autocontext64");
int actx_ns_end = actx_name.find("/", actx_ns_start);
std::string actx_ns =
actx_name.substr(actx_ns_start, actx_ns_end - actx_ns_start);

if (trace->GetName().find(actx_ns) == std::string::npos) {
return false;
}
return true;
}

StatusOr<int> GetMulDepth(utils::MutableGraphView& graph_view,
utils::MutableNodeView const* autocontext) {
// Traverse the graph and return the maximum multiplicative depth.
int const num_nodes = graph_view.NumNodes();
std::vector<uint64_t> node_mul_depth(num_nodes);

uint64_t max_noise = 0;
uint64_t max_depth = 0;
for (int i = 0; i < num_nodes; ++i) {
auto const* this_node_view = graph_view.GetNode(i);
auto const* this_node_def = this_node_view->node();
@@ -221,12 +260,20 @@ int GetMulDepth(utils::MutableGraphView& graph_view) {
int const fanin_a_index = this_node_view->GetRegularFanin(0).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
} else if (IsDecrypt(*this_node_def)) {
// Decryption is where the maximum multiplicative depth is reached.
int const fanin_a_index = this_node_view->GetRegularFanin(2).node_index();
node_mul_depth[i] = node_mul_depth[fanin_a_index];
max_noise = std::max(max_noise, node_mul_depth[i]);

// Ensure the decrypt op uses the same autocontext node as the argument
// (for the case where there are multiple autocontext nodes in the graph).
TF_ASSIGN_OR_RETURN(bool is_same_autocontext,
DecryptUsesSameContext(this_node_view, autocontext));
if (is_same_autocontext) {
max_depth = std::max(max_depth, node_mul_depth[i]);
}
}
}
return max_noise;
return max_depth;
}

// Function for modular exponentiation
@@ -616,6 +663,7 @@ Status EstimateNodeNoise(

template <typename T>
Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
utils::MutableNodeView const* autocontext,
ShellParams const& params,
uint64_t const noise_varaince, uint64_t* log_noise) {
// Estimate the ciphertext noise growth by traversing the graph.
@@ -654,8 +702,18 @@ Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
TF_RETURN_IF_ERROR(
EstimateNodeNoise<T>(graph_view, i, node_noise, params, error_params));

// Update the maximum noise budget.
log_max_noise = std::max(log_max_noise, node_noise[i]);
// If this is a decryption node, update the maximum noise budget. Ensure the
// decryption node uses the same autocontext node as the argument (for the
// case when there are multiple).
utils::MutableNodeView const* node_view = graph_view.GetNode(i);
if (IsDecrypt(*node_view->node())) {
TF_ASSIGN_OR_RETURN(bool is_same_autocontext,
DecryptUsesSameContext(node_view, autocontext));
if (is_same_autocontext) {
// Update the maximum noise budget.
log_max_noise = std::max(log_max_noise, node_noise[i]);
}
}
}

if constexpr (debug_moduli) {
@@ -670,15 +728,14 @@ Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
// decode(encode(a)) where a is a cleartext (tf datatype) and marks nodes to
// delete accordingly.
Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
utils::MutableNodeView* autocontext,
ShellParams const& params,
ShellAutoParams const& auto_params) {
TF_ASSIGN_OR_RETURN(auto* autocontext_node_view,
GetAutoShellContextNode(graph_view));
int autocontext_node_index = autocontext_node_view->node_index();
int autocontext_node_index = autocontext->node_index();

if constexpr (debug_graph) {
std::cout << "Removing AutoShellContext node: "
<< autocontext_node_view->node()->DebugString() << std::endl;
<< autocontext->node()->DebugString() << std::endl;
}

// Create the new inputs for the ShellContext node.
@@ -690,7 +747,7 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
std::string seed_str_name = "ContextImport64/seed";

utils::Mutation* mutation = graph_view.GetMutationBuilder();
std::string device = autocontext_node_view->GetDevice();
std::string device = autocontext->GetDevice();
TF_RETURN_IF_ERROR(
AddScalarConstNode<uint64_t>(params.log_n, mutation, log_n_name, device));
TF_RETURN_IF_ERROR(AddScalarConstNode<std::vector<uint64_t>>(
@@ -707,8 +764,9 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
// Replace the AutoShellContext node with ShellContextImport64.
NodeDef shell_context_import_node;
shell_context_import_node.set_op(kShellContext);
// shell_context_import_node.set_name(autocontext_node_view->GetName());
shell_context_import_node.set_name(kShellContext);
std::string new_name = autocontext->GetName();
new_name = new_name.insert(new_name.find_last_of('/') + 1, "Optimized");
shell_context_import_node.set_name(new_name);
shell_context_import_node.set_device(device);
shell_context_import_node.add_input(log_n_name);
shell_context_import_node.add_input(qs_name);
@@ -726,12 +784,12 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,

// Update all fanouts of the AutoShellContext node to point to the new
// ShellContext node.
auto const& all_fanouts = autocontext_node_view->GetRegularFanouts();
auto const& all_fanouts = autocontext->GetRegularFanouts();
for (int i = 0; i < static_cast<int>(all_fanouts.size()); ++i) {
auto const& fanouts_by_port = all_fanouts[i];
for (auto const& fanout : fanouts_by_port) {
mutation->AddOrUpdateRegularFanin(fanout.node_view(), fanout.index(),
{kShellContext, i});
{new_name, i});
if constexpr (debug_graph) {
std::cout << "Updating fanout: " << fanout.node_view()->node()->name()
<< " index: " << fanout.index()
@@ -741,49 +799,31 @@ Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view,
}

mutation->RemoveNode(graph_view.GetNode(autocontext_node_index));
for (auto const& fanin : autocontext_node_view->GetRegularFanins()) {
mutation->RemoveNode(fanin.node_view());
for (auto const& fanin : autocontext->GetRegularFanins()) {
// When there are multiple autocontexts, the fanins may be shared. Only
// remove the fanin if it is not shared.
if (fanin.node_view()->NumRegularFanouts() == 1) {
mutation->RemoveNode(fanin.node_view());
}
}

TF_RETURN_IF_ERROR(mutation->Apply());

return OkStatus();
}

} // namespace

ModuliAutotuneOptimizer::ModuliAutotuneOptimizer() {}

Status ModuliAutotuneOptimizer::Init(
tensorflow::RewriterConfig_CustomGraphOptimizer const* config) {
return OkStatus();
}

Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
GrapplerItem const& item,
GraphDef* optimized_graph) {
GrapplerItem mutable_item(item);
Status status;
utils::MutableGraphView graph_view(&mutable_item.graph, &status);
TF_RETURN_IF_ERROR(status);

// See if an autocontext node exists in the graph. If not, there is nothing
// to do.
TF_ASSIGN_OR_RETURN(auto* autocontext, GetAutoShellContextNode(graph_view));

Status OptimizeAutocontext(utils::MutableGraphView& graph_view,
utils::MutableNodeView* autocontext) {
// Use GetScalarConstValue to get value of plaintext modulus,
// etc.
ShellAutoParams auto_params;
TF_RETURN_IF_ERROR(GetAutoShellContextParams(autocontext, auto_params));

// Topological sort so all subsequent traversals are in order.
TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));

// Find the maximum multiplicative depth of the graph and use this to set
// the plaintext modulus t, based on the scaling factor and depth.
// Note the mul_depth + 1 accounts for the first multiplication by the
// scaling factor during encoding.
int mul_depth = GetMulDepth(graph_view);
TF_ASSIGN_OR_RETURN(int mul_depth, GetMulDepth(graph_view, autocontext));
uint64_t total_plaintext_bits =
auto_params.cleartext_bits +
std::ceil(std::log2(
@@ -827,7 +867,8 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
}

TF_RETURN_IF_ERROR(EstimateNoiseGrowth<uint64_t>(
graph_view, params, auto_params.noise_variance, &log_max_noise));
graph_view, autocontext, params, auto_params.noise_variance,
&log_max_noise));

uint64_t total_ct_bits =
BitWidth(params.t) + log_max_noise + auto_params.noise_offset_bits;
@@ -885,8 +926,37 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
<< std::endl;
}

TF_RETURN_IF_ERROR(
ReplaceAutoparamWithContext(graph_view, params, auto_params));
TF_RETURN_IF_ERROR(ReplaceAutoparamWithContext(graph_view, autocontext,
params, auto_params));
return OkStatus();
}

} // namespace

ModuliAutotuneOptimizer::ModuliAutotuneOptimizer() {}

Status ModuliAutotuneOptimizer::Init(
tensorflow::RewriterConfig_CustomGraphOptimizer const* config) {
return OkStatus();
}

Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster,
GrapplerItem const& item,
GraphDef* optimized_graph) {
GrapplerItem mutable_item(item);
Status status;
utils::MutableGraphView graph_view(&mutable_item.graph, &status);
TF_RETURN_IF_ERROR(status);

// Topological sort so all subsequent traversals are in order.
TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));

// Optimize each autocontext op in the graph.
utils::MutableNodeView* autocontext = GetNextAutoShellContextNode(graph_view);
while (autocontext != nullptr) {
TF_RETURN_IF_ERROR(OptimizeAutocontext(graph_view, autocontext));
autocontext = GetNextAutoShellContextNode(graph_view);
}

if constexpr (debug_graph) {
std::cout << "Optimized graph: " << std::endl;
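
The core of the new DecryptUsesSameContext helper is a name-scope heuristic: rather than tracing the graph all the way back to the context (which may be cached and read from disk), it extracts the "create_autocontext64" scope component from the autocontext node's name and checks that the TensorListGather feeding the decrypt carries the same scope. A minimal standalone sketch of that string matching follows; the node names and the "_1"/"_2" scope suffixes are hypothetical, not taken from the commit.

#include <iostream>
#include <string>

// Returns true when the gather node's name contains the same
// "create_autocontext64..." scope component as the context node's name.
bool SameAutocontextScope(const std::string& context_name,
                          const std::string& gather_name) {
  size_t start = context_name.find("create_autocontext64");
  if (start == std::string::npos) return false;
  size_t end = context_name.find('/', start);  // npos means "rest of string"
  std::string scope = context_name.substr(start, end - start);
  return gather_name.find(scope) != std::string::npos;
}

int main() {
  // Hypothetical names; TensorFlow disambiguates repeated scopes with suffixes.
  std::string ctx = "create_autocontext64_1/AutoShellContext64";
  std::cout << SameAutocontextScope(ctx, "create_autocontext64_1/TensorListGather")
            << std::endl;  // 1 (same scope)
  std::cout << SameAutocontextScope(ctx, "create_autocontext64_2/TensorListGather")
            << std::endl;  // 0 (different autocontext)
}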
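
GetMulDepth now returns the maximum multiplicative depth reached by decrypts of this particular autocontext, and OptimizeAutocontext uses it to size the plaintext modulus. The sizing expression is truncated in the hunk above, but the surrounding comment suggests it adds the cleartext bits to log2 of the scaling factor raised to mul_depth + 1. A short worked sketch under that assumption, with purely illustrative numbers:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative values only; not taken from the commit.
  uint64_t cleartext_bits = 8;
  uint64_t scaling_factor = 3;
  int mul_depth = 2;  // maximum multiplicative depth seen by this autocontext

  // mul_depth + 1 accounts for the initial multiplication by the scaling
  // factor that happens during encoding.
  uint64_t total_plaintext_bits =
      cleartext_bits +
      static_cast<uint64_t>(
          std::ceil(std::log2(std::pow(scaling_factor, mul_depth + 1))));

  std::cout << total_plaintext_bits << std::endl;  // 8 + ceil(log2(27)) = 13
}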
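
ReplaceAutoparamWithContext now gives each replacement ShellContext node a unique name by splicing "Optimized" in after the last '/' of the autocontext node's name, so two autocontexts no longer collide on a single hard-coded node name. A tiny example of that string manipulation; the input name is hypothetical.

#include <iostream>
#include <string>

int main() {
  std::string name = "create_autocontext64_1/AutoShellContext64";
  // Insert "Optimized" right after the last '/'. If there is no '/',
  // find_last_of returns npos and npos + 1 wraps to 0, prepending instead.
  name.insert(name.find_last_of('/') + 1, "Optimized");
  std::cout << name << std::endl;
  // create_autocontext64_1/OptimizedAutoShellContext64
}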
4 changes: 2 additions & 2 deletions tf_shell/cc/optimizers/utils.cc
@@ -11,9 +11,9 @@ bool IsEncode(NodeDef const& node) { return node.op() == kEncode; }
bool IsDecode(NodeDef const& node) { return node.op() == kDecode; }
bool IsEncrypt(NodeDef const& node) { return node.op() == kEncrypt; }
bool IsPlainDerypt(NodeDef const& node) { return node.op() == kDecrypt; }
bool IsFastDerypt(NodeDef const& node) { return node.op() == kFastDecrypt; }
bool IsFastDecrypt(NodeDef const& node) { return node.op() == kFastDecrypt; }
bool IsDecrypt(NodeDef const& node) {
return IsPlainDerypt(node) || IsFastDerypt(node);
return IsPlainDerypt(node) || IsFastDecrypt(node);
}

bool IsAddCtCt(NodeDef const& node) { return node.op() == kAddCtCt; }
6 changes: 3 additions & 3 deletions tf_shell/cc/optimizers/utils.h
@@ -20,7 +20,7 @@ constexpr char kEncode[] = "PolynomialImport64";
constexpr char kDecode[] = "PolynomialExport64";
constexpr char kEncrypt[] = "Encrypt64";
constexpr char kDecrypt[] = "Decrypt64";
constexpr char kFastDecrypt[] = "FastDecrypt64";
constexpr char kFastDecrypt[] = "DecryptFastRotated64";

constexpr char kAddCtCt[] = "AddCtCt64";
constexpr char kSubCtCt[] = "SubCtCt64";
@@ -45,7 +45,7 @@ constexpr char kMatMulPtCt[] = "MatMulPtCt64";
constexpr char kFastMatMulPtCt[] = "FastMatMulPtCt64";

constexpr char kRoll[] = "Roll64";
constexpr char kReduceSumByRotation[] = "ReduceSumByRotation64";
constexpr char kReduceSumByRotation[] = "ReduceSumByRotationCt64";
constexpr char kFastReduceSumByRotation[] = "FastReduceSumByRotationCt64";
constexpr char kReduceSum[] = "ReduceSumCt64";

@@ -64,7 +64,7 @@ bool IsEncode(NodeDef const& node);
bool IsDecode(NodeDef const& node);
bool IsEncrypt(NodeDef const& node);
bool IsPlainDerypt(NodeDef const& node);
bool IsFastDerypt(NodeDef const& node);
bool IsFastDecrypt(NodeDef const& node);
bool IsDecrypt(NodeDef const& node);

bool IsAddCtCt(NodeDef const& node);