"""Hallucination detection for streamed function-call tokens.

This listing restores the Python carried by a patch whose line breaks were
lost. It covers two files touched by that patch:

* ``model_server/app/commons/constants.py`` -- parsing patterns and
  thresholds appended to the module (section 1 below).
* ``model_server/app/function_calling/hallucination_handler.py`` -- a new
  module with the entropy helpers and ``HallucinationStateHandler``
  (section 2 below).
"""

import ast
import itertools
import json
import logging
import math
import os
import random
from typing import Any, Dict, List, Tuple

import torch

import app.commons.constants as const

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Section 1: additions to model_server/app/commons/constants.py
#
# These names live in ``app.commons.constants``; the handler in section 2
# reads them through the ``const`` alias imported above.
#
# The same patch also inserts a commented-out ``"top_logprobs": 10`` entry
# after ``"stop_token_ids": [151645]`` in a parameter dict whose opening line
# is outside this view.
# ---------------------------------------------------------------------------

# Patterns for function name and parameter parsing
FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")

FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")

# Thresholds, keyed by the mask code of the token being checked
HALLUCINATION_THRESHOLD_DICT = {
    "t": {"entropy": 0.1, "varentropy": 0.5},
    "v": {
        "entropy": 0.5,
        "varentropy": 2.5,
    },
}


# ---------------------------------------------------------------------------
# Section 2: model_server/app/function_calling/hallucination_handler.py
# ---------------------------------------------------------------------------


def check_threshold(entropy: float, varentropy: float, thd: Dict[str, float]) -> bool:
    """
    Check if the given entropy or variance of entropy exceeds the specified thresholds.

    Args:
        entropy (float): The entropy value to check.
        varentropy (float): The variance of entropy value to check.
        thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.

    Returns:
        bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
    """
    return entropy > thd["entropy"] or varentropy > thd["varentropy"]


def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
    """
    Calculate the entropy and variance of entropy (varentropy) from log probabilities.

    Args:
        log_probs (list of float): A list of natural-log probabilities.

    Returns:
        tuple: A tuple containing:
            - entropy (float): The calculated entropy, in bits.
            - varentropy (float): The calculated variance-of-entropy value.
    """
    ln_2 = math.log(2)  # divides natural-log quantities to convert them to bits
    log_probs_t = torch.tensor(log_probs)
    token_probs = torch.exp(log_probs_t)
    entropy = -torch.sum(log_probs_t * token_probs, dim=-1) / ln_2
    # NOTE(review): as written, ``entropy ** 2`` is added to every element
    # before the sum instead of squaring ``(log2 p + entropy)``; the usual
    # variance of surprisal would be ``sum(p * (log2 p + H) ** 2)``. The
    # expression is kept numerically unchanged because check_logprob compares
    # this value against const.HALLUCINATION_THRESHOLD_DICT -- confirm the
    # intent and re-validate those thresholds before changing it.
    varentropy = torch.sum(
        token_probs * (log_probs_t / ln_2) + entropy.unsqueeze(-1) ** 2,
        dim=-1,
    )
    return entropy.item(), varentropy.item()


def check_parameter_property(
    api_description: Dict[str, Any], parameter_name: str, property_name: str
) -> bool:
    """
    Check if a parameter in an API description has a specific property.

    Args:
        api_description (dict): The API description in JSON format.
        parameter_name (str): The name of the parameter to check.
        property_name (str): The property to look for (e.g., 'format', 'default').

    Returns:
        bool: True if the parameter has the specified property, False otherwise.
    """
    parameters = api_description.get("properties", {})
    parameter_info = parameters.get(parameter_name, {})

    return property_name in parameter_info


