diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py
index c0bc92b..e87be41 100644
--- a/metalearners/drlearner.py
+++ b/metalearners/drlearner.py
@@ -13,6 +13,7 @@
     OosMethod,
     Params,
     Scoring,
+    SplitIndices,
     Vector,
     _ScikitModel,
 )
@@ -128,7 +129,7 @@ def __init__(
         )
         self.adaptive_clipping = adaptive_clipping
 
-    def fit(
+    def fit_all_nuisance(
         self,
         X: Matrix,
         y: Vector,
@@ -148,10 +149,12 @@ def fit(
         for treatment_variant in range(self.n_variants):
             self._treatment_variants_indices.append(w == treatment_variant)
 
+        self._cv_split_indices: SplitIndices | None
+
         if synchronize_cross_fitting:
-            cv_split_indices = self._split(X)
+            self._cv_split_indices = self._split(X)
         else:
-            cv_split_indices = None
+            self._cv_split_indices = None
 
         nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
         for treatment_variant in range(self.n_variants):
@@ -176,7 +179,7 @@ def fit(
                 model_ord=0,
                 n_jobs_cross_fitting=n_jobs_cross_fitting,
                 fit_params=qualified_fit_params[NUISANCE][PROPENSITY_MODEL],
-                cv=cv_split_indices,
+                cv=self._cv_split_indices,
             )
         )
 
@@ -189,6 +192,25 @@ def fit(
 
         self._assign_joblib_nuisance_results(results)
 
+        return self
+
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        if not hasattr(self, "_cv_split_indices"):
+            raise ValueError(
+                "The nuisance models need to be fitted before fitting the treatment models. "
+                "In particular, the MetaLearner's attribute _cv_split_indices, "
+                "typically set during nuisance fitting, does not exist."
+            )
+        qualified_fit_params = self._qualified_fit_params(fit_params)
         treatment_jobs: list[_ParallelJoblibSpecification] = []
         for treatment_variant in range(1, self.n_variants):
             pseudo_outcomes = self._pseudo_outcome(
@@ -207,9 +229,10 @@ def fit(
                     model_ord=treatment_variant - 1,
                     n_jobs_cross_fitting=n_jobs_cross_fitting,
                     fit_params=qualified_fit_params[TREATMENT][TREATMENT_MODEL],
-                    cv=cv_split_indices,
+                    cv=self._cv_split_indices,
                 )
             )
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
         results = parallel(
             delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
         )
diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py
index 6214af9..2fb285b 100644
--- a/metalearners/metalearner.py
+++ b/metalearners/metalearner.py
@@ -744,6 +744,41 @@ def _assign_joblib_treatment_results(
             ] = result.cross_fit_estimator
 
     @abstractmethod
+    def fit_all_nuisance(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        """Fit all nuisance models of the MetaLearner.
+
+        If pre-fitted models were passed at instantiation, these are never refitted.
+
+        For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.
+        """
+        ...
+
+    @abstractmethod
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        """Fit all treatment models of the MetaLearner.
+
+        For the parameters check :meth:`metalearners.metalearner.MetaLearner.fit`.
+        """
+        ...
+
     def fit(
         self,
         X: Matrix,
@@ -791,7 +826,27 @@ def fit(
         the same data splits where possible. Note that if there are several models to be synchronized which are
         classifiers, these cannot be split via stratification.
         """
-        ...
+        self.fit_all_nuisance(
+            X=X,
+            y=y,
+            w=w,
+            n_jobs_cross_fitting=n_jobs_cross_fitting,
+            fit_params=fit_params,
+            synchronize_cross_fitting=synchronize_cross_fitting,
+            n_jobs_base_learners=n_jobs_base_learners,
+        )
+
+        self.fit_all_treatment(
+            X=X,
+            y=y,
+            w=w,
+            n_jobs_cross_fitting=n_jobs_cross_fitting,
+            fit_params=fit_params,
+            synchronize_cross_fitting=synchronize_cross_fitting,
+            n_jobs_base_learners=n_jobs_base_learners,
+        )
+
+        return self
 
     def predict_nuisance(
         self,
diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py
index b7e9297..8e85502 100644
--- a/metalearners/rlearner.py
+++ b/metalearners/rlearner.py
@@ -156,7 +156,7 @@ def _supports_multi_treatment(cls) -> bool:
     def _supports_multi_class(cls) -> bool:
         return False
 
-    def fit(
+    def fit_all_nuisance(
         self,
         X: Matrix,
         y: Vector,
@@ -165,14 +165,10 @@ def fit(
         fit_params: dict | None = None,
         synchronize_cross_fitting: bool = True,
         n_jobs_base_learners: int | None = None,
-        epsilon: float = _EPSILON,
     ) -> Self:
-
         self._validate_treatment(w)
         self._validate_outcome(y, w)
 
-        self._variants_indices = []
-
         qualified_fit_params = self._qualified_fit_params(fit_params)
         self._validate_fit_params(qualified_fit_params)
 
@@ -214,7 +210,22 @@ def fit(
         )
         self._assign_joblib_nuisance_results(results)
 
+        return self
+
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+        epsilon: float = _EPSILON,
+    ) -> Self:
+        qualified_fit_params = self._qualified_fit_params(fit_params)
         treatment_jobs: list[_ParallelJoblibSpecification] = []
+        self._variants_indices = []
         for treatment_variant in range(1, self.n_variants):
 
             is_treatment = w == treatment_variant
@@ -246,6 +257,7 @@ def fit(
                     n_jobs_cross_fitting=n_jobs_cross_fitting,
                 )
             )
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
         results = parallel(
             delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
         )
diff --git a/metalearners/slearner.py b/metalearners/slearner.py
index 7158cca..c6c99a6 100644
--- a/metalearners/slearner.py
+++ b/metalearners/slearner.py
@@ -141,7 +141,7 @@ def __init__(
             random_state=random_state,
         )
 
-    def fit(
+    def fit_all_nuisance(
         self,
         X: Matrix,
         y: Vector,
@@ -175,6 +175,18 @@ def fit(
         )
         return self
 
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        return self
+
     def predict(
         self,
         X: Matrix,
diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py
index cc61fe5..17442cf 100644
--- a/metalearners/tlearner.py
+++ b/metalearners/tlearner.py
@@ -48,7 +48,7 @@ def _supports_multi_treatment(cls) -> bool:
     def _supports_multi_class(cls) -> bool:
         return True
 
-    def fit(
+    def fit_all_nuisance(
         self,
         X: Matrix,
         y: Vector,
@@ -92,6 +92,18 @@ def fit(
         self._assign_joblib_nuisance_results(results)
         return self
 
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        return self
+
     def predict(
         self,
         X,
diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py
index 109619d..72d2107 100644
--- a/metalearners/xlearner.py
+++ b/metalearners/xlearner.py
@@ -74,7 +74,7 @@ def _supports_multi_treatment(cls) -> bool:
     def _supports_multi_class(cls) -> bool:
         return False
 
-    def fit(
+    def fit_all_nuisance(
         self,
         X: Matrix,
         y: Vector,
@@ -91,7 +91,7 @@ def fit(
 
         qualified_fit_params = self._qualified_fit_params(fit_params)
 
-        cvs: list = []
+        self._cvs: list = []
 
         for treatment_variant in range(self.n_variants):
             self._treatment_variants_indices.append(w == treatment_variant)
@@ -101,7 +101,7 @@ def fit(
                 )
             else:
                 cv_split_indices = None
-            cvs.append(cv_split_indices)
+            self._cvs.append(cv_split_indices)
 
         nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
         for treatment_variant in range(self.n_variants):
@@ -115,7 +115,7 @@ def fit(
                     model_ord=treatment_variant,
                     n_jobs_cross_fitting=n_jobs_cross_fitting,
                     fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
-                    cv=cvs[treatment_variant],
+                    cv=self._cvs[treatment_variant],
                 )
             )
 
@@ -138,6 +138,32 @@ def fit(
         )
         self._assign_joblib_nuisance_results(results)
 
+        return self
+
+    def fit_all_treatment(
+        self,
+        X: Matrix,
+        y: Vector,
+        w: Vector,
+        n_jobs_cross_fitting: int | None = None,
+        fit_params: dict | None = None,
+        synchronize_cross_fitting: bool = True,
+        n_jobs_base_learners: int | None = None,
+    ) -> Self:
+        if self._treatment_variants_indices is None:
+            raise ValueError(
+                "The nuisance models need to be fitted before fitting the treatment models. "
+                "In particular, the MetaLearner's attribute _treatment_variants_indices, "
+                "typically set during nuisance fitting, is None."
+            )
+        if not hasattr(self, "_cvs"):
+            raise ValueError(
+                "The nuisance models need to be fitted before fitting the treatment models. "
+                "In particular, the MetaLearner's attribute _cvs, "
+                "typically set during nuisance fitting, does not exist."
+            )
+        qualified_fit_params = self._qualified_fit_params(fit_params)
+
         treatment_jobs: list[_ParallelJoblibSpecification] = []
         for treatment_variant in range(1, self.n_variants):
             imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
@@ -153,7 +179,7 @@ def fit(
                     model_ord=treatment_variant - 1,
                     n_jobs_cross_fitting=n_jobs_cross_fitting,
                     fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
-                    cv=cvs[treatment_variant],
+                    cv=self._cvs[treatment_variant],
                 )
             )
 
@@ -165,10 +191,11 @@ def fit(
                     model_ord=treatment_variant - 1,
                     n_jobs_cross_fitting=n_jobs_cross_fitting,
                     fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
-                    cv=cvs[0],
+                    cv=self._cvs[0],
                 )
             )
 
+        parallel = Parallel(n_jobs=n_jobs_base_learners)
         results = parallel(
             delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
         )