Commit f7141a6

fix test

echarlaix committed Sep 26, 2023
1 parent 907871f
Showing 2 changed files with 10 additions and 7 deletions.
7 changes: 4 additions & 3 deletions tests/neural_compressor/test_optimization.py

```diff
@@ -68,11 +68,11 @@
 class OptimizationTest(INCTestMixin):
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
         ("text-classification", "hf-internal-testing/tiny-random-BertForSequenceClassification", 21),
-        # ("text-generation", "hf-internal-testing/tiny-random-BloomForCausalLM", 1), # TODO : enable causal lm task once INC ONNX export fixed
+        # ("text-generation", "hf-internal-testing/tiny-random-BloomForCausalLM", 21), # TODO : enable causal lm task once INC ONNX export fixed
     )

     SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
-        ("fill-mask", "hf-internal-testing/tiny-random-DistilBertForMaskedLM", 22),
+        ("fill-mask", "hf-internal-testing/tiny-random-BertForMaskedLM", 22),
         ("token-classification", "hf-internal-testing/tiny-random-AlbertForTokenClassification", 26),
     )
```
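For orientation, tuples of this shape — (task, model_id, expected number of quantized MatMul nodes) — are typically consumed with parameterized.expand. The sketch below is illustrative only and not part of the commit; the test class and assertion are invented:

```python
# Hedged sketch: how (task, model_id, expected_quantized_matmuls) tuples
# are typically fed to test methods. ExampleTest is illustrative only.
import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):
    @parameterized.expand(
        [
            ("text-classification", "hf-internal-testing/tiny-random-BertForSequenceClassification", 21),
            ("fill-mask", "hf-internal-testing/tiny-random-BertForMaskedLM", 22),
        ]
    )
    def test_case(self, task, model_name, expected_quantized_matmuls):
        # One generated test per tuple; the last element is the expected
        # count of quantized MatMul nodes after dynamic quantization.
        self.assertGreater(expected_quantized_matmuls, 0)
```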

```diff
@@ -88,12 +88,13 @@ def test_dynamic_quantization(self, task, model_name, expected_quantized_matmuls
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         save_onnx_model = False
         quantized_model = None
+        model_kwargs = {"use_cache" : False, "use_io_binding": False} if task == "text-generation" else {}
         with tempfile.TemporaryDirectory() as tmp_dir:
             for backend in ["torch", "ort"]:
                 if backend == "torch":
                     model = model_class.auto_model_class.from_pretrained(model_name)
                 else:
-                    model = model_class.from_pretrained(model_name, export=True)
+                    model = model_class.from_pretrained(model_name, export=True, **model_kwargs)

                 quantizer = INCQuantizer.from_pretrained(model, task=task)
                 quantizer.quantize(
```
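The new model_kwargs mean that, for the text-generation task, the ONNX export is loaded with the KV cache and ONNX Runtime IO binding disabled. Below is a minimal standalone sketch of the equivalent call, assuming optimum's ORTModelForCausalLM; the checkpoint is the tiny test model referenced in the TODO above:

```python
# Minimal sketch, assuming optimum.onnxruntime is installed; mirrors the
# kwargs the test now passes for text-generation models.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_name = "hf-internal-testing/tiny-random-BloomForCausalLM"

# Disabling use_cache exports a single decoder graph (no KV-cache inputs),
# and disabling IO binding keeps inference on the plain tensor path.
model = ORTModelForCausalLM.from_pretrained(
    model_name, export=True, use_cache=False, use_io_binding=False
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokens = tokenizer("This is a sample input", return_tensors="pt")
print(model(**tokens).logits.shape)
```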
10 changes: 6 additions & 4 deletions tests/neural_compressor/utils_tests.py

```diff
@@ -93,12 +93,13 @@ def check_model_outputs(
         load_onnx_model=True,
         load_inc_model=True,
         num_samples=None,
-        file_name=ONNX_WEIGHTS_NAME,
+        file_name=None,
     ):
         tokens = tokenizer("This is a sample input", return_tensors="pt")
+        file_name = ONNX_WEIGHTS_NAME if task!="text-generation" else "decoder_model.onnx"

         model_kwargs = (
-            {"decoder_file_name": file_name, "use_cache": False}
+            {"decoder_file_name": file_name, "use_cache": False, "use_io_binding": False}
             if task == "text-generation"
             else {"file_name": file_name}
         )
```
```diff
@@ -113,7 +114,7 @@
         if load_inc_model:
             inc_model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(save_directory)
             inc_model_outputs = inc_model(**tokens)
-            self.assertTrue(torch.equal(outputs, inc_model_outputs["logits"]))
+            self.assertTrue(torch.allclose(inc_model_outputs["logits"], outputs, atol=1e-3))
             # self.assertEqual(inc_config.save_onnx_model, load_onnx_model)
```
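Replacing torch.equal with torch.allclose(..., atol=1e-3) accepts the small numerical drift that quantization and ONNX export introduce; a bit-exact comparison would fail spuriously. A toy illustration with made-up values:

```python
# Toy example (values invented): exported/quantized logits differ from the
# PyTorch baseline by tiny amounts, so an exact equality check is too strict.
import torch

baseline = torch.tensor([0.1234, -0.5678, 2.3456])
exported = baseline + 1e-5  # small numerical drift

assert not torch.equal(baseline, exported)            # bit-exact: fails
assert torch.allclose(exported, baseline, atol=1e-3)  # tolerant: passes
```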

if load_onnx_model:
Expand All @@ -127,7 +128,8 @@ def check_model_outputs(
ort_model = ORT_SUPPORTED_TASKS[task]["class"][0].from_pretrained(save_directory, **model_kwargs)
ort_outputs = ort_model(**tokens)
self.assertTrue("logits" in ort_outputs)
# self.assertTrue(torch.allclose(ort_outputs.logits, outputs, atol=1e-3))
if task != "fill-mask":
self.assertTrue(torch.allclose(ort_outputs.logits, outputs, atol=1e-3))

@staticmethod
def get_trainer(
