Skip to content

Commit

Permalink
update to Pydantic 2+ and fix other outdated things
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-zero committed Jul 27, 2023
1 parent 3043c7c commit 3706fae
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 65 deletions.
6 changes: 3 additions & 3 deletions cp2k_input_tools/basissets/cp2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
from decimal import Decimal
from typing import Iterator, List, Optional, Sequence, Tuple

from pydantic import BaseModel, Extra
from pydantic import BaseModel

from ..utils import SYM2NUM, DatafileIterMixin, FromDictMixin, dformat

N_VAL_EL_MATCH = re.compile(r"q(?P<nvalel>\d+)$")


class BasisSetCoefficients(BaseModel, extra=Extra.forbid):
class BasisSetCoefficients(BaseModel, extra="forbid"):
"""A 'shell' in one single basis set"""

n: int
l: List[Tuple[int, int]] # noqa: E741
coefficients: List[List[Decimal]]


class BasisSetData(BaseModel, DatafileIterMixin, FromDictMixin, extra=Extra.forbid):
class BasisSetData(BaseModel, DatafileIterMixin, FromDictMixin, extra="forbid"):
"""Basis set data for a single element"""

element: str
Expand Down
6 changes: 3 additions & 3 deletions cp2k_input_tools/basissets/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from decimal import Decimal
from typing import Iterator, List, Optional, Sequence, Tuple

from pydantic import BaseModel, Extra
from pydantic import BaseModel

from ..pseudopotentials.ecp import ECP
from ..utils import NUM2SYM, DatafileIterMixin, FromDictMixin

BLOCK_MATCH = re.compile(r"^\s*\d+\s+\d+\s*$")


class BasisSetCoefficients(BaseModel, extra=Extra.forbid):
class BasisSetCoefficients(BaseModel, extra="forbid"):
"""A 'shell' in one single basis set"""

shell: int # 0: s, 1: s and p, 2: p, 3: d, 4:f
Expand All @@ -23,7 +23,7 @@ class BasisSetCoefficients(BaseModel, extra=Extra.forbid):
coefficients: List[Tuple[Decimal, Decimal]]


class BasisSetData(BaseModel, DatafileIterMixin, FromDictMixin, extra=Extra.forbid):
class BasisSetData(BaseModel, DatafileIterMixin, FromDictMixin, extra="forbid"):
"""Basis set data for a single element"""

Z: int
Expand Down
30 changes: 15 additions & 15 deletions cp2k_input_tools/pseudopotentials/cp2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from decimal import Decimal, InvalidOperation
from typing import Iterator, List, Sequence

from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import BaseModel, Field, model_validator

from ..utils import DatafileIterMixin, FromDictMixin, dformat

Expand All @@ -14,37 +14,37 @@ class PseudopotentialDataLocal(BaseModel):
r: Decimal
coefficients: List[Decimal] = Field(..., alias="coeffs")

class Config:
extra = "forbid"
allow_population_by_field_name = True
model_config = {
"extra": "forbid",
"populate_by_name": True,
}


class PseudopotentialDataNonLocal(BaseModel):
r: Decimal
nproj: int
coefficients: List[Decimal] = Field(..., alias="coeffs")

@root_validator
def check_coefficients(cls, values):
assert (
len(values["coefficients"]) == values["nproj"] * (values["nproj"] + 1) // 2
), "invalid number of coefficients for non-local projection"
return values
@model_validator(mode="after")
def check_coefficients(cls, obj):
assert len(obj.coefficients) == obj.nproj * (obj.nproj + 1) // 2, "invalid number of coefficients for non-local projection"
return obj

class Config:
extra = "forbid"
allow_population_by_field_name = True
model_config = {
"extra": "forbid",
"populate_by_name": True,
}


class PseudopotentialDataNLCC(BaseModel, extra=Extra.forbid):
class PseudopotentialDataNLCC(BaseModel, extra="forbid"):
"""Nonlinear Core Correction data"""

r: Decimal
n: int
c: Decimal


class PseudopotentialData(BaseModel, DatafileIterMixin, FromDictMixin, extra=Extra.forbid):
class PseudopotentialData(BaseModel, DatafileIterMixin, FromDictMixin, extra="forbid"):
element: str
identifiers: List[str]
n_el: List[int]
Expand Down
4 changes: 2 additions & 2 deletions cp2k_input_tools/pseudopotentials/ecp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from decimal import Decimal
from typing import Iterator, List, Tuple

from pydantic import BaseModel, Extra
from pydantic import BaseModel

from ..utils import NUM2SYM, dformat


class ECP(BaseModel, extra=Extra.forbid):
class ECP(BaseModel, extra="forbid"):
"""ECP for a single element"""

Z: int
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Pint = ">=0.15,<0.23"
"ruamel.yaml" = {version = ">=0.16.5,<0.18.0", optional = true}
pygls = {version = "^1.0.0", optional = true}
Jinja2 = ">=2.11.3,<4.0.0"
pydantic = "^1.8"
pydantic = ">=2,<3.0"
click = ">=7.1.2,<9"

[tool.poetry.extras]
Expand All @@ -32,7 +32,7 @@ lsp = ["pygls"]
[tool.poetry.dev-dependencies]
pytest = "^7.1"
pytest-cov = "^4.0"
pytest-console-scripts = "^1.2"
pytest-console-scripts = "^1.4"
sphinx = "^6.1.3"
sphinx-rtd-theme = "^1.0.0"
taskipy = "^1.4"
Expand Down
8 changes: 4 additions & 4 deletions tests/test_basisset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_check_formatting():
def test_datafile_lint(script_runner):
"""check that reformatting a formatted file leaves it as is"""
bsetfile = INPUTS_DIR / "BASIS_SET.formatted"
ret = script_runner.run("cp2k-datafile-lint", "basisset", str(bsetfile))
ret = script_runner.run(["cp2k-datafile-lint", "basisset", str(bsetfile)])

assert ret.stderr == ""
assert ret.success
Expand All @@ -58,7 +58,7 @@ def test_datafile_lint_crystal(script_runner):
"""check that reformatting a formatted file leaves it as is"""
bsetfile = INPUTS_DIR / "29_Cu.pob-DZVP-rev2"
ret = script_runner.run(
"cp2k-datafile-lint", "basisset", "--input-basis-format", "crystal", "--output-basis-format", "crystal", str(bsetfile)
["cp2k-datafile-lint", "basisset", "--input-basis-format", "crystal", "--output-basis-format", "crystal", str(bsetfile)]
)

assert ret.stderr == ""
Expand Down Expand Up @@ -240,8 +240,8 @@ def test_bset_from_dicts():
}

# NOTE: they are not identical since the first one goes via the internal bit-representation of the float
BasisSetData.parse_obj(floated_dict)
BasisSetData.parse_obj(stringified_dict)
BasisSetData.model_validate(floated_dict)
BasisSetData.model_validate(stringified_dict)


def test_new_style_ae_basisset_import():
Expand Down
Loading

0 comments on commit 3706fae

Please sign in to comment.