diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index b6b8ddd165921..b8cfe92baeff0 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -643,6 +643,8 @@ class ScatterAssignOp : public Operator { &ScatterAssignOp::DoRun}, {{TensorProto_DataType_INT32, TensorProto_DataType_INT64}, &ScatterAssignOp::DoRun}, + {{TensorProto_DataType_INT32, TensorProto_DataType_DOUBLE}, + &ScatterAssignOp::DoRun}, {{TensorProto_DataType_INT64, TensorProto_DataType_FLOAT}, &ScatterAssignOp::DoRun}, {{TensorProto_DataType_INT64, TensorProto_DataType_FLOAT16}, @@ -652,7 +654,9 @@ class ScatterAssignOp : public Operator { {{TensorProto_DataType_INT64, TensorProto_DataType_INT32}, &ScatterAssignOp::DoRun}, {{TensorProto_DataType_INT64, TensorProto_DataType_INT64}, - &ScatterAssignOp::DoRun}}) {} + &ScatterAssignOp::DoRun}, + {{TensorProto_DataType_INT64, TensorProto_DataType_DOUBLE}, + &ScatterAssignOp::DoRun}}) {} bool RunOnDevice() override { const auto& data = Input(DATA); diff --git a/caffe2/python/operator_test/rowwise_counter_test.py b/caffe2/python/operator_test/rowwise_counter_test.py index e8719348a1752..a00dd24b3f2c9 100644 --- a/caffe2/python/operator_test/rowwise_counter_test.py +++ b/caffe2/python/operator_test/rowwise_counter_test.py @@ -27,7 +27,7 @@ def test_rowwise_counter(self): n = 5 curr_iter = np.array([100], dtype=np.int64) - update_counter = np.random.randint(99, size=h).astype(np.float32) + update_counter = np.random.randint(99, size=h).astype(np.float64) prev_iter = np.random.rand(h, 1).astype(np.int64) indices = np.unique(np.random.randint(0, h, size=n)) indices.sort(axis=0) diff --git a/caffe2/python/optimizer.py b/caffe2/python/optimizer.py index 5ff90461e2e55..0bfb07562441a 100644 --- a/caffe2/python/optimizer.py +++ b/caffe2/python/optimizer.py @@ -750,6 +750,7 @@ def _run(self, net, param_init_net, param_info): str(param) + "_update_counter", input_as_shape=1, value=0.0, + dtype=core.DataType.DOUBLE, ) prev_update_iter = param_init_net.ConstantFill( num_rows, @@ -764,6 +765,7 @@ def _run(self, net, param_init_net, param_info): str(param) + "_update_counter", shape=[shapes[str(param)][0]], value=0.0, + dtype=core.DataType.DOUBLE, ) prev_update_iter = param_init_net.ConstantFill( [], diff --git a/caffe2/sgd/rowwise_counter.cc b/caffe2/sgd/rowwise_counter.cc index 8c5da359e8675..a9afaadd21325 100644 --- a/caffe2/sgd/rowwise_counter.cc +++ b/caffe2/sgd/rowwise_counter.cc @@ -2,7 +2,7 @@ namespace caffe2 { -REGISTER_CPU_OPERATOR(RowWiseCounter, RowWiseCounterOp); +REGISTER_CPU_OPERATOR(RowWiseCounter, RowWiseCounterOp); OPERATOR_SCHEMA(RowWiseCounter) .NumInputs(4) .NumOutputs(2) diff --git a/caffe2/sgd/rowwise_counter.h b/caffe2/sgd/rowwise_counter.h index af38d2a656c6b..fb0647d596f26 100644 --- a/caffe2/sgd/rowwise_counter.h +++ b/caffe2/sgd/rowwise_counter.h @@ -4,7 +4,6 @@ namespace caffe2 { -template class RowWiseCounterOp final : public Operator { public: RowWiseCounterOp(const OperatorDef& operator_def, Workspace* ws) @@ -28,7 +27,7 @@ class RowWiseCounterOp final : public Operator { bool DoRunWithType() { auto* prev_iter = Output(OUTPUT_PREV_ITER)->template mutable_data(); - auto* counter = Output(OUTPUT_COUNTER)->template mutable_data(); + auto* counter = Output(OUTPUT_COUNTER)->template mutable_data(); const int64_t curr_iter = Input(ITER).template data()[0]; const auto* indices = Input(INDICES).template data();