Skip to content

Commit

Permalink
fix: validator propagating error through the call stack
Browse files Browse the repository at this point in the history
  • Loading branch information
PsicoThePato committed May 22, 2024
1 parent 2933d6e commit 6862919
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/synthia/validator/text_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def cut_to_max_allowed_weights(
max_allowed_weights = settings.max_allowed_weights

# sort the score by highest to lowest
sorted_scores = sorted(score_dict.items(), key=lambda x: x[1], reverse=True)
sorted_scores = sorted(
score_dict.items(), key=lambda x: x[1], reverse=True)

# cut to max_allowed_weights
cut_scores = sorted_scores[:max_allowed_weights]
Expand Down Expand Up @@ -126,7 +127,8 @@ def get_synthia_netuid(clinet: CommuneClient, subnet_name: str = "synthia"):


def get_ip_port(modules_adresses: dict[int, str]):
filtered_addr = {id: extract_address(addr) for id, addr in modules_adresses.items()}
filtered_addr = {id: extract_address(addr)
for id, addr in modules_adresses.items()}
ip_port = {
id: x.group(0).split(":") for id, x in filtered_addr.items() if x is not None
}
Expand Down Expand Up @@ -264,31 +266,38 @@ async def _get_miner_prediction(
val_info: ValidationDataset,
miner_info: tuple[list[str], Ss58Address],
) -> tuple[str | None, ValidationDataset]:
connection, miner_key = miner_info
module_ip, module_port = connection
miner_answer = None

question = get_miner_prompt(
val_info.criteria, val_info.chosen_subject, len(val_info.val_answer)
)
client = ModuleClient(module_ip, int(module_port), self.key)
try:
miner_answer = await client.call(
"generate", miner_key, {"prompt": question}, timeout=self.call_timeout
connection, miner_key = miner_info
module_ip, module_port = connection

question = get_miner_prompt(
val_info.criteria, val_info.chosen_subject, len(
val_info.val_answer)
)
client = ModuleClient(module_ip, int(module_port), self.key)

miner_answer = miner_answer["answer"]
try:
response = await client.call(
"generate", miner_key, {"prompt": question}, timeout=self.call_timeout
)
miner_answer = response.get("answer")
except Exception as e:
log(f"Miner {module_ip}:{module_port} failed to generate an answer: {e}")

# This is needed, so truly nothing can get propagated through the call stack
except Exception as e:
log(f"Miner {module_ip}:{module_port} failed to generate an answer")
print(e)
miner_answer = None
log(f"An unexpected error occurred in _get_miner_prediction: {e}")

return miner_answer, val_info
finally:
return miner_answer, val_info

def _get_unit_euclid_distance(
self, embedded_miner_answer: list[float], embbeded_val_answer: list[float]
):
distance = euclidean_distance(embedded_miner_answer, embbeded_val_answer)
distance = euclidean_distance(
embedded_miner_answer, embbeded_val_answer)
miner_norm = np.linalg.norm(embedded_miner_answer)
val_norm = np.linalg.norm(embbeded_val_answer)
normalized_distance = distance / (miner_norm + val_norm)
Expand All @@ -308,7 +317,7 @@ def _score_miner(
def _split_val_subject(self, val_answer: str):
end_of_subject = val_answer.find("\n")
subject = val_answer[:end_of_subject]
val_answer = val_answer[end_of_subject + 1 :]
val_answer = val_answer[end_of_subject + 1:]
return subject, val_answer

def _test_score(self, text_a: str, text_b: str):
Expand Down Expand Up @@ -353,7 +362,8 @@ async def validate_step(
modules_keys = self.client.query_map_key(syntia_netuid)
val_ss58 = self.key.ss58_address
if val_ss58 not in modules_keys.values():
raise RuntimeError(f"validator key {val_ss58} is not registered in subnet")
raise RuntimeError(
f"validator key {val_ss58} is not registered in subnet")
modules_info: dict[int, ModuleInfo] = {}

modules_filtered_address = get_ip_port(modules_adresses)
Expand All @@ -369,14 +379,16 @@ async def validate_step(
score_dict: dict[int, float] = {}
hf_data_list: list[dict[str, str]] = []
# == Validation loop / Scoring ==
val_dataset = self._get_validation_dataset(settings, NUM_QUESTIONS_PER_CYCLE)
val_dataset = self._get_validation_dataset(
settings, NUM_QUESTIONS_PER_CYCLE)

log(f"Selected the following miners: {modules_info.keys()}")
futures: list[asyncio.Task[tuple[str | None, ValidationDataset]]] = []
for mod_info in modules_info.values():
val_info = random.choice(val_dataset)
future = asyncio.create_task(
self._get_miner_prediction(val_info, (mod_info.address, mod_info.key))
self._get_miner_prediction(
val_info, (mod_info.address, mod_info.key))
)
futures.append(future)
miner_answers = await asyncio.gather(*futures, return_exceptions=True)
Expand All @@ -391,10 +403,10 @@ async def validate_step(
if not miner_answer:
log(f"Skipping miner {uid} that didn't answer")
continue
score = self._score_miner(miner_answer, val_info.embedded_val_answer)
score = self._score_miner(
miner_answer, val_info.embedded_val_answer)
for answer in response_cache:
similarity = fuzz.ratio(answer, miner_answer) # type: ignore
log(f"similarity: {similarity}")
response_cache.append(miner_answer)

# score has to be lower or eq to 1, as one is the best score
Expand Down

0 comments on commit 6862919

Please sign in to comment.