diff --git a/README.md b/README.md old mode 100755 new mode 100644 index d31451eb..e48f69b1 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ --- ## Annoucement +- [2024-09] 🎉🎉 We welcome the new task [MMSearch](https://mmsearch.github.io/). - [2024-09] 🎉🎉 We welcome the new task [MME-RealWorld](https://mme-realworld.github.io/) for inference acceleration - [2024-09] ⚙️️⚙️️️️ We upgrade `lmms-eval` to `0.2.3` with more tasks and features. We support a compact set of language tasks evaluations (code credit to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)), and we remove the registration logic at start (for all models and tasks) to reduce the overhead. Now `lmms-eval` only launches necessary tasks/models. Please check the [release notes](https://github.com/EvolvingLMMs-Lab/lmms-eval/releases/tag/v0.2.3) for more details. - [2024-08] 🎉🎉 We welcome the new model [LLaVA-OneVision](https://huggingface.co/papers/2408.03326), [Mantis](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/162), new tasks [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench), [LongVideoBench](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/117), [MMStar](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/158). We provide new feature of SGlang Runtime API for llava-onevision model, please refer the [doc](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/commands.md) for inference acceleration diff --git a/docs/task_guide.md b/docs/task_guide.md old mode 100755 new mode 100644 index 1e7d3a9d..1fc6a4ba --- a/docs/task_guide.md +++ b/docs/task_guide.md @@ -75,6 +75,44 @@ metadata: - version: 0.0 ``` +Multi-round-generation-based tasks: + +- MMSearch(`lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml`) + +```yaml +dataset_path: CaraJ/MMSearch +dataset_name: end2end +dataset_kwargs: + token: False +task: "mmsearch_end2end" +test_split: end2end +output_type: generate_until_multi_round # Note that here we use the new output_type here for multi-round generation. It basicly follows generate_until but incorporate multi-round inference +doc_to_visual: !function lmms_eval_utils.mmsearch_end2end_doc_to_visual +doc_to_text: !function lmms_eval_utils.mmsearch_end2end_doc_to_text +doc_to_target: "answer" +generation_kwargs: + until: + - "ASSISTANT:" + max_new_tokens: 512 + temperature: 0 + top_p: 0 + num_beams: 1 + do_sample: false +process_results: !function lmms_eval_utils.mmsearch_end2end_process_results +metric_list: + - metric: end2end_f1_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_f1_score + higher_is_better: true + - metric: requery_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_req_score + higher_is_better: true +lmms_eval_specific_kwargs: # Note that here we cache the result of every sample whenever the it is inferenced + middle_resules_dir: /data1/zrr/jdz/mmsearch/mmsearch_middile_results + result_cache_dir: /data1/zrr/jdz/mmsearch/mmsearch_result_cache_dir + +``` + + ## Configurations Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task. @@ -96,8 +134,9 @@ Dataset configuration options: - **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to the expected format expected by a prompt template. Prompting / in-context formatting options: -- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model -- **doc_to_visial** (`Union[Callable, str]`, *optional*) — Function to process a sample into the appropriate input images for the model. +- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model. + + For multi-round generation, (e.g., MMSearch), the function accepts additional parameters about the round index, previous round information and previous model output. It should return the input image for the next round, input text for the next round, a boolean indicating if round inference should terminate, model outputs from all rounds, and extra information from previous rounds. - **doc_to_target** (`Union[Callable, str]`, *optional*) — Column name or or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into - **doc_to_choice** (`Union[Callable, str]`, *optional*) — Column name or or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks. diff --git a/lmms_eval/api/instance.py b/lmms_eval/api/instance.py index 56285ec0..18cfb739 100755 --- a/lmms_eval/api/instance.py +++ b/lmms_eval/api/instance.py @@ -4,7 +4,7 @@ @dataclass class Instance: - request_type: Literal["loglikelihood", "generate_until"] + request_type: Literal["loglikelihood", "generate_until", "generate_until_multi_round"] arguments: tuple idx: int metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here diff --git a/lmms_eval/api/metrics.py b/lmms_eval/api/metrics.py index 6de0d771..8a29833e 100755 --- a/lmms_eval/api/metrics.py +++ b/lmms_eval/api/metrics.py @@ -338,7 +338,7 @@ def mean_stderr(arr): @register_metric( metric="bypass", higher_is_better=True, - output_type=["loglikelihood", "multiple_choice", "generate_until"], + output_type=["loglikelihood", "multiple_choice", "generate_until", "generate_until_multi_round"], aggregation="bypass", ) def bypass(items): @@ -368,7 +368,7 @@ def f1_fn(items): # This is a passthrough function @register_metric( metric="bleu", higher_is_better=True, - output_type="generate_until", + output_type=["generate_until", "generate_until_multi_round"], aggregation="bleu", ) def bleu_fn(items): # This is a passthrough function @@ -378,7 +378,7 @@ def bleu_fn(items): # This is a passthrough function @register_metric( metric="chrf", higher_is_better=True, - output_type="generate_until", + output_type=["generate_until", "generate_until_multi_round"], aggregation="chrf", ) def chrf_fn(items): # This is a passthrough function @@ -388,7 +388,7 @@ def chrf_fn(items): # This is a passthrough function @register_metric( metric="ter", higher_is_better=True, - output_type="generate_until", + output_type=["generate_until", "generate_until_multi_round"], aggregation="ter", ) def ter_fn(items): # This is a passthrough function diff --git a/lmms_eval/api/model.py b/lmms_eval/api/model.py index ab9dbd52..d99ba6ee 100755 --- a/lmms_eval/api/model.py +++ b/lmms_eval/api/model.py @@ -73,6 +73,25 @@ def generate_until(self, requests) -> List[str]: """ pass + @abc.abstractmethod + def generate_until_multi_round(self, requests) -> List[str]: + """Generate greedily until a stopping sequence + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, until). + context: str + Context string + generation_kwargs: dict + Generation Kwargs + 'visual_list: list[dict]' + Visual input to the model. Can be None. + :return: list[str] + A list of strings continuation + continuation: str + The generated continuation. + """ + pass + @classmethod def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T: """ @@ -160,7 +179,7 @@ def fn(requests): eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...") for req in tqdm(requests): hsh = hash_args(attr, req.args) - if attr == "generate_until" and req.args[1].get("do_sample", False): + if attr in ["generate_until", "generate_until_multi_round"] and req.args[1].get("do_sample", False): # when we are doing non-greedy generation, don't use the cache # (else every "randomly sampled" generation would be identical for repeats > 1). if not warned: diff --git a/lmms_eval/api/registry.py b/lmms_eval/api/registry.py index 630fd427..b037ae8f 100755 --- a/lmms_eval/api/registry.py +++ b/lmms_eval/api/registry.py @@ -76,6 +76,7 @@ def decorate(fn): ], "multiple_choice": ["acc", "acc_norm"], "generate_until": ["exact_match"], + "generate_until_multi_round": ["exact_match"], } diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index 085adba0..95304ca4 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -11,6 +11,7 @@ import subprocess from collections.abc import Callable from dataclasses import asdict, dataclass, field +from functools import partial from glob import glob from typing import ( Any, @@ -59,6 +60,7 @@ "loglikelihood", "multiple_choice", "generate_until", + "generate_until_multi_round", ] @@ -130,9 +132,9 @@ def __post_init__(self) -> None: raise ValueError("Got both a `group` and `tag` entry within a TaskConfig. Please use one or the other--`group` values will be deprecated in v0.4.4.") if self.generation_kwargs is not None: - if self.output_type != "generate_until": + if "generate_until" not in self.output_type: eval_logger.warning(f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!") - assert self.output_type != "generate_until" + assert "generate_until" not in self.output_type if "temperature" in self.generation_kwargs: self.generation_kwargs["temperature"] = float(self.generation_kwargs["temperature"]) @@ -140,7 +142,7 @@ def __post_init__(self) -> None: if "until" not in self.generation_kwargs: self.generation_kwargs["until"] = [self.fewshot_delimiter] else: - if self.output_type == "generate_until": + if "generate_until" in self.output_type: # ensure that we greedily generate in absence of explicit arguments otherwise self.generation_kwargs = { "until": None if self.fewshot_delimiter is None else [self.fewshot_delimiter], @@ -1380,6 +1382,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst elif self.OUTPUT_TYPE == "generate_until": arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, doc_id, self.config.task, split) + elif self.OUTPUT_TYPE == "generate_until_multi_round": + arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, partial(self.config.doc_to_text, lmms_eval_specific_kwargs=self.lmms_eval_specific_kwargs), doc_id, self.config.task, split) return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs) # TODO: we add a full_docs interface here for some evaluations that needs to access the full datasets during process_results function. we may have better ways to handle this. @@ -1466,7 +1470,7 @@ def process_results(self, doc, results, full_docs=None): acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 result_dict["acc_mutual_info"] = acc_mutual_info - elif self.OUTPUT_TYPE == "generate_until": + elif "generate_until" in self.OUTPUT_TYPE: gold = self.doc_to_target(doc) result = results[0] if self.config.doc_to_choice is not None: @@ -1524,7 +1528,7 @@ def process_results(self, doc, results, full_docs=None): else: raise ValueError( f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", - "'loglikelihood','generate_until' or 'multiple_choice'", + "'loglikelihood','generate_until', 'generate_until_multi_round', or 'multiple_choice'", ) return result_dict diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index 88149746..9230c827 100755 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -197,7 +197,7 @@ def _adjust_config(task_dict): if task_obj is None: continue lm.task_dict[task_name] = task_obj.dataset - if task_obj.get_config("output_type") == "generate_until": + if "generate_until" in task_obj.get_config("output_type"): if gen_kwargs is not None: task_obj.set_config(key="generation_kwargs", value=gen_kwargs, update=True) diff --git a/lmms_eval/logging_utils.py b/lmms_eval/logging_utils.py index bccd0322..27f25513 100755 --- a/lmms_eval/logging_utils.py +++ b/lmms_eval/logging_utils.py @@ -278,7 +278,7 @@ def _generate_dataset(self, data: List[Dict[str, Any]], config: Dict[str, Any]) choices = ["\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) for x in data] resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data] filtered_resps = [np.argmax([n[0] for n in x["filtered_resps"]]) for x in data] - elif config["output_type"] == "generate_until": + elif "generate_until" in config["output_type"]: instance = [x["arguments"][0][0] for x in data] resps = [x["resps"][0][0] for x in data] filtered_resps = [x["filtered_resps"][0] for x in data] diff --git a/lmms_eval/models/llava_onevision.py b/lmms_eval/models/llava_onevision.py index ce2c9c2d..a2a81079 100644 --- a/lmms_eval/models/llava_onevision.py +++ b/lmms_eval/models/llava_onevision.py @@ -407,7 +407,7 @@ def _collate(x): gen_kwargs.pop("until") question_input = [] - + # import ipdb; ipdb.set_trace() for visual, context in zip(batched_visuals, batched_contexts): if visual is None or visual == []: # for text-only tasks. visual = None @@ -551,3 +551,211 @@ def _collate(x): pbar.close() return res + + def generate_until_multi_round(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + metadata = requests[0].metadata + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_to_text, batched_doc_id, batched_task, batched_split = zip(*chunk) + task = batched_task[0] + split = batched_split[0] + batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id] # [B, N] + assert len(batched_visuals) == 1 + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + if "until" in gen_kwargs: + gen_kwargs.pop("until") + + # multi round inference: terminate when receiving signal from the doc_to_text + round_idx = 0 + batched_round_res = [] + batched_previous_round_info = None + while True: + question_input = [] + + if round_idx != 0: # get current round visual and context from doc_to_text function + batched_visuals, batched_contexts, batched_terminal_singal, batched_round_res, batched_previous_round_info = list( + zip( + *[ + batched_doc_to_text[0]( + self.task_dict[task][split][ids], + previous_output=[round_res[ids_idx] for round_res in batched_round_res], + round_idx=round_idx, + previous_round_info=batched_previous_round_info[ids_idx] if batched_previous_round_info is not None else None, + ) + for ids_idx, ids in enumerate(batched_doc_id) + ] + ) + ) + # import ipdb; ipdb.set_trace() + batched_round_res = list(zip(*batched_round_res)) # [(r1_1, r1_2), (r2_1, r2_2), ...] + if batched_terminal_singal[0]: # terminal signal from doc_to_text function + break + + for visual, context in zip(batched_visuals, batched_contexts): + if visual is None or visual == []: # for text-only tasks. + visual = None + task_type = "text" + placeholder_count = 0 + image_tensor = None + else: + if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__: # for multi image case, we treat per image aspect ratio as "pad" by default. + self._config.image_aspect_ratio = getattr(gen_kwargs, "image_aspect_ratio", "pad") + eval_logger.info(f"In Multi-Image setting, image aspect ratio: {self._config.image_aspect_ratio}") + + if "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata: # overwrite logic for video task with multiple static image frames + assert type(visual) == list, "sample_frames must be specified for video task" + sample_indices = np.linspace(0, len(visual) - 1, metadata["sample_frames"], dtype=int) + visual = [visual[i] for i in sample_indices] + assert len(visual) == metadata["sample_frames"] + + image_tensor = process_images(visual, self._image_processor, self._config) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + + task_type = "video" + placeholder_count = 1 + + elif type(visual[0]) == PIL.Image.Image: # For image, multi-image tasks + image_tensor = process_images(visual, self._image_processor, self._config) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + + task_type = "image" + placeholder_count = len(visual) if isinstance(visual, list) else 1 + + elif type(visual[0]) == str: # For video task + image_tensor = [] + try: + if self.video_decode_backend == "decord": + frames = self.load_video(visual, self.max_frames_num) + elif self.video_decode_backend == "pyav": + frames = read_video_pyav(visual[0], num_frm=self.max_frames_num) + frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda() + image_tensor.append(frames) + except Exception as e: + eval_logger.error(f"Error {e} in loading video") + image_tensor = None + + task_type = "video" + placeholder_count = len(frames) if self.token_strategy == "multiple" else 1 + + if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context: + """ + Three senarios: + 1. No image, and there for, no image token should be added. + 2. image token is already specified in the context, so we don't need to add it. + 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. + 4. For video tasks, we could add a token or multiple tokens for each frame in the context. This depends on the training strategy and should balance in test to decide which is better + """ + # if task_type == "image": # indeed in multi-image case, not the video in frames. + # image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] + # elif task_type == "video": + # image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count if self.token_strategy == "multiple" else [DEFAULT_IMAGE_TOKEN] + image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count + image_tokens = " ".join(image_tokens) + question = image_tokens + "\n" + context + else: + question = context + + # This is much safer for llama3, as we now have some object type in it + if "llama_3" in self.conv_template: + conv = copy.deepcopy(conv_templates[self.conv_template]) + else: + conv = conv_templates[self.conv_template].copy() + + if utils.is_json(question): # conversational question input + question = json.loads(question) + for idx, item in enumerate(question): + role = conv.roles[idx % 2] + message = item["value"] + conv.append_message(role, message) + + assert len(conv.messages) % 2 == 1 + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + else: # only simple string for question + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # preconfigure gen_kwargs with defaults + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "do_sample" not in gen_kwargs: + gen_kwargs["do_sample"] = False + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input] + pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device) + attention_masks = input_ids.ne(pad_token_ids).to(self.device) + + if task_type == "image": + gen_kwargs["image_sizes"] = [batched_visuals[0][idx].size for idx in range(len(batched_visuals[0]))] + elif task_type == "video": + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + gen_kwargs["modalities"] = ["video"] + gen_kwargs["stopping_criteria"] = [stopping_criteria] + self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride + self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode + + # These steps are not in LLaVA's original code, but are necessary for generation to work + # TODO: attention to this major generation step... + if "image_aspect_ratio" in gen_kwargs.keys(): + gen_kwargs.pop("image_aspect_ratio") + try: + with torch.inference_mode(): + cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs) + # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs) + + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) + except Exception as e: + raise e + + text_outputs = [response.strip() for response in text_outputs] + batched_round_res.append(text_outputs) + + round_idx += 1 + + res.extend(list(zip(*batched_round_res))) + self.cache_hook.add_partial("generate_until_multi_round", (context, gen_kwargs), batched_round_res) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res diff --git a/lmms_eval/tasks/mmsearch/constants.py b/lmms_eval/tasks/mmsearch/constants.py new file mode 100644 index 00000000..c0bbcd30 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/constants.py @@ -0,0 +1,12 @@ +# Configuration + +FULLPAGE_SPLIT_DICT = {"slice_height": 512, "max_slices": 10} + +MAX_QUERY_LENGTH = 300 +USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36" +IMAGE_SEARCH_RESULT = 10 +DEFAULT_IMAGE_TOKEN = "" +# Time for loading the website +BRIEF_TIMEOUT = 5000 # ms. Should be changed according to the network speed +FULLPAGE_TIMEOUT = 2500 # ms. Should be changed according to the network speed +FULLPAGE_CONTENT_TIMEOUT = 5000 # ms. Should be changed according to the network speed diff --git a/lmms_eval/tasks/mmsearch/get_final_scores.py b/lmms_eval/tasks/mmsearch/get_final_scores.py new file mode 100644 index 00000000..b0c5e3a7 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/get_final_scores.py @@ -0,0 +1,64 @@ +import argparse +import json +import os + +import datasets + +from lmms_eval.tasks.mmsearch.score.result_summary import get_result_summary + +task_path_dict = dict( + end2end="end2end_path", + requery="requery_path", + rerank="rerank_path", + summarization="summarization_path", +) + +task_key_dict = dict( + end2end="f1_score", + requery="req_score", + rerank="rer_score", + summarization="f1_score", +) + +task_ratio_dict = dict( + end2end=0.75, + requery=0.05, + rerank=0.1, + summarization=0.1, +) + + +def parse_args(): + argparser = argparse.ArgumentParser() + argparser.add_argument("--save_path", default="result_summary_final.json", type=str) + argparser.add_argument("--end2end_path", type=str, help="should be like xxx/submission/mmsearch_end2end_f1_score.json") + argparser.add_argument("--requery_path", type=str, help="should be like xxx/submission/mmsearch_end2end_requery_score.json") + argparser.add_argument("--rerank_path", type=str, help="should be like xxx/submission/mmsearch_end2end_rerank_score.json") + argparser.add_argument("--summarization_path", type=str, help="should be like xxx/submission/mmsearch_end2end_f1_score.json") + return argparser.parse_args() + + +args = parse_args() + +anno = datasets.load_dataset("CaraJ/MMSearch", name="end2end", split="end2end") + +all_task_result_summary = dict() +for task, attr in task_path_dict.items(): + key = task_key_dict[task] + all_task_result_summary[task] = json.load(open(getattr(args, attr)))[key] + +# total dict +final_result_summary = dict() +final_result_summary["total_dict"] = dict() +final_result_summary["total_dict"]["average"] = sum([ratio * all_task_result_summary[task]["total_dict"]["average"] for task, ratio in task_ratio_dict.items()]) +# area dict +final_result_summary["area_dict"] = dict() +for area in all_task_result_summary["end2end"]["area_dict"]: + final_result_summary["area_dict"][area] = sum([ratio * all_task_result_summary[task]["area_dict"][area]["average"] for task, ratio in task_ratio_dict.items()]) +# subfield dict +final_result_summary["subfield_dict"] = dict() +for subfield in all_task_result_summary["end2end"]["subfield_dict"]: + final_result_summary["subfield_dict"][subfield] = sum([ratio * all_task_result_summary[task]["subfield_dict"][subfield]["average"] for task, ratio in task_ratio_dict.items()]) + +print(f"Average final score: {final_result_summary['total_dict']['average']}") +json.dump(final_result_summary, open(args.save_path, "w"), indent=4) diff --git a/lmms_eval/tasks/mmsearch/lmms_eval_utils.py b/lmms_eval/tasks/mmsearch/lmms_eval_utils.py new file mode 100644 index 00000000..7a256469 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/lmms_eval_utils.py @@ -0,0 +1,430 @@ +import json +import os +from pathlib import Path + +import ipdb +import pandas as pd +import yaml +from loguru import logger as eval_logger +from PIL import Image + +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file +from lmms_eval.tasks.mmsearch.constants import * +from lmms_eval.tasks.mmsearch.prompts.prompt import * +from lmms_eval.tasks.mmsearch.prompts.prompt_w_imagesearch import * +from lmms_eval.tasks.mmsearch.retrieve_content.retriever import Content_Retriever +from lmms_eval.tasks.mmsearch.score.f1_score import get_f1_score +from lmms_eval.tasks.mmsearch.score.req_score import get_requery_score +from lmms_eval.tasks.mmsearch.score.result_summary import get_result_summary +from lmms_eval.tasks.mmsearch.utils.image_utils import pil_image_to_bytes +from lmms_eval.tasks.mmsearch.utils.lmms_eval_utils import * +from lmms_eval.tasks.mmsearch.utils.prompt_utils import * +from lmms_eval.tasks.mmsearch.utils.utils import * + +with open(Path(__file__).parent / "mmsearch.yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +# constants +brief_result_num = 8 +fullpage_num = 1 +content_retriever = Content_Retriever() + + +def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_output=None, round_idx=None, previous_round_info=None): + """ + Returns: + visuals (for next round) + contexts (for next round) + terminal_signal + round_result + previous_round_info + """ + # prepare save dir + middle_result_dir = lmms_eval_specific_kwargs["middle_resules_dir"] if lmms_eval_specific_kwargs is not None and "middle_resules_dir" in lmms_eval_specific_kwargs else "mmsearch_middile_results" + result_cache_dir = lmms_eval_specific_kwargs["result_cache_dir"] if lmms_eval_specific_kwargs is not None and "result_cache_dir" in lmms_eval_specific_kwargs else "mmsearch_result_cache_dir" + os.makedirs(middle_result_dir, exist_ok=True) + os.makedirs(result_cache_dir, exist_ok=True) + # prepare query information + if doc["query_image"] is None: + query_has_image = False + prompt_template_dict = text_query_dict + else: # query with image + query_has_image = True + prompt_template_dict = image_search_text_query_dict + query = doc["query"] + + # initial round: round_idx is None. This remains the same output format as other benchmark + eval_logger.info("----------------Round1: Requery----------------") + if round_idx is None: + prompt_template = prompt_template_dict["stage1"] + if not query_has_image: + text_query = prompt_template.format(question=query) + else: + text_query = prompt_template.format(question=DEFAULT_IMAGE_TOKEN + query, image_search_result=DEFAULT_IMAGE_TOKEN) + return text_query + # round2: search result + rerank + if round_idx == 1: + # if exist, return. This check has to be done here to avoid many + cache_path = os.path.join(result_cache_dir, f"{doc['sample_id']}.json") + if os.path.exists(cache_path): + eval_logger.info(f"{doc['sample_id']} already exists. Load the cache result.") + round_res = json.load(open(cache_path))["round_res"] + return None, None, True, round_res, None + eval_logger.info("----------------Round2: Rerank----------------") + # prepare + requery = previous_output[-1] + stage1_screenshot_dir = os.path.join(middle_result_dir, doc["sample_id"], "stage1") + + # search result + result_brief = search_text_brief_result(query=requery, max_result_num=brief_result_num, screenshot_dir=stage1_screenshot_dir) # relative path # [{'title', 'text','screenshot_path', 'url'}] + + if result_brief is None: # the search engine returns None to the requery + round_res = [requery, None, None] + save_result_to_cache(doc, round_res, dict(), result_cache_dir) + return None, None, True, round_res, None + + website_information, input_image_list = get_website_information(result_brief) + input_image_list = [Image.open(f).convert("RGB") for f in input_image_list] + + prompt_template = prompt_template_dict["stage2"] + if not query_has_image: + image_files = input_image_list + text_query = prompt_template.format(brief_result_num=brief_result_num, rerank_num=fullpage_num, question=query, website_information=website_information, incontext_example=get_rerank_incontext_example(fullpage_num)) + else: + image_files = [doc["query_image"].convert("RGB"), doc["image_search_result"].convert("RGB"), *input_image_list] + text_query = prompt_template.format( + brief_result_num=brief_result_num, + rerank_num=fullpage_num, + question=DEFAULT_IMAGE_TOKEN + query, + image_search_result=DEFAULT_IMAGE_TOKEN, + website_information=website_information, + incontext_example=get_rerank_incontext_example(fullpage_num), + ) + + image_files[0] = image_files[0].copy() + return image_files, text_query, False, previous_output, dict(result_brief=result_brief) + # round3: get full page + summarization + if round_idx == 2: + eval_logger.info("----------------Round3: Summarization----------------") + # prepare + stage3_screenshot_dir = os.path.join(middle_result_dir, doc["sample_id"], "stage3") + requery = previous_output[0] + rerank = previous_output[1] + result_brief = previous_round_info["result_brief"] + + # postprocess the rerank result + selected_index, _ = postprocess_rerank(rerank, fullpage_num) + selected_website = [result_brief[i] for i in selected_index] + result_full = search_url_full_result(urls=[web["url"] for web in selected_website], screenshot_dir=stage3_screenshot_dir) # relative path # [{'content', 'screenshot_fullpage_path'}] + + # add title and snippet + for full_idx, brief_idx in enumerate(selected_index): + result_full[full_idx]["title"] = result_brief[brief_idx]["title"] + result_full[full_idx]["snippet"] = result_brief[brief_idx]["snippet"] + + # conduct content retrieval + for idx, inst_full in enumerate(result_full): + if inst_full["content"] is None: # in case cannot get web content + inst_full["content"] = "" + if inst_full["content"].strip() != "": # some web do not contain language content + result_full[idx]["content"] = content_retriever.get_retrieved_content(requery, inst_full["content"]) + + website_full_information, input_image_list = get_full_website_information(result_full=result_full, image_dir=stage3_screenshot_dir, fullpage_split_dict=FULLPAGE_SPLIT_DICT) + + input_image_list = [Image.open(f).convert("RGB") for f in input_image_list] + # text_query and input_image_list + prompt_template = prompt_template_dict["stage3"] + if not query_has_image: + image_files = input_image_list + text_query = prompt_template.format( + rerank_num=fullpage_num, + website_information=website_full_information, + question=query, + ) + else: + image_files = [*input_image_list, doc["image_search_result"].convert("RGB"), doc["query_image"].convert("RGB")] + # assume only 1 image in the query + text_query = prompt_template.format(rerank_num=fullpage_num, website_information=website_full_information, image_search_result=DEFAULT_IMAGE_TOKEN, question=DEFAULT_IMAGE_TOKEN + query) + + image_files[0] = image_files[0].copy() + return image_files, text_query, False, previous_output, dict(result_brief=result_brief, website_full_information=website_full_information) + # the process should terminate + if round_idx == 3: + save_result_to_cache(doc, previous_output, previous_round_info, result_cache_dir) + return None, None, True, previous_output, None + + +def mmsearch_end2end_doc_to_visual(doc): + if doc["query_image"] is None: + return [] + return [doc["query_image"].convert("RGB").copy(), doc["image_search_result"].convert("RGB")] # .copy is a workround of the type judgement in llava-ov + + +def mmsearch_rerank_doc_to_visual(doc): + image_list = [] + # query image + if doc["query_image"] is not None: + image_list.extend([doc["query_image"].convert("RGB"), doc["image_search_result"].convert("RGB")]) + # website screenshot + image_list.extend(doc[f"website{idx}_head_screenshot"].convert("RGB") for idx in range(brief_result_num)) # there are 8 webpages in the dataset + + # a workround to pass the type judgement in llava-ov + image_list[0] = image_list[0].copy() + return image_list + + +def mmsearch_rerank_doc_to_text(doc, lmms_eval_specific_kwargs=None): + if doc["query_image"] is None: + query_has_image = False + prompt_template_dict = text_query_dict + else: + query_has_image = True + prompt_template_dict = image_search_text_query_dict + + result_brief = [dict(**doc[f"website{i}_info"], screenshot_path=doc[f"website{i}_head_screenshot"]) for i in range(brief_result_num)] # [{'title', 'text','screenshot_path', 'd'}] + query = doc["query"] + + website_information, _ = get_website_information(result_brief) + + # add query image + prompt_template = prompt_template_dict["stage2"] + if not query_has_image: + text_query = prompt_template.format(brief_result_num=brief_result_num, rerank_num=fullpage_num, question=query, website_information=website_information, incontext_example=get_rerank_incontext_example(fullpage_num)) + else: + text_query = prompt_template.format( + brief_result_num=brief_result_num, + rerank_num=fullpage_num, + question=DEFAULT_IMAGE_TOKEN + query, + image_search_result=DEFAULT_IMAGE_TOKEN, + website_information=website_information, + incontext_example=get_rerank_incontext_example(fullpage_num), + ) + return text_query + + +def mmsearch_summarization_doc_to_visual(doc): + # from https://github.com/CaraJ7/MMSearch/blob/main/eval_summarization.py + # set up prompt + if doc["query_image"] is None: + query_has_image = False + prompt_template_dict = text_query_dict + else: + query_has_image = True + prompt_template_dict = image_search_text_query_dict + + result_full = [ + dict( + title=doc["website_title"], + snippet=doc["website_snippet"], + content=doc["website_retrieved_content"], + slimmed_website_fullpage_screenshot=pil_image_to_bytes(doc["website_fullpage_screenshot"]), + ) + ] # the screenshot from the dataset has already been slimmed + _, input_image_list = get_full_website_information(result_full=result_full, fullpage_split_dict=FULLPAGE_SPLIT_DICT) + + # add query image in the input image files + if not query_has_image: + image_files = [Image.open(f).convert("RGB") for f in input_image_list] + else: + image_files = [*[Image.open(f).convert("RGB") for f in input_image_list], doc["image_search_result"].convert("RGB"), doc["query_image"].convert("RGB")] + + # a workround to pass the type judgement in llava-ov + image_files[0] = image_files[0].copy() + return image_files + + +def mmsearch_summarization_doc_to_text(doc, lmms_eval_specific_kwargs=None): + # from https://github.com/CaraJ7/MMSearch/blob/main/eval_summarization.py + # set up prompt + if doc["query_image"] is None: + query_has_image = False + prompt_template_dict = text_query_dict + else: + query_has_image = True + prompt_template_dict = image_search_text_query_dict + + result_full = [ + dict( + title=doc["website_title"], + snippet=doc["website_snippet"], + content=doc["website_retrieved_content"], + slimmed_website_fullpage_screenshot=pil_image_to_bytes(doc["website_fullpage_screenshot"]), + ) + ] # the screenshot from the dataset has already been slimmed + website_full_information, input_image_list = get_full_website_information(result_full=result_full, fullpage_split_dict=FULLPAGE_SPLIT_DICT) + query = doc["query"] + + # add query image in the input image files + prompt_template = prompt_template_dict["stage3"] + if not query_has_image: + text_query = prompt_template.format( + rerank_num=fullpage_num, + website_information=website_full_information, + question=query, + ) + else: + # assume only 1 image in the query + text_query = prompt_template.format(rerank_num=fullpage_num, website_information=website_full_information, image_search_result=DEFAULT_IMAGE_TOKEN, question=DEFAULT_IMAGE_TOKEN + query) + return text_query + + +def mmsearch_end2end_process_results(doc, results): + round_res = results[0] + result = { + "sample_id": doc["sample_id"], + "query": doc["query"], + "timestamp": doc["timestamp"], + "area": doc["area"], + "subfield": doc["subfield"], + "gt_answer": doc["gt_answer"], + "gt_requery": doc["gt_requery"], + "alternative_gt_answers": doc["alternative_gt_answers"], + "requery_prediction": round_res[0], + "answer_prediction": round_res[2], + } + + return { + "end2end_f1_score": result, + "requery_score": result, + } + + +def mmsearch_rerank_process_results(doc, results): + prediction = results[0].strip() + + result = { + "sample_id": doc["sample_id"], + "query": doc["query"], + "timestamp": doc["timestamp"], + "area": doc["area"], + "subfield": doc["subfield"], + "gt_answer": doc["gt_answer"], + "rerank_prediction": prediction, + "valid": doc["valid"], + "not_sure": doc["not_sure"], + "invalid": doc["invalid"], + } + + return { + "rek_score": result, + } + + +def mmsearch_summarization_process_results(doc, results): + prediction = results[0].strip() + + result = { + "sample_id": doc["sample_id"], + "query": doc["query"], + "timestamp": doc["timestamp"], + "area": doc["area"], + "subfield": doc["subfield"], + "gt_answer": doc["gt_answer"], + "alternative_gt_answers": doc["alternative_gt_answers"], + "answer_prediction": prediction, + } + + return { + "summarization_f1_score": result, + } + + +def mmsearch_aggregate_results_f1_score(results, args, *, calculate_gain=False, random_scores=None): + result_list = [] + for inst in results: + prediction = inst["answer_prediction"] + gt_answer = inst["gt_answer"] + f1_score = get_f1_score(prediction, gt_answer) + for gt_alternative_answer in inst["alternative_gt_answers"]: + alternative_f1_score = get_f1_score(prediction, gt_alternative_answer) + if alternative_f1_score > f1_score: + f1_score = alternative_f1_score + inst.update(dict(f1_score=f1_score)) + result_list.append(inst) + + # assert len(result_list) == 300 # assert to be the benchmark length, or the get_result_summary function will not work + # save results + path = generate_submission_file(f"{args.tasks}_f1_results.json", args) + with open(path, "w") as f: + json.dump(result_list, f, indent=4) + # save scores + result_summary = get_result_summary(result_list, result_list, summary_key="f1_score") + path = generate_submission_file(f"{args.tasks}_f1_score.json", args) + with open(path, "w") as f: + json.dump(result_summary, f, indent=4) + avg_f1_score = result_summary["f1_score"]["total_dict"]["average"] + return avg_f1_score + + +def mmsearch_aggregate_results_req_score(results, args, *, calculate_gain=False, random_scores=None): + result_list = [] + for inst in results: + requery = inst["requery_prediction"] + gt_requery = inst["gt_requery"] + req_score = get_requery_score(requery, gt_requery) + inst.update( + dict( + req_score=req_score["score"], + req_score_dict=req_score, + ) + ) + result_list.append(inst) + + assert len(result_list) == 300 # assert to be the benchmark length, or the get_result_summary function will not work + # save results + path = generate_submission_file(f"{args.tasks}_requery_results.json", args) + with open(path, "w") as f: + json.dump(result_list, f, indent=4) + # save scores + result_summary = get_result_summary(result_list, result_list, summary_key="req_score") + path = generate_submission_file(f"{args.tasks}_requery_score.json", args) + with open(path, "w") as f: + json.dump(result_summary, f, indent=4) + avg_req_score = result_summary["req_score"]["total_dict"]["average"] + return avg_req_score + + +def mmsearch_aggregate_results_rek_score(results, args, *, calculate_gain=False, random_scores=None): + result_list = [] + for inst in results: + rerank = inst["rerank_prediction"] + selected_index, valid = postprocess_rerank(rerank, fullpage_num) + selected_index = selected_index[0] # only take the first one + + if not valid: + score = 0 + elif selected_index in inst["valid"]: + score = 1 + elif selected_index in inst["not_sure"]: + score = 0.5 + else: + score = 0 + + inst.update( + dict( + model_output_valid=valid, + parsed_answer_rank=selected_index, + rer_score=score, + ) + ) + result_list.append(inst) + assert len(result_list) == 300 # assert to be the benchmark length, or the get_result_summary function will not work + + # save results + path = generate_submission_file(f"{args.tasks}_rerank_results.json", args) + with open(path, "w") as f: + json.dump(result_list, f, indent=4) + # save score + result_summary = get_result_summary(result_list, result_list, summary_key="rer_score") + path = generate_submission_file(f"{args.tasks}_rerank_score.json", args) + with open(path, "w") as f: + json.dump(result_summary, f, indent=4) + avg_rerank_score = result_summary["rer_score"]["total_dict"]["average"] + return avg_rerank_score diff --git a/lmms_eval/tasks/mmsearch/mmsearch.yaml b/lmms_eval/tasks/mmsearch/mmsearch.yaml new file mode 100644 index 00000000..71c2a763 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/mmsearch.yaml @@ -0,0 +1,9 @@ +# the final score is computed using the script lmms_eval/tasks/mmsearch/get_final_scores.py +# the score file in the submission folder of the three task (end2end, rerank, summarization) should be input as args +group: mathverse +task: + - mmsearch_end2end + - mmsearch_rerank + - mmsearch_summarization +metadata: + version: 0.0 \ No newline at end of file diff --git a/lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml b/lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml new file mode 100644 index 00000000..aa70ea54 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml @@ -0,0 +1,31 @@ +# the final score is computed using the script lmms_eval/tasks/mmsearch/get_final_scores.py +# the score file in the submission folder of the three task (end2end, rerank, summarization) should be input as args +dataset_path: CaraJ/MMSearch +dataset_name: end2end +dataset_kwargs: + token: False +task: "mmsearch_end2end" +test_split: end2end +output_type: generate_until_multi_round +doc_to_visual: !function lmms_eval_utils.mmsearch_end2end_doc_to_visual +doc_to_text: !function lmms_eval_utils.mmsearch_end2end_doc_to_text +doc_to_target: "answer" +generation_kwargs: + until: + - "ASSISTANT:" + max_new_tokens: 512 + temperature: 0 + top_p: 0 + num_beams: 1 + do_sample: false +process_results: !function lmms_eval_utils.mmsearch_end2end_process_results +metric_list: + - metric: end2end_f1_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_f1_score + higher_is_better: true + - metric: requery_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_req_score + higher_is_better: true +lmms_eval_specific_kwargs: # whenever a sample is infered, save it + middle_resules_dir: /data1/zrr/jdz/mmsearch/mmsearch_middile_results + result_cache_dir: /data1/zrr/jdz/mmsearch/mmsearch_result_cache_dir diff --git a/lmms_eval/tasks/mmsearch/mmsearch_rerank.yaml b/lmms_eval/tasks/mmsearch/mmsearch_rerank.yaml new file mode 100644 index 00000000..7fc97d02 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/mmsearch_rerank.yaml @@ -0,0 +1,35 @@ +# the final score is computed using the script lmms_eval/tasks/mmsearch/get_final_scores.py +# the score file in the submission folder of the three task (end2end, rerank, summarization) should be input as args +dataset_path: CaraJ/MMSearch +dataset_name: rerank +dataset_kwargs: + token: False +task: "mmsearch_rerank" +test_split: rerank +output_type: generate_until +doc_to_visual: !function lmms_eval_utils.mmsearch_rerank_doc_to_visual +doc_to_text: !function lmms_eval_utils.mmsearch_rerank_doc_to_text +doc_to_target: "answer" +generation_kwargs: + until: + - "ASSISTANT:" + max_new_tokens: 1024 + temperature: 0 + top_p: 0 + num_beams: 1 + do_sample: false +process_results: !function lmms_eval_utils.mmsearch_rerank_process_results +metric_list: + - metric: rek_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_rek_score + higher_is_better: true + +lmms_eval_specific_kwargs: + default: + shot_type: "format-prompt" # can also be "custom-prompt" + query_type: "query_wo" # now only support query_wo +model_specific_generation_kwargs: + llava: + image_aspect_ratio: original + llava_onevision: + image_aspect_ratio: original \ No newline at end of file diff --git a/lmms_eval/tasks/mmsearch/mmsearch_summarization.yaml b/lmms_eval/tasks/mmsearch/mmsearch_summarization.yaml new file mode 100644 index 00000000..fdac99da --- /dev/null +++ b/lmms_eval/tasks/mmsearch/mmsearch_summarization.yaml @@ -0,0 +1,33 @@ +# the final score is computed using the script lmms_eval/tasks/mmsearch/get_final_scores.py +# the score file in the submission folder of the three task (end2end, rerank, summarization) should be input as args +dataset_path: CaraJ/MMSearch +dataset_name: summarization +dataset_kwargs: + token: False +task: "mmsearch_summarization" +test_split: summarization +output_type: generate_until +doc_to_visual: !function lmms_eval_utils.mmsearch_summarization_doc_to_visual +doc_to_text: !function lmms_eval_utils.mmsearch_summarization_doc_to_text +doc_to_target: "answer" +generation_kwargs: + until: + - "ASSISTANT:" + max_new_tokens: 1024 + temperature: 0 + top_p: 0 + num_beams: 1 + do_sample: false +process_results: !function lmms_eval_utils.mmsearch_summarization_process_results +metric_list: + - metric: summarization_f1_score + aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_f1_score + higher_is_better: true + +lmms_eval_specific_kwargs: + default: + shot_type: "format-prompt" # can also be "custom-prompt" + query_type: "query_wo" # now only support query_wo +model_specific_generation_kwargs: + llava: + image_aspect_ratio: original \ No newline at end of file diff --git a/lmms_eval/tasks/mmsearch/prompts/prompt.py b/lmms_eval/tasks/mmsearch/prompts/prompt.py new file mode 100644 index 00000000..08cd4514 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/prompts/prompt.py @@ -0,0 +1,43 @@ +stage1_text_requery_prompt = """You are a helpful assistant. I am giving you a question, which cannot be solved without external knowledge. +Assume you have access to a text-only search engine (e.g., google). Please raise a query to the search engine to search for what is useful for you to answer the question correctly. Your query needs to consider the attribute of the query to search engine. +Here are 3 examples: +Question: Did Zheng Xiuwen wear a knee pad in the women's singles tennis final in 2024 Paris Olympics? +Query to the search engine: Images of Zheng Xiuwen in the women's singles tennis final in 2024 Paris Olympics + +Question: When will Apple release iPhone16? +Query to the search engine: iPhone 16 release date + +Question: Who will sing a French song at the Olympic Games closing ceremony? +Query to the search engine: Singers at the Olympic Games closing ceremony, French song + +Question: {question} +Query to the search engine (do not involve any explanation): """ + +stage2_text_requery_prompt = """You are a helpful assistant. I am giving you a question and {brief_result_num} website information related to the question (including the screenshot, snippet and title). +You should now read the screenshots, snippets and titles. Select {rerank_num} website that are the most helpful for you to answer the question. Once you select it, the detailed content of them will be provided to help you correctly answer the question. +The question is: {question} +The website informations is: +{website_information} + +You should directly output {rerank_num} website's index that can help you most, separated with ',', and enclose each website in angle brackets. The output format should be: . +An example of the output is: {incontext_example} +Your answer: """ + + +stage3_text_requery_prompt = """You are a helpful assistant. I am giving you a question and {rerank_num} website information related to the question. +Please follow these guidelines when formulating your answer: +1. If the question contains a false premise or assumption, answer "invalid question". +2. When answering questions about dates, use the yyyy-mm-dd format. +3. Answer the question with as few words as you can. + +You should now read the information of the website and answer the question. +The website informations is {website_information} +The question is: {question}. +Please directly output the answer without any explanation: """ + + +text_query_dict = { + "stage1": stage1_text_requery_prompt, + "stage2": stage2_text_requery_prompt, + "stage3": stage3_text_requery_prompt, +} diff --git a/lmms_eval/tasks/mmsearch/prompts/prompt_w_imagesearch.py b/lmms_eval/tasks/mmsearch/prompts/prompt_w_imagesearch.py new file mode 100644 index 00000000..9d36f293 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/prompts/prompt_w_imagesearch.py @@ -0,0 +1,46 @@ +stage1_image_search_text_requery_prompt = """You are a helpful assistant. I am giving you a question including an image, which cannot be solved without external knowledge. +Assume you have access to a search engine (e.g., google). Please raise a query to the search engine to search for what is useful for you to answer the question correctly. You need to consider the characteristics of asking questions to search engines when formulating your questions. +You are also provided with the search result of the image in the question. You should leverage the image search result to raise the text query. +Here are 3 examples: +Question: Did Zheng Xiuwen wear a knee pad in the women's singles tennis final in 2024 Paris Olympics? +Query to the search engine: Images of Zheng Xiuwen in the women's singles tennis final in 2024 Paris Olympics + +Question: When will Apple release iPhone16? +Query to the search engine: iPhone 16 release date + +Question: Who will sing a French song at the Olympic Games closing ceremony? +Query to the search engine: Singers at the Olympic Games closing ceremony, French song + +Question: {question} +The image search result is: {image_search_result} +Query to the search engine (do not involve any explanation): """ + +stage2_image_search_text_requery_prompt = """You are a helpful assistant. I am giving you a question including an image. You are provided with the search result of the image in the question. And you are provided with {brief_result_num} website information related to the question (including the screenshot, snippet and title). +You should now read the screenshots, snippets and titles of these websites. Select {rerank_num} website that are the most helpful for you to answer the question. Once you select it, the detailed content of them will be provided to help you correctly answer the question. +The question is: {question} +The image search result is: {image_search_result} +The website informations is: +{website_information} + +You should directly output {rerank_num} website's index that can help you most, separated with ',', and enclose each website in angle brackets. The output format should be: . +An example of the output is: {incontext_example} +Your answer: """ + +stage3_image_search_text_requery_prompt = """You are a helpful assistant. I am giving you a question including an image. You are provided with the search result of the image in the question. And you are provided with {rerank_num} website information related to the question. +Please follow these guidelines when formulating your answer: +1. If the question contains a false premise or assumption, answer "invalid question". +2. When answering questions about dates, use the yyyy-mm-dd format. +3. Answer the question with as few words as you can. + +You should now read the information of the website and answer the question. +The website informations is {website_information} +The image search result is: {image_search_result} +The question is: {question}. +Please directly output the answer without any explanation: """ + + +image_search_text_query_dict = { + "stage1": stage1_image_search_text_requery_prompt, + "stage2": stage2_image_search_text_requery_prompt, + "stage3": stage3_image_search_text_requery_prompt, +} diff --git a/lmms_eval/tasks/mmsearch/retrieve_content/retriever.py b/lmms_eval/tasks/mmsearch/retrieve_content/retriever.py new file mode 100644 index 00000000..49014a17 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/retrieve_content/retriever.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +from FlagEmbedding import BGEM3FlagModel +from tqdm import tqdm + +from lmms_eval.tasks.mmsearch.retrieve_content.tokenization.tokenizers import ( + LexicalAnalyzer, +) + + +@dataclass +class Config: + chunk_length: int = 200 + slidew: bool = False + sentb: bool = False + TopK: int = 8 + + +class Content_Retriever: + def __init__(self): + # define tokenizer + self.tokenizer = LexicalAnalyzer() + self.tokenizer_offsets = LexicalAnalyzer(do_char_positions=True) + + self.config = Config() + + self.tokenizer_offsets.settings["do_sliding_window_passages"] = self.config.slidew + self.tokenizer_offsets.settings["respect_sent_boundaries"] = self.config.sentb + # define retrieval model + self.model = BGEM3FlagModel("BAAI/bge-m3", device="cpu", use_fp16=False) # Setting use_fp16 to True speeds up computation with a slight performance degradation + + def split_doc_into_passages(self, doc): + text = doc + passages = [] + + passages_tokens = self.tokenizer.analyze_excerpts(text) + for _, passage_tokens in enumerate(passages_tokens): + if self.tokenizer.settings["respect_sent_boundaries"]: + tokens = [] + for psg in passage_tokens: + tokens.extend(psg) + passage_tokens = tokens + if len(passage_tokens) == 0: + continue + + passage_text = " ".join(passage_tokens) + passages.append(passage_text) + + return passages + + def get_retrieved_content(self, requery, content): + docs = [content] + all_chucks = self.split_doc_into_passages(content) + # encode + output_1 = self.model.encode([requery], return_dense=True, return_sparse=True, return_colbert_vecs=True, batch_size=12, max_length=self.config.chunk_length) + output_2 = self.model.encode(all_chucks, return_dense=True, return_sparse=True, return_colbert_vecs=True, batch_size=12, max_length=self.config.chunk_length) + scores = [] + for i in range(len(output_2["colbert_vecs"])): + scores.append(self.model.colbert_score(output_1["colbert_vecs"][0], output_2["colbert_vecs"][i]).item()) + + sorted_pairs = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) + sorted_values, original_indices = zip(*sorted_pairs) + return "\n".join([all_chucks[idx] for idx in sorted_values[: self.config.TopK]]) diff --git a/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/__init__.py b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/__init__.py new file mode 100644 index 00000000..97fa08f7 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/__init__.py @@ -0,0 +1 @@ +# Implement your code here. diff --git a/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/tokenizers.py b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/tokenizers.py new file mode 100644 index 00000000..78161ecb --- /dev/null +++ b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/tokenizers.py @@ -0,0 +1,272 @@ +import logging +import re +import sys +import unicodedata +from typing import Any, Dict, List + +import nltk +from nltk.corpus import stopwords +from nltk.stem import WordNetLemmatizer +from nltk.stem.porter import PorterStemmer +from nltk.tokenize import sent_tokenize as sent_tok + +from lmms_eval.tasks.mmsearch.retrieve_content.tokenization.utils import PickleWriteable + +QUOTES = re.compile("(\"|``|'')") + + +class LexemeWithPositions(dict): + def __init__(self, lexeme: str, begin_pos: int, end_pos: int): + if begin_pos > end_pos: + raise ValueError("begin position cannot be greater than end position for a lexeme") + super().__init__(_lexeme=lexeme, _begin_pos=begin_pos, _end_pos=end_pos) + + def fetch_lexeme(self) -> str: + return self["_lexeme"] + + def update_lexeme(self, lex: str) -> None: + self["_lexeme"] = lex + + def fetch_begin_pos(self) -> int: + return self["_begin_pos"] + + def fetch_end_pos(self) -> int: + return self["_end_pos"] + + +class LexicalAnalyzer(PickleWriteable): + PUNCTUATION_TABLE = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")) + PUNCTUATION_CHARS = set(chr(i) for i in PUNCTUATION_TABLE.keys()) + + DEFAULT_SETTINGS = { + "min_lexeme_len": 1, + "max_lexeme_len": 25, + "space_chars": [], + "do_stop_words": False, + "excerpt_len": 200, + "max_sent_len": 35, + "respect_sent_boundaries": False, + "do_sliding_window_excerpts": False, + "do_char_positions": False, + "strip_inlexeme_punctuation": False, + "remove_numerics": False, + "do_stem": False, + "do_lemma": False, + "do_punct_removal": False, + "do_lower": False, + "lexeme_splitter": nltk.tokenize.word_tokenize, + "reject_threshold": -1, + } + + def __init__(self, **kwargs): + self.settings = self.DEFAULT_SETTINGS.copy() + self.settings.update(kwargs) + + if self.settings["do_stop_words"]: + self.stop_words = set(self.settings.get("stop_words", stopwords.words("english"))) + elif "stop_words" in kwargs: + raise ValueError("please set do_stop_words to True when stop_words are provided.") + + if self.settings["do_stem"] and self.settings["do_lemma"]: + raise ValueError("stemming and lemmatization cannot be turned on simultaneously") + + self.stemmer = PorterStemmer() if self.settings["do_stem"] else None + self.lemmatizer = WordNetLemmatizer() if self.settings["do_lemma"] else None + + def analyze_excerpts(self, text: str) -> List[List[Any]]: + if self.settings["respect_sent_boundaries"]: + return self._analyze_excerpts_with_sentence_boundaries(text) + return self._analyze_excerpts_without_sentence_boundaries(text) + + def _analyze_excerpts_with_sentence_boundaries(self, text: str) -> List[List[List[LexemeWithPositions]]]: + analyzed_sentences = self.analyze_sentences(text) + analyzed_sentences = self._split_long_sentences(analyzed_sentences) + + if not self.settings["do_sliding_window_excerpts"]: + return self._generate_nonoverlapping_excerpts(analyzed_sentences) + + nonoverlapping_excerpts = self._generate_nonoverlapping_excerpts(analyzed_sentences) + if len(nonoverlapping_excerpts) <= 1: + return nonoverlapping_excerpts + + return self._generate_overlapping_excerpts(nonoverlapping_excerpts) + + def _analyze_excerpts_without_sentence_boundaries(self, text: str) -> List[List[LexemeWithPositions]]: + analyzed_lexemes = self.analyze(text) + if self.settings["do_sliding_window_excerpts"]: + return self._generate_sliding_window_excerpts(analyzed_lexemes) + return _split_analyzed_text_to_nonoverlapping_excerpts(analyzed_lexemes, self.settings["excerpt_len"]) + + def _generate_nonoverlapping_excerpts(self, analyzed_sentences: List[List[LexemeWithPositions]]) -> List[List[List[LexemeWithPositions]]]: + excerpts = [] + current_excerpt = [] + current_excerpt_len = 0 + sent_lens = [len(sent) for sent in analyzed_sentences] + + for i, sent in enumerate(analyzed_sentences): + current_excerpt.append(sent) + if i < len(analyzed_sentences) - 1 and sum(sent_lens[i + 1 :]) <= self.settings["excerpt_len"] / 2: + current_excerpt.extend(analyzed_sentences[i + 1 :]) + excerpts.append(current_excerpt) + break + current_excerpt_len += len(sent) + if current_excerpt_len >= self.settings["excerpt_len"] or i == len(analyzed_sentences) - 1: + excerpts.append(current_excerpt) + current_excerpt = [] + current_excerpt_len = 0 + return excerpts + + def _generate_overlapping_excerpts(self, nonoverlapping_excerpts: List[List[List[LexemeWithPositions]]]) -> List[List[List[LexemeWithPositions]]]: + overlapping_excerpts = [] + for i in range(len(nonoverlapping_excerpts) - 1): + left_excerpt_start_index = len(nonoverlapping_excerpts[i]) // 2 + right_excerpt_end_index = max(1, len(nonoverlapping_excerpts[i + 1]) // 2) + overlapping_excerpts.append(nonoverlapping_excerpts[i][left_excerpt_start_index:] + nonoverlapping_excerpts[i + 1][:right_excerpt_end_index]) + + excerpts = [] + for i in range(len(nonoverlapping_excerpts) - 1): + excerpts.append(nonoverlapping_excerpts[i]) + excerpts.append(overlapping_excerpts[i]) + excerpts.append(nonoverlapping_excerpts[-1]) + return excerpts + + def _generate_sliding_window_excerpts(self, analyzed_lexemes: List[LexemeWithPositions]) -> List[List[LexemeWithPositions]]: + excerpts = [] + for i in range(0, len(analyzed_lexemes), self.settings["excerpt_len"] // 2): + excerpts.append(analyzed_lexemes[i : i + self.settings["excerpt_len"]]) + if i + self.settings["excerpt_len"] >= len(analyzed_lexemes): + break + return excerpts + + def analyze_sentences(self, text: str) -> List[List[LexemeWithPositions]]: + analyzed_sentences = [] + for sent_with_positions in self._analyze_text(text, sent_tok): + sent_text = sent_with_positions.fetch_lexeme() + sent_start = sent_with_positions.fetch_begin_pos() + analyzed_sentences.append(self.analyze(sent_text, sent_start)) + return analyzed_sentences + + def analyze(self, text: str, offset: int = 0) -> List[LexemeWithPositions]: + analyzed_lexemes = self._analyze_text(text, self.settings["lexeme_splitter"], offset) + transformed_lexemes = self._filter_and_transform(analyzed_lexemes) + if not self.settings["do_char_positions"]: + return [lexeme_with_positions.fetch_lexeme() for lexeme_with_positions in transformed_lexemes] + return transformed_lexemes + + def _split_long_sentences(self, analyzed_sentences: List[List[LexemeWithPositions]]) -> List[List[LexemeWithPositions]]: + split_sentences = [] + for analyzed_sent in analyzed_sentences: + if len(analyzed_sent) > self.settings["max_sent_len"]: + sent_excerpts = _split_analyzed_text_to_nonoverlapping_excerpts(analyzed_sent, self.settings["max_sent_len"]) + split_sentences.extend(sent_excerpts) + else: + split_sentences.append(analyzed_sent) + return split_sentences + + def _analyze_text(self, text: str, analyzer: callable, offset: int = 0) -> List[LexemeWithPositions]: + if not isinstance(text, str): + raise ValueError(f"text type is invalid: {type(text)}") + + for space_char in self.settings["space_chars"]: + text = text.replace(space_char, " ") + + text = self._omit_long_lexemes(text, self.settings["reject_threshold"]) + segments = analyzer(text) + + if analyzer is nltk.word_tokenize: + segments = self._divide_long_lexemes(segments) + + analyzed_segments = [] + start = 0 + for segment in segments: + segment_start = text.find(segment, start) + if analyzer is nltk.word_tokenize and segment in ["``", "''"]: + quotes_match = QUOTES.search(text, start) + if quotes_match: + segment = quotes_match.group(0) + segment_start = quotes_match.start() + else: + segment_start = -1 + + if segment_start < 0: + raise ValueError(f"cannot find the segment {segment} in the text {text}") + + end = segment_start + len(segment) + analyzed_segments.append(LexemeWithPositions(segment, offset + segment_start, offset + end)) + start = end + + return analyzed_segments + + def _divide_long_lexemes(self, segments: List[str]) -> List[str]: + divided_segments = [] + for seg in segments: + if len(seg) <= self.settings["max_lexeme_len"]: + divided_segments.append(seg) + else: + divided_segments.extend([seg[i : i + self.settings["max_lexeme_len"]] for i in range(0, len(seg), self.settings["max_lexeme_len"])]) + return divided_segments + + def _omit_long_lexemes(self, text: str, reject_threshold: int) -> str: + if reject_threshold < 0: + return text + + verified_sentence = [] + omitted_count = 0 + for lexeme in text.split(): + if len(lexeme) < reject_threshold: + verified_sentence.append(lexeme) + else: + omitted_count += 1 + logging.error(f"omitting long lexeme with length: {len(lexeme)}") + + logging.info(f"total omitted: {omitted_count}, total retained: {len(verified_sentence)}") + return " ".join(verified_sentence) + + def _filter_and_transform(self, lexemes_with_positions: List[LexemeWithPositions]) -> List[LexemeWithPositions]: + transformed_lexemes = [] + + for lexeme_w_positions in lexemes_with_positions: + lexeme = lexeme_w_positions.fetch_lexeme() + + if self._should_omit_lexeme(lexeme): + continue + + lexeme = self._transform_lexeme(lexeme) + + if lexeme is None: + logging.error("analysis produced None as a lexeme") + continue + + lexeme_w_positions.update_lexeme(lexeme) + transformed_lexemes.append(lexeme_w_positions) + + return transformed_lexemes + + def _should_omit_lexeme(self, lexeme: str) -> bool: + return ( + (self.settings["do_punct_removal"] and lexeme in self.PUNCTUATION_CHARS) + or (self.settings["do_stop_words"] and lexeme.lower() in self.stop_words) + or (len(lexeme) < self.settings["min_lexeme_len"]) + or (self.settings["remove_numerics"] and lexeme.isnumeric()) + ) + + def _transform_lexeme(self, lexeme: str) -> str: + if self.settings["do_lower"]: + lexeme = lexeme.lower() + if self.settings["strip_inlexeme_punctuation"]: + lexeme = lexeme.translate(self.PUNCTUATION_TABLE) + if self.settings["do_stem"]: + lexeme = self.stemmer.stem(lexeme) + elif self.settings["do_lemma"]: + lexeme = self.lemmatizer.lemmatize(lexeme) + return lexeme + + +def _split_analyzed_text_to_nonoverlapping_excerpts(analyzed_lexemes: List[LexemeWithPositions], excerpt_len: int) -> List[List[LexemeWithPositions]]: + excerpts = [] + for i in range(0, len(analyzed_lexemes), excerpt_len): + if len(analyzed_lexemes) - (i + excerpt_len) <= excerpt_len / 2: + excerpts.append(analyzed_lexemes[i:]) + break + excerpts.append(analyzed_lexemes[i : i + excerpt_len]) + return excerpts diff --git a/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/utils.py b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/utils.py new file mode 100644 index 00000000..abc882e9 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/retrieve_content/tokenization/utils.py @@ -0,0 +1,21 @@ +import json +import pickle + + +class PickleWriteable: + """Mixin for persisting an instance with pickle.""" + + def save(self, path): + try: + with open(path, "wb") as f: + pickle.dump(self, f) + except (pickle.PickleError, OSError) as e: + raise IOError("Unable to save {} to path: {}".format(self.__class__.__name__, path)) from e + + @classmethod + def load(cls, path): + try: + with open(path, "rb") as f: + return pickle.load(f) + except (pickle.PickleError, OSError) as e: + raise IOError("Unable to load {} from path: {}".format(cls.__name__, path)) from e diff --git a/lmms_eval/tasks/mmsearch/score/f1_score.py b/lmms_eval/tasks/mmsearch/score/f1_score.py new file mode 100644 index 00000000..d06e24bf --- /dev/null +++ b/lmms_eval/tasks/mmsearch/score/f1_score.py @@ -0,0 +1,47 @@ +import collections +import re +import string + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +# qa-f1 +# from https://github.com/chanchimin/RQ-RAG/blob/96b4ec981d4a4399e8402da1b75e16f7812aedfe/retrieval_lm/output/sample_from_tree.py#L181 +def get_f1_score(a_pred, a_gold): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 diff --git a/lmms_eval/tasks/mmsearch/score/req_score.py b/lmms_eval/tasks/mmsearch/score/req_score.py new file mode 100644 index 00000000..bef00751 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/score/req_score.py @@ -0,0 +1,23 @@ +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu +from rouge import Rouge + + +def get_requery_score(prediction, gt): + score_dict = dict() + + # 计算BLEUBLEU分数 + smoothing_function = SmoothingFunction().method1 # * used to deal with non-overlap n-gram + + # calculate BLEU-1 score with smoothing function + bleu_score = sentence_bleu([gt.split()], prediction.split(), weights=(1, 0, 0, 0), smoothing_function=smoothing_function) + + # ROUGE + rouge = Rouge() + rouge_scores = rouge.get_scores(prediction, gt)[0] + rouge_l_f1 = rouge_scores["rouge-l"]["f"] + + score_dict["bleu"] = bleu_score + score_dict["rouge_l"] = rouge_l_f1 + score_dict["score"] = (bleu_score + rouge_l_f1) / 2 + + return score_dict diff --git a/lmms_eval/tasks/mmsearch/score/result_summary.py b/lmms_eval/tasks/mmsearch/score/result_summary.py new file mode 100644 index 00000000..f66d9381 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/score/result_summary.py @@ -0,0 +1,53 @@ +def get_area_score(prediction_summary, key): + area_pre_dict = {"news": [], "knowledge": []} + for inst in prediction_summary: + area = inst["area"] + area_pre_dict[area].append(inst[key]) + area_dict = dict() + for k, v in area_pre_dict.items(): + area_dict[k] = dict(length=len(v), average=sum(v) / len(v)) + return area_dict + + +def get_subfield_score(prediction_summary, key, all_subfield): + area_pre_dict = {t: [] for t in all_subfield} + for inst in prediction_summary: + area = inst["subfield"] + area_pre_dict[area].append(inst[key]) + area_dict = dict() + for k, v in area_pre_dict.items(): + area_dict[k] = dict(length=len(v), average=sum(v) / len(v)) + return area_dict + + +def get_result_summary(anno, result_list, summary_key): + if isinstance(summary_key, str): + summary_key = [summary_key] + # change result_list to dict + result_dict = {inst["sample_id"]: inst for inst in result_list} + + all_subfield = [] + # add missing samples to zero + for inst in anno: + if inst["sample_id"] not in result_dict: + dummy_result = dict(sample_id=inst["sample_id"], area=inst["area"], subfield=inst["subfield"], **{k: 0 for k in summary_key}) + result_list.append(dummy_result) + print(f"Missing sample: {inst['sample_id']}") + all_subfield.append(inst["subfield"]) + all_subfield = list(set(all_subfield)) + + return_dict = dict() + for key in summary_key: + try: + return_dict[key] = dict( + total_dict=dict(total_length=len(result_list), average=sum([inst[key] for inst in result_list]) / len(result_list)), + area_dict=get_area_score(result_list, key), + subfield_dict=get_subfield_score(result_list, key, all_subfield), + ) + except: + import pdb + + pdb.set_trace() + print(result_dict) + + return return_dict diff --git a/lmms_eval/tasks/mmsearch/utils/image_utils.py b/lmms_eval/tasks/mmsearch/utils/image_utils.py new file mode 100644 index 00000000..16cc3492 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/utils/image_utils.py @@ -0,0 +1,206 @@ +import os +import shutil +from io import BytesIO + +import cv2 +import numpy as np +from PIL import Image + +Image.MAX_IMAGE_PIXELS = 1000000000 +from loguru import logger as eval_logger + + +# slim fullpage screenshot +def slim_image_and_save(image_path, save_path): + result, is_gray = adaptive_pixel_slimming(image_path) + try: + if is_gray: + save_result = cv2.imwrite(save_path, result.reshape(-1, 1)) + else: + save_result = cv2.imwrite(save_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR)) + + except Exception as e: + eval_logger.info(f"Slim fullpage screenshot failed: {e}") + shutil.copy(image_path, save_path) + return + + if not save_result: + eval_logger.info("Save slimmed fullpage screenshot failed") + shutil.copy(image_path, save_path) + + +def adaptive_pixel_slimming(image_path, RESIZE_W=1024, RESIZE_H=5120, thresh_gradmap=200, thresh_gradsum=50, thresh_length=15): + # Read the source document image + ori_website = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) + + if ori_website is None: + eval_logger.warning("cv2 load failed. Using PIL to load") + ori_website_rgb = Image.open(image_path).convert("RGB") + ori_website_rgb = np.asanyarray(ori_website_rgb) + ori_website = ori_website_rgb[:, :, [2, 1, 0]] + + # Check if the image is grayscale or color + is_gray = len(ori_website.shape) == 2 or (len(ori_website.shape) == 3 and ori_website.shape[2] == 1) + + if is_gray: + ori_website = ori_website.squeeze() if len(ori_website.shape) == 3 else ori_website + else: + ori_website = cv2.cvtColor(ori_website, cv2.COLOR_BGR2RGB) + + H, W = ori_website.shape[:2] + + # Compute Sobel gradients + if is_gray: + sobel_x = np.abs(cv2.Sobel(ori_website, cv2.CV_64F, 1, 0, ksize=3)) + sobel_y = np.abs(cv2.Sobel(ori_website, cv2.CV_64F, 0, 1, ksize=3)) + else: + sobel_x = np.abs(cv2.Sobel(ori_website, cv2.CV_64F, 1, 0, ksize=3)).max(axis=2) + sobel_y = np.abs(cv2.Sobel(ori_website, cv2.CV_64F, 0, 1, ksize=3)).max(axis=2) + + # Compute gradient map + ori_website_gradient_map = np.maximum(sobel_x, sobel_y) + # Resize to apply the threshold + ori_website_gradient_map = cv2.resize(ori_website_gradient_map, (RESIZE_W, RESIZE_H)) + ori_website_gradient_map[ori_website_gradient_map < thresh_gradmap] = 0 + + # Find blank area in y direction + sum_grad_y = np.sum(ori_website_gradient_map, axis=0) + blank_blocks_y = find_blank_block(sum_grad_y, thresh_gradsum, thresh_length) + blank_blocks_y = resize2predefined(blank_blocks_y, W / RESIZE_W) + + # Find blank area in x direction + sum_grad_x = np.sum(ori_website_gradient_map, axis=1) + blank_blocks_x = find_blank_block(sum_grad_x, thresh_gradsum, thresh_length) + blank_blocks_x = resize2predefined(blank_blocks_x, H / RESIZE_H) + # Remove blank blocks + slimmed_website = remove_blocks(blank_blocks_y, blank_blocks_x, ori_website) + + return slimmed_website, is_gray + + +def find_blank_block(arr, thresh_gradsum, thresh_length): + mask = (arr > thresh_gradsum).astype(int) + diff = np.diff(np.concatenate(([1], mask, [1]))) + end_indices = np.where(diff == 1)[0] + start_indices = np.where(diff == -1)[0] + lengths = end_indices - start_indices + valid_blocks = lengths >= thresh_length + return list(zip(start_indices[valid_blocks], end_indices[valid_blocks])) + + +def resize2predefined(blocks, scale): + return [(int(start * scale), int(end * scale)) for start, end in blocks] + + +def remove_blocks(blank_blocks_y, blank_blocks_x, image): + mask = np.ones(image.shape[:2], dtype=bool) + for start, end in blank_blocks_y: + mask[:, start:end] = False + for start, end in blank_blocks_x: + mask[start:end, :] = False + + # Extract non-black regions + if len(image.shape) == 2: # Grayscale + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + return image[rows][:, cols] + else: # Color + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + return image[rows][:, cols] + + +# crop & split fullpage screenshot to multiple small images +def crop_and_split(fullpage_path, fullpage_split_dict, save_slice_path=None): + slice_height = fullpage_split_dict["slice_height"] + max_slices = fullpage_split_dict["max_slices"] + + return_list = [] + # Open the image + with Image.open(fullpage_path) as img: + width, height = img.size + + # Calculate the number of slices needed + num_slices = min(max_slices, (height + slice_height - 1) // slice_height) + + # Slice and save the image + for i in range(num_slices): + top = i * slice_height + bottom = min((i + 1) * slice_height, height) + + slice = img.crop((0, top, width, bottom)) + + # Save the slice + if save_slice_path is not None: + output_path = os.path.join(save_slice_path, f"slice_{i}.jpg") + slice.save(output_path) + else: + output_path = pil_image_to_bytes(slice) + return_list.append(output_path) + + return return_list + + +# crop image search result +def crop_image_search_results(image_path, save_path): + image = cv2.imread(image_path) + if image is None: + logger.warning("cv2 load failed. Using PIL to load") + print(f"image_path: {image_path}; exist: {os.path.exists(image_path)}") + image_rgb = Image.open(image_path) + image_rgb = np.asanyarray(image_rgb) + image = image_rgb[:, :, [2, 1, 0]] + + # Convert to grayscale + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Apply vertical Sobel operator + sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=5) + + # Convert to 8-bit unsigned integer + sobelx = np.uint8(np.absolute(sobelx)) + + # Apply thresholding + _, thresh = cv2.threshold(sobelx, 50, 255, cv2.THRESH_BINARY) + + # Detect lines using Hough transform + lines = cv2.HoughLines(thresh, 1, np.pi / 180, 600) + + max_x = 0 + + # Check detected lines + if lines is not None: + for rho, theta in lines[:, 0]: + # Only consider nearly vertical lines (theta close to 0 or pi) + if theta < 0.1 or theta > np.pi - 0.1: + a = np.cos(theta) + b = np.sin(theta) + x0 = a * rho + # Calculate two endpoints of the line + x1 = int(x0 + 1000 * (-b)) + x2 = int(x0 - 1000 * (-b)) + # Update maximum x value + max_x = max(max_x, x1, x2) + + # Ensure max_x does not exceed image width + max_x = min(max_x, image.shape[1] - 1) + + # Crop the image, keeping only the part to the right of max_x + cropped_image = image[:, max_x:] + cv2.imwrite(save_path, cropped_image) + + +# convert pil images to bytes to unify the loading method +# the object returned by the function can be loaded with Image.open function +def pil_image_to_bytes(pil_image, format="PNG"): + img_byte_arr = BytesIO() + pil_image.save(img_byte_arr, format=format) + return BytesIO(img_byte_arr.getvalue()) + + +if __name__ == "__main__": + image_path = "temp_files/20240824_163026/21/stage3/0/fullpage.png" + if cv2.imread(image_path, cv2.IMREAD_UNCHANGED) is None: + print("Wrong") + else: + print(cv2.imread(image_path, cv2.IMREAD_UNCHANGED).shape) diff --git a/lmms_eval/tasks/mmsearch/utils/lmms_eval_utils.py b/lmms_eval/tasks/mmsearch/utils/lmms_eval_utils.py new file mode 100644 index 00000000..de5dc41e --- /dev/null +++ b/lmms_eval/tasks/mmsearch/utils/lmms_eval_utils.py @@ -0,0 +1,12 @@ +import json +import os + + +def save_result_to_cache(doc, round_res, previous_round_info, save_dir): + save_dict = dict( + sample_id=doc["sample_id"], + query=doc["query"], + round_res=round_res, + ) + save_dict.update(previous_round_info) + json.dump(save_dict, open(os.path.join(save_dir, f"{save_dict['sample_id']}.json"), "w"), indent=4) diff --git a/lmms_eval/tasks/mmsearch/utils/prompt_utils.py b/lmms_eval/tasks/mmsearch/utils/prompt_utils.py new file mode 100644 index 00000000..55af6902 --- /dev/null +++ b/lmms_eval/tasks/mmsearch/utils/prompt_utils.py @@ -0,0 +1,91 @@ +import os +import re + +from lmms_eval.tasks.mmsearch.utils.image_utils import ( + crop_and_split, + slim_image_and_save, +) + +DEFAULT_IMAGE_TOKEN = "" + + +def get_website_information(result_brief): + """ + result_brief: [{'title', 'text','screenshot_path'}] + """ + website_information, input_image_list = [], [] + for idx, inst in enumerate(result_brief): + template = f"Website {idx+1} Title: {inst['title']};\nWebsite {idx+1} snippet: {inst['snippet']};\nWebsite {idx+1} Screenshot: {DEFAULT_IMAGE_TOKEN}" + website_information.append(template) + input_image_list.append(inst["screenshot_path"]) + + return "\n\n".join(website_information), input_image_list + + +def get_rerank_incontext_example(rerank_num): + l = [f"" for i in range(rerank_num)] + return ",".join(l) + + +def get_full_website_information(result_full, image_dir="", fullpage_split_dict=None, save_slim_dir=None): + """ + result_full: [{'title', 'snippet', 'content','screenshot_path'}] + """ + if save_slim_dir is None: + save_slim_dir = image_dir + + input_image_list = [] + inst = result_full[0] # assert only 1 fullpage content + + template = f"Website Title: {inst['title']};\n Website Snippet: {inst['snippet']};\n" + + # add content + template += f"Website Content: {inst['content']};\n" + + ## slim image to be tense + if "screenshot_fullpage_path" in inst: + split_list = inst["screenshot_fullpage_path"].split("/")[-1].split(".") + save_name = ".".join(split_list[:-1]) + f"_slim.{split_list[-1]}" + save_path = os.path.join(save_slim_dir, save_name) + + slim_image_and_save(image_path=os.path.join(image_dir, inst["screenshot_fullpage_path"]), save_path=save_path) + save_slice_path = os.path.join(save_slim_dir, "slices") + os.makedirs(save_slice_path, exist_ok=True) + elif "slimmed_website_fullpage_screenshot" in inst: # the screenshot is already slimmed + save_path = inst["slimmed_website_fullpage_screenshot"] + save_slice_path = None # do not save the slices + else: + raise ValueError("seems that the inst variable does not contain relevant key") + + # here, we split the fullpage to maximum 10 images, each with 512 height (the width depends on the website itself) + screenshot_fullpage_split_list = crop_and_split(fullpage_path=save_path, fullpage_split_dict=fullpage_split_dict, save_slice_path=save_slice_path) + template += f"Website Screenshot: {DEFAULT_IMAGE_TOKEN*len(screenshot_fullpage_split_list)};\n" + input_image_list.extend(screenshot_fullpage_split_list) + + website_information = template + + return website_information, input_image_list + + +def postprocess_rerank(rerank, rerank_num): + pattern = r"" + matches = re.findall(pattern, rerank) + output_index = [int(x) - 1 for x in matches] + if len(output_index) > rerank_num: + print(f"More index than rerank number: {rerank}") + output_index = output_index[:rerank_num] + valid = False + elif len(output_index) < rerank_num: + print(f"Less index than rerank number: {rerank}") + if len(output_index) == 0: + print("No valid output for rereank") + output_index = [i for i in range(rerank_num)] + valid = False + elif not all([[x < 0 for x in output_index]]): + print(f"Some index is less than 1: {rerank}") + output_index = [i for i in range(rerank_num)] + valid = False + else: + valid = True + + return output_index, valid diff --git a/lmms_eval/tasks/mmsearch/utils/utils.py b/lmms_eval/tasks/mmsearch/utils/utils.py new file mode 100644 index 00000000..6773c46d --- /dev/null +++ b/lmms_eval/tasks/mmsearch/utils/utils.py @@ -0,0 +1,422 @@ +import asyncio +import logging +import math +import os +import random +import tempfile +import time +from typing import Any, Dict, List, Optional + +import matplotlib.pyplot as plt +import requests + +# get rank id for random seed +from accelerate import Accelerator +from duckduckgo_search import DDGS +from langchain_community.document_loaders import UnstructuredHTMLLoader +from loguru import logger as eval_logger +from playwright.async_api import TimeoutError as PlaywrightTimeoutError +from playwright.async_api import async_playwright +from requests.exceptions import RequestException + +from lmms_eval.tasks.mmsearch.constants import * +from lmms_eval.tasks.mmsearch.utils.web_content_utils import * + +accelerator = Accelerator() +WORLD_SIZE = accelerator.num_processes +RANK = accelerator.process_index +random.seed(RANK) + + +### Proxy setting +def get_proxy_settings(): + http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy") + https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy") + + proxies = {} + if http_proxy: + proxies["http"] = http_proxy + if https_proxy: + proxies["https"] = https_proxy + + # Try to obtain environ proxies + if not proxies: + try: + system_proxies = requests.utils.get_environ_proxies("") + if system_proxies: + proxies = system_proxies + except Exception as e: + eval_logger.warning(f"Cannot obtain environ proxies: {e}") + + return proxies + + +PROXY = get_proxy_settings() # get proxy if exist + +### Brief Results + + +def search_text_brief_result(query, max_result_num, screenshot_dir): + os.makedirs(screenshot_dir, exist_ok=True) + return asyncio.run(run_query(query, screenshot_dir, max_result_num)) + + +async def run_query(query: str, screenshot_dir_path: str, max_result_num: int): + engine = DDGSQueryRun(max_results=max_result_num) + results = await engine(query, screenshot_dir_path) + return results + + +## Search Engine API +class RapidAPI: + def __init__(self, rapidapi_name): + self.rapidapi_name = rapidapi_name + self.ddgs = DDGS(proxy=PROXY["https"], timeout=50) if len(PROXY) != 0 else DDGS(timeout=50) + + def query(self, text: str, max_results: int) -> List[Dict[str, Any]]: + initial_delay = 1 + max_retries = 3 + + for attempt in range(max_retries): + try: + time.sleep(random.choice([i for i in range(5, 10 + 20 * WORLD_SIZE, 5)])) # Avoid frequent requests and multiple rank query at the same time + response = list(self.ddgs.text(" ".join(text.strip("'").split(" ")[:100]), max_results=max_results)) + return response[:max_results] + except Exception as e: + error_message = str(e) + if "202" in error_message or "Accepted" in error_message: + delay = initial_delay * (2**attempt) + random.uniform(0, 1) + print(f"Received 202 status code, waiting {delay:.2f} seconds before retrying... (Attempt {attempt + 1}/{max_retries})") + time.sleep(delay) + elif isinstance(e, RequestException): + print(f"Network error: {e}") + time.sleep(random.uniform(1, 3)) + else: + print(f"Unknown error: {e}") + raise ValueError + + +## API: Search Engine Retrieval + Screenshot of top section +class DDGSQueryRun: + name = "duckduckgo_search" + signature = f"{name}(query: str) -> str" + + def __init__(self, max_results: int, rapidapi_name: str = "one"): + self.max_results = max_results + self.api_wrapper = RapidAPI(rapidapi_name) + + async def __call__(self, query: str, screenshot_dir_path: str) -> List[Dict[str, Any]]: + try: + output = self.api_wrapper.query(query, max_results=self.max_results + 20) # account for error website + except Exception as e: + eval_logger.error(f"DDGSQueryRun call failed:") + eval_logger.error(f"{e}") + output = [] + + evidences = [] + for idx, result in enumerate(output): + evidence = {"title": result["title"], "snippet": result.get("description", result.get("body", "")), "url": result["href"], "screenshot_path": os.path.join(screenshot_dir_path, f"{idx}.jpg")} + success = await take_screenshot_async(evidence["url"], os.path.join(screenshot_dir_path, f"{idx}.jpg")) + if success: + evidences.append(evidence) + if len(evidences) == self.max_results: + break + + if not evidences: + evidences = None + + return evidences + + +## Screenshot of top section. Set the size to be 1024*1024 +async def take_screenshot_async(url: str, screenshot_path: str, timeout: int = BRIEF_TIMEOUT): + async with async_playwright() as p: + if len(PROXY) != 0: + browser = await p.chromium.launch(headless=True, proxy={"server": PROXY["https"]}) + context = await browser.new_context(user_agent=USER_AGENT, proxy={"server": PROXY["https"]}, viewport={"width": 1024, "height": 1024}) + else: + browser = await p.chromium.launch(headless=True) + context = await browser.new_context(user_agent=USER_AGENT, viewport={"width": 1024, "height": 1024}) + + page = await context.new_page() + try: + await page.goto(url, wait_until="load", timeout=timeout) + await page.screenshot(path=screenshot_path) + eval_logger.info(f"Successfully taking screenshot of current state: {url}") + except PlaywrightTimeoutError: + eval_logger.info(f"Timeout occurred while loading {url}. Taking screenshot of current state.") + try: + await page.screenshot(path=screenshot_path) + except Exception as e: + eval_logger.info(f"An error occurred while taking screenshot of {url}: {str(e)}") + await context.close() + await browser.close() + return False + except Exception as e: + eval_logger.info(f"An error occurred while taking screenshot of {url}: {str(e)}") + await context.close() + await browser.close() + return False + finally: + await context.close() + await browser.close() + return True + + +### Full page content +def search_url_full_result(urls, screenshot_dir): + results = [] + for idx, url in enumerate(urls): + save_dir_path = os.path.join(screenshot_dir, str(idx)) + os.makedirs(save_dir_path, exist_ok=True) + fullpage_success = take_fullpage_screenshot(url, f"{save_dir_path}/fullpage.png") + if not fullpage_success: + eval_logger.info(f"take_fullpage_screenshot failed. Save a blank image") + # Create a 512x512 pixel blank image + fig, ax = plt.subplots(figsize=(512 / 100, 512 / 100), dpi=100) + # Remove coordinate axes + ax.axis("off") + # Add text + ax.text(0.5, 0.5, "No Content", fontsize=50, ha="center", va="center") + # Adjust layout and set image boundaries + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + # Save image + plt.savefig(f"{save_dir_path}/fullpage.png", dpi=100, pad_inches=0) + + results.append( + dict(content=get_fullpage_content(url), screenshot_fullpage_path=f"{str(idx)}/fullpage.png"), + ) + return results + + +## Fullpage screenshot +def take_fullpage_screenshot(url: str, screenshot_path: str, timeout: int = FULLPAGE_TIMEOUT): + return asyncio.run(_take_fullpage_screenshot(url, screenshot_path)) + + +async def _take_fullpage_screenshot(url: str, screenshot_path: str, timeout: int = FULLPAGE_TIMEOUT): + async with async_playwright() as p: + if len(PROXY) != 0: + browser = await p.chromium.launch(headless=True, proxy={"server": PROXY["https"]}) + context = await browser.new_context(user_agent=USER_AGENT, proxy={"server": PROXY["https"]}, viewport={"width": 512, "height": 512}, is_mobile=True) + else: + browser = await p.chromium.launch(headless=True) + context = await browser.new_context( + user_agent=USER_AGENT, + viewport={"width": 512, "height": 512}, + is_mobile=True, + ) + + page = await context.new_page() + try: + await page.goto(url, wait_until="networkidle", timeout=timeout) + # Scroll the full page for all image to be visible + await scroll_full_page(page) + await page.wait_for_timeout(2000) + await page.screenshot(path=screenshot_path, full_page=True) + eval_logger.info(f"Successfully took full page screenshot: {url}") + return True + except PlaywrightTimeoutError: + eval_logger.info(f"Timeout occurred while loading {url}. Taking screenshot of current state.") + try: + await scroll_full_page(page) + await page.wait_for_timeout(2000) + await page.screenshot(path=screenshot_path, full_page=True) + return True + except Exception as e: + eval_logger.error(f"An error occurred while taking full page screenshot of {url}: {str(e)}") + return False + except Exception as e: + eval_logger.error(f"An error occurred while taking full page screenshot of {url}: {str(e)}") + return False + finally: + await context.close() + await browser.close() + + +## Fullpage textual content +def get_fullpage_content(url: str, timeout: int = FULLPAGE_TIMEOUT) -> Optional[str]: + return asyncio.run(_get_fullpage_content(url, timeout)) + + +async def _get_fullpage_content(url: str, timeout: int = FULLPAGE_CONTENT_TIMEOUT) -> Optional[str]: + async with async_playwright() as p: + if len(PROXY) != 0: + browser = await p.chromium.launch(headless=True, proxy={"server": PROXY["https"]}) + context = await browser.new_context( + user_agent=USER_AGENT, + proxy={"server": PROXY["https"]}, + ) + else: + browser = await p.chromium.launch(headless=True) + context = await browser.new_context( + user_agent=USER_AGENT, + ) + + page = await context.new_page() + + try: + # Set navigation timeout (milliseconds) + page.set_default_navigation_timeout(timeout) + + # Navigate to the specified URL + await page.goto(url, wait_until="load", timeout=timeout) + + html_content = await page.content() + + # use UnstructuredHTMLLoader to extract main content + # setup a temporary file + with tempfile.NamedTemporaryFile(mode="w+", suffix=".html", delete=False) as temp_file: + temp_file.write(html_content) + temp_file_path = temp_file.name + + loader = UnstructuredHTMLLoader(temp_file_path) + data = loader.load() + # delete the temporary file + os.unlink(temp_file_path) + main_text = data[0].page_content + + eval_logger.info(f"Successfully scraping content of current state: {url}") + + return main_text + + except PlaywrightTimeoutError: + eval_logger.info(f"Timeout occurred while loading {url}. Scraping content of current state.") + try: + html_content = await page.content() + main_text = extract_main_content(html_content) + return main_text + except Exception as e: + eval_logger.info(f"An error occurred while processing content of {url}: {str(e)}") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + + finally: + await browser.close() + + +### Utils for screenshot +async def scroll_full_page(page, max_height=10000): + return await page.evaluate( + f""" + async () => {{ + const js_height = () => {{ + try {{ + return Math.min(document.body.clientHeight, {max_height}); + }} catch (error) {{ + console.warn("Unable to get clientHeight, using max_height:", error); + return {max_height}; + }} + }}; + + let height = js_height(); + let k = 1; + const scrollStep = 300; // Scroll step length + const pauseDuration = 1000; // Pause duration after each scroll (milliseconds) + const maxHeight = {max_height}; // Maximum scroll height + + while (true) {{ + if (k * scrollStep < height && k * scrollStep < maxHeight) {{ + window.scrollTo(0, k * scrollStep); + await new Promise(resolve => setTimeout(resolve, pauseDuration)); + height = js_height(); + k += 1; + }} else {{ + break; + }} + }} + + // Scroll back to top + window.scrollTo(0, 0); + await new Promise(resolve => setTimeout(resolve, pauseDuration)); + }} + """ + ) + + +async def load_all_images(page): + # Save current scroll position + original_position = await page.evaluate("() => ({ x: window.scrollX, y: window.scrollY })") + + # Find all image elements + locators = page.locator("//img") + + # Create an array of Promises, each corresponding to the loading of an image + promises = await locators.evaluate_all( + """ + elements => elements.map(img => { + if (img.complete) return Promise.resolve(); + return new Promise(resolve => { + img.onload = resolve; + img.onerror = resolve; // Also handle loading failure + // If the image doesn't have a src, it might be a lazy-loaded image + if (!img.src && img.dataset.src) { + img.src = img.dataset.src; + } + }); + }) + """ + ) + + # Wait for all images to finish loading + await page.evaluate("promises => Promise.all(promises)", promises) + + # Restore original scroll position + await page.evaluate("position => window.scrollTo(position.x, position.y)", original_position) + + # Give the page some time to stabilize + await page.wait_for_timeout(1000) + + +### Search image for google lens. Only will be used for new queries to MMSearch-Engine. Can only be used with English Browers. +def search_by_image(url, screenshot_path): + return asyncio.run(_search_by_image(url, screenshot_path)) + + +async def _search_by_image(image_url, screenshot_path="search_results.png", delay=5.0, headless=True): + results = [] + async with async_playwright() as p: + browser = await p.chromium.launch(headless=headless, args=["--lang=en-US"]) + context = await browser.new_context(locale="en-US", viewport={"width": 1280, "height": 800}) + page = await context.new_page() + + await page.goto("https://images.google.com") + await page.wait_for_selector('div[role="button"][aria-label="Search by image"]', state="visible") + await page.click('div[role="button"][aria-label="Search by image"]') + await page.wait_for_selector('input[placeholder="Paste image link"]', state="visible") + await page.fill('input[placeholder="Paste image link"]', image_url) + await page.wait_for_selector('div[jsname="ZtOxCb"]', state="visible") + await page.click('div[jsname="ZtOxCb"]') + + await page.wait_for_selector("img", state="visible") + await load_all_images(page) + await asyncio.sleep(delay) + + # Extract search results + result_cards = await page.query_selector_all(".Vd9M6") + count = 0 + for card in result_cards: + image_element = await card.query_selector("img.wETe9b") + snippet_element = await card.query_selector(".UAiK1e") + a_element = await card.query_selector("a.GZrdsf") + + if image_element and snippet_element: + image_url = await image_element.get_attribute("src") + snippet = await snippet_element.inner_text() + web_url = await a_element.get_attribute("href") + + if image_url.startswith("dat:image"): + print(image_url) + continue + + results.append({"image_url": image_url, "snippet": snippet, "web_url": web_url}) + count += 1 + if count == IMAGE_SEARCH_RESULT: + break + + await page.screenshot(path=screenshot_path, full_page=True) + await browser.close() + + return results diff --git a/lmms_eval/tasks/mmsearch/utils/web_content_utils.py b/lmms_eval/tasks/mmsearch/utils/web_content_utils.py new file mode 100644 index 00000000..510f07ab --- /dev/null +++ b/lmms_eval/tasks/mmsearch/utils/web_content_utils.py @@ -0,0 +1,30 @@ +import re + +from bs4 import BeautifulSoup + + +def clean_text(text): + # Remove excess whitespace characters + text = re.sub(r"\s+", " ", text).strip() + # Remove excess newline characters + text = re.sub(r"\n+", "\n", text) + return text + + +def extract_main_content(html): + soup = BeautifulSoup(html, "html.parser") + + # Remove scripts, styles, navigation, and footer elements + for element in soup(["script", "style", "nav", "footer", "header"]): + element.decompose() + + # Try to find the main content area (assuming it uses
tag or id/class containing "content") + main_content = soup.find("main") or soup.find(id=re.compile("content", re.I)) or soup.find(class_=re.compile("content", re.I)) + + if main_content: + text = main_content.get_text(separator="\n", strip=True) + else: + # If no clear main content area is found, use the content of + text = soup.body.get_text(separator="\n", strip=True) + + return clean_text(text) diff --git a/pyproject.toml b/pyproject.toml index f2e7d5bc..f997bb45 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,12 +95,24 @@ qwen = [ "decord", "qwen_vl_utils", ] +mmsearch = [ + "playwright", + "requests", + "matplotlib", + "duckduckgo_search", + "langchain", + "langchain-community", + "beautifulsoup4", + "FlagEmbedding", + "rouge", +] all = [ "vila", "gemini", "reka", "metrics", "qwen", + "mmsearch" ] [tool.setuptools.packages.find]