Implement Multi-query processing #511

Open
AlekseySh opened this issue Mar 24, 2024 · 5 comments
Labels
documentation Improvements or additions to documentation

Comments

@AlekseySh
Contributor

AlekseySh commented Mar 24, 2024

The concept is that a query involves multiple objects instead of just one. We aim to retrieve results for all these objects simultaneously. A straightforward approach is to use frequency voting:

  • First, we obtain N results for each of the X sub-queries, yielding N*X results.
  • Then, we retain only the N most frequent results.
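The two voting steps above can be sketched in plain Python. This is a minimal illustration, not OML's API: `frequency_vote` is a hypothetical helper, each sub-query result is assumed to be a list of unique gallery ids ordered from best to worst, and the tie-breaking rule (best rank reached in any sub-query, then the id itself) is an assumption because the issue does not specify one.

```python
from collections import Counter
from typing import Dict, List, Sequence


def frequency_vote(sub_results: Sequence[Sequence[int]], n: int) -> List[int]:
    """Merge the top-N lists of X sub-queries into a single top-N list (hypothetical helper).

    Ids are ranked by how many sub-queries returned them. Ties are broken by the
    best (smallest) rank the id reached in any sub-query, then by the id itself
    (an assumed rule, chosen only to make the output deterministic).
    """
    counts: Counter = Counter()
    best_rank: Dict[int, int] = {}
    for ranking in sub_results:
        # step 1: take N results per sub-query, N*X candidates in total
        for rank, item in enumerate(ranking[:n]):
            counts[item] += 1
            best_rank[item] = min(rank, best_rank.get(item, rank))
    # step 2: keep only the N most frequent ids
    ordered = sorted(counts, key=lambda item: (-counts[item], best_rank[item], item))
    return ordered[:n]
```

For example, `frequency_vote([[0, 1, 2, 3], [4, 2, 0, 5]], 4)` returns `[0, 2, 4, 1]`: ids 0 and 2 appear in both sub-queries, and 4 is placed ahead of 1 only because of the assumed best-rank tie-break.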

As a result, we should have an example similar to "Using a trained model for retrieval" (https://github.com/OML-Team/open-metric-learning?tab=readme-ov-file#examples).

@AlekseySh AlekseySh added the good first issue Good for newcomers label Mar 24, 2024
@AlekseySh AlekseySh added the documentation Improvements or additions to documentation label Mar 24, 2024
@VSXV

This comment was marked as outdated.

@AlekseySh

This comment was marked as outdated.

@AlekseySh

This comment was marked as outdated.

@AlekseySh
Contributor Author

AlekseySh commented Apr 12, 2024

EXAMPLE

[Image: multi query]

@AlekseySh
Contributor Author

AlekseySh commented May 22, 2024

DRAFT

from collections import defaultdict

import numpy as np
import torch
from torch import FloatTensor, LongTensor

from oml.retrieval import RetrievalResults

rr = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.6, 0.9]),
        FloatTensor([0.5, 0.8]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 2, 3]),
        LongTensor([4, 2]),
        LongTensor([10, 20]),
        LongTensor([])
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        LongTensor([0, 2, 50]),  # todo: gt_ids of queries in the same group may be inconsistent
        LongTensor([10, 30]),
        LongTensor([50])
    ]
)

query_groups = [[0, 1], [2], [3]]  # queries 0 and 1 form one multi-query; 2 and 3 stay single

rr_expected = RetrievalResults(
    distances=[
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.3, 0.5, 0.7, 0.9]),
        FloatTensor([0.1, 0.2]),
        FloatTensor([]),
    ],
    retrieved_ids=[
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([0, 1, 4, 2, 3]),
        LongTensor([10, 20]),
        LongTensor([])
    ],
    gt_ids=[
        LongTensor([0, 2, 50]),
        LongTensor([0, 2, 50]),
        LongTensor([10, 30]),
        LongTensor([50])
    ]
)

# Merge the results within every query group: each retrieved id gets the mean of
# the distances it received across the group's queries, and all queries in the
# group share the resulting ranking (mean-distance aggregation).
distances_upd, retrieved_ids_upd = dict(), dict()
for group in query_groups:
    group_lens = [len(rr.retrieved_ids[ig]) for ig in group]
    if set(group_lens) == {0}:
        # nothing was retrieved for any query in the group
        for ig in group:
            distances_upd[ig] = FloatTensor([])
            retrieved_ids_upd[ig] = LongTensor([])

    else:
        dist_group = torch.concat([rr.distances[ig] for ig in group])
        ri_group = torch.concat([rr.retrieved_ids[ig] for ig in group])

        # collect every distance observed for each retrieved id
        ri2dist = defaultdict(list)
        for d, ri in zip(dist_group, ri_group):
            ri2dist[int(ri)].append(float(d))

        # average the distances per id and sort ascending (closest first)
        ri_dist = [(ri, float(np.mean(d))) for ri, d in ri2dist.items()]
        ri_dist = sorted(ri_dist, key=lambda x: x[1])
        ri_upd, dist_upd = zip(*ri_dist)

        # every query in the group receives the same merged ranking
        for ig in group:
            distances_upd[ig] = FloatTensor(dist_upd)
            retrieved_ids_upd[ig] = LongTensor(ri_upd)

# restore the original query order
n_queries = len(rr.retrieved_ids)
distances_upd_final = [distances_upd[iq] for iq in range(n_queries)]
retrieved_ids_upd_final = [retrieved_ids_upd[iq] for iq in range(n_queries)]

rr_produced = RetrievalResults(distances=distances_upd_final, retrieved_ids=retrieved_ids_upd_final, gt_ids=rr.gt_ids)

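print(rr_expected)
print(rr_produced)

# sanity check: the produced results match the expected ones
for iq in range(len(rr.retrieved_ids)):
    assert torch.equal(rr_produced.retrieved_ids[iq], rr_expected.retrieved_ids[iq])
    assert torch.allclose(rr_produced.distances[iq], rr_expected.distances[iq])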
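The merge logic of the draft above can also be sketched with plain Python lists, which makes it easy to test without OML installed. `merge_by_mean_distance` is a hypothetical helper, not part of OML, and it assumes every query index appears in exactly one group. Like the draft, it ranks ids by their mean distance across the group, so an id returned by several sub-queries is not favored by its count, unlike the frequency voting described in the opening comment.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Sequence, Tuple


def merge_by_mean_distance(
    distances: Sequence[Sequence[float]],
    retrieved_ids: Sequence[Sequence[int]],
    query_groups: Sequence[Sequence[int]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Give every query in a group the same merged ranking, ordered by mean distance.

    Hypothetical helper; assumes each query index belongs to exactly one group.
    """
    merged_dist: Dict[int, List[float]] = {}
    merged_ids: Dict[int, List[int]] = {}
    for group in query_groups:
        # collect every distance observed for each retrieved id in the group
        id2dists: Dict[int, List[float]] = defaultdict(list)
        for iq in group:
            for d, ri in zip(distances[iq], retrieved_ids[iq]):
                id2dists[ri].append(d)
        # sort ids by their mean distance, closest first (empty groups stay empty)
        ranked = sorted(((ri, mean(ds)) for ri, ds in id2dists.items()), key=lambda x: x[1])
        for iq in group:
            merged_ids[iq] = [ri for ri, _ in ranked]
            merged_dist[iq] = [d for _, d in ranked]
    # restore the original query order
    n = len(retrieved_ids)
    return [merged_dist[i] for i in range(n)], [merged_ids[i] for i in range(n)]
```

On the toy data from the draft, this returns the retrieved ids `[0, 1, 4, 2, 3]` for queries 0 and 1, `[10, 20]` for query 2, and an empty list for query 3, matching `rr_expected`.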

@AlekseySh AlekseySh changed the title Add an example of using multi query Implement Multi-query processing Jun 8, 2024
@AlekseySh AlekseySh removed the good first issue Good for newcomers label Jun 10, 2024
Projects
Status: To do
Development

No branches or pull requests

2 participants