Skip to content

Commit

Permalink
Job usability fixups (#284)
Browse files Browse the repository at this point in the history
* Add %% to runtime context keys

* Allow ProcessCommand to apply templates

Required now that AlertConsumer no longer does this under the hood
  • Loading branch information
jvansanten authored Jan 23, 2025
1 parent c2dd228 commit 99efab8
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 287 deletions.
20 changes: 13 additions & 7 deletions ampel/cli/ProcessCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ampel.model.UnitModel import UnitModel
from ampel.struct.Resource import Resource
from ampel.util.freeze import recursive_freeze
from ampel.util.template import apply_templates


def _handle_traceback(signal, frame):
Expand Down Expand Up @@ -177,21 +178,26 @@ def run(
start_time = time()
logger = AmpelLogger.get_logger(base_flag=LogFlag.MANUAL_RUN)

ctx = self._get_context(
args,
unknown_args,
logger,
)

logger.info(f"Running task {args['name']}")

with open(args["schema"]) as f:
unit_model = UnitModel(**yaml.safe_load(f))
taskd = yaml.safe_load(f)
if "template" in taskd:
taskd = apply_templates(ctx, taskd["template"], taskd, logger)
taskd.pop("template")
unit_model = UnitModel(**taskd)

# always raise exceptions
unit_model.override = (unit_model.override or {}) | {
"raise_exc": not args["handle_exc"]
}

ctx = self._get_context(
args,
unknown_args,
logger,
)

if args["workflow"]:
process_name = f'{args["workflow"]}.{args["name"]}'
else:
Expand Down
11 changes: 4 additions & 7 deletions ampel/t4/T4RunTimeContextUpdater.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,11 @@ def do(self) -> Generator[T4Document, None, None]:
_chan=self.channel # type: ignore[arg-type]
)
if ret := t4_unit.do():
if not isinstance(ret, dict):
raise ValueError(f'Invalid {um.unit} return value, dict expected')
if not (isinstance(ret, dict) and all(isinstance(k, str) for k in ret)):
raise ValueError(f'Invalid {um.unit} return value, dict[str, Any] expected')
# Ensure alias keys start with %%
ret = {k if k.startswith('%%') else f'%%{k}': v for k, v in ret.items()}
for k in ret:
if not k[0] == '%' == k[1]:
raise ValueError(
f'Invalid run time alias returned by {um.unit}, '
f'run time aliases must begin with %%'
)
if not self.allow_alias_override and k in aliases:
raise ValueError(
f'Run time alias {k} was already registered, '
Expand Down
8 changes: 8 additions & 0 deletions ampel/test/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ampel.core.EventHandler import EventHandler
from ampel.log.AmpelLogger import AmpelLogger
from ampel.model.ingest.CompilerOptions import CompilerOptions
from ampel.model.ingest.IngestDirective import IngestDirective
from ampel.model.StateT2Dependency import StateT2Dependency
from ampel.struct.Resource import Resource
from ampel.struct.UnitResult import UnitResult
Expand Down Expand Up @@ -172,3 +173,10 @@ def morph(self, config: dict[str, Any], logger: AmpelLogger) -> dict[str, Any]:
class DummyUnitResultAdapter(AbsUnitResultAdapter):
def handle(self, ur: UnitResult) -> UnitResult:
return ur


class DummyIngestUnit(AbsEventUnit):
directives: list[IngestDirective]

def proceed(self, event_hdlr):
return super().proceed(event_hdlr)
6 changes: 6 additions & 0 deletions ampel/test/test-data/testing-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ unit:
version: 0.8.0a1
base:
- AbsEventUnit
DummyIngestUnit:
fqn: ampel.test.dummy
version: 0.8.0a1
base:
- AbsEventUnit
process:
t0: {}
t1: {}
Expand All @@ -634,3 +639,4 @@ resource:
group: samsa
template:
dummy_processor: ampel.test.dummy:DummyProcessorTemplate
hash_t2_config: ampel.config.alter.HashT2Config
99 changes: 98 additions & 1 deletion ampel/test/test_ProcessCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def test_resource_passing(
vault: Path,
tmpdir,
):

key = "dana"
value = "zuul"

Expand Down Expand Up @@ -135,3 +134,101 @@ def run_task(task_path: Path, first: bool = False):

run_task(writer, first=True)
run_task(reader)


def test_templates(
testing_config,
mock_db: MagicMock,
vault: Path,
tmpdir,
):
"""ProcessCommand resolves templates"""
task = dump(
{
"template": "hash_t2_config",
"unit": "DummyIngestUnit",
"config": dict(
directives=[
dict(
channel="TEST",
ingest=dict(
combine=[
dict(
unit="T1SimpleCombiner",
state_t2=[
dict(
unit="DummyTiedStateT2Unit",
config={
"t2_dependency": [
{
"unit": "DummyStateT2Unit",
"config": {"foo": 37},
}
]
},
)
],
)
]
),
)
]
),
},
tmpdir,
"task.yml",
)

def run_task(task_path: Path):
assert (
run(
[
"ampel",
"process",
"--config",
str(testing_config),
"--secrets",
str(vault),
"--db",
"whatevs",
"--log-profile",
"console_debug",
"--schema",
str(task_path),
"--name",
"task_1",
]
)
is None
)

run_task(task)

conf = mock_db("conf")
# get the last config inserted
doc = next(
d.args[0]
for d in reversed(conf.insert_one.call_args_list)
if "unit" in d.args[0]
)
assert doc["unit"] == "DummyIngestUnit"
config_id = doc["config"]["directives"][0]["ingest"]["combine"][0]["state_t2"][0][
"config"
]

def get_config(config_id: int):
return next(
(
d.args[0]
for d in reversed(conf.insert_one.call_args_list)
if d.args[0].get("_id") == config_id
),
None,
)

assert isinstance(config_id, int), "config was hashed"
config = get_config(config_id)
assert "t2_dependency" in config
subconfig = get_config(config["t2_dependency"][0]["config"])
subconfig.pop("_id")
assert subconfig == {"foo": 37}, "config was hashed recursively"
Loading

0 comments on commit 99efab8

Please sign in to comment.