From 2406aa9928d41089093d8512be7a19656028f6a2 Mon Sep 17 00:00:00 2001 From: Daniel Utt Date: Wed, 16 Aug 2023 13:28:21 +0200 Subject: [PATCH] Add cuda tests --- tests/test_modifier.py | 47 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_modifier.py b/tests/test_modifier.py index 9d5bbd4..8ba5436 100644 --- a/tests/test_modifier.py +++ b/tests/test_modifier.py @@ -1,4 +1,5 @@ import pytest +import torch from ovito.io import import_file from ovito.modifiers import CommonNeighborAnalysisModifier, ExpressionSelectionModifier from ovito.pipeline import Pipeline @@ -109,3 +110,49 @@ def test_fcc_distance_settings(import_pipeline: Pipeline): assert len(data.tables["Convergence"].xy()) == 8 for k, v in expected.items(): assert data.attributes[k] == v + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda not available") +def test_fcc_distance_settings_cuda(import_pipeline: Pipeline): + pipe = import_pipeline + pipe.modifiers.append( + ScoreBasedDenoising(scale=2.5, structure="FCC", device="cuda") + ) + pipe.modifiers.append(CommonNeighborAnalysisModifier()) + data = pipe.compute() + expected = { + "CommonNeighborAnalysis.counts.BCC": 11, + "CommonNeighborAnalysis.counts.FCC": 7898, + "CommonNeighborAnalysis.counts.HCP": 214, + "CommonNeighborAnalysis.counts.ICO": 0, + "CommonNeighborAnalysis.counts.OTHER": 325, + } + + assert len(data.tables["Convergence"].xy()) == 8 + for k, v in expected.items(): + assert data.attributes[k] == v + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda not available") +def test_selection_fcc_settings_cuda(import_pipeline: Pipeline): + pipe = import_pipeline + pipe.modifiers.append( + ExpressionSelectionModifier( + expression="ReducedPosition.Z > 0.4 && ReducedPosition.Z < 0.6" + ) + ) + pipe.modifiers.append( + ScoreBasedDenoising(structure="FCC", only_selected=True, device="cuda") + ) + pipe.modifiers.append(CommonNeighborAnalysisModifier()) + data = pipe.compute() + expected = { + "CommonNeighborAnalysis.counts.BCC": 40, + "CommonNeighborAnalysis.counts.FCC": 4015, + "CommonNeighborAnalysis.counts.HCP": 175, + "CommonNeighborAnalysis.counts.ICO": 0, + "CommonNeighborAnalysis.counts.OTHER": 4218, + } + + for k, v in expected.items(): + assert data.attributes[k] == v