class HallucinationStateHandler:
    """
    A class to handle the state of hallucination detection in token processing.

    The caller appends each new token to ``tokens`` and its candidate log
    probabilities to ``logprobs``, sets ``current_token`` and then calls
    ``process_token``.

    Attributes:
        tokens (list): List of tokens processed.
        logprobs (list): List of log probabilities for each token.
        state (str): Current state of the handler: None, "function_name",
            "parameter_name" or "parameter_value".
        mask (list): One code per token: "t" (token equal to the marker checked
            at the top of process_token), "f" (function name), "p" (parameter
            name), "v" (parameter value), "e" (anything else).
        parameter_name_done (bool): Flag indicating if parameter name extraction is done.
        hallucination (bool): Flag indicating if a hallucination is detected.
        hallucination_message (str): Message describing the hallucination.
        parameter_name (list): List of extracted parameter names.
        function_name (str): The extracted function name ("" until one is parsed).
        apis (list): The API descriptions passed to process_function.
        function_description (dict): Maps function name to its parameter names.
        function_properties (dict): Maps function name to its "parameters" schema.
        token_probs_map (list): List of (token, entropy, varentropy) tuples.
        current_token (str): The current token being processed.
    """

    def __init__(self):
        """
        Initializes the HallucinationStateHandler with default values.
        """
        self.tokens = []
        self.logprobs = []
        self.state = None
        self.mask = []
        self.parameter_name_done = False
        self.hallucination = False
        self.hallucination_message = ""
        self.parameter_name = []

        # Initialised here so the checks below never hit a missing attribute
        # when a name is parsed before process_function was called or before a
        # function name was extracted.
        self.function_name = ""
        self.apis = None
        self.function_description = {}
        self.function_properties = {}

        self.token_probs_map = []
        self.current_token = None

    def process_function(self, apis: List[Dict[str, Any]]) -> None:
        """
        Registers the API descriptions that generated calls are validated against.

        Args:
            apis (list): Function descriptions, each with a "name" and a
                "parameters" dict that contains "properties".

        Raises:
            ValueError: If ``apis`` is None.
        """
        # Validate before storing so a bad call does not clobber earlier state.
        if apis is None:
            raise ValueError("API descriptions not set.")
        self.apis = apis
        self.function_description = {
            func["name"]: list(func["parameters"]["properties"].keys())
            for func in apis
        }
        self.function_properties = {func["name"]: func["parameters"] for func in apis}

    def process_token(self):
        """
        Processes the current token and updates the state and mask accordingly.
        Detects hallucinations based on the token type and log probabilities.
        """
        # Spaces are dropped so the start/end patterns match regardless of how
        # the tokenizer split the whitespace.
        content = "".join(self.tokens).replace(" ", "")
        # NOTE(review): this compares against the empty string exactly as the
        # patch text shows. If the original literal was a tag-like marker that
        # was stripped during extraction, restore it -- confirm against the
        # repository.
        if self.current_token == "":
            self.mask.append("t")
            self.check_logprob()

        # Function name extraction logic
        if self.state == "function_name":
            if self.current_token not in const.FUNC_NAME_END_TOKEN:
                self.mask.append("f")
            else:
                self.state = None
                self.check_function_name()

        if content.endswith(const.FUNC_NAME_START_PATTERN):
            logger.debug("function name entered")
            self.state = "function_name"

        # Parameter name extraction logic
        if self.state == "parameter_name" and not content.endswith(
            const.PARAMETER_NAME_END_TOKENS
        ):
            self.mask.append("p")
        elif self.state == "parameter_name" and content.endswith(
            const.PARAMETER_NAME_END_TOKENS
        ):
            self.state = None
            self.check_parameter_name()
            self.parameter_name_done = True
        elif self.parameter_name_done and content.endswith(
            const.PARAMETER_NAME_START_PATTERN
        ):
            self.state = "parameter_name"

        if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
            self.state = "parameter_name"

        # Parameter value extraction logic
        if self.state == "parameter_value" and not content.endswith(
            const.PARAMETER_VALUE_END_TOKEN
        ):
            if self.current_token.strip() not in ['"', ""]:
                self.mask.append("v")
                # Only the first token of a value is checked, and only when the
                # parameter declares no "default". An unknown function name
                # yields an empty schema instead of a KeyError.
                if (
                    len(self.mask) > 1
                    and self.mask[-2] != "v"
                    and not check_parameter_property(
                        self.function_properties.get(self.function_name, {}),
                        self.parameter_name[-1],
                        "default",
                    )
                ):
                    self.check_logprob()
            else:
                # Quote-only or blank tokens are not part of the value itself.
                self.mask.append("e")

        elif self.state == "parameter_value" and content.endswith(
            const.PARAMETER_VALUE_END_TOKEN
        ):
            self.state = None
        elif self.parameter_name_done and content.endswith(
            const.PARAMETER_VALUE_START_PATTERN
        ):
            self.state = "parameter_value"

        # Maintain consistency between stack and mask
        if len(self.mask) != len(self.tokens):
            self.mask.append("e")

    def check_logprob(self):
        """
        Checks the log probability of the current token and updates the token probability map.
        Detects hallucinations based on entropy and variance of entropy.
        """
        probs = self.logprobs[-1]
        entropy, varentropy = calculate_entropy(probs)
        self.token_probs_map.append((self.tokens[-1], entropy, varentropy))

        # The threshold pair is selected by the mask code just appended.
        if check_threshold(
            entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
        ):
            self.hallucination = True
            self.hallucination_message = f"Token '{self.current_token}' is uncertain."

    def count_consecutive_token(self, token="v") -> int:
        """
        Counts the number of consecutive occurrences of a given token at the end of the mask.

        Args:
            token (str): The mask code to count.

        Returns:
            int: The number of trailing mask entries equal to ``token``.
        """
        return sum(
            1 for _ in itertools.takewhile(lambda m: m == token, reversed(self.mask))
        )

    def check_function_name(self):
        """
        Checks the extracted function name against the function descriptions.
        Detects hallucinations if the function name is not found.
        """
        f_len = self.count_consecutive_token("f")
        # tokens[-1] is the terminating token. A zero count must give an empty
        # name: a bare ``[-0:]`` slice would return every token instead.
        name_tokens = self.tokens[:-1][-f_len:] if f_len else []
        self.function_name = "".join(name_tokens)
        if self.function_name not in self.function_description:
            self.hallucination = True
            self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."

    def check_parameter_name(self):
        """
        Checks the extracted parameter name against the function descriptions.
        Detects hallucinations if the parameter name is not found.
        """
        p_len = self.count_consecutive_token("p")
        # Same zero-length guard as in check_function_name.
        name_tokens = self.tokens[:-1][-p_len:] if p_len else []
        parameter_name = "".join(name_tokens)
        self.parameter_name.append(parameter_name)
        # An unknown function name has no known parameters, so the lookup falls
        # back to an empty list rather than raising KeyError.
        known_parameters = self.function_description.get(self.function_name, [])
        if parameter_name not in known_parameters:
            self.hallucination = True
            self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
