I am trying to export the EdgeSAM decoder to ONNX with the following code:
import torch

def export_decoder_to_onnx(sam, args, batch_size=4):
    # SamCoreMLModel wraps the SAM decoder for export (defined elsewhere in the script).
    sam_decoder = SamCoreMLModel(
        model=sam,
        use_stability_score=args.use_stability_score,
    )
    sam_decoder.eval()

    # Optionally swap exact GELU for the tanh approximation, which exports more cleanly.
    if args.gelu_approximate:
        for n, m in sam.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    # Build dummy inputs matching the prompt encoder's expected shapes.
    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    image_embeddings = torch.randn(batch_size, embed_dim, *embed_size, dtype=torch.float)
    point_coords = torch.randint(low=0, high=1024, size=(batch_size, 5, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=4, size=(batch_size, 5), dtype=torch.float)

    # Define the input names and output names
    input_names = ["image_embeddings", "point_coords", "point_labels"]
    output_names = ["scores", "masks"]

    # Export the decoder model to ONNX format
    onnx_decoder_filename = args.checkpoint.replace('.pth', '_decoder.onnx')
    torch.onnx.export(
        sam_decoder,
        (image_embeddings, point_coords, point_labels),
        onnx_decoder_filename,
        input_names=input_names,
        output_names=output_names,
        opset_version=13,  # use an appropriate ONNX opset version
        dynamic_axes={
            "image_embeddings": {0: "batch_size"},
            "point_coords": {0: "batch_size", 1: "num_points"},
            "point_labels": {0: "batch_size", 1: "num_points"},
        },
        verbose=False,
    )
    print(f"Exported ONNX decoder model to {onnx_decoder_filename}")
When I run it, I get an error like:
File "f:\ai_code\edgesam\edge_sam\modeling\transformer.py", line 165, in forward
k = keys + key_pe
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0
What can I do to solve it?
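Update: one thing I noticed while digging (not sure it applies to EdgeSAM, but the original SAM mask decoder repeat-interleaves the image embeddings by the prompt batch while the dense positional encoding stays at batch 1, which would produce exactly a 16-vs-4 mismatch when both dummy tensors have batch 4). A sketch of the workaround I would try, exporting with a single image embedding so the batch dimension is carried only by the prompts:

# Sketch of a possible workaround (assumes the decoder expects one image
# embedding per call, with batching handled by the prompts, as in original SAM).
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size

image_embeddings = torch.randn(1, embed_dim, *embed_size, dtype=torch.float)  # batch 1
point_coords = torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float)

torch.onnx.export(
    sam_decoder,
    (image_embeddings, point_coords, point_labels),
    onnx_decoder_filename,
    input_names=input_names,
    output_names=output_names,
    opset_version=13,
    dynamic_axes={
        # Only the prompt tensors keep dynamic axes; the image embedding stays fixed.
        "point_coords": {0: "num_prompts", 1: "num_points"},
        "point_labels": {0: "num_prompts", 1: "num_points"},
    },
    verbose=False,
)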