diff --git a/README.md b/README.md index f4ce3ed..1d2484e 100644 --- a/README.md +++ b/README.md @@ -96,8 +96,20 @@ propagator = ASSISTPropagator() ephemerides = propagator.generate_ephemeris(sbdb_orbits, observers) ``` +## Configuration +When initializing the `ASSISTPropagator`, you can configure several parameters that control the integration. +These parameters are passed directly to REBOUND's IAS15 integrator. The IAS15 integrator is a high accuracy integrator that uses adaptive timestepping to maintain precision while optimizing performance. +- `min_dt`: Minimum timestep for the integrator (default: 1e-15 days) +- `initial_dt`: Initial timestep for the integrator (default: 0.001 days) +- `adaptive_mode`: Controls the adaptive timestep behavior (default: 2) +These parameters are passed directly to REBOUND's IAS15 integrator. The IAS15 integrator is a high accuracy integrator that uses adaptive timestepping to maintain precision while optimizing performance. +Example: + +```python +propagator = ASSISTPropagator(min_dt=1e-12, initial_dt=0.0001, adaptive_mode=2) +``` diff --git a/pyproject.toml b/pyproject.toml index 458ed74..d8f8581 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dev = [ check = {composite = ["lint", "typecheck"]} format = { composite = ["black ./src/adam_assist", "isort ./src/adam_assist"]} lint = { composite = ["ruff check ./src/adam_assist", "black --check ./src/adam_assist", "isort --check-only ./src/adam_assist"] } -fix = "ruff ./src/adam_assist --fix" +fix = "ruff check ./src/adam_assist --fix" typecheck = "mypy --strict ./src/adam_assist" test = "pytest --benchmark-disable {args}" diff --git a/src/adam_assist/propagator.py b/src/adam_assist/propagator.py index 5e673e6..fb8c777 100644 --- a/src/adam_assist/propagator.py +++ b/src/adam_assist/propagator.py @@ -25,7 +25,7 @@ C = c.C try: - from adam_core.propagator.adam_assist_version import __version__ + from adam_assist.version import __version__ except ImportError: __version__ = "0.0.0" @@ -60,6 +60,25 @@ def hash_orbit_ids_to_uint32( class ASSISTPropagator(Propagator, ImpactMixin): # type: ignore + def __init__( + self, + *args: object, # Generic type for arbitrary positional arguments + min_dt: float = 1e-15, + initial_dt: float = 0.001, + adaptive_mode: int = 2, + **kwargs: object, # Generic type for arbitrary keyword arguments + ) -> None: + super().__init__(*args, **kwargs) + if min_dt <= 0: + raise ValueError("min_dt must be positive") + if initial_dt <= 0: + raise ValueError("initial_dt must be positive") + if min_dt > initial_dt: + raise ValueError("min_dt must be smaller than initial_dt") + self.min_dt = min_dt + self.initial_dt = initial_dt + self.adaptive_mode = adaptive_mode + def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitType: """ Propagate the orbits to the specified times. @@ -109,8 +128,9 @@ def _propagate_orbits_inner( ) sim = None sim = rebound.Simulation() - sim.ri_ias15.min_dt = 1e-15 - sim.ri_ias15.adaptive_mode = 2 + sim.dt = self.initial_dt + sim.ri_ias15.min_dt = self.min_dt + sim.ri_ias15.adaptive_mode = self.adaptive_mode # Set the simulation time, relative to the jd_ref start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0] diff --git a/tests/test_propagator_settings.py b/tests/test_propagator_settings.py new file mode 100644 index 0000000..d7bfce2 --- /dev/null +++ b/tests/test_propagator_settings.py @@ -0,0 +1,95 @@ +import pyarrow as pa +import pytest +from adam_core.coordinates import CartesianCoordinates, Origin +from adam_core.orbits import Orbits +from adam_core.time import Timestamp + +from adam_assist import ASSISTPropagator + + +@pytest.fixture +def basic_orbit(): + """Create a basic test orbit""" + return Orbits.from_kwargs( + orbit_id=["test"], + coordinates=CartesianCoordinates.from_kwargs( + x=[1.0], + y=[0.0], + z=[0.0], + vx=[0.0], + vy=[1.0], + vz=[0.0], + time=Timestamp.from_mjd([60000], scale="tdb"), + origin=Origin.from_kwargs(code=["SUN"]), + frame="ecliptic", + ), + ) + +def test_default_settings(): + """Test that default settings are applied correctly""" + prop = ASSISTPropagator() + assert prop.min_dt == 1e-15 + assert prop.initial_dt == 0.001 + assert prop.adaptive_mode == 2 + +def test_custom_settings(): + """Test that custom settings are applied correctly""" + prop = ASSISTPropagator(min_dt=1e-12, initial_dt=0.01, adaptive_mode=1) + assert prop.min_dt == 1e-12 + assert prop.initial_dt == 0.01 + assert prop.adaptive_mode == 1 + +def test_invalid_min_dt(): + """Test that invalid min_dt raises ValueError""" + with pytest.raises(ValueError, match="min_dt must be positive"): + ASSISTPropagator(min_dt=-1e-15) + + with pytest.raises(ValueError, match="min_dt must be positive"): + ASSISTPropagator(min_dt=0) + +def test_invalid_initial_dt(): + """Test that invalid initial_dt raises ValueError""" + with pytest.raises(ValueError, match="initial_dt must be positive"): + ASSISTPropagator(initial_dt=-0.001) + + with pytest.raises(ValueError, match="initial_dt must be positive"): + ASSISTPropagator(initial_dt=0) + +def test_min_dt_greater_than_initial(): + """Test that min_dt > initial_dt raises ValueError""" + with pytest.raises(ValueError, match="min_dt must be smaller than initial_dt"): + ASSISTPropagator(min_dt=0.1, initial_dt=0.01) + +def test_propagation_with_different_settings(basic_orbit): + """Test that propagation works with different settings""" + # Test with default settings + prop_default = ASSISTPropagator() + + # Test with more restrictive settings + prop_restrictive = ASSISTPropagator(min_dt=1e-12, initial_dt=0.0001) + + # Test with less restrictive settings + prop_loose = ASSISTPropagator(min_dt=1e-9, initial_dt=0.01) + + # Propagate for 10 days with each propagator + target_time = Timestamp.from_mjd([60010], scale="tdb") + + result_default = prop_default.propagate_orbits(basic_orbit, target_time) + result_restrictive = prop_restrictive.propagate_orbits(basic_orbit, target_time) + result_loose = prop_loose.propagate_orbits(basic_orbit, target_time) + + # All should produce results + assert len(result_default) == 1 + assert len(result_restrictive) == 1 + assert len(result_loose) == 1 + + # Results should be similar but not identical due to different step sizes + # Using a relatively loose tolerance since we expect some differences + tolerance = 1e-6 + + default_pos = result_default.coordinates.values[0, :3] + restrictive_pos = result_restrictive.coordinates.values[0, :3] + loose_pos = result_loose.coordinates.values[0, :3] + + assert abs(default_pos - restrictive_pos).max() < tolerance + assert abs(default_pos - loose_pos).max() < tolerance \ No newline at end of file