diff --git a/model_server/app/function_calling/model_handler.py b/model_server/app/function_calling/model_handler.py index 7b915cd4..e1da914c 100644 --- a/model_server/app/function_calling/model_handler.py +++ b/model_server/app/function_calling/model_handler.py @@ -134,4 +134,4 @@ def fix_json_string(self, json_str: str): fixed_str += opening_bracket[unmatched_opening] # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str + return fixed_str.replace("'", '"') diff --git a/model_server/app/tests/test_cases.json b/model_server/app/tests/test_cases.json new file mode 100644 index 00000000..8fd7ec1e --- /dev/null +++ b/model_server/app/tests/test_cases.json @@ -0,0 +1,794 @@ +[{ + "case": "tool_call_halluciation", + "tokens" : [""], + "expect": 1, + "logprobs": [[-0.3333307206630707, + -1.5310522317886353, + -3.5098977088928223, + -3.9004578590393066, + -5.775152683258057, + -5.814209461212158, + -5.9574151039123535, + -6.0094895362854, + -6.0094895362854, + -6.673445224761963]] +}, +{ + "case" : "parameter_value_hallucination", + "expect" : 0, + "tokens" : ["", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Sea", + ",", + " Australia", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "1", + "'}}\n", + ""], + "logprobs": [[-0.008103232830762863, + -5.085402488708496, + -6.777836799621582, + -7.558959007263184, + -9.850253105163574, + -10.266852378845215, + -10.540244102478027, + -10.722506523132324, + -10.800618171691895, + -10.917786598205566], + [0.0, + -23.25142478942871, + -25.139137268066406, + -26.2847843170166, + -28.992677688598633, + -29.070789337158203, + -29.55248260498047, + -29.91700553894043, + -30.20341682434082, + -30.307567596435547], + [0.0, + -21.66313934326172, + -23.06916046142578, + -23.32953453063965, + -25.65988540649414, + -25.985353469848633, + 
-26.519121170043945, + -27.07892417907715, + -27.977216720581055, + -28.458908081054688], + [0.0, + -28.094383239746094, + -28.56305694580078, + -29.109844207763672, + -29.44832992553711, + -31.79170036315918, + -32.0, + -32.05207443237305, + -32.31244659423828, + -32.364524841308594], + [0.0, + -30.489830017089844, + -31.140766143798828, + -31.81774139404297, + -34.525634765625, + -35.8275032043457, + -36.504478454589844, + -39.05614471435547, + -40.123680114746094, + -40.696502685546875], + [0.0, + -25.646865844726562, + -26.66232681274414, + -27.781936645507812, + -28.979660034179688, + -31.140764236450195, + -31.92188835144043, + -31.973962783813477, + -33.04149627685547, + -33.58828353881836], + [0.0, + -23.511798858642578, + -24.136695861816406, + -25.230268478393555, + -25.777053833007812, + -25.80309295654297, + -26.45402717590332, + -26.636289596557617, + -26.740440368652344, + -26.896663665771484], + [0.0, + -22.366153717041016, + -24.683483123779297, + -26.610252380371094, + -26.610252380371094, + -27.313264846801758, + -27.67778778076172, + -28.510986328125, + -28.615135192871094, + -29.13588523864746], + [0.0, + -22.52237319946289, + -24.292919158935547, + -24.344993591308594, + -24.39706802368164, + -24.73555564880371, + -29.943042755126953, + -29.969079971313477, + -30.021154403686523, + -30.0341739654541], + [0.0, + -30.17738151550293, + -30.411718368530273, + -30.88039207458496, + -30.984540939331055, + -31.270952224731445, + -31.895851135253906, + -32.46867370605469, + -32.624900817871094, + -33.484134674072266], + [0.0, + -28.146459579467773, + -29.396255493164062, + -30.099267959594727, + -31.127744674682617, + -31.179821014404297, + -32.807159423828125, + -33.7445068359375, + -33.770545959472656, + -34.069976806640625], + [0.0, + -26.323841094970703, + -26.558177947998047, + -30.515867233276367, + -30.932466506958008, + -31.37510108947754, + -31.531326293945312, + -31.70056915283203, + -32.065093994140625, + -32.364524841308594], + [0.0, + 
-26.922698974609375, + -30.28152847290039, + -31.505287170410156, + -33.30187225341797, + -33.73148727416992, + -34.27827453613281, + -34.33034896850586, + -34.460533142089844, + -34.720909118652344], + [0.0, + -21.532955169677734, + -26.94873809814453, + -29.109848022460938, + -30.80228042602539, + -31.55736541748047, + -33.484134674072266, + -34.681854248046875, + -35.384864807128906, + -35.853538513183594], + [0.0, + -19.502033233642578, + -20.46541976928711, + -24.110658645629883, + -24.501218795776367, + -25.256305694580078, + -25.82912826538086, + -25.881202697753906, + -26.063465118408203, + -26.063465118408203], + [0.0, + -24.37103271484375, + -25.256305694580078, + -25.933277130126953, + -26.714401245117188, + -28.2506103515625, + -31.010576248168945, + -32.07810974121094, + -34.62977981567383, + -35.241661071777344], + [-1.1920922133867862e-06, + -14.398697853088379, + -14.424736976623535, + -17.158666610717773, + -17.41904067993164, + -18.200162887573242, + -18.434499740600586, + -18.66883659362793, + -19.71033477783203, + -19.71033477783203], + [-0.0001445904199499637, + -8.98305892944336, + -11.35246467590332, + -13.1490478515625, + -13.669795989990234, + -14.073375701904297, + -14.516012191772461, + -14.555068969726562, + -15.622602462768555, + -15.635622024536133], + [-0.44747352600097656, + -1.0202960968017578, + -8.467000961303711, + -10.914518356323242, + -11.25300407409668, + -11.435266494750977, + -12.346576690673828, + -13.075624465942383, + -13.12769889831543, + -13.231849670410156], + [-3.123767137527466, + -1.1188862323760986, + -1.639634370803833, + -2.0562336444854736, + -2.8633930683135986, + -2.9675419330596924, + -3.4882919788360596, + -3.69659161567688, + -4.217339515686035, + -4.243376731872559], + [-7.199982064776123e-05, + -9.76410961151123, + -11.144091606140137, + -16.507802963256836, + -17.132701873779297, + -17.44515037536621, + -17.9138240814209, + -18.33042335510254, + -18.9162654876709, + -19.39795684814453], + [0.0, + 
-22.991050720214844, + -23.824249267578125, + -24.969894409179688, + -25.46460723876953, + -25.829130172729492, + -26.480066299438477, + -26.909683227539062, + -27.33930206298828, + -27.391376495361328], + [-0.21928852796554565, + -1.625309705734253, + -9.775025367736816, + -12.977627754211426, + -16.388530731201172, + -17.091541290283203, + -19.044347763061523, + -19.38283348083496, + -19.460947036743164, + -19.59113311767578], + [0.0, + -24.006507873535156, + -27.443450927734375, + -27.729862213134766, + -28.12042236328125, + -28.276647567749023, + -28.927583694458008, + -30.099267959594727, + -31.479251861572266, + -32.07810974121094], + [0.0, + -18.17412567138672, + -18.772987365722656, + -21.689178466796875, + -21.92351531982422, + -23.7200984954834, + -23.79821014404297, + -23.79821014404297, + -24.032546997070312, + -25.308382034301758], + [-0.12947827577590942, + -2.1083219051361084, + -12.419143676757812, + -15.23118782043457, + -15.595710754394531, + -15.830047607421875, + -17.001731872558594, + -17.60059356689453, + -18.121341705322266, + -18.251529693603516], + [0.0, + -19.449962615966797, + -24.371034622192383, + -24.917821884155273, + -25.529701232910156, + -25.85516929626465, + -26.037429809570312, + -26.115543365478516, + -26.623271942138672, + -26.649309158325195], + [-0.03332124650478363, + -3.4181859493255615, + -15.759925842285156, + -15.812002182006836, + -16.593124389648438, + -17.894996643066406, + -18.09027671813965, + -18.79328727722168, + -19.144792556762695, + -20.147233963012695], + [0.0, + -21.142393112182617, + -22.157852172851562, + -23.511798858642578, + -24.657445907592773, + -25.021968841552734, + -25.5427188873291, + -25.59479331970215, + -25.75101661682129, + -25.95931625366211], + [0.0, + -23.04312515258789, + -24.94385528564453, + -26.323841094970703, + -27.54759979248047, + -28.563060760498047, + -29.786819458007812, + -30.620018005371094, + -30.69812774658203, + -31.08869171142578], + [0.0, + -26.167617797851562, + 
-28.771360397338867, + -29.55248260498047, + -30.906429290771484, + -31.114728927612305, + -31.414159774780273, + -31.622459411621094, + -31.713590621948242, + -31.726608276367188], + [-0.05012698099017143, + -3.018392562866211, + -11.740934371948242, + -13.146955490112305, + -13.797887802124023, + -14.943536758422852, + -16.037107467651367, + -16.375595092773438, + -16.714080810546875, + -17.36501693725586], + [-0.9704352021217346, + -0.7360983490943909, + -2.1941938400268555, + -4.225115776062012, + -5.0062360763549805, + -5.2666120529174805, + -5.839434623718262, + -7.2714948654174805, + -8.33902645111084, + -8.495253562927246], + [-0.014467108063399792, + -4.258565902709961, + -8.789079666137695, + -10.429437637329102, + -10.793962478637695, + -11.835458755493164, + -11.939607620239258, + -13.31959342956543, + -13.866378784179688, + -15.038063049316406], + [0.0, + -20.08787727355957, + -21.350692749023438, + -21.415786743164062, + -21.50691795349121, + -21.50691795349121, + -22.7176570892334, + -24.13669776916504, + -24.188772201538086, + -24.34499740600586]] +}, +{ + "case": "fail_case", + "expect" : 0, + "tokens" : ["", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Seattle", + ",", + " WA", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "7", + "'}}\n", + ""], + "logprobs":[[-0.00013815402053296566, + -9.113236427307129, + -10.571331977844238, + -14.099404335021973, + -14.28166675567627, + -15.583537101745605, + -15.81787395477295, + -16.143341064453125, + -16.143341064453125, + -16.260509490966797], + [0.0, + -26.896663665771484, + -27.32628059387207, + -27.41741180419922, + -32.07810974121094, + -32.07810974121094, + -32.28641128540039, + -32.29943084716797, + -32.44263458251953, + -32.520748138427734], + [0.0, + -22.444263458251953, + -24.527257919311523, + -27.15703773498535, + 
-28.016273498535156, + -28.2506103515625, + -28.693246841430664, + -29.070789337158203, + -29.565500259399414, + -29.812854766845703], + [0.0, + -27.860050201416016, + -28.641170501708984, + -29.448333740234375, + -30.932466506958008, + -31.63547706604004, + -32.33848571777344, + -32.85923767089844, + -33.17168426513672, + -33.45809555053711], + [0.0, + -31.81774139404297, + -31.895854949951172, + -32.05207824707031, + -35.43694305419922, + -36.3482551574707, + -38.61351013183594, + -39.26444625854492, + -40.61839294433594, + -41.71196365356445], + [0.0, + -27.33930206298828, + -27.834014892578125, + -28.849472045898438, + -30.567943572998047, + -32.98942565917969, + -33.067535400390625, + -33.067535400390625, + -35.67127990722656, + -35.69731903076172], + [0.0, + -25.33441925048828, + -26.063465118408203, + -26.219690322875977, + -26.2457275390625, + -26.53213882446289, + -27.365337371826172, + -28.354759216308594, + -28.667207717895508, + -28.74532127380371], + [0.0, + -24.423107147216797, + -24.579330444335938, + -26.81855010986328, + -28.12042236328125, + -28.32872200012207, + -28.61513328552246, + -29.16191864013672, + -29.187957763671875, + -29.240032196044922], + [0.0, + -22.027664184570312, + -23.850284576416016, + -23.980472564697266, + -24.292922973632812, + -24.787633895874023, + -29.279088973999023, + -29.55248260498047, + -29.903987884521484, + -30.190399169921875], + [0.0, + -31.609439849853516, + -31.817739486694336, + -32.54678726196289, + -32.676971435546875, + -32.781124114990234, + -32.98942565917969, + -33.106590270996094, + -33.57526397705078, + -34.369407653808594], + [0.0, + -29.34418296813965, + -29.63059425354004, + -30.021156311035156, + -30.984540939331055, + -33.21073913574219, + -34.30431365966797, + -34.56468963623047, + -34.70789337158203, + -34.79902648925781], + [0.0, + -25.438566207885742, + -25.69894027709961, + -30.190397262573242, + -30.802276611328125, + -31.58340072631836, + -31.609437942504883, + -31.64849281311035, + 
-31.973960876464844, + -32.29943084716797], + [0.0, + -27.157039642333984, + -32.104148864746094, + -32.33848571777344, + -34.04393768310547, + -34.12205505371094, + -34.40846252441406, + -34.42148208618164, + -34.772987365722656, + -34.87713623046875], + [0.0, + -24.813671112060547, + -26.974777221679688, + -31.010578155517578, + -31.08869171142578, + -32.1822624206543, + -35.33279037475586, + -35.489013671875, + -36.999183654785156, + -37.88446044921875], + [0.0, + -20.46541976928711, + -20.647682189941406, + -23.069164276123047, + -24.136699676513672, + -25.438570022583008, + -25.646869659423828, + -26.193655014038086, + -26.297805786132812, + -26.506103515625], + [0.0, + -27.18307113647461, + -28.30268096923828, + -28.56305694580078, + -29.526439666748047, + -32.416595458984375, + -35.202598571777344, + -36.426361083984375, + -39.31651306152344, + -39.38160705566406], + [0.0, + -18.7469482421875, + -20.100894927978516, + -21.402767181396484, + -21.428804397583008, + -22.20992660522461, + -22.34011459350586, + -22.730674743652344, + -23.069162368774414, + -23.980472564697266], + [-3.576278118089249e-07, + -15.2579345703125, + -16.481693267822266, + -17.991863250732422, + -19.215621948242188, + -20.25712013244629, + -21.350692749023438, + -22.314077377319336, + -22.496337890625, + -22.938974380493164], + [-0.08506780862808228, + -2.506549835205078, + -14.848289489746094, + -15.473188400268555, + -16.33242416381836, + -16.358461380004883, + -16.566761016845703, + -17.03543472290039, + -17.686370849609375, + -17.816556930541992], + [-0.0194891095161438, + -4.445854187011719, + -5.591499328613281, + -5.956024169921875, + -6.685070037841797, + -13.142353057861328, + -13.558952331542969, + -15.173273086547852, + -15.303461074829102, + -15.85024642944336], + [-0.0005990855861455202, + -7.4212646484375, + -15.675132751464844, + -15.72720718383789, + -16.76870346069336, + -16.76870346069336, + -17.706050872802734, + -18.669435501098633, + -19.398483276367188, + 
-19.658857345581055], + [0.0, + -24.110658645629883, + -25.829130172729492, + -26.011390686035156, + -26.011390686035156, + -26.532140731811523, + -26.58421516418457, + -27.651750564575195, + -27.75589942932129, + -28.055330276489258], + [-1.1408883333206177, + -0.38580334186553955, + -7.494022369384766, + -12.519245147705078, + -14.576202392578125, + -16.034297943115234, + -16.945608139038086, + -17.908992767333984, + -18.664077758789062, + -19.34105110168457], + [0.0, + -26.688365936279297, + -29.83889389038086, + -30.177383422851562, + -30.64605712890625, + -31.244916915893555, + -31.270954132080078, + -32.83319854736328, + -34.655818939208984, + -34.89015579223633], + [0.0, + -18.929210662841797, + -19.16354751586914, + -23.589908599853516, + -24.683481216430664, + -24.995929718017578, + -25.516677856445312, + -25.542715072631836, + -25.77705192565918, + -26.063465118408203], + [-0.2519786059856415, + -1.5017764568328857, + -12.437495231628418, + -15.457839012145996, + -15.744250297546387, + -16.837820053100586, + -17.41064453125, + -17.56686782836914, + -17.61894416809082, + -18.035541534423828], + [0.0, + -20.517494201660156, + -24.683483123779297, + -25.67290496826172, + -26.58421516418457, + -27.651750564575195, + -27.781936645507812, + -27.912124633789062, + -28.09438705444336, + -28.445892333984375], + [-3.40932747349143e-05, + -10.284820556640625, + -18.252273559570312, + -20.17904281616211, + -21.663175582885742, + -22.027700424194336, + -22.288074493408203, + -22.704673767089844, + -23.12127113342285, + -23.277496337890625], + [0.0, + -22.60049057006836, + -25.46460723876953, + -25.829130172729492, + -26.063467025756836, + -27.287227630615234, + -27.391376495361328, + -27.4694881439209, + -27.67778778076172, + -28.055330276489258], + [0.0, + -23.902362823486328, + -28.823436737060547, + -29.240036010742188, + -29.31814956665039, + -29.917007446289062, + -30.021160125732422, + -31.21887969970703, + -32.416603088378906, + -32.416603088378906], + [0.0, + 
-28.641170501708984, + -31.947925567626953, + -32.59886169433594, + -33.848655700683594, + -34.109031677246094, + -34.73393249511719, + -35.02033996582031, + -35.02033996582031, + -36.074859619140625], + [-0.013183215633034706, + -4.335395336151123, + -19.619365692138672, + -20.035964965820312, + -20.244266510009766, + -21.311800003051758, + -21.441987991333008, + -22.561595916748047, + -23.108383178710938, + -23.264606475830078], + [-8.344646857949556e-07, + -14.190400123596191, + -15.9088716506958, + -18.17412567138672, + -18.46053695678711, + -18.46053695678711, + -18.512611389160156, + -18.90317153930664, + -19.059398651123047, + -19.085433959960938], + [0.0, + -17.70545196533203, + -18.903175354003906, + -20.829944610595703, + -22.574451446533203, + -22.860862731933594, + -23.069162368774414, + -23.32953643798828, + -23.694061279296875, + -24.188772201538086], + [0.0, + -20.022781372070312, + -21.038240432739258, + -21.220502853393555, + -22.496337890625, + -22.769729614257812, + -23.589908599853516, + -23.65500259399414, + -23.94141387939453, + -24.266881942749023]] +} +] diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py new file mode 100644 index 00000000..25ad3303 --- /dev/null +++ b/model_server/app/tests/test_hallucination.py @@ -0,0 +1,60 @@ +import json +from app.function_calling.hallucination_handler import HallucinationStateHandler +import pytest +import os + +# Get the directory of the current file +current_dir = os.path.dirname(__file__) + +# Construct the full path to the JSON file +json_file_path = os.path.join(current_dir, "test_cases.json") + +with open(json_file_path) as f: + test_cases = json.load(f) + +get_weather_api = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", 
+ "format": "City, State", + }, + "unit": { + "type": "str", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + }, + "days": { + "type": "str", + "description": "the number of days for the request.", + }, + }, + "required": ["location", "days"], + }, + }, +} +function_description = get_weather_api["function"] +if type(function_description) != list: + function_description = [get_weather_api["function"]] + + +@pytest.mark.parametrize("case", test_cases) +def test_hallucination(case): + state = HallucinationStateHandler() + state.process_function(function_description) + for token, logprob in zip(case["tokens"], case["logprobs"]): + if token != "": + state.current_token = token + state.tokens.append(token) + state.logprobs.append(logprob) + state.process_token() + if state.hallucination: + break + assert state.hallucination == case["expect"]