Skip to content

Commit

Permalink
Merge pull request #57 from EvolvingLMMs-Lab/pufanyi/realworldqa
Browse files Browse the repository at this point in the history
[Benchmarks] RealWorldQA
  • Loading branch information
Luodian authored Apr 18, 2024
2 parents 1141765 + 6ecb844 commit 9f2d625
Show file tree
Hide file tree
Showing 11 changed files with 435 additions and 83 deletions.
18 changes: 18 additions & 0 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,22 @@ def _prepare_metric_and_aggregation(self):

@retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
def download(self, dataset_kwargs=None) -> None:
# If the dataset is a video dataset,
# Recursively search whether their is a zip and unzip it to the huggingface home
if dataset_kwargs is not None and "video" in dataset_kwargs and dataset_kwargs["video"]:
hf_home = os.environ["HF_HOME"]
cache_dir = dataset_kwargs["cache_dir"]
dataset_kwargs.pop("cache_dir")
cache_dir = os.path.join(hf_home, cache_dir)
cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset")
zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
if not os.path.exists(cache_dir):
for zip_file in zip_files:
shutil.unpack_archive(zip_file, cache_dir)
builder_script = dataset_kwargs["builder_script"]
self.DATASET_PATH = os.path.join(cache_path, builder_script)
dataset_kwargs.pop("video")
dataset_kwargs.pop("builder_script")
download_config = DownloadConfig()
download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
Expand Down Expand Up @@ -970,6 +986,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)

def process_results(self, doc, results):
if self.OUTPUT_TYPE == "generate_until":
results[0] = results[0].strip()
if callable(self.config.process_results):
return self.config.process_results(doc, results)

Expand Down
3 changes: 2 additions & 1 deletion lmms_eval/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lmms_eval.api.filter import FilterEnsemble
from lmms_eval.api.filter import FilterEnsemble, Filter
from . import selection
from . import extraction
from . import transformation
Expand All @@ -13,6 +13,7 @@
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
Expand Down
186 changes: 170 additions & 16 deletions lmms_eval/filters/extraction.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import re

import sys
import unicodedata
from lmms_eval.api.filter import Filter


class WhitespaceFilter(Filter):
""" """

def __init__(self) -> None:
pass

def apply(self, resps, docs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
resp = resp[1:]

filtered_resp.append(resp)

return filtered_resp

filtered_resps = [filter_set(resp) for resp in resps]

return filtered_resps


class RegexFilter(Filter):
""" """

def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback

def apply(self, resps, docs):
Expand All @@ -23,9 +52,12 @@ def apply(self, resps, docs):
def filter_set(inst):
filtered = []
for resp in inst:
match = self.regex.search(resp)
match = self.regex.findall(resp)
if match:
match = match.group(1).strip()
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
else:
match = self.fallback
filtered.append(match)
Expand All @@ -38,23 +70,145 @@ def filter_set(inst):
return filtered_resps


class WhitespaceFilter(Filter):
""" """
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""

def __init__(self) -> None:
pass
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore

def apply(self, resps, docs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
resp = resp[1:]
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)

filtered_resp.append(resp)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match

return filtered_resp
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

filtered_resps = [filter_set(resp) for resp in resps]
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)

if self.ignore_case:
st = st.lower()

if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st

filtered_resps = []

for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"

without_paren_fallback_regexes = []
without_paren_to_target = {}

choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"

without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"

next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})")

filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha)
if not match:
match = find_match(without_paren_fallback_regex, resp, without_paren_to_target)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)

return filtered_resps


class ExtendedRegexFilter(RegexFilter):
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore

def filter_ignores(self, st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)

if self.ignore_case:
st = st.lower()

if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(self.punct_tbl)
return st

def find_match(self, regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
13 changes: 9 additions & 4 deletions lmms_eval/models/gpt4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,24 @@
}


@register_model("gpt4V")
@register_model("gpt4v")
class GPT4V(lmms):
def __init__(self, **kwargs) -> None:
def __init__(
self,
model_version: str = "gpt-4-vision-preview",
**kwargs,
) -> None:
super().__init__()
# Manually set a image token for GPT4V so that we can search for it
# and split the text and image
# Here we just use the same token as llava for convenient
self.model_version = model_version
self.image_token = "<image>"

# Function to encode the image
def encode_image(self, image: Image):
output_buffer = BytesIO()
image.save(output_buffer, format="JPEG")
image.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
return base64_str
Expand All @@ -72,7 +77,7 @@ def generate_until(self, requests) -> List[str]:
img = self.encode_image(visual)
imgs.append(img)

payload = {"model": "gpt-4-vision-preview", "messages": []}
payload = {"model": self.model_version, "messages": []}
response_json = {"role": "user", "content": []}
# When there is no image token in the context, append the image to the text
if self.image_token not in contexts:
Expand Down
14 changes: 13 additions & 1 deletion lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
unset_hf_deepspeed_config,
)

if torch.__version__ > "2.1.2":
best_fit_attn_implementation = "sdpa"
else:
best_fit_attn_implementation = "eager"


@register_model("llava")
class Llava(lmms):
Expand All @@ -52,11 +57,13 @@ def __init__(
batch_size: Optional[Union[int, str]] = 1,
trust_remote_code: Optional[bool] = False,
revision=None,
attn_implementation=best_fit_attn_implementation,
use_flash_attention_2=True,
device_map="auto",
conv_template="vicuna_v1",
use_cache=True,
truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
customized_config=None,
**kwargs,
) -> None:
super().__init__()
Expand All @@ -72,7 +79,12 @@ def __init__(
self._device = torch.device(device)
self.device_map = device_map

self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
llava_model_args = {}
llava_model_args["attn_implementation"] = attn_implementation
if customized_config:
llava_model_args["customized_config"] = customized_config
llava_model_args["use_flash_attention_2"] = False
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, **llava_model_args)
self._config = self._model.config
self.model.eval()
self.model.tie_weights()
Expand Down
Loading

0 comments on commit 9f2d625

Please sign in to comment.