Skip to content

Commit

Permalink
Cast tokens to int64 before passing to ONNX to fix Windows issue
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Aug 27, 2021
1 parent db4e6dc commit 30bb656
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/python/testonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def testClassification(self):
options = SessionOptions()
session = InferenceSession(model, options)

# Tokenize
# Tokenize and cast to int64 to support all platforms
tokens = tokenizer(["cat"], return_tensors="np")
tokens = {x: tokens[x].astype(np.int64) for x in tokens}

# Run inference and validate
outputs = session.run(None, dict(tokens))
Expand All @@ -84,8 +85,9 @@ def testPooling(self):
options = SessionOptions()
session = InferenceSession(model, options)

# Tokenize
# Tokenize and cast to int64 to support all platforms
tokens = tokenizer(["cat"], return_tensors="np")
tokens = {x: tokens[x].astype(np.int64) for x in tokens}

# Run inference and validate
outputs = session.run(None, dict(tokens))
Expand Down

0 comments on commit 30bb656

Please sign in to comment.