hallucination with log probs #281
base: main
@@ -0,0 +1,362 @@
import math
from typing import List, Dict, Any, Tuple

import torch

import app.commons.constants as const


def filter_tokens_and_probs(
    tokens: List[str], probs: List[float]
) -> Tuple[List[str], List[float]]:
    """
    Filters out special tokens from the list of tokens and their corresponding probabilities.

    Args:
        tokens (list): List of tokens.
        probs (list): List of probabilities corresponding to the tokens.

    Returns:
        tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
    """
    # Drop structural JSON/formatting tokens that carry no parameter content
    special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
    filtered_tokens = [token for token in tokens if token not in special_tokens]
    filtered_probs = [
        prob for token, prob in zip(tokens, probs) if token not in special_tokens
    ]

Review comment (lines +25 to +28): you could zip them in the beginning and then use zip(*) to unzip.

    return filtered_tokens, filtered_probs
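
A minimal sketch of the reviewer's zip/unzip suggestion, assuming the same special_tokens list; the function name here is hypothetical:

def filter_tokens_and_probs_zipped(
    tokens: List[str], probs: List[float]
) -> Tuple[List[str], List[float]]:
    # Pair each token with its probability, filter once, then unzip
    special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
    kept = [(t, p) for t, p in zip(tokens, probs) if t not in special_tokens]
    if not kept:
        return [], []  # zip(*[]) would raise on unpacking, so guard the empty case
    filtered_tokens, filtered_probs = map(list, zip(*kept))
    return filtered_tokens, filtered_probs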


def get_all_parameter_values(
    tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extracts parameter values and their corresponding probabilities from the tokens.

    Args:
        tokens (list): List of tokens.
        probs (list): List of probabilities corresponding to the tokens.
        parameter_names (dict): Dictionary of parameter names for each function.

    Returns:
        tuple: A tuple containing two dictionaries - parameter values and their corresponding probabilities.
    """
    parameter_values = {}
    probs_values = {}
    i = 0

    while i < len(tokens):
        # Try to form parameter names by combining tokens
        combined_token = ""
        start = i
        found_param = False

        # Incrementally combine tokens to find a full match with any parameter name
        while i < len(tokens):
            if combined_token:
                # Append the next token to the current combination
                combined_token += tokens[i]
            else:
                combined_token = tokens[i]  # Start a new combination

Review comment (lines +59 to +61): will the first param match contain every token from start to the matched token?

            # Check if the combined token matches any parameter name
            for func, params in parameter_names.items():

Review comment: func => _

                if combined_token in params:
                    # Collect values associated with this parameter
                    values = []
                    prob_values = []
                    i += 1  # Move past the parameter name

                    # Collect tokens as values until the next parameter or end marker
                    while (
                        i < len(tokens)
                        and tokens[i] not in params
                        and tokens[i] != "</tool_call>"
                    ):
                        values.append(tokens[i])
                        prob_values.append(probs[i])
                        i += 1

                    # Store the parameter values and probabilities
                    parameter_values[combined_token] = values
                    probs_values[combined_token] = prob_values

                    found_param = True
                    break  # Stop combining further once a parameter is matched

            if found_param:
                break  # Exit the outer loop if the parameter was matched
            i += 1  # Move to the next token if no match was found yet

Review comment (lines +90 to +92): so even if one parameter is matched we will exit the loop?

        # Reset to the next token if no parameter match was found
        if not found_param:
            i = start + 1

    return parameter_values, probs_values
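
For illustration, a hypothetical call on an already-filtered token stream; the tokens, probabilities, and function name below are made up, not from the PR:

tokens = ["location", "San", " Francisco", "</tool_call>"]
probs = [0.99, 0.90, 0.95, 0.99]
parameter_names = {"get_weather": ["location"]}

values, value_probs = get_all_parameter_values(tokens, probs, parameter_names)
# values      -> {"location": ["San", " Francisco"]}
# value_probs -> {"location": [0.90, 0.95]}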


def calculate_stats(
    data: Dict[str, Any], function_description: Dict[str, Any]
) -> Dict[str, Any]:

Review comment: data => param_probs

    """
    Calculates statistical metrics for the given data.

    Args:
        data (dict): Dictionary containing parameter values and their corresponding probabilities.
        function_description (dict): Description of the function containing parameter properties.

    Returns:
        dict: Dictionary containing statistical metrics for each parameter.
    """
    stats = {}
    try:
        for key, values in data.items():
            if values:
                first = values[0]

Review comment: what is first? is that sorted?

                max_value = max(values)
                min_value = min(values)
                avg_value = sum(values) / len(values)
                has_format = check_parameter_property(
                    function_description, key, "format"
                )
                has_default = check_parameter_property(
                    function_description, key, "default"
                )
                stats[key] = {
                    "first": first,
                    "max": max_value,
                    "min": min_value,
                    "avg": avg_value,
                    "has_format": has_format,
                    "has_default": has_default,
                }
    except Exception as e:
        print(f"calculate_stats failed: {e} | data: {data}")
    return stats
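
For illustration, a hypothetical call; the entropy values and schema are made up:

param_entropies = {"location": [0.12, 0.45, 0.30]}
func_desc = {"parameters": {"properties": {"location": {"type": "string"}}}}

calculate_stats(param_entropies, func_desc)
# -> {"location": {"first": 0.12, "max": 0.45, "min": 0.12,
#                  "avg": 0.29, "has_format": False, "has_default": False}}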


def check_parameter_property(
    api_description: Dict[str, Any], parameter_name: str, property_name: str
) -> bool:

Review comment (lines +141 to +143): check_parameter_property => is_param_property_set

    """
    Check if a parameter in an API description has a specific property.

    Args:
        api_description (dict): The API description in JSON format.
        parameter_name (str): The name of the parameter to check.
        property_name (str): The property to look for (e.g., 'format', 'default').

    Returns:
        bool: True if the parameter has the specified property, False otherwise.
    """
    parameters = api_description.get("parameters", {}).get("properties", {})
    parameter_info = parameters.get(parameter_name, {})

    return property_name in parameter_info
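
For illustration, a hypothetical JSON-schema-style function description; the structure mirrors the parameters.properties lookup above, and the names are made up:

api_description = {
    "name": "get_weather",
    "parameters": {
        "properties": {
            "date": {"type": "string", "format": "date"},
            "units": {"type": "string", "default": "celsius"},
        }
    },
}

check_parameter_property(api_description, "date", "format")    # True
check_parameter_property(api_description, "units", "format")   # False
check_parameter_property(api_description, "units", "default")  # True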


def calculate_entropy(log_probs):

Review comment: add param type

    """
    Calculate the entropy and variance of entropy (varentropy) from log probabilities.

    Args:
        log_probs (list of float): A list of log probabilities.

    Returns:
        tuple: A tuple containing:
            - log_probs (list of float): The input log probabilities as a list.
            - entropy (float): The calculated entropy.
            - varentropy (float): The calculated variance of entropy.
    """
    log_probs = torch.tensor(log_probs)

Review comment: don't shadow variable

    token_probs = torch.exp(log_probs)
    # Entropy in bits: H = -sum(p * ln p) / ln 2
    entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2)
    # Varentropy: sum over tokens of p * (log2 p + H)^2
    varentropy = torch.sum(
        token_probs * (log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2,
        dim=-1,
    )
    return log_probs.tolist(), entropy.item(), varentropy.item()
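
As a quick sanity check of calculate_entropy: a uniform distribution over four tokens should give an entropy of 2 bits and a varentropy of 0.

log_probs = [math.log(0.25)] * 4  # uniform over 4 tokens
_, entropy, varentropy = calculate_entropy(log_probs)
# entropy    -> 2.0 (bits, since each log2(p) = -2)
# varentropy -> 0.0 (every token has log2(p) + entropy == 0)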


def hallucination_detect(

Review comment: I think this current_state could be captured better if you rewrite it as a class and move the state into class variables. I see that you didn't want to use global vars here and hence used a state dict, but I think using a class would be better here. [A sketch of this refactor follows the final review comment below.]

    token: str,
    log_probs: List[float],
    current_state: Dict[str, Any],

Review comment: does the caller need to see

    entropy_thd: float = 0.7,
    varentropy_thd: float = 4.0,
) -> bool:

Review comment (lines +188 to +190): take default values from consts like

    """
    Detects hallucinations in the token sequence based on entropy and varentropy thresholds.

    Args:
        token (str): The current token.
        log_probs (list): List of log probabilities for the current token.
        current_state (dict): The current state of the detection process.
        entropy_thd (float): Entropy threshold for detecting hallucinations.
        varentropy_thd (float): Variance of entropy threshold for detecting hallucinations.

    Returns:
        bool: True if a hallucination is detected, False otherwise.
    """
    if token:
        # there is content in the token, so record it
        current_state["tokens"].append(token)
        current_state["content"] += token
        current_state["logprobs"].append(log_probs)
        # keep track of entropy and varentropy
        _, entropy, varentropy = calculate_entropy(log_probs)
        current_state["entropy"].append(entropy)
        current_state["varentropy"].append(varentropy)
        # first check if the tool call token is certain
        if token == "<tool_call>":
            if entropy > entropy_thd or varentropy > varentropy_thd:
                current_state["hallucination"] = True
                current_state[
                    "hallucination_message"
                ] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
                return True
        elif token == "</tool_call>":
            current_state["state"] = "tool_call_end"
            # try to extract the tool call, else flag the failure
            try:
                current_state[
                    "tool_call"
                ] = const.arch_function_hanlder.extract_tool_calls(
                    current_state["content"]
                )[0]
                current_state["tool_call_process"] = True
            except Exception:
                current_state["tool_call_process"] = False
                print("can't process tool call")
                return True
            # check if the function name is valid
            if (
                current_state["tool_call"]["function"]["name"]
                not in current_state["parameter_names"].keys()
            ):
                current_state["hallucination"] = True
                current_state[
                    "hallucination_message"
                ] = f"function name {current_state['tool_call']['function']['name']} not found"
                return True

            # check if the parameter names come from the given function tools
            current_parameter_names = current_state["tool_call"]["function"][
                "arguments"
            ].keys()
            given_parameter_names = current_state["parameter_names"][
                current_state["tool_call"]["function"]["name"]
            ]
            if not set(current_parameter_names).issubset(given_parameter_names):
                unknown_keys = set(current_parameter_names) - set(given_parameter_names)

                current_state["hallucination"] = True
                current_state[
                    "hallucination_message"
                ] = f"parameter names {unknown_keys} not found"
                return True

            # filter out special tokens that are not needed in the hallucination check for parameter values
            (
                current_state["filtered_tokens"],
                current_state["filtered_entropy"],
            ) = filter_tokens_and_probs(
                current_state["tokens"], current_state["entropy"]
            )
            (
                current_state["filtered_tokens"],
                current_state["filtered_varentropy"],
            ) = filter_tokens_and_probs(
                current_state["tokens"], current_state["varentropy"]
            )
            parameter_values, entropy_values = get_all_parameter_values(
                current_state["filtered_tokens"],
                current_state["filtered_entropy"],
                current_state["parameter_names"],
            )
            parameter_values, varentropy_values = get_all_parameter_values(
                current_state["filtered_tokens"],
                current_state["filtered_varentropy"],
                current_state["parameter_names"],
            )

            current_state["parameter_values"] = parameter_values
            current_state["parameter_values_entropy"] = entropy_values
            current_state["parameter_values_varentropy"] = varentropy_values
            # calculate the max, first, avg of sub-tokens for each parameter value
            current_state["parameter_value_entropy_stat"] = calculate_stats(
                current_state["parameter_values_entropy"],
                current_state["function_description"][0],
            )
            current_state["parameter_value_varentropy_stat"] = calculate_stats(
                current_state["parameter_values_varentropy"],
                current_state["function_description"][0],
            )
            # build token-to-score maps for debugging
            current_state["token_entropy_map"] = {
                x: y for x, y in zip(current_state["tokens"], current_state["entropy"])
            }
            current_state["token_varentropy_map"] = {
                x: y
                for x, y in zip(current_state["tokens"], current_state["varentropy"])
            }

            # check hallucination for each parameter value
            current_state["parameter_value_check"] = {
                x: {"hallucination": False, "message": ""}
                for x in current_state["parameter_values"].keys()
            }
            for key in current_state["parameter_value_check"].keys():
                # if the parameter is given a format, check the first token
                if current_state["parameter_value_entropy_stat"][key]["has_format"]:
                    if (
                        current_state["parameter_value_entropy_stat"][key]["first"]
                        > entropy_thd
                        or current_state["parameter_value_varentropy_stat"][key][
                            "first"
                        ]
                        > varentropy_thd
                    ):
                        current_state["parameter_value_check"][key][
                            "hallucination"
                        ] = True
                        current_state["hallucination"] = True
                        current_state["parameter_value_check"][key][
                            "message"
                        ] = f"parameter {key} with formatting doesn't pass threshold"
                # if the parameter is given a default value, we can always use the default
                elif current_state["parameter_value_entropy_stat"][key]["has_default"]:
                    current_state["parameter_value_check"][key]["hallucination"] = False
                    current_state["parameter_value_check"][key][
                        "message"
                    ] = f"parameter {key} with default"
                # otherwise check if the max sub-token score exceeds the thresholds
                else:
                    if (
                        current_state["parameter_value_entropy_stat"][key]["max"]
                        > entropy_thd
                        or current_state["parameter_value_varentropy_stat"][key]["max"]
                        > varentropy_thd
                    ):
                        current_state["parameter_value_check"][key][
                            "hallucination"
                        ] = True
                        current_state["parameter_value_check"][key][
                            "message"
                        ] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesn't pass threshold"
                        current_state["hallucination"] = True
            if current_state["hallucination"]:
                current_state["hallucination_message"] = "\n".join(
                    [
                        current_state["parameter_value_check"][key]["message"]
                        for key in current_state["parameter_value_check"].keys()
                    ]
                )
                return True
    return False
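
For reference, a minimal initial state dict inferred from the keys this function reads and writes; the exact shape is an assumption, not from the PR:

current_state = {
    "tokens": [],
    "content": "",
    "logprobs": [],
    "entropy": [],
    "varentropy": [],
    "hallucination": False,
    "hallucination_message": "",
    "state": "",
    # per-function parameter names, e.g. {"get_weather": ["location", "units"]}
    "parameter_names": parameter_names,
    # list of JSON-schema-style function descriptions
    "function_description": [func_desc],
}

for token, log_probs in stream:  # hypothetical (token, log-probs) stream
    if hallucination_detect(token, log_probs, current_state):
        print(current_state["hallucination_message"])
        break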

Review comment: is that an exhaustive list, are there more to add? And how did you come up with this list?
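
A hedged sketch of the class-based refactor suggested in the review of hallucination_detect above; the name, defaults, and the subset of state shown are assumptions, not part of the PR:

class HallucinationDetector:
    """Holds the streaming detection state that hallucination_detect keeps in a dict."""

    def __init__(
        self,
        parameter_names: Dict[str, Any],
        function_description: List[Dict[str, Any]],
        entropy_thd: float = 0.7,  # could be read from consts instead, per the review
        varentropy_thd: float = 4.0,
    ):
        self.parameter_names = parameter_names
        self.function_description = function_description
        self.entropy_thd = entropy_thd
        self.varentropy_thd = varentropy_thd
        self.tokens: List[str] = []
        self.content: str = ""
        self.entropy: List[float] = []
        self.varentropy: List[float] = []
        self.hallucination: bool = False
        self.hallucination_message: str = ""

    def observe(self, token: str, log_probs: List[float]) -> bool:
        """Feed one token; returns True once a hallucination is flagged."""
        if not token:
            return False
        self.tokens.append(token)
        self.content += token
        # reuses calculate_entropy from the PR
        _, entropy, varentropy = calculate_entropy(log_probs)
        self.entropy.append(entropy)
        self.varentropy.append(varentropy)
        if token == "<tool_call>" and (
            entropy > self.entropy_thd or varentropy > self.varentropy_thd
        ):
            self.hallucination = True
            self.hallucination_message = (
                f"{token} entropy {entropy} / varentropy {varentropy} over threshold"
            )
        return self.hallucination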