From 2c74a55617cec94a76a5b56b7d9b6527c6bfbbed Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Wed, 13 Sep 2023 08:40:17 +0000 Subject: [PATCH 1/7] first commit --- example.ipynb | 217 +++++++++++++++++++++++--------- jsonformer/logits_processors.py | 24 ++++ jsonformer/main.py | 38 ++++++ 3 files changed, 222 insertions(+), 57 deletions(-) diff --git a/example.ipynb b/example.ipynb index 49520e8..df1a314 100644 --- a/example.ipynb +++ b/example.ipynb @@ -6,35 +6,51 @@ "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ubuntu/jsonformer/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading model and tokenizer...\n", - "Loaded model and tokenizer\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9feb8c50978f4b46b3b9ae40b5051d18", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 str: return response return response.split('"')[0].strip() + + def generate_enum(self) -> str: + prompt = self.get_prompt() + '"' + self.debug("[generate_string]", prompt, is_prompt=True) + input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to( + self.model.device + ) + + response = self.model.generate( + input_tokens, + max_new_tokens=self.max_string_token_length, + num_return_sequences=1, + temperature=self.temperature, + stopping_criteria=[ + StringStoppingCriteria(self.tokenizer, len(input_tokens[0])) + ], + pad_token_id=self.tokenizer.eos_token_id, + ) + + # Some models output the prompt as part of the response + # This removes the prompt from the response if it is present + if ( + len(response[0]) >= len(input_tokens[0]) + and (response[0][: len(input_tokens[0])] == input_tokens).all() + ): + response = response[0][len(input_tokens[0]) :] + if response.shape[0] == 1: + response = response[0] + + response = self.tokenizer.decode(response, skip_special_tokens=True) + + self.debug("[generate_string]", "|" + response + "|") + + if response.count('"') < 1: + return response + + return response.split('"')[0].strip() def generate_object( self, properties: Dict[str, Any], obj: Dict[str, Any] @@ -146,6 +183,7 @@ def generate_object( self.debug("[generate_object] generating value for", key) obj[key] = self.generate_value(schema, obj, key) return obj + def generate_value( self, From 34526044c22a8a15cb59e1119121a35fb1fd7ed9 Mon Sep 17 00:00:00 2001 From: Jakub Dulas Date: Wed, 13 Sep 2023 10:43:11 +0200 Subject: [PATCH 2/7] added tree --- coding.ipynb | 287 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 coding.ipynb diff --git a/coding.ipynb b/coding.ipynb new file mode 100644 index 0000000..d31274f --- /dev/null +++ b/coding.ipynb @@ -0,0 +1,287 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "from jsonformer import Jsonformer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0327f14545b74c38bc99223e52a7bf52", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
Date: Wed, 13 Sep 2023 13:10:12 +0200 Subject: [PATCH 3/7] added type enum --- README.md | 25 +++++++++++ jsonformer/logits_processors.py | 73 +++++++++++++++++++++++++++------ jsonformer/main.py | 30 +++++++++----- 3 files changed, 105 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index a734002..2fff27e 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,31 @@ print(generated_data) } ``` +### Example Enum +```python +color = { + "type": "object", + "properties": { + "color": { + "type": "enum", + "values": [ + "black", + "red", + "white", + "green", + "blue" + ] + } + } +} +``` + +```python +{ + color: "blue" +} +``` + ## Features - Bulletproof JSON generation: Jsonformer ensures that the generated JSON is always syntactically correct and conforms to the specified schema. diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py index 49d9920..560048d 100644 --- a/jsonformer/logits_processors.py +++ b/jsonformer/logits_processors.py @@ -60,6 +60,30 @@ def __call__( return True return False + + +class EnumStoppingCriteria(StoppingCriteria): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + prompt_length: int, + enums + ): + self.tokenizer = tokenizer + self.prompt_length = prompt_length + self.enums = enums + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + ) -> bool: + decoded = self.tokenizer.decode( + input_ids[0][self.prompt_length :], skip_special_tokens=True + ) + + return decoded in self.enums + class OutputNumbersTokens(LogitsWarper): def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str): @@ -84,25 +108,48 @@ def __call__(self, _, scores): return scores -class OutputLiteralTokens(LogitsWarper): - def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str, enums): +class OutputEnumTokens(LogitsWarper): + def __init__(self, tokenizer: PreTrainedTokenizer, enums): self.tokenizer = tokenizer - self.tokenized_prompt = tokenizer(prompt, return_tensors="pt") vocab_size = len(tokenizer) self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool) - self.enums = [f'"{enum}"'if isinstance(enum,str) else enum for enum in enums] - + self.tree = self.build_tree(enums) + self.is_first_call = True + self.vocab_size = len(tokenizer) - def __call__(self, _, scores): - allowed_tokens = [] + def create_mask(self, allowed_tokens): + allowed_mask = torch.zeros(self.vocab_size, dtype=torch.bool) for _, token_id in self.tokenizer.get_vocab().items(): - token_str = self.tokenizer.decode(token_id).strip() - for enum in self.enums: - if enum.startswith(token_str): - allowed_tokens.append(token_id) + if token_id in allowed_tokens: + allowed_mask[token_id] = True + return allowed_mask + + def build_tree(self, enums): + tree = {} + for enum in enums: + encoded_enum = self.tokenizer.encode(enum)[1:] # we want to skip sos token + curr_obj = tree + for code in encoded_enum: + if code in curr_obj.keys(): + curr_obj = curr_obj[code] + else: + curr_obj[code] = {} + curr_obj = curr_obj[code] + return tree + + def __call__(self, input_ids, scores): + if not self.is_first_call: + self.tree = self.tree[int(input_ids[0][-1])] + else: + self.is_first_call = False + + allowed_tokens = self.tree.keys() + + if not len(allowed_tokens): + raise Exception("Shouldn't happen") - - mask = self.allowed_mask.expand_as(scores) + allowed_mask = self.create_mask(allowed_tokens) + mask = allowed_mask.expand_as(scores) scores[~mask] = -float("inf") return scores diff --git a/jsonformer/main.py b/jsonformer/main.py index 72f49f7..3d95944 100644 --- a/jsonformer/main.py +++ b/jsonformer/main.py @@ -4,6 +4,8 @@ NumberStoppingCriteria, OutputNumbersTokens, StringStoppingCriteria, + EnumStoppingCriteria, + OutputEnumTokens ) from termcolor import cprint from transformers import PreTrainedModel, PreTrainedTokenizer @@ -139,20 +141,22 @@ def generate_string(self) -> str: return response.split('"')[0].strip() - def generate_enum(self) -> str: - prompt = self.get_prompt() + '"' - self.debug("[generate_string]", prompt, is_prompt=True) + def generate_enum(self, values) -> str: + prompt = self.get_prompt() + self.debug("[generate_enum]", prompt, is_prompt=True) input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to( self.model.device ) + values = [f'"{value}"'if isinstance(value,str) else value for value in values] response = self.model.generate( input_tokens, - max_new_tokens=self.max_string_token_length, + max_new_tokens=max([len(self.tokenizer.encode(value)[1:]) for value in values]), num_return_sequences=1, temperature=self.temperature, + logits_processor=[OutputEnumTokens(self.tokenizer, values)], stopping_criteria=[ - StringStoppingCriteria(self.tokenizer, len(input_tokens[0])) + EnumStoppingCriteria(self.tokenizer, len(input_tokens[0]), values) ], pad_token_id=self.tokenizer.eos_token_id, ) @@ -169,12 +173,12 @@ def generate_enum(self) -> str: response = self.tokenizer.decode(response, skip_special_tokens=True) - self.debug("[generate_string]", "|" + response + "|") - - if response.count('"') < 1: - return response + self.debug("[generate_enum]", "|" + response + "|") - return response.split('"')[0].strip() + if response[0] == response[-1] == '"': + return response[1:-1] + + return float(response) def generate_object( self, properties: Dict[str, Any], obj: Dict[str, Any] @@ -221,6 +225,12 @@ def generate_value( else: obj.append(new_obj) return self.generate_object(schema["properties"], new_obj) + elif schema_type == "enum": + if key: + obj[key] = self.generation_marker + else: + obj.append(self.generation_marker) + return self.generate_enum(schema["values"]) else: raise ValueError(f"Unsupported schema type: {schema_type}") From ec6d4c11c9d06fb0b58e94d59e5556da7714a280 Mon Sep 17 00:00:00 2001 From: Jakub Dulas Date: Wed, 13 Sep 2023 13:44:43 +0200 Subject: [PATCH 4/7] bug fixing --- jsonformer/logits_processors.py | 2 +- jsonformer/main.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py index 560048d..518e785 100644 --- a/jsonformer/logits_processors.py +++ b/jsonformer/logits_processors.py @@ -127,7 +127,7 @@ def create_mask(self, allowed_tokens): def build_tree(self, enums): tree = {} for enum in enums: - encoded_enum = self.tokenizer.encode(enum)[1:] # we want to skip sos token + encoded_enum = self.tokenizer.encode(enum, add_special_tokens=False) curr_obj = tree for code in encoded_enum: if code in curr_obj.keys(): diff --git a/jsonformer/main.py b/jsonformer/main.py index 3d95944..2805ee7 100644 --- a/jsonformer/main.py +++ b/jsonformer/main.py @@ -147,7 +147,7 @@ def generate_enum(self, values) -> str: input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to( self.model.device ) - values = [f'"{value}"'if isinstance(value,str) else value for value in values] + values = [f'"{value}"'if isinstance(value,str) else str(value) for value in values] response = self.model.generate( input_tokens, @@ -178,7 +178,9 @@ def generate_enum(self, values) -> str: if response[0] == response[-1] == '"': return response[1:-1] - return float(response) + if '.' in response: + return float(response) + return int(response) def generate_object( self, properties: Dict[str, Any], obj: Dict[str, Any] From 42971268ec3bb9fb2ac08ce99526f02581612228 Mon Sep 17 00:00:00 2001 From: Jakub Dulas <70340102+jakubdulas@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:03:27 +0200 Subject: [PATCH 5/7] added files --- example.ipynb | 217 +++++++++++++------------------------------------- 1 file changed, 57 insertions(+), 160 deletions(-) diff --git a/example.ipynb b/example.ipynb index df1a314..49520e8 100644 --- a/example.ipynb +++ b/example.ipynb @@ -6,51 +6,35 @@ "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9feb8c50978f4b46b3b9ae40b5051d18", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 Date: Wed, 13 Sep 2023 14:03:56 +0200 Subject: [PATCH 6/7] Delete file --- coding.ipynb | 287 --------------------------------------------------- 1 file changed, 287 deletions(-) delete mode 100644 coding.ipynb diff --git a/coding.ipynb b/coding.ipynb deleted file mode 100644 index d31274f..0000000 --- a/coding.ipynb +++ /dev/null @@ -1,287 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer\n", - "from jsonformer import Jsonformer" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0327f14545b74c38bc99223e52a7bf52", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "VBox(children=(HTML(value='
Date: Thu, 14 Sep 2023 12:49:55 +0200 Subject: [PATCH 7/7] fixed bug --- jsonformer/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jsonformer/main.py b/jsonformer/main.py index 2805ee7..4785c15 100644 --- a/jsonformer/main.py +++ b/jsonformer/main.py @@ -151,7 +151,7 @@ def generate_enum(self, values) -> str: response = self.model.generate( input_tokens, - max_new_tokens=max([len(self.tokenizer.encode(value)[1:]) for value in values]), + max_new_tokens=max([len(self.tokenizer.encode(value, add_special_tokens=False)) for value in values]), num_return_sequences=1, temperature=self.temperature, logits_processor=[OutputEnumTokens(self.tokenizer, values)],