Skip to content

Commit

Permalink
Conv2d op and Layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Oct 18, 2024
1 parent f3a9456 commit ee65e2f
Show file tree
Hide file tree
Showing 18 changed files with 1,896 additions and 55 deletions.
2 changes: 2 additions & 0 deletions tf_shell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from tf_shell.python.shell_tensor import broadcast_to
from tf_shell.python.shell_tensor import split
from tf_shell.python.shell_tensor import segment_sum
from tf_shell.python.shell_tensor import conv2d
from tf_shell.python.shell_tensor import conv2d_transpose

from tf_shell.python.shell_context import ShellContext64
from tf_shell.python.shell_context import create_context64
Expand Down
661 changes: 661 additions & 0 deletions tf_shell/cc/kernels/conv_kernels.cc

Large diffs are not rendered by default.

193 changes: 192 additions & 1 deletion tf_shell/cc/ops/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,195 @@ Status ShellSegmentReductionWithNumSegmentsShape(InferenceContext* c) {
c->set_output(0, data_out);
c->set_output(1, reduction_counters);
return OkStatus();
}
}

Status ShellConv2dImpl(InferenceContext* c, bool different_num_in_channels) {
// Input shape s_x is {height, width, in_channels}. Output shape is
// {out_height, out_width, out_channels}. The batch size is implicit in the
// ciphertext ring degree and not part of the shape.
ShapeHandle s_x = c->input(1);
ShapeHandle s_filter = c->input(2);
DimensionHandle one = c->MakeDim(1);

// Check that the input tensor has rank 3.
TF_RETURN_IF_ERROR(c->WithRank(s_x, 3, &s_x));
DimensionHandle const height = c->Dim(s_x, 0);
DimensionHandle const width = c->Dim(s_x, 1);
DimensionHandle const in_channels = c->Dim(s_x, 2);

// Check that the filter tensor has rank 4.
TF_RETURN_IF_ERROR(c->WithRank(s_filter, 4, &s_filter));
DimensionHandle const filter_height = c->Dim(s_filter, 0);
DimensionHandle const filter_width = c->Dim(s_filter, 1);
DimensionHandle const filter_in_channels = c->Dim(s_filter, 2);
DimensionHandle const filter_out_channels = c->Dim(s_filter, 3);

// Check the stride.
std::vector<tsl::int32> stride;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &stride));
// DimensionHandle const stride_batch = c->MakeDim(stride[0]);
DimensionHandle const stride_height = c->MakeDim(stride[1]);
DimensionHandle const stride_width = c->MakeDim(stride[2]);
DimensionHandle const stride_in_channels = c->MakeDim(stride[3]);

// Check the padding tensor.
std::vector<tsl::int32> padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
DimensionHandle const padding_top = c->MakeDim(padding[0]);
DimensionHandle const padding_bottom = c->MakeDim(padding[1]);
DimensionHandle const padding_left = c->MakeDim(padding[2]);
DimensionHandle const padding_right = c->MakeDim(padding[3]);

std::vector<tsl::int32> dilations;
TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
// DimensionHandle const dilation_batch = c->MakeDim(dilations[0]);
DimensionHandle const dilation_height = c->MakeDim(dilations[1]);
DimensionHandle const dilation_width = c->MakeDim(dilations[2]);
// DimensionHandle const dilation_channel = c->MakeDim(dilations[3]);

// Prepare dilated filter dimensions.
DimensionHandle filter_dilated_height = filter_height;
TF_RETURN_IF_ERROR(
c->Subtract(filter_dilated_height, one, &filter_dilated_height));
TF_RETURN_IF_ERROR(c->Multiply(filter_dilated_height, dilation_height,
&filter_dilated_height));
TF_RETURN_IF_ERROR(
c->Add(filter_dilated_height, one, &filter_dilated_height));

DimensionHandle filter_dilated_width = filter_width;
TF_RETURN_IF_ERROR(
c->Subtract(filter_dilated_width, one, &filter_dilated_width));
TF_RETURN_IF_ERROR(
c->Multiply(filter_dilated_width, dilation_width, &filter_dilated_width));
TF_RETURN_IF_ERROR(c->Add(filter_dilated_width, one, &filter_dilated_width));

