Commit
[Feat] Add qwen loglikelihood (EvolvingLMMs-Lab#43)
* Add qwen loglikelihood

* Revise the pyproject dependencies: move tiktoken out of optional-dependencies

* Add ferret-bench

* Add SEED-Bench-2, tested on LLaVA
kcz358 authored Feb 10, 2024
1 parent 2bb8fd6 commit 801829a
Showing 9 changed files with 834 additions and 1 deletion.
Empty file.
419 changes: 419 additions & 0 deletions lmms_eval/models/model_utils/qwen/qwen_generate_utils.py

Large diffs are not rendered by default.
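The diff body is omitted above, but the only symbol the rest of this commit imports from this file is make_context. As used in qwen_vl.py below, it wraps a system prompt and the user query in Qwen's chat format and returns both the raw prompt text and its token ids. A rough usage sketch, inferred from the call sites below (the literal values are illustrative assumptions, not part of the diff):

raw_text, context_tokens = make_context(
    tokenizer,                                                      # the Qwen-VL tokenizer
    "Picture 1: <img>/tmp/ABC123.png</img>\nWhat is shown here?",   # output of tokenizer.from_list_format
    history=None,
    system="You are a helpful assistant",
    max_window_size=6144,          # qwen_vl.py reads this from model.generation_config
    chat_format="chatml",          # likewise taken from generation_config
)
# raw_text is the fully formatted chat prompt; context_tokens is its list of token ids.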

71 changes: 71 additions & 0 deletions lmms_eval/models/qwen_vl.py
@@ -5,6 +5,7 @@
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.qwen.qwen_generate_utils import make_context
from accelerate import Accelerator, DistributedType
from typing import List, Optional, Union, Tuple
import uuid
@@ -115,6 +116,76 @@ def world_size(self):
return self._world_size

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # Resolve the continuation (target) text for this document
            if isinstance(doc_to_target, str):
                continuation = doc_to_target
            else:
                continuation = doc_to_target(self.task_dict[task][split][doc_id])
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)
            query = []
            visual_paths = []
            for visual in visuals:
                name = uuid.uuid4().hex.upper()[0:6]
                visual.save(f"/tmp/{name}.png")
                visual_paths.append(f"/tmp/{name}.png")
                query.append({"image": f"/tmp/{name}.png"})

            # Copy the image-only query to build the context prompt; its tokens are masked out of the loss below
            context_query = list(query)
            context_query.append({"text": contexts})
            query.append({"text": contexts + continuation})

            context_query = self.tokenizer.from_list_format(context_query)
            query = self.tokenizer.from_list_format(query)

            raw_context_text, context_tokens = make_context(
                self.tokenizer,
                context_query,
                history=None,
                system="You are a helpful assistant",
                max_window_size=self.model.generation_config.max_window_size,
                chat_format=self.model.generation_config.chat_format,
            )
            context_tokens = torch.tensor([context_tokens])

            raw_continuation_text, continuation_tokens = make_context(
                self.tokenizer,
                query,
                history=None,
                system="You are a helpful assistant",
                max_window_size=self.model.generation_config.max_window_size,
                chat_format=self.model.generation_config.chat_format,
            )
            continuation_tokens = torch.tensor([continuation_tokens]).to(self.model.device)
            attn_mask = torch.ones_like(continuation_tokens).to(self.model.device)
            labels = continuation_tokens.clone().to(self.model.device)
            labels[:, : context_tokens.shape[1]] = -100
            with torch.inference_mode():
                outputs = self.model(input_ids=continuation_tokens, labels=labels, attention_mask=attn_mask)
            loss = outputs.loss
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = continuation_tokens[:, context_tokens.shape[1] :]
            greedy_tokens = greedy_tokens[:, context_tokens.shape[1] : continuation_tokens.shape[1]]  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)

        pbar.close()
        return res

def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
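For reference, the masking pattern used in loglikelihood above is the standard causal-LM trick: tokenize the context and the context+continuation, set the context positions of labels to -100 so the loss covers only the continuation, and optionally check whether greedy decoding reproduces the continuation. Below is a minimal, text-only sketch with a generic Hugging Face model (gpt2 is only a stand-in here, not the Qwen-VL path of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in model for illustration; the commit itself scores Qwen-VL prompts built by make_context.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

context, continuation = "The capital of France is", " Paris"
context_ids = tokenizer(context, return_tensors="pt").input_ids
full_ids = tokenizer(context + continuation, return_tensors="pt").input_ids

labels = full_ids.clone()
labels[:, : context_ids.shape[1]] = -100   # ignore context tokens; the loss covers only the continuation

with torch.inference_mode():
    outputs = model(input_ids=full_ids, labels=labels)

# Greedy check: logits at position i predict token i + 1, hence the one-token shift here.
greedy = outputs.logits.argmax(dim=-1)[:, context_ids.shape[1] - 1 : full_ids.shape[1] - 1]
cont_ids = full_ids[:, context_ids.shape[1] :]
print(float(outputs.loss), bool((greedy == cont_ids).all()))
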
35 changes: 35 additions & 0 deletions lmms_eval/tasks/ferret/ferret.yaml
@@ -0,0 +1,35 @@
dataset_path: lmms-lab/Ferret-Bench
dataset_kwargs:
token: True
task: "ferret"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.ferret_doc_to_visual
doc_to_text: !function utils.ferret_doc_to_text
doc_to_target: "gpt_answer"
generation_kwargs:
until:
- "ASSISTANT:"
image_aspect_ratio: original
max_new_tokens: 1024
temperature: 0
top_p: 0
num_beams: 1
do_sample: false
process_results: !function utils.ferret_process_results
metric_list:
- metric: gpt_eval_ferret_all
aggregation: !function utils.ferret_all_aggregation
higher_is_better: true
- metric: gpt_eval_ferret_refer_desc
aggregation: !function utils.ferret_refer_desc_aggregation
higher_is_better: true
- metric: gpt_eval_ferret_refer_reason
aggregation: !function utils.ferret_refer_reason_aggregation
higher_is_better: true
- metric: gpt_eval_ferret_ground_conv
aggregation: !function utils.ferret_ground_conv_aggregation
higher_is_better: true
metadata:
version: 0.0
gpt_eval_model_name: "gpt-4-0314"
5 changes: 5 additions & 0 deletions lmms_eval/tasks/ferret/rule.json
@@ -0,0 +1,5 @@
{
"refer_desc": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question about specific region of an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. Also, the relationships between pairs of objects are provided, in the format of object -> relationship -> subject, where the object/subject are indexed by object id from previous object lists as well as the object names. Also, several region description are given, each describing a box region of image, with detailed coordinates. \nPlease rate the spatial correspondence, helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
"refer_reason": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question about specific region of an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. Also, the relationships between pairs of objects are provided, in the format of object -> relationship -> subject, where the object/subject are indexed by object id from previous object lists as well as the object names. Also, several region description are given, each describing a box region of image, with detailed coordinates. \nPlease rate the spatial correspondence, helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
"ground_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question that requires model to predict the coordinates of relevant object. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. Also, the relationships between pairs of objects are provided, in the format of object -> relationship -> subject, where the object/subject are indexed by object id from previous object lists as well as the object names. Also, several region description are given, each describing a box region of image, with detailed coordinates. \nPlease rate the predicted coordinates, helpfulness, relevance, accuracy, level of details of their responses. Specifically, pay your attention to the precision of the coordinates and whether it matches the object. Small deviation (<20% of ground-truth box width or height) of coordinates is allowed and shouldn't be punished. More than that, the degree of deviation should be reflected in scoring too. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
}
202 changes: 202 additions & 0 deletions lmms_eval/tasks/ferret/utils.py
@@ -0,0 +1,202 @@
import json
import logging
import os
import requests
import numpy as np
import openai
from openai import OpenAI
import time
import yaml
from pathlib import Path
from copy import deepcopy

eval_logger = logging.getLogger("lmms-eval")
NUM_SECONDS_TO_SLEEP = 0.5

FERRET_W_METRICS = ["gpt_eval_ferret_refer_desc", "gpt_eval_ferret_refer_reason", "gpt_eval_ferret_ground_conv"]

rule_dict = json.load(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rule.json"), "r"))

with open(Path(__file__).parent / "ferret.yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
for i, line in enumerate(raw_data):
# remove function definition since yaml load cannot handle it
if "!function" not in line:
safe_data.append(line)

config = yaml.safe_load("".join(safe_data))

GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]

API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
headers = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
}
elif API_TYPE == "azure":
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
headers = {
"api-key": API_KEY,
"Content-Type": "application/json",
}


def get_eval(content: str, max_tokens: int, retries: int = 3):
global headers

messages = [
{
"role": "system",
"content": "You are a helpful and precise assistant for checking the quality of the answer.",
},
{"role": "user", "content": content},
]

payload = {
"model": GPT_EVAL_MODEL_NAME,
"messages": messages,
"temperature": 0.2,
"max_tokens": max_tokens,
}

for attempt in range(retries):
try:
response = requests.post(API_URL, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()

content = response_data["choices"][0]["message"]["content"].strip()
if content != "":
return content, response_data["model"]
            break  # Empty content: stop retrying and fall through to the empty-string return below

except Exception as e:
eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < retries - 1: # If we have retries left, sleep and then continue to next attempt
time.sleep(NUM_SECONDS_TO_SLEEP)
else: # If this was the last attempt, log and return empty
eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
return "", ""
return "", ""


def parse_score(review):
try:
score_pair = review.split("\n")[0]
score_pair = score_pair.replace(",", " ")
sp = score_pair.split(" ")
if len(sp) == 2:
return [float(sp[0]), float(sp[1])]
else:
eval_logger.debug(f"Can not split: {review}. Returning [-1, -1]")
return [-1, -1]
except Exception as e:
eval_logger.debug(f"Error: {e}. Returning [-1, -1]")
return [-1, -1]
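

# Annotation (not part of the commit): parse_score only trusts the first line of the
# judge review returned by get_eval, which is expected to hold exactly two scores,
# e.g. "8 7" (a comma-separated "8,7" is normalized the same way). Anything else
# falls back to [-1, -1]:
#   parse_score("8 7\nAssistant 1 gave more precise coordinates ...")  -> [8.0, 7.0]
#   parse_score("I cannot compare these answers.")                     -> [-1, -1]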


def ferret_doc_to_visual(doc):
return [doc["image"].convert("RGB")]


def ferret_doc_to_text(doc):
question = doc["question"]
return question


def ferret_process_results(doc, result):
"""
Args:
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name (in this case coco_bleu), value: metric value
"""
try:
question = doc.get("question", "")
ans1 = doc.get("gpt_answer", "")
ans2 = result[0] if result else ""
context = doc.get("context", [])
context = "\n".join(context) if isinstance(context, list) else context
category = doc.get("category", "")
rule = rule_dict.get(category, {})
prompt = rule.get("prompt", "")
role = rule.get("role", "user")
content = f"[Context]\n{context}\n\n" f"[Question]\n{question}\n\n" f"[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n" f"[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n" f"[System]\n{prompt}\n\n"
review, model_name = get_eval(content, 1024)
scores = parse_score(review)
except Exception as e:
eval_logger.error(f"Error for Question ID: {doc.get('question_id', 'Unknown')}: {e}")
review = "Failed to Get a Proper Review."
model_name = "Failed Request"
scores = [-1, -1]

metric = f"gpt_eval_ferret_{doc.get('category', 'all')}"
category_review_dict = {
"question": question,
"ans1": ans1,
"ans2": ans2,
"context": context,
"category": category,
"review": review,
"scores": scores,
"eval_model": model_name,
}

non_category_review_dict = deepcopy(category_review_dict)
non_category_review_dict["scores"] = [-999, -999]

data_dict = {}
for m in FERRET_W_METRICS:
if m == metric:
data_dict[m] = category_review_dict
else:
data_dict[m] = non_category_review_dict
data_dict["gpt_eval_ferret_all"] = category_review_dict

# return {"gpt_eval_ferret_all": review_dict}
return data_dict


def ferret_refer_desc_aggregation(results):
return ferret_aggregation(results, "refer_desc")


def ferret_refer_reason_aggregation(results):
return ferret_aggregation(results, "refer_reason")


def ferret_ground_conv_aggregation(results):
return ferret_aggregation(results, "ground_conv")


def ferret_all_aggregation(results):
return ferret_aggregation(results, "all")


def ferret_aggregation(results, category):
try:
scores = []
for result in results:
if -999 in result["scores"]:
continue
scores.append(result["scores"])

stats = np.asarray(scores).mean(0).tolist()
stats = [round(x, 3) for x in stats]
# gpt4_score_percentage = stats[0] * 10
# model_score_percentage = stats[1] * 10
# eval_logger.info(f"Category: {category}")
# eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
# eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
# eval_logger.info("=========================")
return round(stats[1] / stats[0] * 100, 1)
except Exception as e:
eval_logger.info(f"Error in ferret_aggregation: {e}, and in category: {category}")
return None
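
To make the reported number concrete: each entry in results carries a [gpt4_score, model_score] pair, the -999 sentinels written by ferret_process_results for the other categories are skipped, and the aggregation reports the model's mean score as a percentage of GPT-4's mean score. A self-contained sketch with made-up scores (not dataset values):

import numpy as np

# Hypothetical per-question [gpt4_score, model_score] pairs for one category.
scores = [[8.0, 7.0], [9.0, 6.0], [7.0, 7.0]]
stats = [round(x, 3) for x in np.asarray(scores).mean(0).tolist()]  # [8.0, 6.667]
relative = round(stats[1] / stats[0] * 100, 1)                      # 6.667 / 8.0 * 100 -> 83.3
print(relative)
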
46 changes: 46 additions & 0 deletions lmms_eval/tasks/seedbench_2/seedbench_2.yaml
@@ -0,0 +1,46 @@
dataset_path: lmms-lab/SEED-Bench-2
dataset_kwargs:
token: True
task: "seedbench-2"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.seed_doc_to_visual
doc_to_text: !function utils.seed_doc_to_text
doc_to_target: "answer"
generation_kwargs:
until:
- "ASSISTANT:"
max_new_tokens: 16
image_aspect_ratio: original
# The return value of process_results will be used by metrics
process_results: !function utils.seed_process_result
# Note that the metric name can be either a registered metric function (as in the GQA case) or a key name returned by process_results
metric_list:
- metric: seed_Video
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_Multiple_Images
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_Image_&_Text_Generation
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_Single_Image
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_Image_Generation
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_Interleaved_Image
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
- metric: seed_all
aggregation: !function utils.seed_aggregation_result
higher_is_better: true
metadata:
- version: 0.0

model_specific_prompt_kwargs:
llava :
img_token : <image>
post_prompt : "Answer with the option's letter from the given choices directly."
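
The seedbench_2/utils.py diff is not rendered on this page, so the following is only a hypothetical sketch (function body assumed, not taken from the commit) of how a doc_to_text hook can consume the model_specific_prompt_kwargs above: the per-model img_token and post_prompt are spliced around the question so that, for example, LLaVA receives its <image> placeholder and the answer-with-a-letter instruction.

# Hypothetical illustration only; the real seed_doc_to_text lives in the unrendered utils.py diff.
def seed_doc_to_text(doc, model_specific_prompt_kwargs=None):
    kwargs = model_specific_prompt_kwargs or {}
    img_token = kwargs.get("img_token", "")      # e.g. "<image>" for llava (from the yaml above)
    post_prompt = kwargs.get("post_prompt", "")  # e.g. "Answer with the option's letter ..."
    question = doc["question"]                   # field name assumed for illustration
    return f"{img_token}\n{question}\n{post_prompt}".strip()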
