diff --git a/src/synthia/validator/text_validator.py b/src/synthia/validator/text_validator.py index 3b759e4..9c73ee5 100644 --- a/src/synthia/validator/text_validator.py +++ b/src/synthia/validator/text_validator.py @@ -91,7 +91,8 @@ def cut_to_max_allowed_weights( max_allowed_weights = settings.max_allowed_weights # sort the score by highest to lowest - sorted_scores = sorted(score_dict.items(), key=lambda x: x[1], reverse=True) + sorted_scores = sorted( + score_dict.items(), key=lambda x: x[1], reverse=True) # cut to max_allowed_weights cut_scores = sorted_scores[:max_allowed_weights] @@ -126,7 +127,8 @@ def get_synthia_netuid(clinet: CommuneClient, subnet_name: str = "synthia"): def get_ip_port(modules_adresses: dict[int, str]): - filtered_addr = {id: extract_address(addr) for id, addr in modules_adresses.items()} + filtered_addr = {id: extract_address(addr) + for id, addr in modules_adresses.items()} ip_port = { id: x.group(0).split(":") for id, x in filtered_addr.items() if x is not None } @@ -264,31 +266,38 @@ async def _get_miner_prediction( val_info: ValidationDataset, miner_info: tuple[list[str], Ss58Address], ) -> tuple[str | None, ValidationDataset]: - connection, miner_key = miner_info - module_ip, module_port = connection + miner_answer = None - question = get_miner_prompt( - val_info.criteria, val_info.chosen_subject, len(val_info.val_answer) - ) - client = ModuleClient(module_ip, int(module_port), self.key) try: - miner_answer = await client.call( - "generate", miner_key, {"prompt": question}, timeout=self.call_timeout + connection, miner_key = miner_info + module_ip, module_port = connection + + question = get_miner_prompt( + val_info.criteria, val_info.chosen_subject, len( + val_info.val_answer) ) + client = ModuleClient(module_ip, int(module_port), self.key) - miner_answer = miner_answer["answer"] + try: + response = await client.call( + "generate", miner_key, {"prompt": question}, timeout=self.call_timeout + ) + miner_answer = response.get("answer") + except Exception as e: + log(f"Miner {module_ip}:{module_port} failed to generate an answer: {e}") + # This is needed, so truly nothing can get propagated through the call stack except Exception as e: - log(f"Miner {module_ip}:{module_port} failed to generate an answer") - print(e) - miner_answer = None + log(f"An unexpected error occurred in _get_miner_prediction: {e}") - return miner_answer, val_info + finally: + return miner_answer, val_info def _get_unit_euclid_distance( self, embedded_miner_answer: list[float], embbeded_val_answer: list[float] ): - distance = euclidean_distance(embedded_miner_answer, embbeded_val_answer) + distance = euclidean_distance( + embedded_miner_answer, embbeded_val_answer) miner_norm = np.linalg.norm(embedded_miner_answer) val_norm = np.linalg.norm(embbeded_val_answer) normalized_distance = distance / (miner_norm + val_norm) @@ -308,7 +317,7 @@ def _score_miner( def _split_val_subject(self, val_answer: str): end_of_subject = val_answer.find("\n") subject = val_answer[:end_of_subject] - val_answer = val_answer[end_of_subject + 1 :] + val_answer = val_answer[end_of_subject + 1:] return subject, val_answer def _test_score(self, text_a: str, text_b: str): @@ -353,7 +362,8 @@ async def validate_step( modules_keys = self.client.query_map_key(syntia_netuid) val_ss58 = self.key.ss58_address if val_ss58 not in modules_keys.values(): - raise RuntimeError(f"validator key {val_ss58} is not registered in subnet") + raise RuntimeError( + f"validator key {val_ss58} is not registered in subnet") modules_info: dict[int, ModuleInfo] = {} modules_filtered_address = get_ip_port(modules_adresses) @@ -369,14 +379,16 @@ async def validate_step( score_dict: dict[int, float] = {} hf_data_list: list[dict[str, str]] = [] # == Validation loop / Scoring == - val_dataset = self._get_validation_dataset(settings, NUM_QUESTIONS_PER_CYCLE) + val_dataset = self._get_validation_dataset( + settings, NUM_QUESTIONS_PER_CYCLE) log(f"Selected the following miners: {modules_info.keys()}") futures: list[asyncio.Task[tuple[str | None, ValidationDataset]]] = [] for mod_info in modules_info.values(): val_info = random.choice(val_dataset) future = asyncio.create_task( - self._get_miner_prediction(val_info, (mod_info.address, mod_info.key)) + self._get_miner_prediction( + val_info, (mod_info.address, mod_info.key)) ) futures.append(future) miner_answers = await asyncio.gather(*futures, return_exceptions=True) @@ -391,10 +403,10 @@ async def validate_step( if not miner_answer: log(f"Skipping miner {uid} that didn't answer") continue - score = self._score_miner(miner_answer, val_info.embedded_val_answer) + score = self._score_miner( + miner_answer, val_info.embedded_val_answer) for answer in response_cache: similarity = fuzz.ratio(answer, miner_answer) # type: ignore - log(f"similarity: {similarity}") response_cache.append(miner_answer) # score has to be lower or eq to 1, as one is the best score