Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

can you share the code about how to sample the support set? #14

Open
ToneLi opened this issue Mar 7, 2023 · 1 comment
Open

can you share the code about how to sample the support set? #14

ToneLi opened this issue Mar 7, 2023 · 1 comment

Comments

@ToneLi
Copy link

ToneLi commented Mar 7, 2023

No description provided.

@thuan00
Copy link

thuan00 commented May 17, 2024

Hi, I've implemented the sampling method from the paper for my own use.
Hope it's helpful for you too.

import random

def sample_ner_data_struct_shot(samples, count_fn, k=1, random_state=None):
    """Select a few-shot support set with the greedy sampling of
    https://arxiv.org/abs/2010.02405.

    Entity types are visited from the least to the most frequent (ties keep
    first-seen order). For each type, samples containing it are drawn without
    replacement until the selected samples hold at least ``k`` instances of
    that type. Each drawn sample contributes all of its entity counts, so
    later (more frequent) types may already be satisfied when reached.

    Args:
        samples: list of samples; each is passed to ``count_fn``.
        count_fn: callable taking one sample and returning a mapping
            ``{entity_type: count}`` for that sample.
        k: target number of entity instances per entity type.
        random_state: seed for a private ``random.Random`` generator, or a
            ``random.Random`` instance to draw from. The module-level
            ``random`` state is never reseeded.

    Returns:
        A tuple ``(selected_ids, selected_count)``: the indices of the chosen
        samples, and a dict of entity counts summed over the chosen samples
        (one key per entity type seen in ``samples``).

        If ``samples`` holds fewer than ``k`` instances of a type, every
        sample containing that type is taken and its count stays below ``k``
        instead of raising.
    """
    # Per-sample entity counts, and the total per entity type.
    samples_count = [count_fn(sample) for sample in samples]
    total = {}
    for sample_count in samples_count:
        for e_type, e_count in sample_count.items():
            total[e_type] = total.get(e_type, 0) + e_count

    # Indices of samples that really contain each type (count > 0), built once
    # instead of rescanning all samples on every draw. Indices are appended in
    # ascending order, so the candidate lists (and thus seeded draws) match a
    # full scan.
    containing = {e_type: [] for e_type in total}
    for idx, sample_count in enumerate(samples_count):
        for e_type, e_count in sample_count.items():
            if e_count > 0:
                containing[e_type].append(idx)

    # Rarest entity types first, as in the paper's greedy algorithm.
    entity_types = sorted(total, key=total.get)
    selected_ids = set()
    selected_count = {e_type: 0 for e_type in entity_types}
    # A private generator avoids clobbering the global random state.
    rng = (random_state if isinstance(random_state, random.Random)
           else random.Random(random_state))

    for entity_type in entity_types:
        while selected_count[entity_type] < k:
            candidates = [i for i in containing[entity_type]
                          if i not in selected_ids]
            if not candidates:
                # Pool exhausted for this type: keep what has been taken.
                break
            sample_id = rng.choice(candidates)
            selected_ids.add(sample_id)
            # A sample contributes every entity it contains, not just this type.
            for e_type, e_count in samples_count[sample_id].items():
                selected_count[e_type] += e_count

    return list(selected_ids), selected_count



from collections import Counter

def count_entity_(sample):
    """Return a Counter of entity labels for a single sample.

    ``sample['slots']`` is iterated and each slot's ``'label'`` value is
    tallied; the result can be used as the ``count_fn`` of
    ``sample_ner_data_struct_shot``.
    """
    labels = (slot['label'] for slot in sample['slots'])
    return Counter(labels)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants