-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
587 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
--- | ||
dataset_info: | ||
features: | ||
- name: conversation | ||
dtype: string | ||
splits: | ||
- name: train | ||
num_bytes: 87758119 | ||
num_examples: 1322 | ||
- name: validation | ||
num_bytes: 7731418 | ||
num_examples: 111 | ||
- name: test | ||
num_bytes: 27041394 | ||
num_examples: 331 | ||
download_size: 63044464 | ||
dataset_size: 122530931 | ||
--- | ||
|
||
# Dataset Card for "poetry-instructions" | ||
|
||
A dataset of user-assistant dialogue instructions for guided poetry creation. | ||
Poems used were taken from | ||
[merve/poetry](https://huggingface.co/datasets/merve/poetry) and | ||
[matthh/gutenberg-poetry-corpus](https://huggingface.co/datasets/matthh/gutenberg-poetry-corpus). | ||
|
||
The dataset contains dialogues in the following formats: | ||
|
||
- Poetry Completion: | ||
|
||
``` | ||
User: Can you continue this poem for me? <poem_start> | ||
Assistant: Sure, a continuation for this poem could be: <poem end> | ||
``` | ||
|
||
- Create poem in style of (?): | ||
|
||
``` | ||
User: Can you write a poem for me in the style of <author>? | ||
Assistant: Sure, here's a poem in the style of <author>: <poem> | ||
``` | ||
|
||
- Creat poem about (?): | ||
|
||
``` | ||
User: Can you write me a poem about <keywords (extracted using keyphrase model)>? | ||
Assistant: Sure, here's a poem about <keywords>: <poem> | ||
``` | ||
|
||
- Create poem about (?) in the style of (?): | ||
|
||
``` | ||
User: Can you write me a poem about <keywords> in the style of <author>? | ||
Assistant: Sure, here's a poem about <keywords> in the style of <author>: <poem> | ||
``` |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
import torch | ||
from transformers import AutoModelForTokenClassification, AutoTokenizer, TokenClassificationPipeline | ||
from transformers.pipelines import AggregationStrategy | ||
|
||
|
||
def extract_keywords(poem: str, num_keywords: int) -> List[str]: | ||
model_out = MODEL_STORE.keyword_extractor(poem, num_keywords=num_keywords) | ||
# print(model_out) | ||
return [e["word"] for e in model_out] | ||
|
||
|
||
class FixedERPipeline(TokenClassificationPipeline): | ||
"""Pipeline for Entity Recognition, modified to allow specification | ||
of the number of entities to extract | ||
""" | ||
|
||
def __init__(self, model, *args, **kwargs): | ||
super().__init__( | ||
model=AutoModelForTokenClassification.from_pretrained(model), | ||
tokenizer=AutoTokenizer.from_pretrained(model), | ||
*args, | ||
**kwargs, | ||
) | ||
self._num_keywords = 1 | ||
self.non_entity_label = "O" | ||
self.key_label = "KEY" | ||
(self.b_key_label_idx,) = [ | ||
idx for idx, lbl in self.model.config.id2label.items() if lbl == "B-KEY" | ||
] # index for the B-KEY entity | ||
|
||
def __call__(self, *args, num_keywords: int = 1, **kwargs): | ||
self._num_keywords = num_keywords | ||
return super().__call__(*args, **kwargs) | ||
|
||
"""Taken from | ||
https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py#L341 | ||
Modified to fix the total number of keyphrases returned | ||
""" | ||
|
||
def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]: | ||
aggregation_strategy = AggregationStrategy.SIMPLE | ||
sorted_entities = sorted( | ||
[e for e in pre_entities], key=lambda e: e["scores"][self.b_key_label_idx], reverse=True | ||
) | ||
extracted_b_key_words = set() | ||
extracted_b_key_idxs = set() | ||
# extract top-n tokens with highest B-KEY score, skipping duplicates | ||
for e in sorted_entities: | ||
if e["word"] not in extracted_b_key_words: | ||
extracted_b_key_words.add(e["word"]) | ||
extracted_b_key_idxs.add(e["index"]) | ||
if len(extracted_b_key_idxs) >= self._num_keywords: | ||
break | ||
|
||
if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}: | ||
entities = [] | ||
for pre_entity in pre_entities: | ||
# if entity is one of our extracted B-KEYs, fix prediction to indicate new keyphrase | ||
if pre_entity["index"] in extracted_b_key_idxs: | ||
entity_idx = self.b_key_label_idx | ||
else: | ||
pre_entity["scores"][self.b_key_label_idx] = 0 | ||
entity_idx = pre_entity["scores"].argmax() | ||
|
||
score = pre_entity["scores"][entity_idx] | ||
entity = { | ||
"entity": self.model.config.id2label[entity_idx], | ||
"score": score, | ||
"index": pre_entity["index"], | ||
"word": pre_entity["word"], | ||
"start": pre_entity["start"], | ||
"end": pre_entity["end"], | ||
} | ||
entities.append(entity) | ||
else: | ||
entities = self.aggregate_words(pre_entities, aggregation_strategy) | ||
|
||
if aggregation_strategy == AggregationStrategy.NONE: | ||
return entities | ||
return self.group_entities(entities) | ||
|
||
"""Taken from | ||
https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py#L420 | ||
Modified to prevent lone I-KEYs from being extracted | ||
""" | ||
|
||
def group_sub_entities(self, entities: List[dict]) -> dict: | ||
""" | ||
Group together the adjacent tokens with the same entity predicted. | ||
Args: | ||
entities (`dict`): The entities predicted by the pipeline. | ||
""" | ||
# Get the first entity in the entity group | ||
|
||
# modify to set as non-entity if no B-KEY exists in group | ||
entity = self.non_entity_label | ||
score = np.nanmean([entity["score"] for entity in entities]) | ||
for e in entities: | ||
bi, tag = self.get_tag(e["entity"]) | ||
if bi == "B" and tag == self.key_label: | ||
entity = self.key_label | ||
score = e["score"] | ||
break | ||
|
||
tokens = [entity["word"] for entity in entities] | ||
|
||
entity_group = { | ||
"entity_group": entity, | ||
"score": score, | ||
"word": self.tokenizer.convert_tokens_to_string(tokens), | ||
"start": entities[0]["start"], | ||
"end": entities[-1]["end"], | ||
} | ||
return entity_group | ||
|
||
|
||
class ModelStore: | ||
def __init__(self): | ||
self.keyword_extractor = None | ||
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" | ||
|
||
def load(self): | ||
self.keyword_extractor = FixedERPipeline( | ||
model="yanekyuk/bert-uncased-keyword-extractor", | ||
device=self.device, | ||
) | ||
|
||
|
||
MODEL_STORE = ModelStore() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
|
||
import numpy as np | ||
|
||
from .augmentation import extract_keywords | ||
from .template import ( | ||
CONTINUE_POEM_TEMPLATE, | ||
NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE, | ||
NEW_POEM_ABOUT_TEMPLATE, | ||
NEW_POEM_IN_STYLE_OF_TEMPLATE, | ||
) | ||
|
||
|
||
@dataclass | ||
class PoetryRecord: | ||
poem: str | ||
title: str | ||
author: str | ||
theme: str | ||
time_period: str | ||
|
||
|
||
class PoetryDialogueTask(Enum): | ||
CONTINUE = "make_continue_poem_dialogue" | ||
IN_STYLE = "make_poem_in_style_dialogue" | ||
ABOUT_KEYWORDS = "make_poem_about_keywords_dialogue" | ||
ABOUT_KEYWORDS_IN_STYLE = "make_poem_about_keywords_in_style_dialogue" | ||
|
||
@staticmethod | ||
def make_continue_poem_dialogue(record: "PoetryRecord") -> str: | ||
line_splits = record.poem.split("\n") | ||
if len(line_splits) <= 1: | ||
return PoetryDialogueTask.random_task_excluding(PoetryDialogueTask.CONTINUE).prepare_dialogue(record) | ||
line_split_idx = np.random.randint(1, len(line_splits)) | ||
return CONTINUE_POEM_TEMPLATE.format( | ||
poem_start="\n".join(line_splits[:line_split_idx]), | ||
poem_end="\n".join(line_splits[line_split_idx:]), | ||
) | ||
|
||
@staticmethod | ||
def make_poem_in_style_dialogue(record: "PoetryRecord") -> str: | ||
return NEW_POEM_IN_STYLE_OF_TEMPLATE.format( | ||
author=record.author, | ||
poem=record.poem, | ||
) | ||
|
||
@staticmethod | ||
def make_poem_about_keywords_dialogue(record: "PoetryRecord") -> str: | ||
keywords = extract_keywords(record.poem, num_keywords=np.random.randint(1, 4)) | ||
keyword_string = keywords[-1] | ||
if len(keywords) > 1: | ||
keyword_string = ", ".join(keywords[:-1]) + " and " + keyword_string | ||
return NEW_POEM_ABOUT_TEMPLATE.format( | ||
about=keyword_string, | ||
poem=record.poem, | ||
) | ||
|
||
@staticmethod | ||
def make_poem_about_keywords_in_style_dialogue(record: "PoetryRecord") -> str: | ||
keywords = extract_keywords(record.poem, num_keywords=np.random.randint(1, 4)) | ||
keyword_string = keywords[-1] | ||
if len(keywords) > 1: | ||
keyword_string = ", ".join(keywords[:-1]) + " and " + keyword_string | ||
return NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE.format( | ||
about=keyword_string, | ||
author=record.author, | ||
poem=record.poem, | ||
) | ||
|
||
@staticmethod | ||
def random_task(): | ||
return np.random.choice(PoetryDialogueTask) | ||
|
||
@staticmethod | ||
def random_task_excluding(*exclude_tasks: "PoetryDialogueTask"): | ||
return np.random.choice([t for t in PoetryDialogueTask if t not in exclude_tasks]) | ||
|
||
def prepare_dialogue(self, record: PoetryRecord) -> str: | ||
return getattr(PoetryDialogueTask, self.value)(record) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
CONTINUE_POEM_TEMPLATE = """User: Can you continue this poem for me? | ||
{poem_start} | ||
Assistant: Sure, a continuation for this poem could be: | ||
{poem_end}""" | ||
|
||
NEW_POEM_IN_STYLE_OF_TEMPLATE = """User: Can you write a poem for me in the style of {author}? | ||
Assistant: Sure, here's a poem in the style of {author}: | ||
{poem}""" | ||
|
||
NEW_POEM_ABOUT_TEMPLATE = """User: Can you write me a poem about {about}? | ||
Assistant: Sure, here's a poem about {about}: | ||
{poem}""" | ||
|
||
NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE = """User: Can you write me a poem about {about} in the style of {author}? | ||
Assistant: Sure, here's a poem about {about} in the style of {author}: | ||
{poem}""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from dataclasses import dataclass | ||
|
||
import datasets | ||
|
||
|
||
@dataclass | ||
class OpenAssistantConfig(datasets.BuilderConfig): | ||
"""BuilderConfig for OpenAssistant datasets.""" | ||
|
||
name: str = None | ||
version: datasets.Version = None | ||
description: str = None | ||
schema: str = None | ||
subset_id: str = None | ||
|
||
|
||
features = datasets.Features( | ||
{ | ||
"conversation": datasets.Value("string"), | ||
} | ||
) |
Oops, something went wrong.