Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for null and union types #30

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions jsonformer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
OutputNumbersTokens,
StringStoppingCriteria,
)
from jsonformer.type_prefixes import get_prefix_tokens_for_types

from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
import json
Expand Down Expand Up @@ -33,6 +35,8 @@ def __init__(
self.json_schema = json_schema
self.prompt = prompt

self.type_prefix_tokens = get_prefix_tokens_for_types(tokenizer)

self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)

self.generation_marker = "|GENERATION|"
Expand Down Expand Up @@ -147,13 +151,49 @@ def generate_object(
obj[key] = self.generate_value(schema, obj, key)
return obj

def choose_type_to_generate(self, possible_types: List[str]) -> str:
possible_types = list(set(possible_types)) # remove duplicates
self.debug("[choose_type_to_generate]", possible_types)
if len(possible_types) < 1:
raise ValueError(f"Union type must not be empty")
elif len(possible_types) == 1:
return possible_types[0]

prompt = self.get_prompt()
input_tensor = self.tokenizer.encode(prompt, return_tensors="pt")
output = self.model.forward(input_tensor.to(self.model.device))
logits = output.logits[0, -1]

max_type = None
max_logit = -float("inf")
for possible_type in possible_types:
try:
prefix_tokens = self.type_prefix_tokens[possible_type]
except KeyError:
raise ValueError(f"Unsupported schema type: {possible_type}")
max_type_logit = logits[prefix_tokens].max()
if max_type_logit > max_logit:
max_type = possible_type
max_logit = max_type_logit

if max_type is None:
raise Exception("Unable to find best type to generate for union type")
self.debug("[choose_type_to_generate]", max_type)
return max_type

def generate_value(
self,
schema: Dict[str, Any],
obj: Union[Dict[str, Any], List[Any]],
key: Union[str, None] = None,
) -> Any:
schema_type = schema["type"]
if isinstance(schema_type, list):
if key:
obj[key] = self.generation_marker
else:
obj.append(self.generation_marker)
schema_type = self.choose_type_to_generate(schema_type)
if schema_type == "number":
if key:
obj[key] = self.generation_marker
Expand Down Expand Up @@ -183,6 +223,8 @@ def generate_value(
else:
obj.append(new_obj)
return self.generate_object(schema["properties"], new_obj)
elif schema_type == "null":
return None
else:
raise ValueError(f"Unsupported schema type: {schema_type}")

Expand Down
32 changes: 32 additions & 0 deletions jsonformer/type_prefixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from transformers import PreTrainedTokenizer
from typing import Dict, List
import re

def is_number_prefix(s: str) -> bool:
return re.match(r"^[\-\d]+\.?[\d]*$", s)

def is_boolean_prefix(s: str) -> bool:
return 'true'.startswith(s) or 'false'.startswith(s)

def is_null_prefix(s: str) -> bool:
return 'null'.startswith(s)

def is_string_prefix(s: str) -> bool:
return re.match(r'^"[^"]*"?$', s)

def is_array_prefix(s: str) -> bool:
return re.match(r'^\[["\-\d\[{]*$', s)

def is_object_prefix(s: str) -> bool:
return re.match(r'^\{"?$', s)

def get_prefix_tokens_for_types(tokenizer: PreTrainedTokenizer) -> Dict[str, List[str]]:
vocab = tokenizer.vocab.items()
return {
"number": [v for k, v in vocab if is_number_prefix(k)],
"boolean": [v for k, v in vocab if is_boolean_prefix(k)],
"null": [v for k, v in vocab if is_null_prefix(k)],
"string": [v for k, v in vocab if is_string_prefix(k)],
"array": [v for k, v in vocab if is_array_prefix(k)],
"object": [v for k, v in vocab if is_object_prefix(k)],
}