From 5a22e54474777ef35bbe64fa7ec2379e67b7db88 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Sat, 27 Apr 2024 01:17:34 +0200 Subject: [PATCH 1/4] feat: add smiles tokenizer --- src/xtal2txt/tokenizer.py | 26 +++ src/xtal2txt/vocabs/1.json | 2 +- src/xtal2txt/vocabs/smiles_vocab.json | 143 ++++++++++++ src/xtal2txt/vocabs/smiles_vocab.txt | 138 ++++++++++++ src/xtal2txt/vocabs/smiles_vocab_rt.json | 263 +++++++++++++++++++++++ src/xtal2txt/vocabs/smiles_vocab_rt.txt | 258 ++++++++++++++++++++++ tests/tokenizer/test_smiles_tokenizer.py | 18 ++ 7 files changed, 847 insertions(+), 1 deletion(-) create mode 100644 src/xtal2txt/vocabs/smiles_vocab.json create mode 100644 src/xtal2txt/vocabs/smiles_vocab.txt create mode 100644 src/xtal2txt/vocabs/smiles_vocab_rt.json create mode 100644 src/xtal2txt/vocabs/smiles_vocab_rt.txt create mode 100644 tests/tokenizer/test_smiles_tokenizer.py diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index c562324..bd37a44 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -30,6 +30,8 @@ CRYSTAL_LLM_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab.json") CRYSTAL_LLM_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab_rt.json") +SMILES_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab.json") +SMILES_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab_rt.json") ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json") @@ -443,6 +445,30 @@ def token_analysis(self, list_of_tokens): ] + +class SmilesTokenizer(Xtal2txtTokenizer): + def __init__( + self, + special_num_token: bool = False, + vocab_file=CRYSTAL_LLM_VOCAB, + model_max_length=None, + padding_length=None, + **kwargs, + ): + if special_num_token: + vocab_file = SMILES_RT_VOCAB + else: + vocab_file = SMILES_VOCAB + super(SmilesTokenizer, self).__init__( + special_num_token=special_num_token, + vocab_file=vocab_file, + model_max_length=model_max_length, + padding_length=padding_length, + **kwargs, + ) + + + class RobocrysTokenizer: """Tokenizer for Robocrystallographer. Would be BPE tokenizer. trained on the Robocrystallographer dataset.
diff --git a/src/xtal2txt/vocabs/1.json b/src/xtal2txt/vocabs/1.json index 8869489..6437e4a 100644 --- a/src/xtal2txt/vocabs/1.json +++ b/src/xtal2txt/vocabs/1.json @@ -1 +1 @@ -{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[CLS]": 149, "[SEP]": 150} \ No newline at end of file +{"_._": 0, "_0_-0_": 1, "_0_-1_": 2, "_0_-2_": 3, "_0_-3_": 4, "_0_-4_": 5, "_0_-5_": 6, "_0_0_": 7, "_0_1_": 8, "_0_2_": 9, "_0_3_": 10, "_0_4_": 11, "_0_5_": 12, "_1_-0_": 13, "_1_-1_": 14, "_1_-2_": 15, "_1_-3_": 16, "_1_-4_": 17, "_1_-5_": 18, "_1_-6_": 19, "_1_0_": 20, "_1_1_": 21, "_1_2_": 22, "_1_3_": 23, "_1_4_": 24, "_1_5_": 25, "_2_-0_": 26, "_2_-1_": 27, "_2_-2_": 28, "_2_-3_": 29, "_2_-4_": 30, "_2_-5_": 31, "_2_-6_": 32, "_2_0_": 33, "_2_1_": 34, "_2_2_": 35, "_2_3_": 36, "_2_4_": 37, "_2_5_": 38, "_3_-0_": 39, "_3_-1_": 40, "_3_-2_": 41, "_3_-3_": 42, "_3_-4_": 43, "_3_-5_": 44, "_3_-6_": 45, "_3_0_": 46, "_3_1_": 47, "_3_2_": 48, "_3_3_": 49, "_3_4_": 50, "_3_5_": 51, "_4_-0_": 52, "_4_-1_": 53, "_4_-2_": 54, "_4_-3_": 55, "_4_-4_": 56, "_4_-5_": 57, "_4_-6_": 58, "_4_0_": 59, "_4_1_": 60, "_4_2_": 61, "_4_3_": 62, "_4_4_": 63, "_4_5_": 64, "_5_-0_": 65, "_5_-1_": 66, "_5_-2_": 67, "_5_-3_": 68, "_5_-4_": 69, "_5_-5_": 70, "_5_-6_": 71, "_5_0_": 72, "_5_1_": 73, "_5_2_": 74, "_5_3_": 75, "_5_4_": 76, "_5_5_": 77, "_6_-0_": 78, "_6_-1_": 79, "_6_-2_": 80, "_6_-3_": 81, "_6_-4_": 82, "_6_-5_": 83, "_6_-6_": 84, "_6_0_": 85, "_6_1_": 86, "_6_2_": 87, "_6_3_": 88, "_6_4_": 89, "_6_5_": 90, "_7_-0_": 91, "_7_-1_": 92, "_7_-2_": 93, "_7_-3_": 94, "_7_-4_": 95, "_7_-5_": 96, "_7_-6_": 97, "_7_0_": 98, "_7_1_": 99, "_7_2_": 100, "_7_3_": 101, "_7_4_": 102, "_7_5_": 103, "_8_-0_": 104, "_8_-1_": 105, "_8_-2_": 106, "_8_-3_": 107, "_8_-4_": 108, "_8_-5_": 109, "_8_-6_": 110, "_8_0_": 111, "_8_1_": 112, "_8_2_": 113, "_8_3_": 114, "_8_4_": 115, "_8_5_": 116, "_9_-0_": 117, "_9_-1_": 118, "_9_-2_": 119, "_9_-3_": 120, "_9_-4_": 121, "_9_-5_": 122, "_9_-6_": 123, "_9_0_": 124, "_9_1_": 125, "_9_2_": 
126, "_9_3_": 127, "_9_4_": 128, "_9_5_": 129, "H": 130, "He": 131, "Li": 132, "Be": 133, "B": 134, "C": 135, "N": 136, "O": 137, "F": 138, "Ne": 139, "Na": 140, "Mg": 141, "Al": 142, "Si": 143, "P": 144, "S": 145, "Cl": 146, "K": 147, "Ar": 148, "Ca": 149, "Sc": 150, "Ti": 151, "V": 152, "Cr": 153, "Mn": 154, "Fe": 155, "Ni": 156, "Co": 157, "Cu": 158, "Zn": 159, "Ga": 160, "Ge": 161, "As": 162, "Se": 163, "Br": 164, "Kr": 165, "Rb": 166, "Sr": 167, "Y": 168, "Zr": 169, "Nb": 170, "Mo": 171, "Tc": 172, "Ru": 173, "Rh": 174, "Pd": 175, "Ag": 176, "Cd": 177, "In": 178, "Sn": 179, "Sb": 180, "Te": 181, "I": 182, "Xe": 183, "Cs": 184, "Ba": 185, "La": 186, "Ce": 187, "Pr": 188, "Nd": 189, "Pm": 190, "Sm": 191, "Eu": 192, "Gd": 193, "Tb": 194, "Dy": 195, "Ho": 196, "Er": 197, "Tm": 198, "Yb": 199, "Lu": 200, "Hf": 201, "Ta": 202, "W": 203, "Re": 204, "Os": 205, "Ir": 206, "Pt": 207, "Au": 208, "Hg": 209, "Tl": 210, "Pb": 211, "Bi": 212, "Th": 213, "Pa": 214, "U": 215, "Np": 216, "Pu": 217, "Am": 218, "Cm": 219, "Bk": 220, "Cf": 221, "Es": 222, "Fm": 223, "Md": 224, "No": 225, "Lr": 226, "Rf": 227, "Db": 228, "Sg": 229, "Bh": 230, "Hs": 231, "Mt": 232, "Ds": 233, "Rg": 234, "Cn": 235, "Nh": 236, "Fl": 237, "Mc": 238, "Lv": 239, "Ts": 240, "Og": 241, "+": 242, "-": 243, "/": 244, "\n": 245, "a": 246, "n": 247, "c": 248, "b": 249, "m": 250, "d": 251, "R": 252, "A": 253, "(": 254, ")": 255, "[": 256, "]": 257, "*": 258, ".": 259, " ": 260, "[CLS]": 261, "[SEP]": 262} \ No newline at end of file diff --git a/src/xtal2txt/vocabs/smiles_vocab.json b/src/xtal2txt/vocabs/smiles_vocab.json new file mode 100644 index 0000000..e5f5ddf --- /dev/null +++ b/src/xtal2txt/vocabs/smiles_vocab.json @@ -0,0 +1,143 @@ +{ + "H": 0, + "He": 1, + "Li": 2, + "Be": 3, + "B": 4, + "C": 5, + "N": 6, + "O": 7, + "F": 8, + "Ne": 9, + "Na": 10, + "Mg": 11, + "Al": 12, + "Si": 13, + "P": 14, + "S": 15, + "Cl": 16, + "K": 17, + "Ar": 18, + "Ca": 19, + "Sc": 20, + "Ti": 21, + "V": 22, + "Cr": 23, + "Mn": 24, + "Fe": 25, + "Ni": 26, + "Co": 27, + "Cu": 28, + "Zn": 29, + "Ga": 30, + "Ge": 31, + "As": 32, + "Se": 33, + "Br": 34, + "Kr": 35, + "Rb": 36, + "Sr": 37, + "Y": 38, + "Zr": 39, + "Nb": 40, + "Mo": 41, + "Tc": 42, + "Ru": 43, + "Rh": 44, + "Pd": 45, + "Ag": 46, + "Cd": 47, + "In": 48, + "Sn": 49, + "Sb": 50, + "Te": 51, + "I": 52, + "Xe": 53, + "Cs": 54, + "Ba": 55, + "La": 56, + "Ce": 57, + "Pr": 58, + "Nd": 59, + "Pm": 60, + "Sm": 61, + "Eu": 62, + "Gd": 63, + "Tb": 64, + "Dy": 65, + "Ho": 66, + "Er": 67, + "Tm": 68, + "Yb": 69, + "Lu": 70, + "Hf": 71, + "Ta": 72, + "W": 73, + "Re": 74, + "Os": 75, + "Ir": 76, + "Pt": 77, + "Au": 78, + "Hg": 79, + "Tl": 80, + "Pb": 81, + "Bi": 82, + "Th": 83, + "Pa": 84, + "U": 85, + "Np": 86, + "Pu": 87, + "Am": 88, + "Cm": 89, + "Bk": 90, + "Cf": 91, + "Es": 92, + "Fm": 93, + "Md": 94, + "No": 95, + "Lr": 96, + "Rf": 97, + "Db": 98, + "Sg": 99, + "Bh": 100, + "Hs": 101, + "Mt": 102, + "Ds": 103, + "Rg": 104, + "Cn": 105, + "Nh": 106, + "Fl": 107, + "Mc": 108, + "Lv": 109, + "Ts": 110, + "Og": 111, + "0": 112, + "1": 113, + "2": 114, + "3": 115, + "4": 116, + "5": 117, + "6": 118, + "7": 119, + "8": 120, + "9": 121, + " ": 122, + "+": 123, + "-": 124, + "/": 125, + "\n": 126, + "a": 127, + "n": 128, + "c": 129, + "b": 130, + "m": 131, + "d": 132, + "R": 133, + "A": 134, + "(": 135, + ")": 136, + "[": 137, + "]": 138, + "*": 139, + ".": 140 +} diff --git a/src/xtal2txt/vocabs/smiles_vocab.txt b/src/xtal2txt/vocabs/smiles_vocab.txt new file mode 100644 index 0000000..203c2a1 --- 
/dev/null +++ b/src/xtal2txt/vocabs/smiles_vocab.txt @@ -0,0 +1,138 @@ +H +He +Li +Be +B +C +N +O +F +Ne +Na +Mg +Al +Si +P +S +Cl +K +Ar +Ca +Sc +Ti +V +Cr +Mn +Fe +Ni +Co +Cu +Zn +Ga +Ge +As +Se +Br +Kr +Rb +Sr +Y +Zr +Nb +Mo +Tc +Ru +Rh +Pd +Ag +Cd +In +Sn +Sb +Te +I +Xe +Cs +Ba +La +Ce +Pr +Nd +Pm +Sm +Eu +Gd +Tb +Dy +Ho +Er +Tm +Yb +Lu +Hf +Ta +W +Re +Os +Ir +Pt +Au +Hg +Tl +Pb +Bi +Th +Pa +U +Np +Pu +Am +Cm +Bk +Cf +Es +Fm +Md +No +Lr +Rf +Db +Sg +Bh +Hs +Mt +Ds +Rg +Cn +Nh +Fl +Mc +Lv +Ts +Og +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +n +c +b +m +d +R +A +( +) +[ +] +. ++ +- +\n \ No newline at end of file diff --git a/src/xtal2txt/vocabs/smiles_vocab_rt.json b/src/xtal2txt/vocabs/smiles_vocab_rt.json new file mode 100644 index 0000000..31f7063 --- /dev/null +++ b/src/xtal2txt/vocabs/smiles_vocab_rt.json @@ -0,0 +1,263 @@ +{ + "_._": 0, + "_0_-0_": 1, + "_0_-1_": 2, + "_0_-2_": 3, + "_0_-3_": 4, + "_0_-4_": 5, + "_0_-5_": 6, + "_0_0_": 7, + "_0_1_": 8, + "_0_2_": 9, + "_0_3_": 10, + "_0_4_": 11, + "_0_5_": 12, + "_1_-0_": 13, + "_1_-1_": 14, + "_1_-2_": 15, + "_1_-3_": 16, + "_1_-4_": 17, + "_1_-5_": 18, + "_1_-6_": 19, + "_1_0_": 20, + "_1_1_": 21, + "_1_2_": 22, + "_1_3_": 23, + "_1_4_": 24, + "_1_5_": 25, + "_2_-0_": 26, + "_2_-1_": 27, + "_2_-2_": 28, + "_2_-3_": 29, + "_2_-4_": 30, + "_2_-5_": 31, + "_2_-6_": 32, + "_2_0_": 33, + "_2_1_": 34, + "_2_2_": 35, + "_2_3_": 36, + "_2_4_": 37, + "_2_5_": 38, + "_3_-0_": 39, + "_3_-1_": 40, + "_3_-2_": 41, + "_3_-3_": 42, + "_3_-4_": 43, + "_3_-5_": 44, + "_3_-6_": 45, + "_3_0_": 46, + "_3_1_": 47, + "_3_2_": 48, + "_3_3_": 49, + "_3_4_": 50, + "_3_5_": 51, + "_4_-0_": 52, + "_4_-1_": 53, + "_4_-2_": 54, + "_4_-3_": 55, + "_4_-4_": 56, + "_4_-5_": 57, + "_4_-6_": 58, + "_4_0_": 59, + "_4_1_": 60, + "_4_2_": 61, + "_4_3_": 62, + "_4_4_": 63, + "_4_5_": 64, + "_5_-0_": 65, + "_5_-1_": 66, + "_5_-2_": 67, + "_5_-3_": 68, + "_5_-4_": 69, + "_5_-5_": 70, + "_5_-6_": 71, + "_5_0_": 72, + "_5_1_": 73, + "_5_2_": 74, + "_5_3_": 75, + "_5_4_": 76, + "_5_5_": 77, + "_6_-0_": 78, + "_6_-1_": 79, + "_6_-2_": 80, + "_6_-3_": 81, + "_6_-4_": 82, + "_6_-5_": 83, + "_6_-6_": 84, + "_6_0_": 85, + "_6_1_": 86, + "_6_2_": 87, + "_6_3_": 88, + "_6_4_": 89, + "_6_5_": 90, + "_7_-0_": 91, + "_7_-1_": 92, + "_7_-2_": 93, + "_7_-3_": 94, + "_7_-4_": 95, + "_7_-5_": 96, + "_7_-6_": 97, + "_7_0_": 98, + "_7_1_": 99, + "_7_2_": 100, + "_7_3_": 101, + "_7_4_": 102, + "_7_5_": 103, + "_8_-0_": 104, + "_8_-1_": 105, + "_8_-2_": 106, + "_8_-3_": 107, + "_8_-4_": 108, + "_8_-5_": 109, + "_8_-6_": 110, + "_8_0_": 111, + "_8_1_": 112, + "_8_2_": 113, + "_8_3_": 114, + "_8_4_": 115, + "_8_5_": 116, + "_9_-0_": 117, + "_9_-1_": 118, + "_9_-2_": 119, + "_9_-3_": 120, + "_9_-4_": 121, + "_9_-5_": 122, + "_9_-6_": 123, + "_9_0_": 124, + "_9_1_": 125, + "_9_2_": 126, + "_9_3_": 127, + "_9_4_": 128, + "_9_5_": 129, + "H": 130, + "He": 131, + "Li": 132, + "Be": 133, + "B": 134, + "C": 135, + "N": 136, + "O": 137, + "F": 138, + "Ne": 139, + "Na": 140, + "Mg": 141, + "Al": 142, + "Si": 143, + "P": 144, + "S": 145, + "Cl": 146, + "K": 147, + "Ar": 148, + "Ca": 149, + "Sc": 150, + "Ti": 151, + "V": 152, + "Cr": 153, + "Mn": 154, + "Fe": 155, + "Ni": 156, + "Co": 157, + "Cu": 158, + "Zn": 159, + "Ga": 160, + "Ge": 161, + "As": 162, + "Se": 163, + "Br": 164, + "Kr": 165, + "Rb": 166, + "Sr": 167, + "Y": 168, + "Zr": 169, + "Nb": 170, + "Mo": 171, + "Tc": 172, + "Ru": 173, + "Rh": 174, + "Pd": 175, + "Ag": 176, + "Cd": 177, + "In": 178, + "Sn": 179, + "Sb": 180, + "Te": 181, + "I": 182, + 
"Xe": 183, + "Cs": 184, + "Ba": 185, + "La": 186, + "Ce": 187, + "Pr": 188, + "Nd": 189, + "Pm": 190, + "Sm": 191, + "Eu": 192, + "Gd": 193, + "Tb": 194, + "Dy": 195, + "Ho": 196, + "Er": 197, + "Tm": 198, + "Yb": 199, + "Lu": 200, + "Hf": 201, + "Ta": 202, + "W": 203, + "Re": 204, + "Os": 205, + "Ir": 206, + "Pt": 207, + "Au": 208, + "Hg": 209, + "Tl": 210, + "Pb": 211, + "Bi": 212, + "Th": 213, + "Pa": 214, + "U": 215, + "Np": 216, + "Pu": 217, + "Am": 218, + "Cm": 219, + "Bk": 220, + "Cf": 221, + "Es": 222, + "Fm": 223, + "Md": 224, + "No": 225, + "Lr": 226, + "Rf": 227, + "Db": 228, + "Sg": 229, + "Bh": 230, + "Hs": 231, + "Mt": 232, + "Ds": 233, + "Rg": 234, + "Cn": 235, + "Nh": 236, + "Fl": 237, + "Mc": 238, + "Lv": 239, + "Ts": 240, + "Og": 241, + "+": 242, + "-": 243, + "/": 244, + "\n": 245, + "a": 246, + "n": 247, + "c": 248, + "b": 249, + "m": 250, + "d": 251, + "R": 252, + "A": 253, + "(": 254, + ")": 255, + "[": 256, + "]": 257, + "*": 258, + ".": 259, + " ":260 +} diff --git a/src/xtal2txt/vocabs/smiles_vocab_rt.txt b/src/xtal2txt/vocabs/smiles_vocab_rt.txt new file mode 100644 index 0000000..370eeeb --- /dev/null +++ b/src/xtal2txt/vocabs/smiles_vocab_rt.txt @@ -0,0 +1,258 @@ +H +He +Li +Be +B +C +N +O +F +Ne +Na +Mg +Al +Si +P +S +Cl +K +Ar +Ca +Sc +Ti +V +Cr +Mn +Fe +Ni +Co +Cu +Zn +Ga +Ge +As +Se +Br +Kr +Rb +Sr +Y +Zr +Nb +Mo +Tc +Ru +Rh +Pd +Ag +Cd +In +Sn +Sb +Te +I +Xe +Cs +Ba +La +Ce +Pr +Nd +Pm +Sm +Eu +Gd +Tb +Dy +Ho +Er +Tm +Yb +Lu +Hf +Ta +W +Re +Os +Ir +Pt +Au +Hg +Tl +Pb +Bi +Th +Pa +U +Np +Pu +Am +Cm +Bk +Cf +Es +Fm +Md +No +Lr +Rf +Db +Sg +Bh +Hs +Mt +Ds +Rg +Cn +Nh +Fl +Mc +Lv +Ts +Og +_._ +_0_-0_ +_0_-1_ +_0_-2_ +_0_-3_ +_0_-4_ +_0_-5_ +_0_0_ +_0_1_ +_0_2_ +_0_3_ +_0_4_ +_0_5_ +_1_-0_ +_1_-1_ +_1_-2_ +_1_-3_ +_1_-4_ +_1_-5_ +_1_-6_ +_1_0_ +_1_1_ +_1_2_ +_1_3_ +_1_4_ +_1_5_ +_2_-0_ +_2_-1_ +_2_-2_ +_2_-3_ +_2_-4_ +_2_-5_ +_2_-6_ +_2_0_ +_2_1_ +_2_2_ +_2_3_ +_2_4_ +_2_5_ +_3_-0_ +_3_-1_ +_3_-2_ +_3_-3_ +_3_-4_ +_3_-5_ +_3_-6_ +_3_0_ +_3_1_ +_3_2_ +_3_3_ +_3_4_ +_3_5_ +_4_-0_ +_4_-1_ +_4_-2_ +_4_-3_ +_4_-4_ +_4_-5_ +_4_-6_ +_4_0_ +_4_1_ +_4_2_ +_4_3_ +_4_4_ +_4_5_ +_5_-0_ +_5_-1_ +_5_-2_ +_5_-3_ +_5_-4_ +_5_-5_ +_5_-6_ +_5_0_ +_5_1_ +_5_2_ +_5_3_ +_5_4_ +_5_5_ +_6_-0_ +_6_-1_ +_6_-2_ +_6_-3_ +_6_-4_ +_6_-5_ +_6_-6_ +_6_0_ +_6_1_ +_6_2_ +_6_3_ +_6_4_ +_6_5_ +_7_-0_ +_7_-1_ +_7_-2_ +_7_-3_ +_7_-4_ +_7_-5_ +_7_-6_ +_7_0_ +_7_1_ +_7_2_ +_7_3_ +_7_4_ +_7_5_ +_8_-0_ +_8_-1_ +_8_-2_ +_8_-3_ +_8_-4_ +_8_-5_ +_8_-6_ +_8_0_ +_8_1_ +_8_2_ +_8_3_ +_8_4_ +_8_5_ +_9_-0_ +_9_-1_ +_9_-2_ +_9_-3_ +_9_-4_ +_9_-5_ +_9_-6_ +_9_0_ +_9_1_ +_9_2_ +_9_3_ +_9_4_ +_9_5_ +a +n +c +b +m +d +R +A +( +) +[ +] +. 
++ +- +\n \ No newline at end of file diff --git a/tests/tokenizer/test_smiles_tokenizer.py b/tests/tokenizer/test_smiles_tokenizer.py new file mode 100644 index 0000000..57144f5 --- /dev/null +++ b/tests/tokenizer/test_smiles_tokenizer.py @@ -0,0 +1,18 @@ +import pytest +from xtal2txt.tokenizer import SmilesTokenizer +import os + + +@pytest.fixture +def smiles_rt_tokenizer(scope="module"): + return SmilesTokenizer( + special_num_token=True, model_max_length=512, truncation=False, padding=False + ) + + + +def test_composition_rt_tokens(smiles_rt_tokenizer) -> None: + excepted_output = ['[CLS]', 'I', '-', '_4_1_', '_2_0_', 'd', '\n', 'S', '_2_0_', '-', ' ', '(', '_8_0_', 'd', ')', ' ', '[', 'Cu', ']', 'S', '(', '[', 'In', ']', ')', '(', '[', 'In', ']', ')', '[', 'Cu', ']', '\n', 'Cu', '+', ' ', '(', '_4_0_', 'a', ')', ' ','[', 'S', ']', '[', 'Cu', ']', '(', '[', 'S', ']', ')', '(', '[', 'S', ']', ')', '[', 'S', ']', '\n', 'In', '_3_0_', '+', ' ','(', '_4_0_', 'b', ')', ' ','[', 'S', ']', '[', 'In', ']', '(', '[', 'S', ']', ')', '[', 'S', ']', '.', '[', 'S', ']', '[SEP]'] + input_string = "I-42d\nS2- (8d) [Cu]S([In])([In])[Cu]\nCu+ (4a) [S][Cu]([S])([S])[S]\nIn3+ (4b) [S][In]([S])[S].[S]" + tokens = smiles_rt_tokenizer.tokenize(input_string) + assert tokens == excepted_output From d4516440056d4ab2a4a65b419f0ee2d596258dfa Mon Sep 17 00:00:00 2001 From: Nawaf <86834161+n0w0f@users.noreply.github.com> Date: Sat, 27 Apr 2024 12:58:48 +0200 Subject: [PATCH 2/4] Update tests/tokenizer/test_smiles_tokenizer.py Co-authored-by: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> --- tests/tokenizer/test_smiles_tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tokenizer/test_smiles_tokenizer.py b/tests/tokenizer/test_smiles_tokenizer.py index 57144f5..8f4b592 100644 --- a/tests/tokenizer/test_smiles_tokenizer.py +++ b/tests/tokenizer/test_smiles_tokenizer.py @@ -11,7 +11,7 @@ def smiles_rt_tokenizer(scope="module"): -def test_composition_rt_tokens(smiles_rt_tokenizer) -> None: +def test_smiles_rt_tokens(smiles_rt_tokenizer) -> None: excepted_output = ['[CLS]', 'I', '-', '_4_1_', '_2_0_', 'd', '\n', 'S', '_2_0_', '-', ' ', '(', '_8_0_', 'd', ')', ' ', '[', 'Cu', ']', 'S', '(', '[', 'In', ']', ')', '(', '[', 'In', ']', ')', '[', 'Cu', ']', '\n', 'Cu', '+', ' ', '(', '_4_0_', 'a', ')', ' ','[', 'S', ']', '[', 'Cu', ']', '(', '[', 'S', ']', ')', '(', '[', 'S', ']', ')', '[', 'S', ']', '\n', 'In', '_3_0_', '+', ' ','(', '_4_0_', 'b', ')', ' ','[', 'S', ']', '[', 'In', ']', '(', '[', 'S', ']', ')', '[', 'S', ']', '.', '[', 'S', ']', '[SEP]'] input_string = "I-42d\nS2- (8d) [Cu]S([In])([In])[Cu]\nCu+ (4a) [S][Cu]([S])([S])[S]\nIn3+ (4b) [S][In]([S])[S].[S]" tokens = smiles_rt_tokenizer.tokenize(input_string) From ec55b772e616dab538c90ca163988fd11b745381 Mon Sep 17 00:00:00 2001 From: Kevin Maik Jablonka Date: Sun, 28 Apr 2024 11:54:06 +0200 Subject: [PATCH 3/4] feat: local env representation --- src/xtal2txt/core.py | 23 +++++- src/xtal2txt/decoder.py | 2 +- src/xtal2txt/local_env.py | 8 +- src/xtal2txt/tokenizer.py | 19 +++-- src/xtal2txt/transforms.py | 1 + tests/test_textrep.py | 8 ++ tests/tokenizer/test_cif_tokenizer.py | 2 +- tests/tokenizer/test_composition_tokenizer.py | 25 ++++--- tests/tokenizer/test_crystal_llm_tokenizer.py | 2 +- tests/tokenizer/test_rt_tokenier.py | 34 ++++++--- tests/tokenizer/test_slice_tokenizer.py | 74 ++++++++++++++++--- 11 files changed, 153 insertions(+), 45 deletions(-) diff --git a/src/xtal2txt/core.py 
b/src/xtal2txt/core.py index 4bcf0ae..3128297 100644 --- a/src/xtal2txt/core.py +++ b/src/xtal2txt/core.py @@ -2,7 +2,7 @@ import re from collections import Counter from pathlib import Path -from typing import List, Union, Tuple +from typing import List, Union, Tuple, Optional from invcryrep.invcryrep import InvCryRep from pymatgen.core import Structure @@ -12,6 +12,7 @@ from robocrys import StructureCondenser, StructureDescriber from xtal2txt.transforms import TransformationCallback +from xtal2txt.local_env import LocalEnvAnalyzer class TextRep: @@ -225,6 +226,26 @@ def get_composition(self, format="hill") -> str: composition = composition_string.replace(" ", "") return composition + def get_local_env_rep(self, local_env_kwargs: Optional[dict] = None) -> str: + """ + Get the local environment representation of the crystal structure. + + The local environment representation is a string that contains + the space group symbol and the local environment of each atom in the unit cell. + The local environment of each atom is represented as SMILES string and the + Wyckoff symbol of the local environment. + + Args: + local_env_kwargs (dict): Keyword arguments to pass to the LocalEnvAnalyzer. + + Returns: + str: The local environment representation of the crystal structure. + """ + if not local_env_kwargs: + local_env_kwargs = {} + analyzer = LocalEnvAnalyzer(**local_env_kwargs) + return analyzer.structure_to_local_env_string(self.structure) + def get_crystal_llm_rep( self, permute_atoms: bool = False, diff --git a/src/xtal2txt/decoder.py b/src/xtal2txt/decoder.py index 6214cf5..15b8056 100644 --- a/src/xtal2txt/decoder.py +++ b/src/xtal2txt/decoder.py @@ -218,7 +218,7 @@ def wyckoff_matcher( output_struct = DecodeTextRep(self.text).wyckoff_decoder( self.text, lattice_params=True ) - + return StructureMatcher( ltol, stol, diff --git a/src/xtal2txt/local_env.py b/src/xtal2txt/local_env.py index 614e0f4..2a9e108 100644 --- a/src/xtal2txt/local_env.py +++ b/src/xtal2txt/local_env.py @@ -43,7 +43,9 @@ def __init__(self, distance_cutoff: float = 1.4, angle_cutoff: float = 0.3): self.distance_cutoff = distance_cutoff self.angle_cutoff = angle_cutoff - def get_local_environments(self, structure: Structure) -> Tuple[List[dict], List[dict], str]: + def get_local_environments( + self, structure: Structure + ) -> Tuple[List[dict], List[dict], str]: """Get the local environments of the atoms in a structure. Args: @@ -58,7 +60,9 @@ def get_local_environments(self, structure: Structure) -> Tuple[List[dict], List sga = SpacegroupAnalyzer(structure) symm_struct = sga.get_symmetrized_structure() - inequivalent_indices = [indices[0] for indices in symm_struct.equivalent_indices] + inequivalent_indices = [ + indices[0] for indices in symm_struct.equivalent_indices + ] wyckoffs = symm_struct.wyckoff_symbols # a Voronoi tessellation is used to determine the local environment of each atom diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index c562324..c5e5111 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -51,7 +51,9 @@ def num_matcher(self, text: str) -> str: r"\d+(?:\.\d+)?" 
# Match any number, whether it is part of a string or not ) matches = list(re.finditer(pattern, text)) - for match in reversed(matches): #since we are replacing substring with a bigger subtring the string we are working on + for match in reversed( + matches + ): # since we are replacing substring with a bigger subtring the string we are working on start, end = match.start(), match.end() tokens = self.tokenize(match.group()) replacement = "".join(tokens) @@ -123,7 +125,10 @@ def __init__( self, special_num_token: bool = False, vocab_file=None, - special_tokens={"cls_token": "[CLS]","sep_token": "[SEP]",}, + special_tokens={ + "cls_token": "[CLS]", + "sep_token": "[SEP]", + }, model_max_length=None, padding_length=None, **kwargs, @@ -133,14 +138,13 @@ def __init__( ) self.truncation = False self.padding = False - self.padding_length = padding_length - + self.padding_length = padding_length self.special_num_tokens = special_num_token self.vocab = self.load_vocab(vocab_file) self.vocab_file = vocab_file - # Initialize special tokens + # Initialize special tokens self.special_tokens = special_tokens if special_tokens is not None else {} self.add_special_tokens(self.special_tokens) @@ -156,7 +160,6 @@ def load_vocab(self, vocab_file): else: raise ValueError(f"Unsupported file type: {file_extension}") - def get_vocab(self): return self.vocab @@ -181,7 +184,9 @@ def tokenize(self, text): matches = [self.cls_token] + matches if self.truncation and len(matches) > self.model_max_length: - matches = matches[: self.model_max_length-1] # -1 since we add sep token later + matches = matches[ + : self.model_max_length - 1 + ] # -1 since we add sep token later if self.sep_token is not None: matches += [self.sep_token] diff --git a/src/xtal2txt/transforms.py b/src/xtal2txt/transforms.py index e216be5..d32e8a7 100644 --- a/src/xtal2txt/transforms.py +++ b/src/xtal2txt/transforms.py @@ -3,6 +3,7 @@ from pymatgen.core.structure import Structure from typing import Union, List + def set_seed(seed: int): """ Set the random seed for both random and numpy.random. diff --git a/tests/test_textrep.py b/tests/test_textrep.py index ade4a99..4e9bb05 100644 --- a/tests/test_textrep.py +++ b/tests/test_textrep.py @@ -74,3 +74,11 @@ def test_robocrys_for_cif_format() -> None: def test_get_robocrys_rep() -> None: excepted_output = "SrTiO3 is (Cubic) Perovskite structured and crystallizes in the cubic Pm-3m space group. Sr(1)2+ is bonded to twelve equivalent O(1)2- atoms to form SrO12 cuboctahedra that share corners with twelve equivalent Sr(1)O12 cuboctahedra, faces with six equivalent Sr(1)O12 cuboctahedra, and faces with eight equivalent Ti(1)O6 octahedra. All Sr(1)-O(1) bond lengths are 2.77 Å. Ti(1)4+ is bonded to six equivalent O(1)2- atoms to form TiO6 octahedra that share corners with six equivalent Ti(1)O6 octahedra and faces with eight equivalent Sr(1)O12 cuboctahedra. The corner-sharing octahedra are not tilted. All Ti(1)-O(1) bond lengths are 1.96 Å. O(1)2- is bonded in a distorted linear geometry to four equivalent Sr(1)2+ and two equivalent Ti(1)4+ atoms." 
assert srtio3_p1.get_robocrys_rep() == excepted_output + + +def test_get_local_env_rep() -> None: + expected_output = """Pm-3m +Sr2+ (1a) [O][Sr][O].[O].[O].[O].[O].[O].[O].[O].[O].[O].[O] +Ti4+ (1b) [O][Ti]([O])([O])([O])([O])[O] +O2- (3c) [Ti]O[Ti]""" + assert srtio3_p1.get_local_env_rep() == expected_output diff --git a/tests/tokenizer/test_cif_tokenizer.py b/tests/tokenizer/test_cif_tokenizer.py index 939e8c8..e1509c1 100644 --- a/tests/tokenizer/test_cif_tokenizer.py +++ b/tests/tokenizer/test_cif_tokenizer.py @@ -25,5 +25,5 @@ def test_encode_decode(tokenizer): for name, struct in structures.items(): input_string = struct.get_cif_string() token_ids = tokenizer.encode(input_string) - decoded_tokens = tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = tokenizer.decode(token_ids, skip_special_tokens=True) assert input_string == decoded_tokens diff --git a/tests/tokenizer/test_composition_tokenizer.py b/tests/tokenizer/test_composition_tokenizer.py index c1c114a..b2400a2 100644 --- a/tests/tokenizer/test_composition_tokenizer.py +++ b/tests/tokenizer/test_composition_tokenizer.py @@ -32,21 +32,22 @@ def test_convert_id_to_token(tokenizer): def test_encode_decode(tokenizer): input_string = "SrTiO3" token_ids = tokenizer.encode(input_string) - decoded_tokens = tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = tokenizer.decode(token_ids, skip_special_tokens=True) assert input_string == decoded_tokens input_string_2 = "Cr4P16Pb4" token_ids = tokenizer.encode(input_string_2) - decoded_tokens = tokenizer.decode(token_ids,skip_special_tokens=True) - - - -@pytest.mark.parametrize("input_string,expected", [ - ("Ba2ClSr", ['[CLS]','Ba', '2', 'Cl','Sr','[SEP]']), - ("BrMn2V", ['[CLS]','Br', 'Mn', '2', 'V','[SEP]']), - ("La2Ta4", ['[CLS]','La', '2', 'Ta', '4','[SEP]']), - ("Cr4P16Pb4", ['[CLS]','Cr', '4', 'P', '1', '6', 'Pb', '4','[SEP]']), - -]) + decoded_tokens = tokenizer.decode(token_ids, skip_special_tokens=True) + + +@pytest.mark.parametrize( + "input_string,expected", + [ + ("Ba2ClSr", ["[CLS]", "Ba", "2", "Cl", "Sr", "[SEP]"]), + ("BrMn2V", ["[CLS]", "Br", "Mn", "2", "V", "[SEP]"]), + ("La2Ta4", ["[CLS]", "La", "2", "Ta", "4", "[SEP]"]), + ("Cr4P16Pb4", ["[CLS]", "Cr", "4", "P", "1", "6", "Pb", "4", "[SEP]"]), + ], +) def test_tokenizer(tokenizer, input_string, expected): tokens = tokenizer.tokenize(input_string) assert tokens == expected diff --git a/tests/tokenizer/test_crystal_llm_tokenizer.py b/tests/tokenizer/test_crystal_llm_tokenizer.py index 2fab13b..76dfa2d 100644 --- a/tests/tokenizer/test_crystal_llm_tokenizer.py +++ b/tests/tokenizer/test_crystal_llm_tokenizer.py @@ -132,5 +132,5 @@ def test_convert_id_to_token(tokenizer): def test_encode_decode(tokenizer): input_string = "5.6 5.6 5.6\n90 90 90\nN0+\n0.48 0.98 0.52\nN0+\n0.98 0.52 0.48\nN0+\n0.02 0.02 0.02\nN0+\n0.52 0.48 0.98" token_ids = tokenizer.encode(input_string) - decoded_tokens = tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = tokenizer.decode(token_ids, skip_special_tokens=True) assert input_string == decoded_tokens diff --git a/tests/tokenizer/test_rt_tokenier.py b/tests/tokenizer/test_rt_tokenier.py index a19c1c0..8342799 100644 --- a/tests/tokenizer/test_rt_tokenier.py +++ b/tests/tokenizer/test_rt_tokenier.py @@ -3,7 +3,12 @@ import difflib from xtal2txt.core import TextRep -from xtal2txt.tokenizer import CifTokenizer, CrysllmTokenizer, SliceTokenizer, CompositionTokenizer +from xtal2txt.tokenizer import ( + CifTokenizer, + CrysllmTokenizer, + 
SliceTokenizer, + CompositionTokenizer, +) THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -53,27 +58,35 @@ def print_diff(input_string, decoded_tokens): print("\n".join(diff)) -def test_encode_decode(cif_rt_tokenizer, crystal_llm_rt_tokenizer, slice_rt_tokenizer,composition_rt_tokenizer): +def test_encode_decode( + cif_rt_tokenizer, + crystal_llm_rt_tokenizer, + slice_rt_tokenizer, + composition_rt_tokenizer, +): for name, struct in structures.items(): - input_string = struct.get_composition() token_ids = composition_rt_tokenizer.encode(input_string) - decoded_tokens = composition_rt_tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = composition_rt_tokenizer.decode( + token_ids, skip_special_tokens=True + ) assert input_string == decoded_tokens input_string = struct.get_cif_string(format="p1", decimal_places=2) token_ids = cif_rt_tokenizer.encode(input_string) - decoded_tokens = cif_rt_tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = cif_rt_tokenizer.decode(token_ids, skip_special_tokens=True) assert input_string == decoded_tokens input_string = struct.get_crystal_llm_rep() token_ids = crystal_llm_rt_tokenizer.encode(input_string) - decoded_tokens = crystal_llm_rt_tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = crystal_llm_rt_tokenizer.decode( + token_ids, skip_special_tokens=True + ) assert input_string == decoded_tokens input_string = struct.get_slice() token_ids = slice_rt_tokenizer.encode(input_string) - decoded_tokens = slice_rt_tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = slice_rt_tokenizer.decode(token_ids, skip_special_tokens=True) try: assert input_string.strip() == decoded_tokens except AssertionError: @@ -82,13 +95,12 @@ def test_encode_decode(cif_rt_tokenizer, crystal_llm_rt_tokenizer, slice_rt_toke def test_composition_rt_tokens(composition_rt_tokenizer) -> None: - excepted_output = ["[CLS]","Se", "_2_0_", "Se", "_3_0_","[SEP]"] + excepted_output = ["[CLS]", "Se", "_2_0_", "Se", "_3_0_", "[SEP]"] input_string = "Se2Se3" tokens = composition_rt_tokenizer.tokenize(input_string) assert tokens == excepted_output - def test_cif_rt_tokenize(cif_rt_tokenizer): input_string = "data_N2\n_symmetry_space_group_name_H-M 'P 1'\n_cell_length_a 5.605\n_cell_length_b 5.605\n_cell_length_c 5.605\n_cell_angle_alpha 90.0\n_cell_angle_beta 90.0\n_cell_angle_gamma 90.0\n_symmetry_Int_Tables_number 1\n_chemical_formula_structural N2\n_chemical_formula_sum N4\n_cell_volume 176.125\n_cell_formula_units_Z 2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n 1 'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n N0+ 0.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n N0+ N0 1 0.477 0.977 0.523 1.0\n N0+ N1 1 0.977 0.523 0.477 1.0\n N0+ N2 1 0.023 0.023 0.023 1.0\n N0+ N3 1 0.523 0.477 0.977 1.0\n" tokens = cif_rt_tokenizer.tokenize(input_string) @@ -364,7 +376,7 @@ def test_cif_rt_tokens(cif_rt_tokenizer) -> None: "_._", "_0_-1_", "\n", - "[SEP]" + "[SEP]", ] input_string = "data_N2\n_symmetry_space_group_name_H-M 'P 1'\n_cell_length_a 5.605\n_cell_length_b 5.605\n_cell_length_c 5.605\n_cell_angle_alpha 90.0\n_cell_angle_beta 90.0\n_cell_angle_gamma 90.0\n_symmetry_Int_Tables_number 1\n_chemical_formula_structural N2\n_chemical_formula_sum N4\n_cell_volume 176.125\n_cell_formula_units_Z 2\nloop_\n _symmetry_equiv_pos_site_id\n 
_symmetry_equiv_pos_as_xyz\n 1 'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n N0+ 0.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n N0+ N0 1 0.477 0.977 0.523 1.0\n N0+ N1 1 0.977 0.523 0.477 1.0\n N0+ N2 1 0.023 0.023 0.023 1.0\n N0+ N3 1 0.523 0.477 0.977 1.0\n" tokens = cif_rt_tokenizer.tokenize(input_string) @@ -489,7 +501,7 @@ def test_crystal_llm_tokens(crystal_llm_rt_tokenizer) -> None: "_._", "_5_-1_", "_0_-2_", - "[SEP]" + "[SEP]", ] input_string = "3.9 3.9 3.9\n90 90 90\nSr2+\n0.00 0.00 0.00\nTi4+\n0.50 0.50 0.50\nO2-\n0.50 0.00 0.50\nO2-\n0.50 0.50 0.00\nO2-\n0.00 0.50 0.50" tokens = crystal_llm_rt_tokenizer.tokenize(input_string) diff --git a/tests/tokenizer/test_slice_tokenizer.py b/tests/tokenizer/test_slice_tokenizer.py index 3dcd378..178e8bd 100644 --- a/tests/tokenizer/test_slice_tokenizer.py +++ b/tests/tokenizer/test_slice_tokenizer.py @@ -46,18 +46,74 @@ def test_convert_id_to_token(tokenizer): def test_encode_decode(tokenizer): input_string = "Ga Ga Ga Ga P P P P 0 3 - - o 0 2 - o - 0 1 o - - 0 7 o o o 0 6 o o o 0 5 o o o 1 2 - + o 1 3 - o + 2 3 o - + 4 5 o o o 4 6 o o o 4 7 o o o 5 7 o o o 5 6 o o o 6 7 o o o" token_ids = tokenizer.encode(input_string) - decoded_tokens = tokenizer.decode(token_ids,skip_special_tokens=True) + decoded_tokens = tokenizer.decode(token_ids, skip_special_tokens=True) assert input_string == decoded_tokens - -@pytest.mark.parametrize("input_string,expected", [ - ("Se Se Mo 0 2 o o + 0 2 + o o 0 2 o + o 1 2 o o + 1 2 o + o 1 2 + o o", ['[CLS]','Se', 'Se', 'Mo', '0', '2', 'o o +', '0', '2', '+ o o', '0', '2', 'o + o', '1', '2', 'o o +', '1', '2', 'o + o', '1', '2', '+ o o','[SEP]']), - ("H H O", ['[CLS]','H', 'H', 'O','[SEP]']), - ("Sc Sc 0 1 - - - ", ['[CLS]','Sc', 'Sc', '0', '1', '- - -','[SEP]']), - ("Cu Cu Cu Cu 0 3 - - o 0 2 - o - 0 1 o - - 1 2 - + o 1 3 - o + 2 3 o - + ", ['[CLS]','Cu', 'Cu', 'Cu', 'Cu', '0', '3', '- - o', '0', '2', '- o -', '0', '1', 'o - -', '1', '2', '- + o', '1', '3', '- o +', '2', '3', 'o - +','[SEP]']), - -]) +@pytest.mark.parametrize( + "input_string,expected", + [ + ( + "Se Se Mo 0 2 o o + 0 2 + o o 0 2 o + o 1 2 o o + 1 2 o + o 1 2 + o o", + [ + "[CLS]", + "Se", + "Se", + "Mo", + "0", + "2", + "o o +", + "0", + "2", + "+ o o", + "0", + "2", + "o + o", + "1", + "2", + "o o +", + "1", + "2", + "o + o", + "1", + "2", + "+ o o", + "[SEP]", + ], + ), + ("H H O", ["[CLS]", "H", "H", "O", "[SEP]"]), + ("Sc Sc 0 1 - - - ", ["[CLS]", "Sc", "Sc", "0", "1", "- - -", "[SEP]"]), + ( + "Cu Cu Cu Cu 0 3 - - o 0 2 - o - 0 1 o - - 1 2 - + o 1 3 - o + 2 3 o - + ", + [ + "[CLS]", + "Cu", + "Cu", + "Cu", + "Cu", + "0", + "3", + "- - o", + "0", + "2", + "- o -", + "0", + "1", + "o - -", + "1", + "2", + "- + o", + "1", + "3", + "- o +", + "2", + "3", + "o - +", + "[SEP]", + ], + ), + ], +) def test_tokenizer(tokenizer, input_string, expected): tokens = tokenizer.tokenize(input_string) assert tokens == expected From 7fe6c055a89544d0a41ba18d445304b7e62a5c23 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Sun, 28 Apr 2024 20:02:29 +0200 Subject: [PATCH 4/4] chore: update get rep methods --- src/xtal2txt/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/xtal2txt/core.py b/src/xtal2txt/core.py index 3128297..3703f0d 100644 --- a/src/xtal2txt/core.py +++ b/src/xtal2txt/core.py @@ -468,6 +468,7 @@ def get_all_text_reps(self, decimal_places: int = 2): decimal_places=decimal_places, ), 
"zmatrix": self._safe_call(self.get_zmatrix_rep), + "local_env": self._safe_call(self.get_local_env_rep, local_env_kwargs=None), } def get_requested_text_reps( @@ -508,6 +509,8 @@ def get_requested_text_reps( decimal_places=decimal_places, ), "zmatrix": lambda: self._safe_call(self.get_zmatrix_rep, decimal_places=1), + "local_env": lambda: self._safe_call(self.get_local_env_rep, + local_env_kwargs=None), } return {rep: all_reps[rep]() for rep in requested_reps if rep in all_reps}