// Add the padding to the height and width.
DimensionHandle out_height = height;
TF_RETURN_IF_ERROR(c->Add(out_height, padding_bottom, &out_height));
TF_RETURN_IF_ERROR(c->Add(out_height, padding_top, &out_height));
TF_RETURN_IF_ERROR(
c->Subtract(out_height, filter_dilated_height, &out_height));
TF_RETURN_IF_ERROR(c->Divide(out_height, stride_height, false, &out_height));
TF_RETURN_IF_ERROR(c->Add(out_height, one, &out_height));

DimensionHandle out_width = width;
TF_RETURN_IF_ERROR(c->Add(out_width, padding_right, &out_width));
TF_RETURN_IF_ERROR(c->Add(out_width, padding_left, &out_width));
TF_RETURN_IF_ERROR(c->Subtract(out_width, filter_dilated_width, &out_width));
TF_RETURN_IF_ERROR(c->Divide(out_width, stride_width, false, &out_width));
TF_RETURN_IF_ERROR(c->Add(out_width, one, &out_width));

DimensionHandle channels = in_channels;
TF_RETURN_IF_ERROR(c->Subtract(channels, filter_in_channels, &channels));
TF_RETURN_IF_ERROR(c->Divide(channels, stride_in_channels, true, &channels));
TF_RETURN_IF_ERROR(c->Add(channels, one, &channels));

ShapeHandle output_shape;
TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(out_height), c->Vector(out_width),
&output_shape));
if (different_num_in_channels) {
// If the number of input channels in the input image and the filter can be
// different, the output shape must include the channel dimension, i.e.,
// {out_height, out_width, channels, out_channels}.
// Otherwise, channels is excluded, i.e.
// {out_height, out_width, out_channels}.
TF_RETURN_IF_ERROR(
c->Concatenate(output_shape, c->Vector(channels), &output_shape));
}
TF_RETURN_IF_ERROR(c->Concatenate(
output_shape, c->Vector(filter_out_channels), &output_shape));

c->set_output(0, output_shape);
return OkStatus();
}

Status ShellConv2d(InferenceContext* c) { return ShellConv2dImpl(c, false); }

Status ShellConv2dWithChan(InferenceContext* c) {
return ShellConv2dImpl(c, true);
}

Status ShellConv2dTransposeImpl(InferenceContext* c,
bool different_num_in_channels) {
// Input shape s_x is {height, width, in_channels}. Output shape is
// {out_height, out_width, out_channels}. The batch size is implicit in the
// ciphertext ring degree and not part of the shape.
ShapeHandle s_x = c->input(1);
ShapeHandle s_filter = c->input(2);
DimensionHandle one = c->MakeDim(1);

// Check that the input tensor has rank 3.
TF_RETURN_IF_ERROR(c->WithRank(s_x, 3, &s_x));
DimensionHandle const height = c->Dim(s_x, 0);
DimensionHandle const width = c->Dim(s_x, 1);
DimensionHandle const in_channels = c->Dim(s_x, 2);

// Check that the filter tensor has rank 4.
TF_RETURN_IF_ERROR(c->WithRank(s_filter, 4, &s_filter));
DimensionHandle const filter_height = c->Dim(s_filter, 0);
DimensionHandle const filter_width = c->Dim(s_filter, 1);
DimensionHandle const filter_out_channels = c->Dim(s_filter, 2);
DimensionHandle const filter_in_channels = c->Dim(s_filter, 3);

// Check the stride.
std::vector<tsl::int32> stride;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &stride));
// DimensionHandle const stride_batch = c->MakeDim(stride[0]);
DimensionHandle const stride_height = c->MakeDim(stride[1]);
DimensionHandle const stride_width = c->MakeDim(stride[2]);
DimensionHandle const stride_in_channels = c->MakeDim(stride[3]);

// Check the padding tensor.
std::vector<tsl::int32> padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
DimensionHandle const padding_top = c->MakeDim(padding[0]);
DimensionHandle const padding_bottom = c->MakeDim(padding[1]);
DimensionHandle const padding_left = c->MakeDim(padding[2]);
DimensionHandle const padding_right = c->MakeDim(padding[3]);

