Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RerankSorer implementation based on KNN #50

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from

Conversation

Dmitryv-2024
Copy link
Collaborator

Реализовал RerankScorer. В реализации я попробовал немного порефакторить код поэтому класс наследуется от KNNScorer.
После того как сделал рефакторинг, я понял что не знаю как проверить методы load/dump.
Если вам не понравится мой подход, я могу легко за_copy-paste этот класс как сделано во всех других.

Еще я сначала решил реализовать возможность выбрасывать соседей для которых CrossEncoder дает уверенность меньше заданного порога (rank_threshold_cutoff) но потом осознал:

  • количество кандидатов для разных utterance может быть разным, поэтому _count_scores из KNNScorer не будет работать и нужна другая реализация
  • может получится, что для какого-то utterance вообще нет вариантов. По идее, такой вариант подходит под класс OOS, но количество классов нужно расширять, а это должно быть согласовано во всей pipeline.

Тест получился, практически полной копией теста для KNNScorer, тоже кандидат для рефакторинга.

self.scorer_name = scorer_name
self.m = k if m is None else m
self.rank_threshold_cutoff = rank_threshold_cutoff
self._scorer = CrossEncoder(self.scorer_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

это надо переместить в fit(), слишком тяжелая операция

у нас философия такая же как sklearn: в конструкторе мы только сохраняем и валидируем аургменты конструктора

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

в связи с этим надо будет реализовать загрузку кросс энкодера в методе load()

embedder_name: str,
k: int,
weights: WEIGHT_TYPES,
scorer_name: str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

я бы дал название этому аргументу cross_encoder_name

@voorhs
Copy link
Collaborator

voorhs commented Nov 25, 2024

еще надо добавить этот модуль в серч спейс тестов с оптимизацией и инференсом

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants