Fix shape issues with convolutional layers.
james-choncholas committed Nov 4, 2024
1 parent 2237f5e commit 5a9de54
Showing 13 changed files with 314 additions and 140 deletions.
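A note on why the new `output_shape` attribute is needed: with a stride greater than 1, a convolution's output size does not uniquely determine its input size, so the transposed (gradient) convolution cannot recover the original extent from stride, padding, and filter size alone. The sketch below is mine, not part of the commit; it shows two input heights that forward-convolve to the same output height. This mirrors why `tf.nn.conv2d_transpose` takes an explicit `output_shape` argument in stock TensorFlow.

```cpp
#include <cstdint>
#include <cstdio>

// Floor division in the output-size formula is what loses information:
// with a 3-tap filter and stride 2 (no padding), inputs of height 7 and 8
// both produce an output of height 3.
int64_t ConvOutSize(int64_t in, int64_t filter, int64_t stride) {
  return (in - filter) / stride + 1;
}

int main() {
  std::printf("in=7 -> out=%lld\n", static_cast<long long>(ConvOutSize(7, 3, 2)));  // 3
  std::printf("in=8 -> out=%lld\n", static_cast<long long>(ConvOutSize(8, 3, 2)));  // 3
  return 0;
}
```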
3 changes: 2 additions & 1 deletion README.md
@@ -51,7 +51,8 @@ the labels.
2. Run the tests.

```bash
- bazel test ...
+ bazel test //tf_shell/...
+ bazel test //tf_shell_ml/... # Large tests, requires 128GB of memory.
```

3. Build the code.
84 changes: 71 additions & 13 deletions tf_shell/cc/kernels/conv_kernels.cc
@@ -120,6 +120,7 @@ class Conv2dOp : public OpKernel {
std::vector<tsl::int32> stride;
std::vector<tsl::int32> padding;
std::vector<tsl::int32> dilation;
std::vector<tsl::int32> output_shape;

public:
explicit Conv2dOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {
@@ -135,6 +136,11 @@
OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("dilations", &dilation));
OP_REQUIRES(op_ctx, dilation.size() == 4,
InvalidArgument("dilations must have 4 elements."));

OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("output_shape", &output_shape));
OP_REQUIRES(
op_ctx, output_shape.size() == 0 || output_shape.size() == 5,
InvalidArgument("output_shape must have 5 elements if provided."));
}

void Compute(OpKernelContext* op_ctx) override {
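Note an asymmetry visible in this diff: the forward kernel accepts a 5-element `output_shape` while `Conv2dTransposeOp` below accepts 4, presumably because the with-channel forward variants emit a rank-4 output on top of the implicit batch dimension. In the hunks shown, both kernels only ever read indices 1 and 2 (height and width); index 0 appears to be the implicit batch.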
@@ -222,25 +228,39 @@
(filter_width - 1) * dilation_width + 1;

int64_t const h_start = -padding_top;
- int64_t const h_end = height + padding_bottom - filter_dilated_height;
+ int64_t h_end = height + padding_bottom - filter_dilated_height;
int64_t const w_start = -padding_left;
- int64_t const w_end = width + padding_right - filter_dilated_width;
+ int64_t w_end = width + padding_right - filter_dilated_width;
int64_t const c_start = 0;
int64_t const c_end = in_channels - filter_in_channels;

- // Allocate output with shape
+ // Compute the output shape.
// [batch size (implicit), out_height, out_width, out_channels]
- int64_t const out_height = (h_end - h_start) / stride_height + 1;
- int64_t const out_width = (w_end - w_start) / stride_width + 1;
+ int64_t out_height = (h_end - h_start) / stride_height + 1;
+ int64_t out_width = (w_end - w_start) / stride_width + 1;
int64_t const out_channels = (c_end - c_start) / stride_in_channels + 1;
if (output_shape.size() > 0) {
if (output_shape[1] != out_height) {
h_end += output_shape[1] - out_height;
out_height = output_shape[1];
}
if (output_shape[2] != out_width) {
w_end += output_shape[2] - out_width;
out_width = output_shape[2];
}
}

// Allocate the output tensor.
Tensor* output;
- TensorShape output_shape;
+ TensorShape allocated_output_shape;
if constexpr (AllowDifferentNumInChannels) {
- output_shape = {out_height, out_width, out_channels, filter_out_channels};
+ allocated_output_shape = {out_height, out_width, out_channels,
+                           filter_out_channels};
} else {
- output_shape = {out_height, out_width, filter_out_channels};
+ allocated_output_shape = {out_height, out_width, filter_out_channels};
}
- OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
+ OP_REQUIRES_OK(op_ctx,
+                op_ctx->allocate_output(0, allocated_output_shape, &output));
auto shaped_output = output->shaped<Variant, 4>(
{out_height, out_width, out_channels, filter_out_channels});
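The `output_shape` handling above slides `h_end` and `w_end` so the sliding-window loops later in `Compute` visit exactly the requested number of rows and columns instead of tripping a shape check downstream. A minimal sketch of that bookkeeping, under the assumption of stride 1 (where the window count is simply `h_end - h_start + 1`):

```cpp
#include <cstdint>
#include <utility>

// Sketch (not from the commit) of the bound adjustment above, stride 1.
// Shifting h_end by (requested - computed) makes a loop over
// [h_start, h_end] produce exactly `requested` rows.
std::pair<int64_t, int64_t> AdjustRange(int64_t h_start, int64_t h_end,
                                        int64_t requested_out) {
  int64_t const computed_out = h_end - h_start + 1;
  h_end += requested_out - computed_out;
  return {h_end, requested_out};
}

int main() {
  // A range [0, 4] yields 5 rows; requesting 4 shrinks it to [0, 3].
  auto [new_end, rows] = AdjustRange(0, 4, 4);
  return (new_end == 3 && rows == 4) ? 0 : 1;
}
```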

@@ -361,6 +381,7 @@ class Conv2dTransposeOp : public OpKernel {
std::vector<tsl::int32> stride;
std::vector<tsl::int32> padding;
std::vector<tsl::int32> dilation;
std::vector<tsl::int32> output_shape;

public:
explicit Conv2dTransposeOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {
@@ -385,6 +406,11 @@
dilation_batch == 1 && dilation_height == 1 &&
dilation_width == 1 && dilation_in_channels == 1,
InvalidArgument("All dilations must be 1."));

OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("output_shape", &output_shape));
OP_REQUIRES(
op_ctx, output_shape.size() == 0 || output_shape.size() == 4,
InvalidArgument("output_shape must have 4 elements if provided."));
}

void Compute(OpKernelContext* op_ctx) override {
@@ -467,10 +493,11 @@
dilation_width == 1 && dilation_in_channels == 1,
InvalidArgument("Dilation is not yet supported."));

// Compute the output shape.
int64_t const h_start = -filter_height + 1 + padding_top;
- int64_t const h_end = ((height - 1) * stride_height) - padding_bottom;
+ int64_t h_end = ((height - 1) * stride_height) - padding_bottom;
int64_t const w_start = -filter_width + 1 + padding_left;
- int64_t const w_end = ((width - 1) * stride_width) - padding_right;
+ int64_t w_end = ((width - 1) * stride_width) - padding_right;
int64_t c_start = -filter_in_channels + 1;
int64_t c_end = ((in_channels - 1) * stride_in_channels);
if constexpr (!AllowDifferentNumInChannels) {
@@ -480,9 +507,22 @@

// Allocate output with shape
// [batch size (implicit), out_height, out_width, out_channels]
- int64_t const out_height = h_end - h_start + 1;
- int64_t const out_width = w_end - w_start + 1;
+ int64_t out_height = h_end - h_start + 1;
+ int64_t out_width = w_end - w_start + 1;
int64_t const out_channels = c_end - c_start + 1;
if (output_shape.size() > 0) {
if (output_shape[1] != out_height) {
h_end += output_shape[1] - out_height;
out_height = output_shape[1];
}
if (output_shape[2] != out_width) {
w_end += output_shape[2] - out_width;
out_width = output_shape[2];
}
}

// Allocate the output tensor.
Tensor* output;
TensorShape output_shape;
if constexpr (AllowDifferentNumInChannels) {
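Two observations on this hunk. First, the local `TensorShape output_shape` declared here shadows the member attribute of the same name from this point on; the forward kernel's hunk above renames its local to `allocated_output_shape` to avoid exactly that. Second, the requested-shape adjustment is exact here regardless of stride, because the transpose kernel computes its extent without a division: expanding `h_end - h_start + 1` gives the familiar transposed-convolution size formula, restated below as a sketch (mine, not code from the commit).

```cpp
#include <cassert>
#include <cstdint>

// out = h_end - h_start + 1
//     = ((in - 1) * stride - pad_bottom) - (-filter + 1 + pad_top) + 1
//     = (in - 1) * stride + filter - pad_top - pad_bottom
int64_t ConvTransposeOutSize(int64_t in, int64_t filter, int64_t stride,
                             int64_t pad_top, int64_t pad_bottom) {
  int64_t const h_start = -filter + 1 + pad_top;
  int64_t const h_end = (in - 1) * stride - pad_bottom;
  return h_end - h_start + 1;
}

int main() {
  // Inputs of height 7 and 8 both forward-convolve to height 3 under
  // stride 2 with a 3-tap filter; without output_shape, the transpose
  // can only reproduce one of them.
  assert(ConvTransposeOutSize(3, 3, 2, 0, 0) == 7);
  return 0;
}
```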
@@ -567,6 +607,24 @@
}
}
} // End matrix multiplication

// For some inputs which use the output_shape attribute, elements
// of the output may have no inputs associated with them. In this
// case, the dot product will be nullptr. A valid zero must be
// inserted. Since there is no way to create a zero ciphertext
// without the key, subtract one of the ciphertext inputs from
// itself to get a zero.
if (dot_product == nullptr) {
x_val = shaped_x(0, 0, 0).get<InputCtOrPoly>();
filter_val = shaped_filter(0, 0, 0, 0).get<FilterCtOrPoly>();
if constexpr (std::is_same<InputCtOrPoly,
SymmetricCtVariant<T>>::value) {
dot_product = new SymmetricCt(x_val->ct); // copy
} else {
dot_product = new SymmetricCt(filter_val->ct); // copy
}
OP_REQUIRES_OK(op_ctx, dot_product->SubInPlace(*dot_product));
}
OP_REQUIRES(op_ctx, dot_product != nullptr,
Internal("Internal error, dot product is NULL."));
OP_REQUIRES(op_ctx, x_val != nullptr,
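The back-fill above leans on additive homomorphism: for any ciphertext, `Enc(x) - Enc(x)` is a valid encryption of zero, so a zero can be minted from whichever input ciphertext happens to be available, with no key material. A generic sketch against a hypothetical ciphertext type (illustrative names only; the real kernel uses SHELL's `SymmetricCt`, whose `SubInPlace` returns a status that `Compute` checks with `OP_REQUIRES_OK`):

```cpp
// Hypothetical additively homomorphic ciphertext interface, for
// illustration only -- not the SHELL API.
template <typename Ciphertext>
Ciphertext MakeZeroFrom(const Ciphertext& any_ct) {
  Ciphertext zero = any_ct;  // copy an arbitrary available ciphertext
  zero.SubInPlace(zero);     // Enc(x) - Enc(x) decrypts to 0
  return zero;
}
```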
34 changes: 32 additions & 2 deletions tf_shell/cc/ops/shape_inference.cc
@@ -179,6 +179,21 @@ Status ShellSegmentReductionWithNumSegmentsShape(InferenceContext* c) {
}

Status ShellConv2dImpl(InferenceContext* c, bool different_num_in_channels) {
// If the output shape is provided, return that.
std::vector<tsl::int32> requested_output_shape;
TF_RETURN_IF_ERROR(c->GetAttr("output_shape", &requested_output_shape));
if (!requested_output_shape.empty()) {
ShapeHandle output_shape_handle = c->Scalar();
// Skip the batching dimension.
for (size_t i = 1; i < requested_output_shape.size(); i++) {
DimensionHandle h = c->MakeDim(requested_output_shape[i]);
TF_RETURN_IF_ERROR(c->Concatenate(output_shape_handle, c->Vector(h),
&output_shape_handle));
}
c->set_output(0, output_shape_handle);
return OkStatus();
}

// Input shape s_x is {height, width, in_channels}. Output shape is
// {out_height, out_width, out_channels}. The batch size is implicit in the
// ciphertext ring degree and not part of the shape.
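When the attribute is present, the shape function short-circuits: it builds the output rank-by-rank, concatenating onto a rank-0 scalar and skipping element 0 because the batch is implicit in the ciphertext ring degree. An equivalent one-shot construction is sketched below; `MakeShape` is part of `tensorflow::shape_inference::InferenceContext`, though the loop-with-`Concatenate` form above matches the style used later in this file.

```cpp
// Equivalent construction without repeated Concatenate calls (sketch).
std::vector<shape_inference::DimensionHandle> dims;
for (size_t i = 1; i < requested_output_shape.size(); ++i) {
  dims.push_back(c->MakeDim(requested_output_shape[i]));
}
c->set_output(0, c->MakeShape(dims));
return OkStatus();
```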
@@ -259,7 +274,7 @@ Status ShellConv2dImpl(InferenceContext* c, bool different_num_in_channels) {
TF_RETURN_IF_ERROR(c->Divide(channels, stride_in_channels, true, &channels));
TF_RETURN_IF_ERROR(c->Add(channels, one, &channels));

- ShapeHandle output_shape;
+ ShapeHandle output_shape = c->Scalar();
TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(out_height), c->Vector(out_width),
&output_shape));
if (different_num_in_channels) {
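The `c->Scalar()` initialization (also applied in `ShellConv2dTransposeImpl` below) looks defensive: a default-constructed `ShapeHandle` is unset until assigned, and starting from an explicit rank-0 shape guarantees the handle is well-defined on every path, even though `Concatenate` overwrites it on success. This is a reading of the diff, not a claim from the commit message.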
@@ -286,6 +301,21 @@ Status ShellConv2dWithChan(InferenceContext* c) {

Status ShellConv2dTransposeImpl(InferenceContext* c,
bool different_num_in_channels) {
// If the output shape is provided, return that.
std::vector<tsl::int32> requested_output_shape;
TF_RETURN_IF_ERROR(c->GetAttr("output_shape", &requested_output_shape));
if (!requested_output_shape.empty()) {
ShapeHandle output_shape_handle = c->Scalar();
// Skip the batching dimension.
for (size_t i = 1; i < requested_output_shape.size(); i++) {
DimensionHandle h = c->MakeDim(requested_output_shape[i]);
TF_RETURN_IF_ERROR(c->Concatenate(output_shape_handle, c->Vector(h),
&output_shape_handle));
}
c->set_output(0, output_shape_handle);
return OkStatus();
}

// Input shape s_x is {height, width, in_channels}. Output shape is
// {out_height, out_width, out_channels}. The batch size is implicit in the
// ciphertext ring degree and not part of the shape.
@@ -342,7 +372,7 @@ Status ShellConv2dTransposeImpl(InferenceContext* c,
TF_RETURN_IF_ERROR(c->Multiply(channels, stride_in_channels, &channels));
TF_RETURN_IF_ERROR(c->Add(channels, filter_in_channels, &channels));

- ShapeHandle output_shape;
+ ShapeHandle output_shape = c->Scalar();
TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(out_height), c->Vector(out_width),
&output_shape));
if (different_num_in_channels) {
12 changes: 12 additions & 0 deletions tf_shell/cc/ops/shell_ops.cc
@@ -458,6 +458,7 @@ REGISTER_OP("Conv2dPtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

@@ -469,6 +470,7 @@ REGISTER_OP("Conv2dCtPt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

@@ -480,6 +482,7 @@ REGISTER_OP("Conv2dCtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2d);

@@ -491,6 +494,7 @@ REGISTER_OP("Conv2dWithChanPtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

@@ -502,6 +506,7 @@ REGISTER_OP("Conv2dWithChanCtPt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

@@ -513,6 +518,7 @@ REGISTER_OP("Conv2dWithChanCtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dWithChan);

@@ -524,6 +530,7 @@ REGISTER_OP("Conv2dTransposePtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

@@ -535,6 +542,7 @@ REGISTER_OP("Conv2dTransposeCtPt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

@@ -546,6 +554,7 @@ REGISTER_OP("Conv2dTransposeCtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTranspose);

@@ -557,6 +566,7 @@ REGISTER_OP("Conv2dTransposeWithChanPtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

@@ -568,6 +578,7 @@ REGISTER_OP("Conv2dTransposeWithChanCtPt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

@@ -579,6 +590,7 @@ REGISTER_OP("Conv2dTransposeWithChanCtCt64")
.Attr("padding: list(int)")
.Attr("dilations: list(int)")
.Attr("filter_num_elements: int")
.Attr("output_shape: list(int)")
.Output("output: variant")
.SetShapeFn(ShellConv2dTransposeWithChan);

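As registered, `output_shape: list(int)` carries no default value, which in TensorFlow op semantics makes it a required attribute on all twelve convolution variants; the kernels treat an empty list as "not provided". A registration like `.Attr("output_shape: list(int) = []")` would instead make the attribute optional at the graph level; that is a design alternative, not what this commit does.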
9 remaining changed files not shown.
