diff --git a/tests/conftest.py b/tests/conftest.py index 873d2c3f0a..369ccebebd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -229,12 +229,8 @@ def DummyStellarator(tmpdir_factory): @pytest.fixture(scope="session") -def DummyCoilSet(tmpdir_factory): +def DummyCoilSet(): """Create and save a dummy coil set for testing.""" - output_dir = tmpdir_factory.mktemp("result") - output_path_sym = output_dir.join("DummyCoilSet_sym.h5") - output_path_asym = output_dir.join("DummyCoilSet_asym.h5") - eq = get("precise_QH") minor_radius = eq.compute("a")["a"] @@ -256,25 +252,16 @@ def DummyCoilSet(tmpdir_factory): ) coils.append(coil) coilset_sym = CoilSet(coils, NFP=eq.NFP, sym=eq.sym) - coilset_sym.save(output_path_sym) # equivalent CoilSet without symmetry coilset_asym = CoilSet.from_symmetry(coilset_sym, NFP=eq.NFP, sym=eq.sym) - coilset_asym.save(output_path_asym) - DummyCoilSet_out = { - "output_path_sym": output_path_sym, - "output_path_asym": output_path_asym, - } - return DummyCoilSet_out + return coilset_sym, coilset_asym @pytest.fixture(scope="session") -def DummyMixedCoilSet(tmpdir_factory): +def DummyMixedCoilSet(): """Create and save a dummy mixed coil set for testing.""" - output_dir = tmpdir_factory.mktemp("result") - output_path = output_dir.join("DummyMixedCoilSet.h5") - tf_coil = FourierPlanarCoil(current=3, center=[2, 0, 0], normal=[0, 1, 0], r_n=[1]) tf_coil.rotate(angle=np.pi / 4) tf_coilset = CoilSet(tf_coil, NFP=2, sym=True) @@ -295,10 +282,16 @@ def DummyMixedCoilSet(tmpdir_factory): full_coilset = MixedCoilSet( (tf_coilset, vf_coilset, xyz_coil, spline_coil), check_intersection=False ) + return full_coilset + - full_coilset.save(output_path) - DummyMixedCoilSet_out = {"output_path": output_path} - return DummyMixedCoilSet_out +@pytest.fixture(scope="session") +def DummyNestedCoilSet(DummyCoilSet, DummyMixedCoilSet): + """Create and save a dummy nested coil set for testing.""" + sym_coils, __ = DummyCoilSet + mixed_coils = DummyMixedCoilSet + nested_coils = MixedCoilSet(sym_coils, mixed_coils, check_intersection=False) + return nested_coils @pytest.fixture(scope="session") diff --git a/tests/test_examples.py b/tests/test_examples.py index 6bae9dd16b..a31a9408d5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1314,16 +1314,25 @@ def test_second_stage_optimization_CoilSet(): @pytest.mark.slow @pytest.mark.unit -def test_optimize_with_all_coil_types(DummyCoilSet, DummyMixedCoilSet): +@pytest.mark.parametrize( + "coil,optimizer,index", + [ + (FourierPlanarCoil(), "fmintr", None), + (FourierRZCoil(), "fmintr", None), + (FourierXYZCoil(), "fmintr", None), + ("DummyMixedCoilSet", "fmintr", -1), # spline coil + ("DummyCoilSet", "lsq-exact", 0), # sym coils + ("DummyCoilSet", "lsq-exact", 1), # asym coils + ("DummyMixedCoilSet", "lsq-exact", None), + ("DummyNestedCoilSet", "lsq-exact", None), + ], +) +def test_optimize_with_all_coil_types(coil, optimizer, index, request): """Test optimizing for every type of coil and dummy coil sets.""" - sym_coils = load(load_from=str(DummyCoilSet["output_path_sym"]), file_format="hdf5") - asym_coils = load( - load_from=str(DummyCoilSet["output_path_asym"]), file_format="hdf5" - ) - mixed_coils = load( - load_from=str(DummyMixedCoilSet["output_path"]), file_format="hdf5" - ) - nested_coils = MixedCoilSet(sym_coils, mixed_coils, check_intersection=False) + if isinstance(coil, str): + coil = request.getfixturevalue(coil) + coil = coil[index] if index is not None else coil + eq = Equilibrium() # not attempting to accurately calc B for this test, # so make the grids very coarse @@ -1349,9 +1358,7 @@ def test(c, method): (c,), _ = optimizer.optimize(c, obj, maxiter=2, ftol=0, xtol=1e-15) # now check with optimizing geometry and actually check result - objs = [ - CoilLength(c, target=target), - ] + objs = [CoilLength(c, target=target)] extra_msg = "" if isinstance(c, MixedCoilSet): # just to check they work without error @@ -1366,7 +1373,7 @@ def test(c, method): obj = ObjectiveFunction(objs) - (c,), _ = optimizer.optimize(c, obj, maxiter=25, ftol=5e-3, xtol=1e-15) + (c,), _ = optimizer.optimize(c, obj, maxiter=50, ftol=5e-3, xtol=1e-15) flattened_coils = tree_leaves( c, is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet) ) @@ -1375,21 +1382,7 @@ def test(c, method): lengths, target, rtol=rtol, err_msg=f"lengths {c}" + extra_msg ) - spline_coil = mixed_coils.coils[-1].copy() - - # single coil - test(FourierPlanarCoil(), "fmintr") - test(FourierRZCoil(), "fmintr") - test(FourierXYZCoil(), "fmintr") - test(spline_coil, "fmintr") - - # CoilSet - test(sym_coils, "lsq-exact") - test(asym_coils, "lsq-exact") - - # MixedCoilSet - test(mixed_coils, "lsq-exact") - test(nested_coils, "lsq-exact") + test(coil, optimizer) @pytest.mark.unit