Merge pull request #73 from lamalab-org/vocab-store
feat: move vocabs to external storage
n0w0f authored Jun 6, 2024
2 parents 3f7081c + 1d9d0c6 commit b823574
Showing 18 changed files with 90 additions and 62,788 deletions.
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -34,7 +34,14 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
]
keywords = ["llm", "materials", "chemistry"]
dependencies = ["transformers", "slices", "robocrys", "matminer", "keras<3"]
dependencies = [
"transformers",
"slices",
"robocrys",
"matminer",
"keras<3",
"pystow",
]
 [project.urls]
 Homepage = "https://github.com/lamalab-org/xtal2txt"
 Issues = "https://github.com/lamalab-org/xtal2txt/issues"
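The new `pystow` dependency backs the vocab downloads in `tokenizer.py` below. Because the `ensure()` calls there run at module level, the vocab files are fetched from Zenodo on first import and reused from the local cache afterwards. A minimal sketch of what this looks like downstream, assuming pystow's default `~/.data` cache root:

```python
# First import triggers a one-time download of each vocab file from Zenodo;
# later imports are served from the local pystow cache without network I/O.
from xtal2txt.tokenizer import SLICE_VOCAB

# SLICE_VOCAB is a plain string path to the cached file,
# e.g. ~/.data/xtal2txt/slice_vocab.txt under pystow's default root.
print(SLICE_VOCAB)
```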
105 changes: 79 additions & 26 deletions src/xtal2txt/tokenizer.py
@@ -14,26 +14,69 @@
 )

 from typing import List
+from xtal2txt.utils import xtal2txt_storage


 THIS_DIR = os.path.dirname(os.path.abspath(__file__))

-SLICE_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab.txt")
-SLICE_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab_rt.txt")
+SLICE_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/slice_vocab.txt?download=1"
+    )
+)
+SLICE_RT_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/slice_vocab_rt.txt?download=1"
+    )
+)

-COMPOSITION_VOCAB = os.path.join(THIS_DIR, "vocabs", "composition_vocab.txt")
-COMPOSITION_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "composition_vocab_rt.txt")
+COMPOSITION_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/composition_vocab.txt?download=1"
+    )
+)
+COMPOSITION_RT_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/composition_vocab_rt.txt?download=1"
+    )
+)

-CIF_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab.json")
-CIF_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab_rt.json")
+CIF_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/cif_vocab.json?download=1"
+    )
+)
+CIF_RT_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/cif_vocab_rt.json?download=1"
+    )
+)

-CRYSTAL_LLM_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab.json")
-CRYSTAL_LLM_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab_rt.json")
+CRYSTAL_LLM_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/crystal_llm_vocab.json?download=1"
+    )
+)
+CRYSTAL_LLM_RT_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/crystal_llm_vocab_rt.json?download=1"
+    )
+)

-SMILES_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab.json")
-SMILES_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab_rt.json")
+SMILES_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/smiles_vocab.json?download=1"
+    )
+)
+SMILES_RT_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/smiles_vocab_rt.json?download=1"
+    )
+)

-ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json")
+ROBOCRYS_VOCAB = str(
+    xtal2txt_storage.ensure(
+        url="https://zenodo.org/records/11484062/files/robocrys_vocab.json?download=1"
+    )
+)


 class NumTokenizer:
@@ -203,9 +246,11 @@ def convert_tokens_to_string(self, tokens):
         if self.special_num_tokens:
             return "".join(
                 [
-                    token
-                    if not (token.startswith("_") and token.endswith("_"))
-                    else token.split("_")[1]
+                    (
+                        token
+                        if not (token.startswith("_") and token.endswith("_"))
+                        else token.split("_")[1]
+                    )
                     for token in tokens
                 ]
             )
@@ -272,9 +317,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None):

         vocab_file = os.path.join(
             save_directory,
-            f"{index + 1}-{filename_prefix}.json"
-            if filename_prefix
-            else f"{index + 1}.json",
+            (
+                f"{index + 1}-{filename_prefix}.json"
+                if filename_prefix
+                else f"{index + 1}.json"
+            ),
         )

         with open(vocab_file, "w", encoding="utf-8") as f:
@@ -335,17 +382,20 @@ def convert_tokens_to_string(self, tokens):
         if self.special_num_tokens:
             return " ".join(
                 [
-                    token
-                    if not (token.startswith("_") and token.endswith("_"))
-                    else token.split("_")[1]
+                    (
+                        token
+                        if not (token.startswith("_") and token.endswith("_"))
+                        else token.split("_")[1]
+                    )
                     for token in tokens
                 ]
             )
         return " ".join(tokens).rstrip()

     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
+        """
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = SLICE_ANALYSIS_DICT
         return [
@@ -377,7 +427,8 @@ def __init__(

     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
+        """
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = COMPOSITION_ANALYSIS_DICT
         return [
@@ -409,7 +460,8 @@ def __init__(

     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
+        """
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = CIF_ANALYSIS_DICT
         return [
@@ -441,7 +493,8 @@ def __init__(

     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
+        """
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = CRYSTAL_LLM_ANALYSIS_DICT
         return [
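The changes to `convert_tokens_to_string` and `save_vocabulary` above are Black-style formatting only (the conditional expressions gain parentheses); behavior is unchanged. A minimal sketch of the joining rule, with illustrative tokens (the underscore-wrapped `_2_` standing in for a special numeric token, as implied by the `startswith`/`endswith` test):

```python
# Underscore-wrapped tokens are special numeric tokens: the join unwraps
# them via split("_")[1]; every other token passes through unchanged.
tokens = ["Li", "_2_", "O"]

joined = "".join(
    tok if not (tok.startswith("_") and tok.endswith("_")) else tok.split("_")[1]
    for tok in tokens
)
print(joined)  # -> Li2O
```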
3 changes: 3 additions & 0 deletions src/xtal2txt/utils.py
@@ -0,0 +1,3 @@
+import pystow
+
+xtal2txt_storage = pystow.module("xtal2txt")
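`pystow.module("xtal2txt")` namespaces all downloads under a single application directory. A minimal sketch of the caching behavior, assuming pystow's default `~/.data` root (relocatable via the `PYSTOW_HOME` environment variable):

```python
import pystow

storage = pystow.module("xtal2txt")

# ensure() downloads the file on the first call and returns a pathlib.Path
# to the cached copy; repeated calls hit the cache with no network access.
path = storage.ensure(
    url="https://zenodo.org/records/11484062/files/robocrys_vocab.json?download=1"
)
print(path)  # e.g. ~/.data/xtal2txt/robocrys_vocab.json
```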
1 change: 0 additions & 1 deletion src/xtal2txt/vocabs/1.json

This file was deleted.

185 changes: 0 additions & 185 deletions src/xtal2txt/vocabs/cif_vocab.json

This file was deleted.
