Feature/clean synth text #186

Open · wants to merge 9 commits into base: production
validator/control_node/src/main.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@
 
 async def main() -> None:
     nltk.download('punkt_tab')
-
+    nltk.download('stopwords')
     config = load_config()
     await config.psql_db.connect()
 
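The new stopwords download pairs with the stop-word based sentence splitting introduced in synthetic_utils.py below. As a quick sanity check (not part of the PR), both NLTK resources the validator now relies on can be verified like this:

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

nltk.download('punkt_tab')   # tokenizer tables used by word_tokenize
nltk.download('stopwords')   # English stop-word list used by the new splitter

print(len(stopwords.words('english')))            # roughly 179 entries, version dependent
print(word_tokenize("A quick tokenizer check."))  # ['A', 'quick', 'tokenizer', 'check', '.']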
validator/control_node/src/synthetics/synthetic_generation_funcs.py (20 changes: 10 additions & 10 deletions)
@@ -25,11 +25,11 @@
 async def generate_chat_synthetic(model: str, task_config: Any, word_to_token: float = 4) -> payload_models.ChatPayload:
     start = time()
     synth_corpus = sutils.get_synth_corpus()
 
     try:
         total_n_words = sutils.get_random_int_from_dist(size=1, max_value=task_config.orchestrator_server_config.load_model_config['max_model_len']//word_to_token)
         if total_n_words.size == 0:
             total_n_words = 1000
         else:
             total_n_words = int(total_n_words[0])
             total_n_words = total_n_words if total_n_words > 0 else 20
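As a rough illustration of the word budget computed above (the numbers are hypothetical; the real max_model_len comes from the task's load_model_config):

# Illustrative values only, not taken from any real task config.
max_model_len = 16384                          # hypothetical context window, in tokens
word_to_token = 4                              # default ratio from the function signature
max_value = max_model_len // word_to_token     # 4096 -> upper bound on the sampled word count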
@@ -65,9 +65,9 @@ async def generate_chat_synthetic(model: str, task_config: Any, word_to_token: f
         )
 
         logger.debug(f"Generated {total_n_words} words chat synth in {round(time()-start, 3)}s")
-        logger.debug(f"prompt : {messages}")
+        logger.debug(f"payload : {payload}")
         return payload
 
     except Exception as e:
 
         logger.error("Error in new version of generate_chat_synthetic: %s", e)
@@ -78,18 +78,18 @@ async def generate_chat_synthetic(model: str, task_config: Any, word_to_token: f
 async def generate_chat_comp_synthetic(model: str, task_config: Any, word_to_token: float = 4) -> payload_models.CompletionPayload:
     start = time()
     synth_corpus = sutils.get_synth_corpus()
 
     try:
         total_n_words = sutils.get_random_int_from_dist(size=1, max_value=task_config.orchestrator_server_config.load_model_config['max_model_len']//word_to_token)
         if total_n_words.size == 0:
             total_n_words = 1000
         else:
             total_n_words = int(total_n_words[0])
             total_n_words = total_n_words if total_n_words > 0 else 20
         logger.debug(f"generating prompt with {total_n_words} words for synth")
 
-        message = await sutils.generate_text(synth_corpus, total_n_words)
+        message = await sutils.generate_text(synth_corpus, total_n_words, completions=True)
 
         payload = payload_models.CompletionPayload(
             prompt=message,
             temperature=round(random.random(), 1),
@@ -100,9 +100,9 @@ async def generate_chat_comp_synthetic(model: str, task_config: Any, word_to_tok
         )
 
         logger.debug(f"Generated {total_n_words} words chat completion synth in {round(time()-start, 3)}s")
-        logger.debug(f"prompt : {message}")
+        logger.debug(f"payload : {payload}")
         return payload
 
     except Exception as e:
 
         logger.error("Error in new version of generate_chat_comp_synthetic: %s", e)
validator/utils/synthetic/synthetic_utils.py (84 changes: 58 additions & 26 deletions)
@@ -8,7 +8,8 @@
 from validator.utils.synthetic import synthetic_constants as scst
 from redis.asyncio import Redis
 from core.models import config_models as cmodels
-from nltk.tokenize import sent_tokenize, word_tokenize
+from nltk.tokenize import word_tokenize
+from nltk.corpus import stopwords
 import numpy as np
 from time import time
 import re
@@ -19,9 +20,7 @@
 import diskcache
 from PIL import Image
 import uuid
-import numpy as np
 from fiber.logging_utils import get_logger
 import random
 import json
 from functools import lru_cache
 
@@ -40,20 +39,53 @@ def get_synth_corpus():
     return synth_corpus
 
 
-def split_sentences(text):
-    fragments = sent_tokenize(text)
-    return [frag for frag in fragments if len(frag.split()) > 2]
+def clean_text(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'\s*([.,!?;:])\s*', r'\1 ', text)
+    text = re.sub(r'[^\x20-\x7E]', '', text)
+    text = re.sub(r'([.!?])[.!?]+', r'\1', text)
+    text = re.sub(r'\.{2}(?!\.)', '.', text)
+    return text.strip()
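For reference, a small hand-worked example of what the new clean_text helper does (the input string is made up):

messy = "  The  quick\tbrown fox .It jumped !"
print(clean_text(messy))
# -> "The quick brown fox. It jumped!"
# Whitespace is collapsed, punctuation is re-spaced, and any non-ASCII characters
# or repeated terminal punctuation would be stripped as well.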

+def split_sentences_by_stops(text):
+    stop_words = set(stopwords.words('english'))
+    rough_sentences = re.split(r'(?<=[.!?])\s+', text)
+    final_sentences = []
+
+    for sentence in rough_sentences:
+        words = sentence.split()
+        if len(words) < 15:  # don't split short sentences
+            final_sentences.append(sentence)
+            continue
+
+        current_segment = []
+        for i, word in enumerate(words):
+            current_segment.append(word)
+
+            # check if current word is a stop word and we have a reasonable segment length
+            if (word.lower() in stop_words and
+                len(current_segment) >= 5 and
+                i != len(words) - 1):  # don't split at the last word
+
+                final_sentences.append(' '.join(current_segment))
+                current_segment = []
+
+        if current_segment:  # add any remaining words
+            final_sentences.append(' '.join(current_segment))
+
+    return [s.strip() for s in final_sentences if s.strip()]
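A small, hypothetical example of the new splitter's behaviour: inputs shorter than 15 words pass through untouched, while longer sentences are cut just after a stop word once a segment has at least five words.

print(split_sentences_by_stops("The cat sat on the mat."))
# -> ['The cat sat on the mat.']   (fewer than 15 words, kept whole)

long_sentence = ("The validator keeps sampling fragments from the corpus until it "
                 "has collected roughly the requested number of words for the prompt.")
print(split_sentences_by_stops(long_sentence))
# With NLTK's standard English stop-word list this yields segments such as
# 'The validator keeps sampling fragments from' and 'the corpus until it has',
# with the leftover words kept as a final segment.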

 async def get_random_text_from_queue():
     try:
         if not random_text_queue.empty():
-            return await random_text_queue.get()
+            text = await random_text_queue.get()
+            return text
         return None
     except Exception as e:
         logger.error(f"Error retrieving text from queue: {e}")
         return None
 
-async def generate_text(corpus, n_words):
+async def generate_text(corpus: dict, n_words: int, completions : bool = False):
     random.seed(time()%10000)
     generated_text_parts = []
 
@@ -62,26 +94,26 @@ async def generate_text(corpus, n_words):
 
     while current_word_count < n_words:
         random.shuffle(categories)
         # randomly select text from random categories, until we reach n_words
 
         for i, category in enumerate(categories):
-            sentence = random.choice(corpus[category]).strip()
-            sentences_in_category = split_sentences(sentence)
+            sentence = clean_text(random.choice(corpus[category]))
+            sentences_in_category = split_sentences_by_stops(sentence)
 
             if not sentences_in_category:
                 continue
 
             if i > 0 and i%3 == 0:
                 sentence_part = await get_random_text_from_queue()
             else:
                 sentence_part = random.choice(sentences_in_category)
 
             if not sentence_part:
                 continue
 
             sentence_word_count = len(word_tokenize(sentence_part))
             if current_word_count + sentence_word_count > n_words:
                 remaining_words = n_words - current_word_count
-                truncated_part = ' '.join(sentence_part.split()[:remaining_words])
+                truncated_part = ' '.join(word_tokenize(sentence_part)[:remaining_words])
                 generated_text_parts.append(truncated_part)
                 current_word_count += remaining_words
                 break
@@ -94,16 +126,16 @@
 
     if not generated_text_parts:
         raise ValueError("Unable to generate text, problem with corpus?")
-    merged_text = ' '.join(generated_text_parts).strip()
+
+    merged_text = ' '.join(generated_text_parts)
     possible_endings = ['.', '!', '?', '...']
 
-    if merged_text and merged_text[-1] not in possible_endings:
+    if merged_text and merged_text[-1] not in possible_endings and not completions:
         if random.choice([True, False]):
             merged_text += random.choice(possible_endings)
 
-    merged_text = re.sub(r'[^\x20-\x7E]', '', merged_text).strip()
-    return merged_text
+    return clean_text(merged_text)
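For context, a minimal sketch of how the reworked generate_text could be exercised on its own. It assumes, as the visible code implies, that the collapsed part of the function initialises current_word_count and the category list from the corpus keys; the toy corpus and word target are made up.

import asyncio

toy_corpus = {
    "science": ["Light from distant stars takes years to reach us. Telescopes gather it."],
    "history": ["Ancient traders crossed deserts with caravans. Their routes shaped cities."],
}

async def demo():
    # completions=True skips the random end-of-sentence punctuation that chat prompts get
    text = await generate_text(toy_corpus, 50, completions=True)
    print(len(text.split()), text)

asyncio.run(demo())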


 def get_random_int_from_dist(size=1, gamma_mean=1000, max_value=8000, gamma_shape=0.5, gaussian_mean=1000, gaussian_weight=0.3, gaussian_std=850):
     gamma_scale = gamma_mean / gamma_shape
@@ -119,7 +151,7 @@ async def fetch_random_text() -> Tuple[str, int, int]:
     n_paragraphes = random.randint(2, 4)
     n_sentences = random.randint(1, 6)
     url = f'http://metaphorpsum.com/paragraphs/{n_paragraphes}/{n_sentences}'
 
     async with aiohttp.ClientSession() as session:
         try:
             async with session.get(url) as response:
@@ -142,16 +174,16 @@ async def get_save_random_text() -> None:
     while True:
         try:
             queue_size = random_text_queue.qsize()
 
             if queue_size < scst.RANDOM_TEXT_QUEUE_MAX_SIZE:
                 text, n_paragraphes, n_sentences = await fetch_random_text()
                 await random_text_queue.put(text)
                 logger.debug(f"Pushed random metaphorpsum.com text with {n_paragraphes} paragraphs, and {n_sentences} sentences to queue")
             else:
                 logger.debug("Queue is full. Skipping text insertion")
 
             await asyncio.sleep(2)
 
         except Exception as e:
             logger.error(f"Error fetching and saving synthetic data: {e} - sleeping for 60s")
             await asyncio.sleep(60)
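Since get_save_random_text loops forever, it is presumably launched as a long-lived background task next to the consumers of random_text_queue. A hypothetical way to wire it up for a quick local test (not how the validator actually starts it):

import asyncio

async def run_producer_briefly():
    producer = asyncio.create_task(get_save_random_text())  # keeps filling random_text_queue
    await asyncio.sleep(10)                                  # let it fetch a couple of texts
    producer.cancel()                                        # stop the demo

asyncio.run(run_producer_briefly())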