diff --git a/op_map.cc b/op_map.cc index b5e1d0b..7fdc32f 100644 --- a/op_map.cc +++ b/op_map.cc @@ -871,8 +871,8 @@ struct FullyConnectedMapper if (outputs[0]->GetShape().size() > 2) { std::vector real_output_shape = { weight_tensor->GetShape()[1], temp_batch}; - tim::vx::TensorSpec real_output_spec(inputs[0]->GetDataType(), - real_output_shape, tim::vx::TensorAttribute::TRANSIENT); + tim::vx::TensorSpec real_output_spec(outputs[0]->GetSpec()); + real_output_spec.SetShape(real_output_shape); auto real_output = delegate->GetGraph()->CreateTensor(real_output_spec); (*op).BindOutput(real_output);