diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index cbe3ca76a..9c8be0356 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -106,7 +106,8 @@ def _command(self, args: argparse.Namespace): parser = argparse.ArgumentParser() configurator = run_configurators_mapping[ConfigurationType(conf.type)] configurator.register(parser) - configurator.apply(parser.parse_args(args.unknown), conf) + args, unknown = parser.parse_known_args(args.unknown) + configurator.apply(args, unknown, conf) with console.status("Getting run plan..."): run_plan = self.api.runs.get_plan( diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index f1e5f9d97..164ad7e41 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -11,7 +11,10 @@ ConfigurationType, DevEnvironmentConfiguration, PortMapping, + ServiceConfiguration, + TaskConfiguration, ) +from dstack._internal.utils.interpolator import VariablesInterpolator class BaseRunConfigurator: @@ -30,11 +33,20 @@ def register(cls, parser: argparse.ArgumentParser): ) @classmethod - def apply(cls, args: argparse.Namespace, conf: BaseConfiguration): + def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfiguration): if args.envs: for k, v in args.envs: conf.env[k] = v + cls.interpolate_run_args(conf.setup, unknown) + + @classmethod + def interpolate_run_args(cls, value: List[str], unknown): + run_args = " ".join(unknown) + interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"]) + for i in range(len(value)): + value[i] = interpolator.interpolate(value[i]) + class RunWithPortsConfigurator(BaseRunConfigurator): @classmethod @@ -51,8 +63,8 @@ def register(cls, parser: argparse.ArgumentParser): ) @classmethod - def apply(cls, args: argparse.Namespace, conf: BaseConfigurationWithPorts): - super().apply(args, conf) + def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfigurationWithPorts): + super().apply(args, unknown, conf) if args.ports: conf.ports = list(merge_ports(conf.ports, args.ports).values()) @@ -60,13 +72,21 @@ def apply(cls, args: argparse.Namespace, conf: BaseConfigurationWithPorts): class TaskRunConfigurator(RunWithPortsConfigurator): TYPE = ConfigurationType.TASK + @classmethod + def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfiguration): + super().apply(args, unknown, conf) + + cls.interpolate_run_args(conf.commands, unknown) + class DevEnvironmentRunConfigurator(RunWithPortsConfigurator): TYPE = ConfigurationType.DEV_ENVIRONMENT @classmethod - def apply(cls, args: argparse.Namespace, conf: DevEnvironmentConfiguration): - super().apply(args, conf) + def apply( + cls, args: argparse.Namespace, unknown: List[str], conf: DevEnvironmentConfiguration + ): + super().apply(args, unknown, conf) if conf.ide == "vscode" and conf.version is None: conf.version = _detect_vscode_version() if conf.version is None: @@ -80,6 +100,12 @@ def apply(cls, args: argparse.Namespace, conf: DevEnvironmentConfiguration): class ServiceRunConfigurator(BaseRunConfigurator): TYPE = ConfigurationType.SERVICE + @classmethod + def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfiguration): + super().apply(args, unknown, conf) + + cls.interpolate_run_args(conf.commands, unknown) + def env_var(v: str) -> Tuple[str, str]: r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v) diff --git a/src/tests/_internal/cli/services/configurators/test_run.py b/src/tests/_internal/cli/services/configurators/test_run.py index e36c28c97..48acd874e 100644 --- a/src/tests/_internal/cli/services/configurators/test_run.py +++ b/src/tests/_internal/cli/services/configurators/test_run.py @@ -67,6 +67,6 @@ def apply_args( configurator = run_configurators_mapping[conf.type] configurator.register(parser) conf = conf.copy(deep=True) # to avoid modifying the original configuration - args = parser.parse_args(args) - configurator.apply(args, conf) + args, unknown = parser.parse_known_args(args) + configurator.apply(args, unknown, conf) return conf, args