Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed double-counting of periopdic image pairs #260

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions python/rascaline/rascaline/systems/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,23 @@ def compute_neighbors(self, cutoff):

nl_result = neighborlist.neighbor_list("ijdDS", self._atoms, cutoff)
for i, j, d, D, S in zip(*nl_result):
# we want a half neighbor list, so drop all duplicated neighbors
if j < i:
# we want a half neighbor list, so drop all duplicated
# neighbors
continue
elif i == j:
if S[0] == 0 and S[1] == 0 and S[2] == 0:
# only create pairs with the same atom twice if the pair spans more
# than one unit cell
continue
elif S[0] + S[1] + S[2] < 0 or (
(S[0] + S[1] + S[2] == 0) and (S[2] < 0 or (S[2] == 0 and S[1] < 0))
):
# When creating pairs between an atom and one of its periodic
# images, the code generate multiple redundant pairs (e.g. with
# shifts 0 1 1 and 0 -1 -1); and we want to only keep one of these.
# We keep the pair in the positive half plane of shifts.
continue

self._pairs.append((i, j, d, D, S))

self._pairs_by_center = []
Expand Down
54 changes: 54 additions & 0 deletions python/rascaline/tests/systems/ase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from rascaline import SphericalExpansion
from rascaline.systems import AseSystem


Expand Down Expand Up @@ -145,3 +146,56 @@ def test_no_pbc_cell():
)
with pytest.warns(Warning, match=message):
AseSystem(atoms)


def test_same_spherical_expansion():
system = ase.Atoms(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the feeling that this cell is this to small to check for more then one unit cell?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's 3.6 in one direction, you should get pairs with a shift of 2 with a cutoff of 9

"CaC6",
positions=[
(0.0, 0.0, 0.0),
(1.88597, 1.92706, 0.0113749),
(2.66157, 3.55479, 7.7372),
(2.35488, 3.36661, 3.88),
(2.19266, 2.11524, 3.86858),
(2.52936, 4.62777, 3.87771),
(2.01817, 0.85408, 3.87087),
],
cell=[[3.57941, 0, 0], [0.682558, 5.02733, 0], [0.285565, 0.454525, 7.74858]],
pbc=True,
)

calculator = SphericalExpansion(
# make sure to choose a cutoff larger then the cell to test for pairs crossing
# multiple periodic boundaries
cutoff=9,
Luthaf marked this conversation as resolved.
Show resolved Hide resolved
max_radial=5,
max_angular=5,
atomic_gaussian_width=0.3,
radial_basis={"Gto": {}},
center_atom_weight=1.0,
cutoff_function={"Step": {}},
)

rascaline_nl = calculator.compute(
system, gradients=["positions", "cell"], use_native_system=True
)

ase_nl = calculator.compute(
system, gradients=["positions", "cell"], use_native_system=False
)

for key, block in rascaline_nl.items():
ase_block = ase_nl.block(key)

assert ase_block.samples == block.samples
# Since the pairs are in a different order, the values are slightly different
assert np.allclose(ase_block.values, block.values, atol=1e-16, rtol=1e-9)

for parameter in ["cell", "positions"]:
gradient = block.gradient(parameter)
ase_gradient = ase_block.gradient(parameter)

assert gradient.samples == ase_gradient.samples
assert np.allclose(
ase_gradient.values, gradient.values, atol=1e-16, rtol=1e-6
)
104 changes: 85 additions & 19 deletions rascaline/src/calculators/neighbor_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct NeighborList {
pub full_neighbor_list: bool,
/// Should individual atoms be considered their own neighbor? Setting this
/// to `true` will add "self pairs", i.e. pairs between an atom and itself,
/// with the distance 0. The `pair_id` of such pairs is set to -1.
/// with the distance 0.
pub self_pairs: bool,
}

