Different approach to enum generation #41

Open
wants to merge 7 commits into
base: main
bug fixing
Jakub Dulas committed Sep 13, 2023
commit ec6d4c11c9d06fb0b58e94d59e5556da7714a280
2 changes: 1 addition & 1 deletion jsonformer/logits_processors.py
@@ -127,7 +127,7 @@ def create_mask(self, allowed_tokens):
def build_tree(self, enums):
tree = {}
for enum in enums:
- encoded_enum = self.tokenizer.encode(enum)[1:] # we want to skip sos token
+ encoded_enum = self.tokenizer.encode(enum, add_special_tokens=False)
curr_obj = tree
for code in encoded_enum:
if code in curr_obj.keys():
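
The change above replaces slicing off the first token with `add_special_tokens=False`. A plausible motivation is that not every tokenizer prepends a start-of-sequence token, in which case `[1:]` drops a real token. The sketch below illustrates this with a hypothetical character-level `ToyTokenizer` (not part of jsonformer or transformers) and an assumed completion of `build_tree` as a nested-dict prefix tree; the rest of the real method is not shown in this diff.

```python
from typing import Dict, List, Optional


class ToyTokenizer:
    """Hypothetical stand-in for a tokenizer; encodes one id per character."""

    def __init__(self, bos_token_id: Optional[int] = None):
        self.bos_token_id = bos_token_id

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [ord(ch) for ch in text]
        if add_special_tokens and self.bos_token_id is not None:
            return [self.bos_token_id] + ids
        return ids


def build_tree(tokenizer: ToyTokenizer, enums: List[str]) -> Dict:
    """Assumed shape of the prefix tree: nested dicts keyed by token id."""
    tree: Dict = {}
    for enum in enums:
        node = tree
        for code in tokenizer.encode(enum, add_special_tokens=False):
            node = node.setdefault(code, {})
    return tree


# A tokenizer that does not prepend a start token: [1:] loses the opening quote.
no_bos = ToyTokenizer(bos_token_id=None)
sliced = no_bos.encode('"red"')[1:]
full = no_bos.encode('"red"', add_special_tokens=False)

# A tokenizer that does prepend one: both approaches agree.
with_bos = ToyTokenizer(bos_token_id=1)
same = with_bos.encode('"red"')[1:] == with_bos.encode('"red"', add_special_tokens=False)

tree = build_tree(no_bos, ['"red"', '"rose"'])
```

With `add_special_tokens=False` the result no longer depends on whether the tokenizer happens to add a start token.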
6 changes: 4 additions & 2 deletions jsonformer/main.py
@@ -147,7 +147,7 @@ def generate_enum(self, values) -> str:
input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
- values = [f'"{value}"'if isinstance(value,str) else value for value in values]
+ values = [f'"{value}"'if isinstance(value,str) else str(value) for value in values]

response = self.model.generate(
input_tokens,
@@ -178,7 +178,9 @@ def generate_enum(self, values) -> str:
if response[0] == response[-1] == '"':
return response[1:-1]

- return float(response)
+ if '.' in response:
+ return float(response)
+ return int(response)

def generate_object(
self, properties: Dict[str, Any], obj: Dict[str, Any]
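
The two `generate_enum` hunks above form a round trip: enum values are rendered to strings before generation, and the generated string is converted back to a Python value afterwards. The sketch below shows that round trip in isolation. The helper names `render_enum_value` and `parse_enum_response` are hypothetical and do not appear in the PR. The parsing step tries `int` first and falls back to `float`, which is one possible alternative to the `'.' in response` check and would also accept exponent notation such as `1e3`; it is not what the commit does.

```python
from typing import Union

EnumValue = Union[str, int, float]


def render_enum_value(value: EnumValue) -> str:
    """Quote strings; stringify everything else (mirrors the changed list comprehension)."""
    return f'"{value}"' if isinstance(value, str) else str(value)


def parse_enum_response(response: str) -> EnumValue:
    """Convert a generated enum string back into a Python value."""
    # Quoted responses are strings; the length check avoids treating a lone quote as "".
    if len(response) >= 2 and response[0] == response[-1] == '"':
        return response[1:-1]
    # Assumption: anything unquoted is numeric. Try int first, then float.
    try:
        return int(response)
    except ValueError:
        return float(response)


round_tripped = [parse_enum_response(render_enum_value(v)) for v in ["red", 3, 2.5]]
```

Booleans and `None` are deliberately left out of this sketch; handling them would need an explicit mapping.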