@@ -463,8 +463,8 @@ def lennard_jones_pair(
463
463
# the system. Thanks to the fact we rely on ``torch`` autodifferentiation mechanism, the
464
464
# forces acting on the virtual sites will be automatically split between O and H atoms,
465
465
# in a way that is consistent with the definition.
466
-
467
- # forces acting on the M sites will be automatically split between O and H atoms, in a
466
+ #
467
+ # Forces acting on the M sites will be automatically split between O and H atoms, in a
468
468
# way that is consistent with the definition.
469
469
470
470
@@ -583,7 +583,8 @@ def get_molecular_geometry(
583
583
)
584
584
585
585
tensor = TensorMap (
586
- keys = Labels ("_" , torch .zeros (1 , 1 , dtype = torch .int32 )), blocks = [data ]
586
+ keys = Labels ("_" , torch .zeros (1 , 1 , device = charges .device , dtype = torch .int32 )),
587
+ blocks = [data ],
587
588
)
588
589
589
590
m_system .add_data (name = "charges" , tensor = tensor )
@@ -637,8 +638,6 @@ def __init__(
637
638
hoh_angle_eq : float ,
638
639
hoh_angle_k : float ,
639
640
p3m_options : Optional [dict ] = None ,
640
- dtype : Optional [torch .dtype ] = None ,
641
- device : Optional [torch .device ] = None ,
642
641
):
643
642
super ().__init__ ()
644
643
@@ -653,34 +652,24 @@ def __init__(
653
652
** p3m_parameters ,
654
653
prefactor = torchpme .prefactors .kcalmol_A , # consistent units
655
654
)
656
- self .p3m_calculator .to (device = device , dtype = dtype )
657
655
658
656
self .coulomb = torchpme .CoulombPotential ()
659
- self .coulomb .to (device = device , dtype = dtype )
660
657
661
658
# We use a half neighborlist and allow to have pairs farther than cutoff
662
659
# (`strict=False`) since this is not problematic for PME and may speed up the
663
660
# computation of the neigbors.
664
661
self .nlo = NeighborListOptions (cutoff = cutoff , full_list = False , strict = False )
665
662
666
- # registers model parameters as buffers
667
- self .dtype = dtype if dtype is not None else torch .get_default_dtype ()
668
- self .device = device if device is not None else torch .get_default_device ()
669
-
670
- self .register_buffer ("cutoff" , torch .tensor (cutoff , dtype = self .dtype ))
671
- self .register_buffer ("lj_sigma" , torch .tensor (lj_sigma , dtype = self .dtype ))
672
- self .register_buffer ("lj_epsilon" , torch .tensor (lj_epsilon , dtype = self .dtype ))
673
- self .register_buffer ("m_gamma" , torch .tensor (m_gamma , dtype = self .dtype ))
674
- self .register_buffer ("m_charge" , torch .tensor (m_charge , dtype = self .dtype ))
675
- self .register_buffer ("oh_bond_eq" , torch .tensor (oh_bond_eq , dtype = self .dtype ))
676
- self .register_buffer ("oh_bond_k" , torch .tensor (oh_bond_k , dtype = self .dtype ))
677
- self .register_buffer (
678
- "oh_bond_alpha" , torch .tensor (oh_bond_alpha , dtype = self .dtype )
679
- )
680
- self .register_buffer (
681
- "hoh_angle_eq" , torch .tensor (hoh_angle_eq , dtype = self .dtype )
682
- )
683
- self .register_buffer ("hoh_angle_k" , torch .tensor (hoh_angle_k , dtype = self .dtype ))
663
+ self .register_buffer ("cutoff" , torch .tensor (cutoff ))
664
+ self .register_buffer ("lj_sigma" , torch .tensor (lj_sigma ))
665
+ self .register_buffer ("lj_epsilon" , torch .tensor (lj_epsilon ))
666
+ self .register_buffer ("m_gamma" , torch .tensor (m_gamma ))
667
+ self .register_buffer ("m_charge" , torch .tensor (m_charge ))
668
+ self .register_buffer ("oh_bond_eq" , torch .tensor (oh_bond_eq ))
669
+ self .register_buffer ("oh_bond_k" , torch .tensor (oh_bond_k ))
670
+ self .register_buffer ("oh_bond_alpha" , torch .tensor (oh_bond_alpha ))
671
+ self .register_buffer ("hoh_angle_eq" , torch .tensor (hoh_angle_eq ))
672
+ self .register_buffer ("hoh_angle_k" , torch .tensor (hoh_angle_k ))
684
673
685
674
def requested_neighbor_lists (self ):
686
675
"""Returns the list of neighbor list options that are needed."""
@@ -778,25 +767,28 @@ def forward(
778
767
779
768
# Rename property label to follow metatensor's convention for an atomistic model
780
769
samples = Labels (
781
- ["system" ], torch .arange (len (systems ), device = self .device ).reshape (- 1 , 1 )
770
+ ["system" ],
771
+ torch .arange (len (systems ), device = energy_tot .device ).reshape (- 1 , 1 ),
782
772
)
773
+ properties = Labels (["energy" ], torch .tensor ([[0 ]], device = energy_tot .device ))
774
+
783
775
block = TensorBlock (
784
776
values = torch .sum (energy_tot ).reshape (- 1 , 1 ),
785
777
samples = samples ,
786
778
components = torch .jit .annotate (List [Labels ], []),
787
- properties = Labels ([ "energy" ], torch . tensor ([[ 0 ]], device = self . device )) ,
779
+ properties = properties ,
788
780
)
789
781
return {
790
782
"energy" : TensorMap (
791
- Labels ("_" , torch .tensor ([[0 ]], device = self .device )), [block ]
783
+ Labels ("_" , torch .tensor ([[0 ]], device = energy_tot .device )), [block ]
792
784
),
793
785
}
794
786
795
787
796
788
# %%
797
789
#
798
790
# All this class does is take a ``System`` and return its energy (as a
799
- # :clas :`metatensor.TensorMap``).
791
+ # :class :`metatensor.TensorMap``).
800
792
801
793
qtip4pf_parameters = dict (
802
794
cutoff = 7.0 ,
@@ -861,7 +853,7 @@ def forward(
861
853
# Model options include a definition of its units, and a description of the quantities
862
854
# it can compute.
863
855
#
864
- # .. note::
856
+ # .. note::
865
857
#
866
858
# We neeed to specify that the model has infinite interaction range because of the
867
859
# presence of a long-range term that means one cannot assume that forces decay to zero
@@ -876,7 +868,7 @@ def forward(
876
868
atomic_types = [1 , 8 ],
877
869
interaction_range = torch .inf ,
878
870
length_unit = length_unit ,
879
- supported_devices = ["cpu " , "cuda " ],
871
+ supported_devices = ["cuda " , "cpu " ],
880
872
dtype = "float32" ,
881
873
)
882
874
0 commit comments