Commit

fix(abstractions): data transformation with identical source and destination

TianyiQ committed Nov 30, 2024
1 parent 418255f commit cf09339

Showing 11 changed files with 123 additions and 91 deletions.
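The substantive change is in src/abstractions/data.py: Data.transform streams samples from self.data_path while writing to {root}/output/datasets/{result_data_name}.json, so a result_data_name identical to the source name (or resolving to the same path) would overwrite the file still being read. The fix detours through a temporary copy ("temp_transform_artifact") and re-runs the transformation from there. Below is a minimal usage sketch of the now-supported call pattern; the dataset name, data type, and transformation are hypothetical, and only the Data, copy, and transform signatures mirror the diff.

from src.abstractions.data import Data

# Hypothetical dataset; "pretrain" as data_type and the "content" field are
# assumptions for illustration, not taken from the repository.
data = Data("my_corpus", "pretrain", "./output/datasets/my_corpus.json")

# Reusing the source name used to clobber the json file mid-read; transform
# now detects the collision, copies the source to a temporary artifact, and
# streams the transformation from that copy instead.
data = data.transform(
    transformation=lambda sample: {**sample, "content": sample["content"].lower()},
    result_data_name="my_corpus",  # identical to the source data_name
)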
6 changes: 3 additions & 3 deletions build_dataset.py
@@ -1,5 +1,5 @@
from src.path import root
import src.utils.text_utils as tw
import src.utils.text_utils as tu
import src.cleanser.rule_based_cleanser as rb
import src.cleanser.localllm_cleanser as llm_cleanser
import src.model_training.train_hislm as hislm
@@ -53,7 +53,7 @@ def build_pile_of_law():


if __name__ == "__main__":
tw.write_log(f"\n\n\n\n\n\n=========== NEW RUN ============\n\n")
tu.write_log(f"\n\n\n\n\n\n=========== NEW RUN ============\n\n")
print(
"This script is NOT meant to be run as part of the benchmarking process. Unless you would like to replicate the dataset building & model training process, you could directly run `run_benchmark.py` instead, which will automatically download the pre-built dataset and/or models on demand."
)
@@ -82,7 +82,7 @@ def build_pile_of_law():
max_hours=10
) # takes ~100h, but if max_hours is supplied then stops after this many hours (won't affect data integrity)
# finishing up
tw.seal_all_files()
tu.seal_all_files()
print("Finished building entire dataset. Proceed to data cleansing.")

if (
1 change: 1 addition & 0 deletions src/abstractions/backends.py
@@ -481,6 +481,7 @@ def sglang_process_batch(
)
assert len(output) == len(sample_dicts)

count = 0
for _ in range(20):
bad_indices = [
k
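The backends.py hunk is truncated by the diff view; the surrounding context suggests sglang_process_batch initializes a counter before a loop that retries, up to 20 times, the samples whose outputs came back bad. A self-contained sketch of that bounded-retry-over-failed-indices pattern follows, with hypothetical names (process_fn, the None failure test) that are not the repository's actual code.

from typing import Callable, List, Optional

def retry_failed(
    sample_dicts: List[dict],
    process_fn: Callable[[List[dict]], List[Optional[str]]],
    max_rounds: int = 20,
) -> List[Optional[str]]:
    # First pass over the whole batch.
    output = process_fn(sample_dicts)
    assert len(output) == len(sample_dicts)

    count = 0  # total retried samples, mirroring the counter added in the diff
    for _ in range(max_rounds):
        # Indices whose result is still missing after the previous round.
        bad_indices = [k for k, result in enumerate(output) if result is None]
        if not bad_indices:
            break
        count += len(bad_indices)
        retried = process_fn([sample_dicts[k] for k in bad_indices])
        for k, result in zip(bad_indices, retried):
            output[k] = result
    return output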
116 changes: 73 additions & 43 deletions src/abstractions/data.py
@@ -15,7 +15,7 @@
import os
import json
import warnings
import src.utils.text_utils as tw
import src.utils.text_utils as tu
from tqdm import tqdm
from src.abstractions.configs.templates_configs import *

@@ -126,7 +126,7 @@ def __init__(
data_path = f"{root}/output/datasets/{data_name}.json"
Data.ask_and_remove_if_exists(data_path, forced_rewrite=True)

with tw.JsonListWriter(data_path) as json_writer:
with tu.JsonListWriter(data_path) as json_writer:
for element in data_content:
json_writer.append(element)

@@ -169,9 +169,19 @@ def __init__(
else:
Data.name2data[data_name] = [self]

def copy(self) -> "Data":
"""Returns a shallow copy of the current Data instance."""
cp = Data(self.data_name, self.data_type, self.data_path)
def copy(self, data_name: str = None) -> "Data":
"""
Returns a copy of the current Data instance.
Shallow copy if data_name is not provided or identical to the current data_name; deep copy otherwise.
"""
if data_name and data_name != self.data_name:
new_data_path = f"{root}/output/datasets/{data_name}.json"
Data.ask_and_remove_if_exists(new_data_path, forced_rewrite=True)
execute(f"cp {escape(self.data_path)} {escape(new_data_path)}")
cp = Data(data_name, self.data_type, new_data_path)
else:
cp = Data(self.data_name, self.data_type, self.data_path)

cp.key_fields = self.key_fields.copy()
return cp

@@ -213,6 +223,19 @@ def transform(
:rtype: Data.
"""
out_path = f"{root}/output/datasets/{result_data_name}.json"
if self.data_name == result_data_name or self.data_path == out_path:
warnings.warn(
f"Data name {result_data_name} is the same as the current data name. The old instance will be invalidated."
)
return self.copy("temp_transform_artifact").transform(
transformation,
result_data_name,
forced_rewrite,
max_batch_size,
keep_key_fields,
map_key_fields,
)

Data.ask_and_remove_if_exists(out_path, forced_rewrite)

def write_dict(sample_dict: Dict):
@@ -225,7 +248,7 @@ def write_dict(sample_dict: Dict):
def map_key_fields_fn(sample_dict: Dict) -> Dict:
nonlocal self
for k, v in self.default_key_fields.items():
if k in self.key_fields and self.key_fields[k] != v:
if k in self.key_fields and self.key_fields.get(k, v) != v and self.key_fields[k] in sample_dict:
sample_dict[v] = sample_dict[self.key_fields[k]]
del sample_dict[self.key_fields[k]]

@@ -234,7 +257,7 @@ def map_key_fields_fn(sample_dict: Dict) -> Dict:
def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
nonlocal self
for k, v in self.default_key_fields.items():
if v in sample_dict and self.key_fields[k] != v:
if v in sample_dict and self.key_fields.get(k, v) != v:
sample_dict[self.key_fields[k]] = sample_dict[v]
del sample_dict[v]

@@ -245,27 +268,29 @@ def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
is_first = True

if max_batch_size == 1:
for element in tw.read_json_memory_efficient(self.data_path):
if map_key_fields:
element = map_key_fields_fn(element)

transformed = transformation(element)
if transformed is not None:
write_dict(transformed if not map_key_fields else inv_map_key_fields_fn(transformed))
with tu.JsonListReader(self.data_path) as reader:
for element in reader:
if map_key_fields:
element = map_key_fields_fn(element)

transformed = transformation(element)
if transformed is not None:
write_dict(transformed if not map_key_fields else inv_map_key_fields_fn(transformed))

else:
buffer = []

for element in tw.read_json_memory_efficient(self.data_path):
if map_key_fields:
element = map_key_fields_fn(element)

buffer.append(element)
if len(buffer) == max_batch_size:
for e in transformation(buffer):
write_dict(e if not map_key_fields else inv_map_key_fields_fn(e))
buffer = []
out_file.flush()
with tu.JsonListReader(self.data_path) as reader:
for element in reader:
if map_key_fields:
element = map_key_fields_fn(element)

buffer.append(element)
if len(buffer) == max_batch_size:
for e in transformation(buffer):
write_dict(e if not map_key_fields else inv_map_key_fields_fn(e))
buffer = []
out_file.flush()

if buffer:
for e in transformation(buffer):
@@ -566,8 +591,9 @@ def all_passages(self) -> Iterable[Dict[Hashable, Any]]:
"""
Returns an iterator of all passages (json dicts) in this dataset.
"""
for element in tw.read_json_memory_efficient(self.data_path):
yield element
with tu.JsonListReader(self.data_path) as reader:
for element in reader:
yield element


class DataFileCollection:
@@ -704,8 +730,9 @@ def all_passages(self) -> Iterable[Dict[Hashable, Any]]:
list(self.all_files())
): # remove list() if it turns out that the file count is super huge
assert in_path[: len(self.collection_path)] == self.collection_path
for element in tw.read_json_memory_efficient(in_path):
yield element
with tu.JsonListReader(in_path) as reader:
for element in reader:
yield element

def transform(
self,
@@ -766,21 +793,23 @@ def write_dict(sample_dict: Dict):
is_first = True

if max_batch_size == 1:
for element in tw.read_json_memory_efficient(in_path):
transformed = transformation(element)
if transformed is not None:
write_dict(transformed)
with tu.JsonListReader(in_path) as reader:
for element in reader:
transformed = transformation(element)
if transformed is not None:
write_dict(transformed)

else:
buffer = []

for element in tw.read_json_memory_efficient(in_path):
buffer.append(element)
if len(buffer) == max_batch_size:
for e in transformation(buffer):
write_dict(e)
buffer = []
out_file.flush()
with tu.JsonListReader(in_path) as reader:
for element in reader:
buffer.append(element)
if len(buffer) == max_batch_size:
for e in transformation(buffer):
write_dict(e)
buffer = []
out_file.flush()

if buffer:
for e in transformation(buffer):
@@ -843,10 +872,11 @@ def convert_to_Data(
for in_path in tqdm(
list(self.all_files())
): # remove list() if it turns out that the file count is super huge
for element in tw.read_json_memory_efficient(in_path):
out_file.write("\n" if is_first else ",\n")
is_first = False
out_file.write(json.dumps(clean_dict(element, filter_fields)))
with tu.JsonListReader(in_path) as reader:
for element in reader:
out_file.write("\n" if is_first else ",\n")
is_first = False
out_file.write(json.dumps(clean_dict(element, filter_fields)))

out_file.write("\n]")

14 changes: 7 additions & 7 deletions src/abstractions/model.py
Expand Up @@ -8,7 +8,7 @@
import json
import torch
import warnings
import src.utils.text_utils as tw
import src.utils.text_utils as tu
import random
import numpy as np
from transformers import (
@@ -702,7 +702,7 @@ def inference(
:code:`serial` - Serial inference.
"""

tw.write_log(
tu.write_log(
f"Inference start, with result_data_name = {result_data_name} and backend = {backend}."
)
input_is_data = isinstance(data, Data)
@@ -773,7 +773,7 @@ def inference(
f'Backend {backend} not recognized. Options are "sglang", "vllm", "deepspeed", and "serial".'
)

tw.write_log(
tu.write_log(
f"Inference finished, with result_data_name = {result_data_name} and backend = {backend}."
)

@@ -908,7 +908,7 @@ def __inference_parallel_deepspeed(
result_data_path, f"{escape(result_data_name)}.json"
)

with tw.JsonListWriter(final_file_path) as writer:
with tu.JsonListWriter(final_file_path) as writer:
with open(initial_file_path, "r") as results_file:
for i, (input_dict, result) in enumerate(
zip(data.all_passages(), results_file)
@@ -960,11 +960,11 @@ def __inference_serial(
root, "output", "inference_results", "inf", data_name + ".json"
)

with tw.JsonListWriter(
with tu.JsonListWriter(
data_path
) as writer: # memory-efficient: no need to place all answers in memory
if isinstance(input_data, Data):
eles = tw.read_json_memory_efficient(input_data.data_path)
eles = tu.read_json_memory_efficient(input_data.data_path)
else:
eles = input_data

@@ -988,7 +988,7 @@

# display the first element to showcase results
if writer.is_first:
tw.write_log(
tu.write_log(
f"Inference sample: {ele}. Raw response: {repr(response)}."
)

8 changes: 4 additions & 4 deletions src/cleanser/rule_based_cleanser.py
@@ -1,7 +1,7 @@
from src.path import root
import os, json
import re, unicodedata
import src.utils.text_utils as tw
import src.utils.text_utils as tu
from tqdm import tqdm

last_cleaned = ""
@@ -63,13 +63,13 @@ def cleanse_text(text):
def cleanse_dir(dirct, to_dir):
os.makedirs(to_dir)
for year in tqdm(os.listdir(dirct), desc=dirct.split("/")[-1]):
generator = tw.read_json_memory_efficient(os.path.join(dirct, year))
generator = tu.read_json_memory_efficient(os.path.join(dirct, year))
out = []
for boi in generator:
orig_len = len(boi["content"])
boi["content"] = cleanse_text(boi["content"])

tw.write_log(
tu.write_log(
"cleansed an object in "
+ year
+ ", length reduced from "
@@ -83,7 +83,7 @@ def cleanse_dir(dirct, to_dir):
if len(boi["content"]) > 200:
out.append(boi)
else:
tw.write_log(f"Ignoring {repr(boi['content'])}.")
tu.write_log(f"Ignoring {repr(boi['content'])}.")

with open(os.path.join(to_dir, year), "w") as file:
json.dump(out, file)
8 changes: 4 additions & 4 deletions src/eebo/process_eebo.py
@@ -3,7 +3,7 @@
import os
from tqdm import tqdm
import json
import src.utils.text_utils as tw
import src.utils.text_utils as tu


# a utility function called by build_eebo_dataset, to read the contents in an eebo xml file in a suitable manner
@@ -106,10 +106,10 @@ def build_eebo_dataset(eebo_path: str = f"{root}/dataset/raw_downloads/EEBO/"):
1000 < year_earliest <= year <= year_latest < 2025
and year_latest - year_earliest < 50
):
tw.write_single_entry(json_dict=json_element)
tu.write_single_entry(json_dict=json_element)
else:
del json_element["creation_year"]
tw.write_log(
tu.write_log(
f"EEBO: Uncertainty too large, saving to undated.json: {line.strip()}"
)
tw.report_undated_entry(json_dict=json_element)
tu.report_undated_entry(json_dict=json_element)
12 changes: 6 additions & 6 deletions src/gutenberg/get_meta.py
@@ -1,5 +1,5 @@
from src.path import root
import src.utils.text_utils as tw
import src.utils.text_utils as tu
import os, json
import csv
from tqdm import tqdm
@@ -56,7 +56,7 @@ def gather_meta(raw_dir, record):
break

assert add["content"]
# add['creation_year'] = tw.decode_year_num(add["created_timestamp"], 1100, 2024)
# add['creation_year'] = tu.decode_year_num(add["created_timestamp"], 1100, 2024)
"""
Taking average from the author's y.o.b & y.o.d
"""
@@ -81,20 +81,20 @@ def gather_meta(raw_dir, record):
add["creation_year"] = None
break
if add["creation_year"] is not None:
tw.write_single_entry(json_dict=add)
tu.write_single_entry(json_dict=add)
else:
tw.report_undated_entry(add)
tu.report_undated_entry(add)
gutenberg_failure_counter += 1
if (
gutenberg_failure_counter <= 100
or gutenberg_failure_counter % 100 == 0
):
tw.write_log(
tu.write_log(
f'Gutenberg: {gutenberg_failure_counter}-th time, saving to undated.json: created_timestamp={add["created_timestamp"]},{full_timestamp}'
)

except Exception as e:
gutenberg_failure_counter += 1
tw.write_log(
tu.write_log(
f"Gutenberg: {gutenberg_failure_counter}-th time, exception {type(e)} {e}"
)