Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: local env representation #58

Merged
merged 6 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/xtal2txt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from collections import Counter
from pathlib import Path
from typing import List, Union, Tuple
from typing import List, Union, Tuple, Optional

from invcryrep.invcryrep import InvCryRep
from pymatgen.core import Structure
Expand All @@ -12,6 +12,7 @@
from robocrys import StructureCondenser, StructureDescriber

from xtal2txt.transforms import TransformationCallback
from xtal2txt.local_env import LocalEnvAnalyzer


class TextRep:
Expand Down Expand Up @@ -225,6 +226,26 @@ def get_composition(self, format="hill") -> str:
composition = composition_string.replace(" ", "")
return composition

def get_local_env_rep(self, local_env_kwargs: Optional[dict] = None) -> str:
    """Return the local-environment text representation of the structure.

    The representation is a string combining the space group symbol with
    the local environment of each atom in the unit cell, where every
    local environment is rendered as a SMILES string together with its
    Wyckoff symbol.

    Args:
        local_env_kwargs (dict, optional): Keyword arguments forwarded to
            the :class:`LocalEnvAnalyzer` constructor. ``None`` (or any
            falsy value) is treated as an empty dict.

    Returns:
        str: The local environment representation of the crystal structure.
    """
    # Normalize the falsy-default into an empty kwargs dict.
    kwargs = local_env_kwargs or {}
    analyzer = LocalEnvAnalyzer(**kwargs)
    return analyzer.structure_to_local_env_string(self.structure)

def get_crystal_llm_rep(
self,
permute_atoms: bool = False,
Expand Down Expand Up @@ -447,6 +468,7 @@ def get_all_text_reps(self, decimal_places: int = 2):
decimal_places=decimal_places,
),
"zmatrix": self._safe_call(self.get_zmatrix_rep),
"local_env": self._safe_call(self.get_local_env_rep, local_env_kwargs=None),
}

def get_requested_text_reps(
Expand Down Expand Up @@ -487,6 +509,8 @@ def get_requested_text_reps(
decimal_places=decimal_places,
),
"zmatrix": lambda: self._safe_call(self.get_zmatrix_rep, decimal_places=1),
"local_env": lambda: self._safe_call(self.get_local_env_rep,
local_env_kwargs=None),
}

