add poetry-instructions dataset
IsaacRe committed Feb 26, 2023
1 parent 9abd816 commit 1e8a3ac
Showing 9 changed files with 587 additions and 0 deletions.
55 changes: 55 additions & 0 deletions data/datasets/poetry_instructions/README.md
---
dataset_info:
  features:
    - name: conversation
      dtype: string
  splits:
    - name: train
      num_bytes: 87758119
      num_examples: 1322
    - name: validation
      num_bytes: 7731418
      num_examples: 111
    - name: test
      num_bytes: 27041394
      num_examples: 331
  download_size: 63044464
  dataset_size: 122530931
---

# Dataset Card for "poetry-instructions"

A dataset of user-assistant dialogue instructions for guided poetry creation.
The poems were sourced from
[merve/poetry](https://huggingface.co/datasets/merve/poetry) and
[matthh/gutenberg-poetry-corpus](https://huggingface.co/datasets/matthh/gutenberg-poetry-corpus).

The dataset contains dialogues in the following formats:

- Poem completion:

```
User: Can you continue this poem for me? <poem_start>
Assistant: Sure, a continuation for this poem could be: <poem_end>
```

- Create a poem in the style of a given author:

```
User: Can you write a poem for me in the style of <author>?
Assistant: Sure, here's a poem in the style of <author>: <poem>
```

- Create a poem about given keywords:

```
User: Can you write me a poem about <keywords (extracted using keyphrase model)>?
Assistant: Sure, here's a poem about <keywords>: <poem>
```

- Create a poem about given keywords in the style of a given author:

```
User: Can you write me a poem about <keywords> in the style of <author>?
Assistant: Sure, here's a poem about <keywords> in the style of <author>: <poem>
```
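
Each example is stored as a single `conversation` string in one of the formats above. A minimal loading sketch follows; the Hub dataset id and the turn-splitting helper are assumptions for illustration, not part of this dataset's code:

```python
# Minimal sketch: load the dataset and split one conversation into turns.
from datasets import load_dataset

# Hypothetical dataset id -- adjust to wherever this dataset is published.
ds = load_dataset("OpenAssistant/poetry-instructions", split="train")

def split_turns(conversation: str) -> dict:
    # Conversations look like "User: ...\nAssistant: ..."; split on the
    # first "Assistant:" marker (requires Python 3.9+ for removeprefix).
    user_part, _, assistant_part = conversation.partition("Assistant:")
    return {
        "user": user_part.removeprefix("User:").strip(),
        "assistant": assistant_part.strip(),
    }

print(split_turns(ds[0]["conversation"]))
```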
Empty file.
Empty file.
134 changes: 134 additions & 0 deletions data/datasets/poetry_instructions/data/augmentation.py
from typing import List

import numpy as np
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, TokenClassificationPipeline
from transformers.pipelines import AggregationStrategy


def extract_keywords(poem: str, num_keywords: int) -> List[str]:
    """Extract keyphrases from a poem using the shared keyword-extraction pipeline."""
    model_out = MODEL_STORE.keyword_extractor(poem, num_keywords=num_keywords)
    return [e["word"] for e in model_out]


class FixedERPipeline(TokenClassificationPipeline):
"""Pipeline for Entity Recognition, modified to allow specification
of the number of entities to extract
"""

def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs,
)
self._num_keywords = 1
self.non_entity_label = "O"
self.key_label = "KEY"
(self.b_key_label_idx,) = [
idx for idx, lbl in self.model.config.id2label.items() if lbl == "B-KEY"
] # index for the B-KEY entity

def __call__(self, *args, num_keywords: int = 1, **kwargs):
self._num_keywords = num_keywords
return super().__call__(*args, **kwargs)

"""Taken from
https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py#L341
Modified to fix the total number of keyphrases returned
"""

    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        # Force SIMPLE aggregation; the word-aggregation branch below is kept
        # from the upstream implementation but is never taken.
        aggregation_strategy = AggregationStrategy.SIMPLE
        # Rank tokens by their B-KEY score, highest first.
        sorted_entities = sorted(
            pre_entities, key=lambda e: e["scores"][self.b_key_label_idx], reverse=True
        )
extracted_b_key_words = set()
extracted_b_key_idxs = set()
# extract top-n tokens with highest B-KEY score, skipping duplicates
for e in sorted_entities:
if e["word"] not in extracted_b_key_words:
extracted_b_key_words.add(e["word"])
extracted_b_key_idxs.add(e["index"])
if len(extracted_b_key_idxs) >= self._num_keywords:
break

if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
entities = []
for pre_entity in pre_entities:
# if entity is one of our extracted B-KEYs, fix prediction to indicate new keyphrase
if pre_entity["index"] in extracted_b_key_idxs:
entity_idx = self.b_key_label_idx
else:
pre_entity["scores"][self.b_key_label_idx] = 0
entity_idx = pre_entity["scores"].argmax()

score = pre_entity["scores"][entity_idx]
entity = {
"entity": self.model.config.id2label[entity_idx],
"score": score,
"index": pre_entity["index"],
"word": pre_entity["word"],
"start": pre_entity["start"],
"end": pre_entity["end"],
}
entities.append(entity)
else:
entities = self.aggregate_words(pre_entities, aggregation_strategy)

if aggregation_strategy == AggregationStrategy.NONE:
return entities
return self.group_entities(entities)

"""Taken from
https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py#L420
Modified to prevent lone I-KEYs from being extracted
"""

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.
        Args:
            entities (`List[dict]`): The entities predicted by the pipeline.
        """
        # Label the group as a keyphrase only if it contains a B-KEY token;
        # otherwise mark it as a non-entity.
        entity = self.non_entity_label
score = np.nanmean([entity["score"] for entity in entities])
for e in entities:
bi, tag = self.get_tag(e["entity"])
if bi == "B" and tag == self.key_label:
entity = self.key_label
score = e["score"]
break

tokens = [entity["word"] for entity in entities]

entity_group = {
"entity_group": entity,
"score": score,
"word": self.tokenizer.convert_tokens_to_string(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return entity_group


class ModelStore:
def __init__(self):
self.keyword_extractor = None
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def load(self):
        # Must be called once before `extract_keywords` is used.
        self.keyword_extractor = FixedERPipeline(
            model="yanekyuk/bert-uncased-keyword-extractor",
            device=self.device,
        )


MODEL_STORE = ModelStore()
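
A quick usage sketch for the pipeline above (the import path is an assumption, and the printed keywords are illustrative):

```python
# Usage sketch: load the pipeline once, then extract keyphrases from a poem.
from augmentation import MODEL_STORE, extract_keywords  # import path is an assumption

MODEL_STORE.load()  # downloads yanekyuk/bert-uncased-keyword-extractor on first use

poem = (
    "Shall I compare thee to a summer's day?\n"
    "Thou art more lovely and more temperate."
)
print(extract_keywords(poem, num_keywords=2))  # e.g. ['summer', 'temperate']
```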
80 changes: 80 additions & 0 deletions data/datasets/poetry_instructions/data/helper.py
from dataclasses import dataclass
from enum import Enum

import numpy as np

from .augmentation import extract_keywords
from .template import (
CONTINUE_POEM_TEMPLATE,
NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE,
NEW_POEM_ABOUT_TEMPLATE,
NEW_POEM_IN_STYLE_OF_TEMPLATE,
)


@dataclass
class PoetryRecord:
poem: str
title: str
author: str
theme: str
time_period: str


class PoetryDialogueTask(Enum):
    # Each member's value is the name of the static method that renders its dialogue.
    CONTINUE = "make_continue_poem_dialogue"
IN_STYLE = "make_poem_in_style_dialogue"
ABOUT_KEYWORDS = "make_poem_about_keywords_dialogue"
ABOUT_KEYWORDS_IN_STYLE = "make_poem_about_keywords_in_style_dialogue"

@staticmethod
def make_continue_poem_dialogue(record: "PoetryRecord") -> str:
        line_splits = record.poem.split("\n")
        if len(line_splits) <= 1:
            # A one-line poem can't be split; fall back to a different random task.
            return PoetryDialogueTask.random_task_excluding(PoetryDialogueTask.CONTINUE).prepare_dialogue(record)
        # Choose a random split point, keeping at least one line on each side.
        line_split_idx = np.random.randint(1, len(line_splits))
return CONTINUE_POEM_TEMPLATE.format(
poem_start="\n".join(line_splits[:line_split_idx]),
poem_end="\n".join(line_splits[line_split_idx:]),
)

@staticmethod
def make_poem_in_style_dialogue(record: "PoetryRecord") -> str:
return NEW_POEM_IN_STYLE_OF_TEMPLATE.format(
author=record.author,
poem=record.poem,
)

@staticmethod
def make_poem_about_keywords_dialogue(record: "PoetryRecord") -> str:
        keywords = extract_keywords(record.poem, num_keywords=np.random.randint(1, 4))
        # Render the keywords as "a, b and c".
        keyword_string = keywords[-1]
        if len(keywords) > 1:
            keyword_string = ", ".join(keywords[:-1]) + " and " + keyword_string
return NEW_POEM_ABOUT_TEMPLATE.format(
about=keyword_string,
poem=record.poem,
)

@staticmethod
def make_poem_about_keywords_in_style_dialogue(record: "PoetryRecord") -> str:
keywords = extract_keywords(record.poem, num_keywords=np.random.randint(1, 4))
keyword_string = keywords[-1]
if len(keywords) > 1:
keyword_string = ", ".join(keywords[:-1]) + " and " + keyword_string
return NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE.format(
about=keyword_string,
author=record.author,
poem=record.poem,
)

@staticmethod
    def random_task():
        # np.random.choice cannot sample an Enum class directly; pass a list of members.
        return np.random.choice(list(PoetryDialogueTask))

@staticmethod
def random_task_excluding(*exclude_tasks: "PoetryDialogueTask"):
return np.random.choice([t for t in PoetryDialogueTask if t not in exclude_tasks])

    def prepare_dialogue(self, record: PoetryRecord) -> str:
        # Dispatch to the static method named by this member's value.
        return getattr(PoetryDialogueTask, self.value)(record)
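
A sketch of generating one training dialogue from a record. The record values are made up, and the import path is an assumption; keyword-based tasks additionally require `MODEL_STORE.load()` from `augmentation.py`:

```python
# Sketch: build one dialogue from a made-up poetry record.
from helper import PoetryDialogueTask, PoetryRecord  # import path is an assumption

record = PoetryRecord(
    poem="The rose is red,\nthe violet's blue,\nthe honey's sweet,\nand so are you.",
    title="The Rose Is Red",
    author="Anonymous",
    theme="love",
    time_period="traditional",
)

# Keyword-based tasks call extract_keywords, which needs MODEL_STORE.load() first.
dialogue = PoetryDialogueTask.random_task().prepare_dialogue(record)
print(dialogue)
```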
16 changes: 16 additions & 0 deletions data/datasets/poetry_instructions/data/template.py
CONTINUE_POEM_TEMPLATE = """User: Can you continue this poem for me?
{poem_start}
Assistant: Sure, a continuation for this poem could be:
{poem_end}"""

NEW_POEM_IN_STYLE_OF_TEMPLATE = """User: Can you write a poem for me in the style of {author}?
Assistant: Sure, here's a poem in the style of {author}:
{poem}"""

NEW_POEM_ABOUT_TEMPLATE = """User: Can you write me a poem about {about}?
Assistant: Sure, here's a poem about {about}:
{poem}"""

NEW_POEM_ABOUT_IN_STYLE_OF_TEMPLATE = """User: Can you write me a poem about {about} in the style of {author}?
Assistant: Sure, here's a poem about {about} in the style of {author}:
{poem}"""
21 changes: 21 additions & 0 deletions data/datasets/poetry_instructions/hub.py
from dataclasses import dataclass
from typing import Optional

import datasets


@dataclass
class OpenAssistantConfig(datasets.BuilderConfig):
    """BuilderConfig for OpenAssistant datasets."""

    name: Optional[str] = None
    version: Optional[datasets.Version] = None
    description: Optional[str] = None
    schema: Optional[str] = None
    subset_id: Optional[str] = None


features = datasets.Features(
{
"conversation": datasets.Value("string"),
}
)
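
A minimal sketch of how this config and feature schema might plug into a `datasets` builder; the builder class, split file layout, and config values below are all assumptions:

```python
# Sketch: a GeneratorBasedBuilder wired to OpenAssistantConfig and `features`.
import json


class PoetryInstructions(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [
        OpenAssistantConfig(
            name="poetry_instructions",
            version=datasets.Version("1.0.0"),
            description="Dialogue instructions for guided poetry creation",
            schema="conversation",
            subset_id="poetry_instructions",
        )
    ]

    def _info(self):
        return datasets.DatasetInfo(features=features)

    def _split_generators(self, dl_manager):
        # Assumed layout: one jsonl file per split alongside the script.
        return [
            datasets.SplitGenerator(name=split, gen_kwargs={"filepath": f"{split}.jsonl"})
            for split in ("train", "validation", "test")
        ]

    def _generate_examples(self, filepath):
        with open(filepath) as f:
            for idx, line in enumerate(f):
                yield idx, {"conversation": json.loads(line)["conversation"]}
```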