// Add the padding to the height and width.
DimensionHandle out_height = height;
TF_RETURN_IF_ERROR(c->Subtract(out_height, one, &out_height));
TF_RETURN_IF_ERROR(c->Multiply(out_height, stride_height, &out_height));
TF_RETURN_IF_ERROR(c->Add(out_height, filter_height, &out_height));
TF_RETURN_IF_ERROR(c->Subtract(out_height, padding_bottom, &out_height));
TF_RETURN_IF_ERROR(c->Subtract(out_height, padding_top, &out_height));

DimensionHandle out_width = width;
TF_RETURN_IF_ERROR(c->Subtract(out_width, one, &out_width));
TF_RETURN_IF_ERROR(c->Multiply(out_width, stride_width, &out_width));
TF_RETURN_IF_ERROR(c->Add(out_width, filter_width, &out_width));
TF_RETURN_IF_ERROR(c->Subtract(out_width, padding_right, &out_width));
TF_RETURN_IF_ERROR(c->Subtract(out_width, padding_left, &out_width));

DimensionHandle channels = in_channels;
TF_RETURN_IF_ERROR(c->Subtract(channels, one, &channels));
TF_RETURN_IF_ERROR(c->Multiply(channels, stride_in_channels, &channels));
TF_RETURN_IF_ERROR(c->Add(channels, filter_in_channels, &channels));

ShapeHandle output_shape;
TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(out_height), c->Vector(out_width),
&output_shape));
if (different_num_in_channels) {
// If the number of input channels in the input image and the filter can be
// different, the output shape must include the channel dimension, i.e.,
// {out_height, out_width, channels, out_channels}.
// Otherwise, channels is excluded, i.e.
// {out_height, out_width, out_channels}.
TF_RETURN_IF_ERROR(
c->Concatenate(output_shape, c->Vector(channels), &output_shape));
}
TF_RETURN_IF_ERROR(c->Concatenate(
output_shape, c->Vector(filter_out_channels), &output_shape));

c->set_output(0, output_shape);
return OkStatus();
}

Status ShellConv2dTranspose(InferenceContext* c) {
return ShellConv2dTransposeImpl(c, false);
}

Status ShellConv2dTransposeWithChan(InferenceContext* c) {
return ShellConv2dTransposeImpl(c, true);
}
10 changes: 9 additions & 1 deletion tf_shell/cc/ops/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,12 @@ Status ShellMatMulCtPtShape(InferenceContext* c);

Status ShellMatMulPtCtShape(InferenceContext* c);

Status ShellSegmentReductionWithNumSegmentsShape(InferenceContext* c);
Status ShellSegmentReductionWithNumSegmentsShape(InferenceContext* c);

Status ShellConv2d(InferenceContext* c);

Status ShellConv2dWithChan(InferenceContext* c);

Status ShellConv2dTranspose(InferenceContext* c);

Status ShellConv2dTransposeWithChan(InferenceContext* c);
133 changes: 133 additions & 0 deletions tf_shell/cc/ops/shell_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,139 @@ REGISTER_OP("UnsortedCtSegmentSum")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(ShellSegmentReductionWithNumSegmentsShape);

// Convolutions.
REGISTER_OP("Conv2dPtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

REGISTER_OP("Conv2dCtPt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

REGISTER_OP("Conv2dCtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

REGISTER_OP("Conv2dWithChanPtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

REGISTER_OP("Conv2dWithChanCtPt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

REGISTER_OP("Conv2dWithChanCtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

REGISTER_OP("Conv2dTransposePtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

REGISTER_OP("Conv2dTransposeCtPt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

REGISTER_OP("Conv2dTransposeCtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

REGISTER_OP("Conv2dTransposeWithChanPtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

REGISTER_OP("Conv2dTransposeWithChanCtPt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

REGISTER_OP("Conv2dTransposeWithChanCtCt64")
.Input("shell_context: variant")
.Input("x: variant")
.Input("filter: variant")
.Attr("strides: list(int)")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

// MPC-based kernels.
REGISTER_OP("ClipAndNoiseFeaturesParty")
.Attr("Dtype: {int32, int64}")
Expand Down
Loading

0 comments on commit ee65e2f

Please sign in to comment.