return {rep: all_reps[rep]() for rep in requested_reps if rep in all_reps}
2 changes: 1 addition & 1 deletion src/xtal2txt/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def wyckoff_matcher(
output_struct = DecodeTextRep(self.text).wyckoff_decoder(
self.text, lattice_params=True
)

return StructureMatcher(
ltol,
stol,
Expand Down
8 changes: 6 additions & 2 deletions src/xtal2txt/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def __init__(self, distance_cutoff: float = 1.4, angle_cutoff: float = 0.3):
self.distance_cutoff = distance_cutoff
self.angle_cutoff = angle_cutoff

def get_local_environments(self, structure: Structure) -> Tuple[List[dict], List[dict], str]:
def get_local_environments(
self, structure: Structure
) -> Tuple[List[dict], List[dict], str]:
"""Get the local environments of the atoms in a structure.

Args:
Expand All @@ -58,7 +60,9 @@ def get_local_environments(self, structure: Structure) -> Tuple[List[dict], List
sga = SpacegroupAnalyzer(structure)
symm_struct = sga.get_symmetrized_structure()

inequivalent_indices = [indices[0] for indices in symm_struct.equivalent_indices]
inequivalent_indices = [
indices[0] for indices in symm_struct.equivalent_indices
]
wyckoffs = symm_struct.wyckoff_symbols

# a Voronoi tessellation is used to determine the local environment of each atom
Expand Down
45 changes: 38 additions & 7 deletions src/xtal2txt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
CRYSTAL_LLM_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab.json")
CRYSTAL_LLM_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab_rt.json")

SMILES_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab.json")
SMILES_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab_rt.json")

ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json")

Expand All @@ -51,7 +53,9 @@ def num_matcher(self, text: str) -> str:
r"\d+(?:\.\d+)?" # Match any number, whether it is part of a string or not
)
matches = list(re.finditer(pattern, text))
for match in reversed(matches): #since we are replacing substring with a bigger subtring the string we are working on
for match in reversed(
matches
): # since we are replacing substring with a bigger subtring the string we are working on
start, end = match.start(), match.end()
tokens = self.tokenize(match.group())
replacement = "".join(tokens)
Expand Down Expand Up @@ -123,7 +127,10 @@ def __init__(
self,
special_num_token: bool = False,
vocab_file=None,
special_tokens={"cls_token": "[CLS]","sep_token": "[SEP]",},
special_tokens={
"cls_token": "[CLS]",
"sep_token": "[SEP]",
},
model_max_length=None,
padding_length=None,
**kwargs,
Expand All @@ -133,14 +140,13 @@ def __init__(
)
self.truncation = False
self.padding = False
self.padding_length = padding_length

self.padding_length = padding_length

self.special_num_tokens = special_num_token
self.vocab = self.load_vocab(vocab_file)
self.vocab_file = vocab_file

# Initialize special tokens
# Initialize special tokens
self.special_tokens = special_tokens if special_tokens is not None else {}
self.add_special_tokens(self.special_tokens)

Expand All @@ -156,7 +162,6 @@ def load_vocab(self, vocab_file):
else:
raise ValueError(f"Unsupported file type: {file_extension}")


def get_vocab(self):
return self.vocab

Expand All @@ -181,7 +186,9 @@ def tokenize(self, text):
matches = [self.cls_token] + matches

if self.truncation and len(matches) > self.model_max_length:
matches = matches[: self.model_max_length-1] # -1 since we add sep token later
matches = matches[
: self.model_max_length - 1
] # -1 since we add sep token later

if self.sep_token is not None:
matches += [self.sep_token]
Expand Down Expand Up @@ -443,6 +450,30 @@ def token_analysis(self, list_of_tokens):
]



class SmilesTokenizer(Xtal2txtTokenizer):
    """Tokenizer for the SMILES-based local-environment representation.

    Chooses the numeric-special-token ("rt") SMILES vocabulary when
    ``special_num_token`` is True, and the plain SMILES vocabulary
    otherwise, unless the caller explicitly provides ``vocab_file``.
    """

    def __init__(
        self,
        special_num_token: bool = False,
        vocab_file=None,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):
        # Bug fix: the previous default (CRYSTAL_LLM_VOCAB) was a
        # copy-paste leftover from another tokenizer, and any
        # caller-supplied vocab_file was unconditionally overwritten and
        # silently ignored. Respect an explicit vocab_file; otherwise
        # select the SMILES vocabulary matching special_num_token
        # (identical behavior to before for callers that omit it).
        if vocab_file is None:
            vocab_file = SMILES_RT_VOCAB if special_num_token else SMILES_VOCAB
        super().__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )



class RobocrysTokenizer:
"""Tokenizer for Robocrystallographer. Would be BPE tokenizer.
trained on the Robocrystallographer dataset.
Expand Down
1 change: 1 addition & 0 deletions src/xtal2txt/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pymatgen.core.structure import Structure
from typing import Union, List


def set_seed(seed: int):
"""
Set the random seed for both random and numpy.random.
Expand Down
2 changes: 1 addition & 1 deletion src/xtal2txt/vocabs/1.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[CLS]": 149, "[SEP]": 150}
{"_._": 0, "_0_-0_": 1, "_0_-1_": 2, "_0_-2_": 3, "_0_-3_": 4, "_0_-4_": 5, "_0_-5_": 6, "_0_0_": 7, "_0_1_": 8, "_0_2_": 9, "_0_3_": 10, "_0_4_": 11, "_0_5_": 12, "_1_-0_": 13, "_1_-1_": 14, "_1_-2_": 15, "_1_-3_": 16, "_1_-4_": 17, "_1_-5_": 18, "_1_-6_": 19, "_1_0_": 20, "_1_1_": 21, "_1_2_": 22, "_1_3_": 23, "_1_4_": 24, "_1_5_": 25, "_2_-0_": 26, "_2_-1_": 27, "_2_-2_": 28, "_2_-3_": 29, "_2_-4_": 30, "_2_-5_": 31, "_2_-6_": 32, "_2_0_": 33, "_2_1_": 34, "_2_2_": 35, "_2_3_": 36, "_2_4_": 37, "_2_5_": 38, "_3_-0_": 39, "_3_-1_": 40, "_3_-2_": 41, "_3_-3_": 42, "_3_-4_": 43, "_3_-5_": 44, "_3_-6_": 45, "_3_0_": 46, "_3_1_": 47, "_3_2_": 48, "_3_3_": 49, "_3_4_": 50, "_3_5_": 51, "_4_-0_": 52, "_4_-1_": 53, "_4_-2_": 54, "_4_-3_": 55, "_4_-4_": 56, "_4_-5_": 57, "_4_-6_": 58, "_4_0_": 59, "_4_1_": 60, "_4_2_": 61, "_4_3_": 62, "_4_4_": 63, "_4_5_": 64, "_5_-0_": 65, "_5_-1_": 66, "_5_-2_": 67, "_5_-3_": 68, "_5_-4_": 69, "_5_-5_": 70, "_5_-6_": 71, "_5_0_": 72, "_5_1_": 73, "_5_2_": 74, "_5_3_": 75, "_5_4_": 76, "_5_5_": 77, "_6_-0_": 78, "_6_-1_": 79, "_6_-2_": 80, "_6_-3_": 81, "_6_-4_": 82, "_6_-5_": 83, "_6_-6_": 84, "_6_0_": 85, "_6_1_": 86, "_6_2_": 87, "_6_3_": 88, "_6_4_": 89, "_6_5_": 90, "_7_-0_": 91, "_7_-1_": 92, "_7_-2_": 93, "_7_-3_": 94, "_7_-4_": 95, "_7_-5_": 96, "_7_-6_": 97, "_7_0_": 98, "_7_1_": 99, "_7_2_": 100, "_7_3_": 101, "_7_4_": 102, "_7_5_": 103, "_8_-0_": 104, "_8_-1_": 105, "_8_-2_": 106, "_8_-3_": 107, "_8_-4_": 108, "_8_-5_": 109, "_8_-6_": 110, "_8_0_": 111, "_8_1_": 112, "_8_2_": 113, "_8_3_": 114, "_8_4_": 115, "_8_5_": 116, "_9_-0_": 117, "_9_-1_": 118, "_9_-2_": 119, "_9_-3_": 120, "_9_-4_": 121, "_9_-5_": 122, "_9_-6_": 123, "_9_0_": 124, "_9_1_": 125, "_9_2_": 126, "_9_3_": 127, "_9_4_": 128, "_9_5_": 129, "H": 130, "He": 131, "Li": 132, "Be": 133, "B": 134, "C": 135, "N": 136, "O": 137, "F": 138, "Ne": 139, "Na": 140, "Mg": 141, "Al": 142, "Si": 143, "P": 144, "S": 145, "Cl": 146, "K": 147, "Ar": 148, "Ca": 149, "Sc": 150, 
"Ti": 151, "V": 152, "Cr": 153, "Mn": 154, "Fe": 155, "Ni": 156, "Co": 157, "Cu": 158, "Zn": 159, "Ga": 160, "Ge": 161, "As": 162, "Se": 163, "Br": 164, "Kr": 165, "Rb": 166, "Sr": 167, "Y": 168, "Zr": 169, "Nb": 170, "Mo": 171, "Tc": 172, "Ru": 173, "Rh": 174, "Pd": 175, "Ag": 176, "Cd": 177, "In": 178, "Sn": 179, "Sb": 180, "Te": 181, "I": 182, "Xe": 183, "Cs": 184, "Ba": 185, "La": 186, "Ce": 187, "Pr": 188, "Nd": 189, "Pm": 190, "Sm": 191, "Eu": 192, "Gd": 193, "Tb": 194, "Dy": 195, "Ho": 196, "Er": 197, "Tm": 198, "Yb": 199, "Lu": 200, "Hf": 201, "Ta": 202, "W": 203, "Re": 204, "Os": 205, "Ir": 206, "Pt": 207, "Au": 208, "Hg": 209, "Tl": 210, "Pb": 211, "Bi": 212, "Th": 213, "Pa": 214, "U": 215, "Np": 216, "Pu": 217, "Am": 218, "Cm": 219, "Bk": 220, "Cf": 221, "Es": 222, "Fm": 223, "Md": 224, "No": 225, "Lr": 226, "Rf": 227, "Db": 228, "Sg": 229, "Bh": 230, "Hs": 231, "Mt": 232, "Ds": 233, "Rg": 234, "Cn": 235, "Nh": 236, "Fl": 237, "Mc": 238, "Lv": 239, "Ts": 240, "Og": 241, "+": 242, "-": 243, "/": 244, "\n": 245, "a": 246, "n": 247, "c": 248, "b": 249, "m": 250, "d": 251, "R": 252, "A": 253, "(": 254, ")": 255, "[": 256, "]": 257, "*": 258, ".": 259, " ": 260, "[CLS]": 261, "[SEP]": 262}
143 changes: 143 additions & 0 deletions src/xtal2txt/vocabs/smiles_vocab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
{
"H": 0,
"He": 1,
"Li": 2,
"Be": 3,
"B": 4,
"C": 5,
"N": 6,
"O": 7,
"F": 8,
"Ne": 9,
"Na": 10,
"Mg": 11,
"Al": 12,
"Si": 13,
"P": 14,
"S": 15,
"Cl": 16,
"K": 17,
"Ar": 18,
"Ca": 19,
"Sc": 20,
"Ti": 21,
"V": 22,
"Cr": 23,
"Mn": 24,
"Fe": 25,
"Ni": 26,
"Co": 27,
"Cu": 28,
"Zn": 29,
"Ga": 30,
"Ge": 31,
"As": 32,
"Se": 33,
"Br": 34,
"Kr": 35,
"Rb": 36,
"Sr": 37,
"Y": 38,
"Zr": 39,
"Nb": 40,
"Mo": 41,
"Tc": 42,
"Ru": 43,
"Rh": 44,
"Pd": 45,
"Ag": 46,
"Cd": 47,
"In": 48,
"Sn": 49,
"Sb": 50,
"Te": 51,
"I": 52,
"Xe": 53,
"Cs": 54,
"Ba": 55,
"La": 56,
"Ce": 57,
"Pr": 58,
"Nd": 59,
"Pm": 60,
"Sm": 61,
"Eu": 62,
"Gd": 63,
"Tb": 64,
"Dy": 65,
"Ho": 66,
"Er": 67,
"Tm": 68,
"Yb": 69,
"Lu": 70,
"Hf": 71,
"Ta": 72,
"W": 73,
"Re": 74,
"Os": 75,
"Ir": 76,
"Pt": 77,
"Au": 78,
"Hg": 79,
"Tl": 80,
"Pb": 81,
"Bi": 82,
"Th": 83,
"Pa": 84,
"U": 85,
"Np": 86,
"Pu": 87,
"Am": 88,
"Cm": 89,
"Bk": 90,
"Cf": 91,
"Es": 92,
"Fm": 93,
"Md": 94,
"No": 95,
"Lr": 96,
"Rf": 97,
"Db": 98,
"Sg": 99,
"Bh": 100,
"Hs": 101,
"Mt": 102,
"Ds": 103,
"Rg": 104,
"Cn": 105,
"Nh": 106,
"Fl": 107,
"Mc": 108,
"Lv": 109,
"Ts": 110,
"Og": 111,
"0": 112,
"1": 113,
"2": 114,
"3": 115,
"4": 116,
"5": 117,
"6": 118,
"7": 119,
"8": 120,
"9": 121,
" ": 122,
"+": 123,
"-": 124,
"/": 125,
"\n": 126,
"a": 127,
"n": 128,
"c": 129,
"b": 130,
"m": 131,
"d": 132,
"R": 133,
"A": 134,
"(": 135,
")": 136,
"[": 137,
"]": 138,
"*": 139,
".": 140
}
Loading
Loading