diff --git a/prem_utils/connectors/anyscale.py b/prem_utils/connectors/anyscale.py index 8e247f7..1903d30 100644 --- a/prem_utils/connectors/anyscale.py +++ b/prem_utils/connectors/anyscale.py @@ -83,6 +83,9 @@ def create_job( validation_dataset: list[Datapoint] | None = None, num_epochs: int = 3, ) -> str: + if "anyscale" in model: + model = model.replace("anyscale/", "", 1) + training_file_id = self._upload_data(training_dataset, size=20) validation_file_id = None