Skip to content

Commit

Permalink
[Bug]: The dstack run command doesn't support ${{ run.args }} any…
Browse files Browse the repository at this point in the history
…more (#832)

[Bug]: The `dstack run` command doesn't support `${{ run.args }}` anymore #820
  • Loading branch information
peterschmidt85 authored Jan 12, 2024
1 parent ec575cb commit a16e3c2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/dstack/_internal/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def _command(self, args: argparse.Namespace):
parser = argparse.ArgumentParser()
configurator = run_configurators_mapping[ConfigurationType(conf.type)]
configurator.register(parser)
configurator.apply(parser.parse_args(args.unknown), conf)
args, unknown = parser.parse_known_args(args.unknown)
configurator.apply(args, unknown, conf)

with console.status("Getting run plan..."):
run_plan = self.api.runs.get_plan(
Expand Down
36 changes: 31 additions & 5 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
ConfigurationType,
DevEnvironmentConfiguration,
PortMapping,
ServiceConfiguration,
TaskConfiguration,
)
from dstack._internal.utils.interpolator import VariablesInterpolator


class BaseRunConfigurator:
Expand All @@ -30,11 +33,20 @@ def register(cls, parser: argparse.ArgumentParser):
)

@classmethod
def apply(cls, args: argparse.Namespace, conf: BaseConfiguration):
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfiguration):
if args.envs:
for k, v in args.envs:
conf.env[k] = v

cls.interpolate_run_args(conf.setup, unknown)

@classmethod
def interpolate_run_args(cls, value: List[str], unknown):
run_args = " ".join(unknown)
interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"])
for i in range(len(value)):
value[i] = interpolator.interpolate(value[i])


class RunWithPortsConfigurator(BaseRunConfigurator):
@classmethod
Expand All @@ -51,22 +63,30 @@ def register(cls, parser: argparse.ArgumentParser):
)

@classmethod
def apply(cls, args: argparse.Namespace, conf: BaseConfigurationWithPorts):
super().apply(args, conf)
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfigurationWithPorts):
super().apply(args, unknown, conf)
if args.ports:
conf.ports = list(merge_ports(conf.ports, args.ports).values())


class TaskRunConfigurator(RunWithPortsConfigurator):
TYPE = ConfigurationType.TASK

@classmethod
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfiguration):
super().apply(args, unknown, conf)

cls.interpolate_run_args(conf.commands, unknown)


class DevEnvironmentRunConfigurator(RunWithPortsConfigurator):
TYPE = ConfigurationType.DEV_ENVIRONMENT

@classmethod
def apply(cls, args: argparse.Namespace, conf: DevEnvironmentConfiguration):
super().apply(args, conf)
def apply(
cls, args: argparse.Namespace, unknown: List[str], conf: DevEnvironmentConfiguration
):
super().apply(args, unknown, conf)
if conf.ide == "vscode" and conf.version is None:
conf.version = _detect_vscode_version()
if conf.version is None:
Expand All @@ -80,6 +100,12 @@ def apply(cls, args: argparse.Namespace, conf: DevEnvironmentConfiguration):
class ServiceRunConfigurator(BaseRunConfigurator):
TYPE = ConfigurationType.SERVICE

@classmethod
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfiguration):
super().apply(args, unknown, conf)

cls.interpolate_run_args(conf.commands, unknown)


def env_var(v: str) -> Tuple[str, str]:
r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v)
Expand Down
4 changes: 2 additions & 2 deletions src/tests/_internal/cli/services/configurators/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def apply_args(
configurator = run_configurators_mapping[conf.type]
configurator.register(parser)
conf = conf.copy(deep=True) # to avoid modifying the original configuration
args = parser.parse_args(args)
configurator.apply(args, conf)
args, unknown = parser.parse_known_args(args)
configurator.apply(args, unknown, conf)
return conf, args

0 comments on commit a16e3c2

Please sign in to comment.