Skip to content

Commit

Permalink
Add cuda tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nnn911 committed Aug 16, 2023
1 parent 2c1d697 commit 2406aa9
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/test_modifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import torch
from ovito.io import import_file
from ovito.modifiers import CommonNeighborAnalysisModifier, ExpressionSelectionModifier
from ovito.pipeline import Pipeline
Expand Down Expand Up @@ -109,3 +110,49 @@ def test_fcc_distance_settings(import_pipeline: Pipeline):
assert len(data.tables["Convergence"].xy()) == 8
for k, v in expected.items():
assert data.attributes[k] == v


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda not available")
def test_fcc_distance_settings_cuda(import_pipeline: Pipeline):
pipe = import_pipeline
pipe.modifiers.append(
ScoreBasedDenoising(scale=2.5, structure="FCC", device="cuda")
)
pipe.modifiers.append(CommonNeighborAnalysisModifier())
data = pipe.compute()
expected = {
"CommonNeighborAnalysis.counts.BCC": 11,
"CommonNeighborAnalysis.counts.FCC": 7898,
"CommonNeighborAnalysis.counts.HCP": 214,
"CommonNeighborAnalysis.counts.ICO": 0,
"CommonNeighborAnalysis.counts.OTHER": 325,
}

assert len(data.tables["Convergence"].xy()) == 8
for k, v in expected.items():
assert data.attributes[k] == v


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda not available")
def test_selection_fcc_settings_cuda(import_pipeline: Pipeline):
pipe = import_pipeline
pipe.modifiers.append(
ExpressionSelectionModifier(
expression="ReducedPosition.Z > 0.4 && ReducedPosition.Z < 0.6"
)
)
pipe.modifiers.append(
ScoreBasedDenoising(structure="FCC", only_selected=True, device="cuda")
)
pipe.modifiers.append(CommonNeighborAnalysisModifier())
data = pipe.compute()
expected = {
"CommonNeighborAnalysis.counts.BCC": 40,
"CommonNeighborAnalysis.counts.FCC": 4015,
"CommonNeighborAnalysis.counts.HCP": 175,
"CommonNeighborAnalysis.counts.ICO": 0,
"CommonNeighborAnalysis.counts.OTHER": 4218,
}

for k, v in expected.items():
assert data.attributes[k] == v

0 comments on commit 2406aa9

Please sign in to comment.