Expand Down Expand Up @@ -423,7 +423,8 @@ impl FullNeighborList {
let cell_c = pair.cell_shift_indices[2];

if species_first == species_second {
// same species for both atoms in the pair
// same species for both atoms in the pair, add the pair
// twice in both directions.
if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() {
builder.add(&[
LabelValue::from(system_i),
Expand All @@ -434,18 +435,14 @@ impl FullNeighborList {
LabelValue::from(cell_c),
]);

if pair.first != pair.second {
// if the pair is between two different atoms,
// also add the reversed (second -> first) pair.
builder.add(&[
LabelValue::from(system_i),
LabelValue::from(pair.second),
LabelValue::from(pair.first),
LabelValue::from(-cell_a),
LabelValue::from(-cell_b),
LabelValue::from(-cell_c),
]);
}
builder.add(&[
LabelValue::from(system_i),
LabelValue::from(pair.second),
LabelValue::from(pair.first),
LabelValue::from(-cell_a),
LabelValue::from(-cell_b),
LabelValue::from(-cell_c),
]);
}
} else {
// different species, find the right order for the pair
Expand Down Expand Up @@ -501,6 +498,11 @@ impl FullNeighborList {
let species = system.species()?;

for pair in system.pairs()? {
if pair.first == pair.second {
// self pairs should not be part of the neighbors list
assert_ne!(pair.cell_shift_indices, [0, 0, 0]);
}

let first_block_i = descriptor.keys().position(&[
species[pair.first].into(), species[pair.second].into()
]);
Expand Down Expand Up @@ -565,11 +567,6 @@ impl FullNeighborList {
}
}

if pair.first == pair.second {
// do not duplicate self pairs
continue;
}

// then the pair second -> first
if let Some(second_block_i) = second_block_i {
let mut block = descriptor.block_mut_by_id(second_block_i);
Expand Down Expand Up @@ -764,6 +761,75 @@ mod tests {
assert_relative_eq!(array, expected, max_relative=1e-6);
}

#[test]
fn periodic_neighbor_list() {
let mut calculator = Calculator::from(Box::new(NeighborList{
cutoff: 12.0,
full_neighbor_list: false,
self_pairs: false,
}) as Box<dyn CalculatorBase>);

let mut systems = test_systems(&["CH"]);

let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
assert_eq!(*descriptor.keys(), Labels::new(
["species_first_atom", "species_second_atom"],
&[[1, 1], [1, 6], [6, 6]]
));

// H-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.samples(), Labels::new(
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
// the pairs only differ in cell shifts
&[[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 1, 1, 1, 0, 0]]
));

let array = block.values().to_array();
let expected = &ndarray::arr3(&[
[[0.0], [0.0], [10.0]],
[[0.0], [10.0], [0.0]],
[[10.0], [0.0], [0.0]],
]).into_dyn();
assert_relative_eq!(array, expected, max_relative=1e-6);

// now a full NL
let mut calculator = Calculator::from(Box::new(NeighborList{
cutoff: 12.0,
full_neighbor_list: true,
self_pairs: false,
}) as Box<dyn CalculatorBase>);

let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
assert_eq!(*descriptor.keys(), Labels::new(
["species_first_atom", "species_second_atom"],
&[[1, 1], [1, 6], [6, 1], [6, 6]]
));

// H-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.samples(), Labels::new(
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
// twice as many pairs
&[
[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, -1],
[0, 1, 1, 0, 1, 0], [0, 1, 1, 0, -1, 0],
[0, 1, 1, 1, 0, 0], [0, 1, 1, -1, 0, 0],
]
));

let array = block.values().to_array();
let expected = &ndarray::arr3(&[
[[0.0], [0.0], [10.0]],
[[0.0], [0.0], [-10.0]],
[[0.0], [10.0], [0.0]],
[[0.0], [-10.0], [0.0]],
[[10.0], [0.0], [0.0]],
[[-10.0], [0.0], [0.0]],
]).into_dyn();
assert_relative_eq!(array, expected, max_relative=1e-6);
}

#[test]
fn finite_differences_positions() {
// half neighbor list
Expand Down
8 changes: 1 addition & 7 deletions rascaline/src/calculators/soap/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,6 @@ impl SphericalExpansion {
}
}

if pair.first == pair.second {
// do not compute for the reversed pair if the pair is
// between an atom and its image
continue;
}

if let Some(mapped_center) = result.centers_mapping[pair.second] {
// add the pair contribution to the atomic environnement
// corresponding to the **second** atom in the pair
Expand Down Expand Up @@ -778,7 +772,7 @@ mod tests {

fn parameters() -> SphericalExpansionParameters {
SphericalExpansionParameters {
cutoff: 3.5,
cutoff: 7.8,
max_radial: 6,
max_angular: 6,
atomic_gaussian_width: 0.3,
Expand Down
10 changes: 2 additions & 8 deletions rascaline/src/calculators/soap/spherical_expansion_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,13 +755,7 @@ impl CalculatorBase for SphericalExpansionByPair {
}
}

// also check for the block with a reversed pair, except if
// we are handling a pair between an atom and it's own
// periodic image
if pair.first == pair.second {
continue;
}

// also check for the block with a reversed pair
contribution.inverse_pair(&self.m_1_pow_l);

for spherical_harmonics_l in 0..=self.parameters.max_angular {
Expand Down Expand Up @@ -817,7 +811,7 @@ mod tests {

fn parameters() -> SphericalExpansionParameters {
SphericalExpansionParameters {
cutoff: 3.5,
cutoff: 7.3,
max_radial: 6,
max_angular: 6,
atomic_gaussian_width: 0.3,
Expand Down
Loading
Loading