From 23f5db94da56acc87ee1de9aa451645bea2d21ae Mon Sep 17 00:00:00 2001
From: Enio Borges <63541694+borgesenioc@users.noreply.github.com>
Date: Wed, 20 Nov 2024 08:23:15 -0300
Subject: [PATCH] Update test_generation.py

Refactor test_generation.py: parameterize text model tests and add an edge case

- Used @pytest.mark.parametrize to simplify and reuse dialog tests in TestTextModelInference.
- Added a new test case for empty dialog inputs to improve edge-case coverage.
- Kept the vision model tests skipped via @unittest.skip and @pytest.mark.skip.
- Improved overall readability and modularity of the test file.
---
 models/llama3/tests/api/test_generation.py | 38 +++++++++++++---------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/models/llama3/tests/api/test_generation.py b/models/llama3/tests/api/test_generation.py
index a71738ba..43c4988b 100644
--- a/models/llama3/tests/api/test_generation.py
+++ b/models/llama3/tests/api/test_generation.py
@@ -7,13 +7,10 @@
 import os
 import unittest
-
 from pathlib import Path
-
 import numpy as np
 import pytest
 from llama_models.llama3.api.datatypes import ImageMedia, SystemMessage, UserMessage
-
 from llama_models.llama3.reference_impl.generation import Llama
 from PIL import Image as PIL_Image
@@ -42,33 +39,44 @@ class TestTextModelInference(unittest.TestCase):
     def setUpClass(cls):
         cls.generator = build_generator("TEXT_MODEL_CHECKPOINT_DIR")
 
-    def test_run_generation(self):
-        dialogs = [
+    @pytest.mark.parametrize(
+        "dialogs",
+        [
             [
-                SystemMessage(content="Always answer with Haiku"),
-                UserMessage(content="I am going to Paris, what should I see?"),
+                [
+                    SystemMessage(content="Always answer with Haiku"),
+                    UserMessage(content="I am going to Paris, what should I see?"),
+                ],
+                [
+                    SystemMessage(content="Always answer with emojis"),
+                    UserMessage(content="How to go from Beijing to NY?"),
+                ],
             ],
             [
-                SystemMessage(
-                    content="Always answer with emojis",
-                ),
-                UserMessage(content="How to go from Beijing to NY?"),
+                [
+                    SystemMessage(content="Always answer in riddles"),
+                    UserMessage(content="What has keys but can't open locks?"),
+                ]
             ],
-        ]
+        ],
+    )
+    def test_run_generation(self, dialogs):
         for dialog in dialogs:
             result = self.__class__.generator.chat_completion(
                 dialog,
                 temperature=0,
                 logprobs=True,
             )
-
             out_message = result.generation
             self.assertTrue(len(out_message.content) > 0)
             shape = np.array(result.logprobs).shape
-            # assert at least 10 tokens
             self.assertTrue(shape[0] > 10)
             self.assertEqual(shape[1], 1)
 
+    def test_empty_dialog(self):
+        with self.assertRaises(ValueError):
+            self.__class__.generator.chat_completion([], temperature=0)
+
 
 class TestVisionModelInference(unittest.TestCase):
@@ -106,10 +114,8 @@ def test_run_generation(self):
                 temperature=0,
                 logprobs=True,
             )
-
             out_message = result.generation
             self.assertTrue(len(out_message.content) > 0)
             shape = np.array(result.logprobs).shape
-            # assert at least 10 tokens
             self.assertTrue(shape[0] > 10)
             self.assertEqual(shape[1], 1)
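
Note on the parametrization above (not part of the patch itself): pytest documents that parametrization and non-autouse fixtures do not apply to test methods of unittest.TestCase subclasses, so @pytest.mark.parametrize will not supply the dialogs argument while TestTextModelInference still derives from unittest.TestCase. Below is a minimal pytest-native sketch of the same tests, for reference only. It assumes the build_generator helper already defined in this test module and the ValueError behavior that test_empty_dialog asserts; it also parametrizes over individual dialogs rather than the patch's grouped lists, and the fixture and test names are illustrative, not part of the patch.

    import numpy as np
    import pytest

    from llama_models.llama3.api.datatypes import SystemMessage, UserMessage

    # Same dialogs as the patch, flattened so each one becomes its own test case.
    DIALOGS = [
        [
            SystemMessage(content="Always answer with Haiku"),
            UserMessage(content="I am going to Paris, what should I see?"),
        ],
        [
            SystemMessage(content="Always answer with emojis"),
            UserMessage(content="How to go from Beijing to NY?"),
        ],
        [
            SystemMessage(content="Always answer in riddles"),
            UserMessage(content="What has keys but can't open locks?"),
        ],
    ]


    @pytest.fixture(scope="module")
    def generator():
        # Reuses the module's existing build_generator helper (assumption: it
        # reads the checkpoint path from the named environment variable, as in
        # the setUpClass call shown in the diff).
        return build_generator("TEXT_MODEL_CHECKPOINT_DIR")


    @pytest.mark.parametrize("dialog", DIALOGS)
    def test_run_generation(generator, dialog):
        result = generator.chat_completion(dialog, temperature=0, logprobs=True)
        assert len(result.generation.content) > 0
        shape = np.array(result.logprobs).shape
        assert shape[0] > 10  # at least 10 tokens generated
        assert shape[1] == 1


    def test_empty_dialog(generator):
        # Mirrors the patch's expectation that an empty dialog is rejected.
        with pytest.raises(ValueError):
            generator.chat_completion([], temperature=0)

With this shape, each dialog shows up under its own test id in the pytest report, which is the usual payoff of parametrizing at this granularity.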