Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data module #674

Merged
merged 52 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d641a4b
WIP keyspec implementation
WillBaldwin0 Sep 2, 2024
d61f917
WIP2
WillBaldwin0 Sep 2, 2024
f8a3c51
Merge branch 'multihead-merge' into refactor-data
WillBaldwin0 Sep 2, 2024
cefb4f7
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
a7cead1
WIP3
WillBaldwin0 Sep 2, 2024
79bd1df
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
818bcc8
fixed some tests
WillBaldwin0 Sep 2, 2024
bfae06a
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
b52d20b
new interface passing old tests
WillBaldwin0 Sep 2, 2024
a742297
linting and fixed preprocess data
WillBaldwin0 Sep 2, 2024
ddfb31d
more linting
WillBaldwin0 Sep 2, 2024
4253216
Update unittest.yaml
WillBaldwin0 Sep 3, 2024
8803c48
fix key overwriting and unittests
WillBaldwin0 Sep 3, 2024
6e31995
small bug in settings REF_forces
WillBaldwin0 Sep 3, 2024
dfefca0
remove head key and some minor fixes
WillBaldwin0 Sep 4, 2024
2e4d524
default to Default for heads
WillBaldwin0 Sep 5, 2024
7d5c70d
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 5, 2024
1390a0b
added new test and fix calculator
WillBaldwin0 Sep 6, 2024
339231e
fix tests seed
WillBaldwin0 Sep 6, 2024
70341b1
fix average e0s method
WillBaldwin0 Sep 9, 2024
92d39eb
added missing charges and dipoles weights
WillBaldwin0 Sep 9, 2024
e9e2779
linting
WillBaldwin0 Sep 9, 2024
c0b65e2
moved keyspec construction into run_train
WillBaldwin0 Sep 9, 2024
3b0c34f
pass copies to neighborhood
WillBaldwin0 Sep 18, 2024
b78d1ec
convience function for logging dataset stats
Oct 22, 2024
eed2a41
formatting
Oct 22, 2024
b6bf6c0
fix type hint
RokasEl Oct 24, 2024
5b44fce
minor fixes from review
Oct 29, 2024
503e410
Merge branch 'develop' into refactor-data
WillBaldwin0 Oct 29, 2024
c00410c
fixes for new tests and linting
Oct 29, 2024
908acd1
head key in preprocessor
Oct 29, 2024
808f794
formatting
Oct 29, 2024
8d9f6cb
Merge branch 'develop' into refactor-data
WillBaldwin0 Oct 29, 2024
7a19ed6
new calculator syntax in test_run_train
WillBaldwin0 Oct 29, 2024
b8ef3ab
Merge branch 'develop' into refactor-data
ilyes319 Nov 6, 2024
8579c26
Merge pull request #663 from WillBaldwin0/refactor-data
ilyes319 Nov 6, 2024
8f0c6f3
configargparse for the preprocessing parser
Jan 22, 2025
251bb43
Merge pull request #795 from WillBaldwin0/refactor-data
ilyes319 Jan 22, 2025
fcf8385
Merge branch 'develop' into refactor_data
RokasEl Mar 20, 2025
0058751
Test fixes and default changes for mp data assembly
RokasEl Mar 20, 2025
44e2c77
add future annotations
RokasEl Mar 21, 2025
a2ba8ab
old python doesnt support strict zipping
RokasEl Mar 21, 2025
86927b4
pre-commit fixes
RokasEl Mar 21, 2025
c764a0d
Update .gitignore
ilyes319 Mar 27, 2025
b55cbee
Merge pull request #879 from ACEsuit/refactor_data_update
ilyes319 Mar 27, 2025
1778893
Merge branch 'refactor_data' of https://github.com/ACEsuit/mace into …
ilyes319 Mar 28, 2025
e897a7d
Merge remote-tracking branch 'origin/develop' into refactor_data
ilyes319 Mar 28, 2025
4f1719e
add custom encoder for wandb
ilyes319 Mar 28, 2025
8a7bfdd
fix all linter
ilyes319 Mar 28, 2025
99de42d
Merge branch 'develop' into refactor_data
ilyes319 Mar 28, 2025
f50f779
fix the head selection in test_multifiles
ilyes319 Mar 28, 2025
379f91e
fix fixture names
ilyes319 Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ dist/
*.xyz
/checkpoints
*.model
/results
*.db

.benchmarks
*.db
*.png
30 changes: 26 additions & 4 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,24 @@ def __init__(
[int(z) for z in self.models[0].atomic_numbers]
)
self.charges_key = charges_key

try:
self.heads = self.models[0].heads
self.available_heads: List[str] = self.models[0].heads # type: ignore
except AttributeError:
self.heads = ["Default"]
self.available_heads = ["Default"]
kwarg_head = kwargs.get("head", None)
if kwarg_head is not None:
self.head = kwarg_head
else:
self.head = self.available_heads[0]
if kwarg_head is None and self.head.lower() != "default":
raise ValueError(
"Head keyword was not provided, and the head in the model is not 'Default'"
f"Please provide a head keyword to specify the head you want to use. Available heads are: {self.available_heads}"
)

print("Using head", self.head, "out of", self.available_heads)

model_dtype = get_model_dtype(self.models[0])
if default_dtype == "":
print(
Expand Down Expand Up @@ -250,11 +264,19 @@ def _create_result_tensors(
return dict_of_tensors

def _atoms_to_batch(self, atoms):
config = data.config_from_atoms(atoms, charges_key=self.charges_key)
keyspec = data.KeySpecification(
info_keys={}, arrays_keys={"charges": self.charges_key}
)
config = data.config_from_atoms(
atoms, key_specification=keyspec, head_name=self.head
)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads
config,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)
],
batch_size=1,
Expand Down
Loading