From 0abe27ef214d74d5ebf441ff04ec6288b6cf7615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Pi=C3=B1a?= Date: Fri, 13 Dec 2024 16:46:03 -0500 Subject: [PATCH] feat: add extra tags to metronome beats Co-authored-by: Mijail Rondon --- .../main/java/io/littlehorse/canary/Main.java | 3 +- .../aggregator/topology/MetricsTopology.java | 13 +- .../canary/config/CanaryConfig.java | 13 +- .../metronome/internal/BeatProducer.java | 18 +- canary/src/main/proto/beats.proto | 2 + .../topology/MetricsTopologyTest.java | 67 +++++++- .../canary/config/CanaryConfigTest.java | 21 +++ sdk-python/examples/basic/example_basic.py | 2 +- sdk-python/examples/user_tasks/user_tasks.py | 36 ++-- sdk-python/littlehorse/workflow.py | 97 ++++++----- sdk-python/tests/test_workflow.py | 160 +++++++++++++----- 11 files changed, 312 insertions(+), 120 deletions(-) diff --git a/canary/src/main/java/io/littlehorse/canary/Main.java b/canary/src/main/java/io/littlehorse/canary/Main.java index 3428abda0..55085d54c 100644 --- a/canary/src/main/java/io/littlehorse/canary/Main.java +++ b/canary/src/main/java/io/littlehorse/canary/Main.java @@ -68,7 +68,8 @@ private static void initialize(final String[] args) throws IOException { lhConfig.getApiBootstrapPort(), lhClient.getServerVersion(), canaryConfig.getTopicName(), - canaryConfig.toKafkaConfig().toMap()); + canaryConfig.toKafkaConfig().toMap(), + canaryConfig.getMetronomeBeatExtraTags()); // start worker if (canaryConfig.isMetronomeWorkerEnabled()) { diff --git a/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java b/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java index 4e55747d2..d89fa61e1 100644 --- a/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java +++ b/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java @@ -45,8 +45,7 @@ public Topology toTopology() { .filterNot(MetricsTopology::isExhaustedRetries) // remove the id .groupBy( - MetricsTopology::cleanBeatKey, - Grouped.with(ProtobufSerdes.BeatKey(), ProtobufSerdes.BeatValue())) + MetricsTopology::removeWfId, Grouped.with(ProtobufSerdes.BeatKey(), ProtobufSerdes.BeatValue())) // reset aggregator every minute .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofMinutes(1), Duration.ofSeconds(5))) // calculate average @@ -62,8 +61,7 @@ public Topology toTopology() { final KStream countMetricStream = beatsStream // remove the id .groupBy( - MetricsTopology::cleanBeatKey, - Grouped.with(ProtobufSerdes.BeatKey(), ProtobufSerdes.BeatValue())) + MetricsTopology::removeWfId, Grouped.with(ProtobufSerdes.BeatKey(), ProtobufSerdes.BeatValue())) // count all .count(initializeCountStore(COUNT_STORE)) .toStream() @@ -181,6 +179,10 @@ private static MetricKey buildMetricKey(final BeatKey key, final String id) { Tag.newBuilder().setKey("status").setValue(key.getStatus().toLowerCase())); } + if (key.getTagsCount() > 0) { + builder.addAllTags(key.getTagsList()); + } + return builder.build(); } @@ -195,13 +197,14 @@ private static AverageAggregator initializeAverageAggregator() { return AverageAggregator.newBuilder().build(); } - private static BeatKey cleanBeatKey(final BeatKey key, final BeatValue value) { + private static BeatKey removeWfId(final BeatKey key, final BeatValue value) { return BeatKey.newBuilder() .setType(key.getType()) .setServerVersion(key.getServerVersion()) .setServerHost(key.getServerHost()) .setServerPort(key.getServerPort()) .setStatus(key.getStatus()) + .addAllTags(key.getTagsList()) .build(); } diff --git a/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java b/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java index 18ef1add5..f70e85d12 100644 --- a/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java +++ b/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java @@ -32,11 +32,14 @@ public class CanaryConfig implements Config { public static final String METRONOME_GET_RETRIES = "metronome.get.retries"; public static final String METRONOME_WORKER_ENABLE = "metronome.worker.enable"; public static final String METRONOME_DATA_PATH = "metronome.data.path"; + public static final String METRONOME_BEAT_EXTRA_TAGS = "metronome.beat.extra.tags"; + public static final String METRONOME_BEAT_EXTRA_TAGS_PREFIX = "%s.".formatted(METRONOME_BEAT_EXTRA_TAGS); public static final String AGGREGATOR_ENABLE = "aggregator.enable"; + public static final String AGGREGATOR_STORE_RETENTION_MS = "aggregator.store.retention.ms"; + public static final String METRICS_PORT = "metrics.port"; public static final String METRICS_PATH = "metrics.path"; - public static final String AGGREGATOR_STORE_RETENTION_MS = "aggregator.store.retention.ms"; public static final String METRICS_COMMON_TAGS = "metrics.common.tags"; public static final String METRICS_COMMON_TAGS_PREFIX = "%s.".formatted(METRICS_COMMON_TAGS); @@ -147,6 +150,14 @@ public Map getCommonTags() { entry -> entry.getValue().toString())); } + public Map getMetronomeBeatExtraTags() { + return configs.entrySet().stream() + .filter(entry -> entry.getKey().startsWith(METRONOME_BEAT_EXTRA_TAGS_PREFIX)) + .collect(Collectors.toMap( + entry -> entry.getKey().substring(METRONOME_BEAT_EXTRA_TAGS_PREFIX.length()), + entry -> entry.getValue().toString())); + } + public boolean isTopicCreationEnabled() { return Boolean.parseBoolean(getConfig(TOPIC_CREATION_ENABLE)); } diff --git a/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java b/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java index 1e6f89041..96671335c 100644 --- a/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java +++ b/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java @@ -5,8 +5,10 @@ import io.littlehorse.canary.proto.BeatKey; import io.littlehorse.canary.proto.BeatType; import io.littlehorse.canary.proto.BeatValue; +import io.littlehorse.canary.proto.Tag; import io.littlehorse.canary.util.ShutdownHook; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -18,6 +20,7 @@ public class BeatProducer { private final Producer producer; + private final Map extraTags; private final String lhServerHost; private final int lhServerPort; private final String lhServerVersion; @@ -28,13 +31,15 @@ public BeatProducer( final int lhServerPort, final String lhServerVersion, final String topicName, - final Map kafkaProducerConfigMap) { + final Map producerConfig, + final Map extraTags) { this.lhServerHost = lhServerHost; this.lhServerPort = lhServerPort; this.lhServerVersion = lhServerVersion; this.topicName = topicName; + this.extraTags = extraTags; - producer = new KafkaProducer<>(kafkaProducerConfigMap); + producer = new KafkaProducer<>(producerConfig); ShutdownHook.add("Beat Producer", producer); } @@ -91,6 +96,15 @@ private BeatKey buildKey(final String id, final BeatType type, final String stat builder.setStatus(status); } + final List tags = extraTags.entrySet().stream() + .map(entry -> Tag.newBuilder() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()) + .toList(); + + builder.addAllTags(tags); + return builder.build(); } } diff --git a/canary/src/main/proto/beats.proto b/canary/src/main/proto/beats.proto index e1ae9a2bd..b0a2e7261 100644 --- a/canary/src/main/proto/beats.proto +++ b/canary/src/main/proto/beats.proto @@ -5,6 +5,7 @@ option java_multiple_files = true; option java_package = "io.littlehorse.canary.proto"; import "google/protobuf/timestamp.proto"; +import "metrics.proto"; enum BeatType { WF_RUN_REQUEST = 0; @@ -25,6 +26,7 @@ message BeatKey { BeatType type = 4; optional string status = 5; optional string id = 6; + repeated Tag tags = 7; } message BeatValue { diff --git a/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java b/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java index c3d74fed1..8b7a23ce8 100644 --- a/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java +++ b/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java @@ -10,6 +10,8 @@ import java.io.IOException; import java.nio.file.Files; import java.time.Duration; +import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.UUID; import org.apache.kafka.streams.StreamsConfig; @@ -46,14 +48,18 @@ private static MetricKey newMetricKey(String id) { } private static MetricKey newMetricKey(String id, String status) { - return newMetricKey(HOST_1, PORT_1, id, status); + return newMetricKey(HOST_1, PORT_1, id, status, null); } private static MetricKey newMetricKey(String host, int port, String id) { - return newMetricKey(host, port, id, null); + return newMetricKey(host, port, id, null, null); } - private static MetricKey newMetricKey(String host, int port, String id, String status) { + private static MetricKey newMetricKey(String id, String status, Map tags) { + return newMetricKey(HOST_1, PORT_1, id, status, tags); + } + + private static MetricKey newMetricKey(String host, int port, String id, String status, Map tags) { MetricKey.Builder builder = MetricKey.newBuilder().setServerHost(host).setServerPort(port).setId(id); @@ -61,19 +67,40 @@ private static MetricKey newMetricKey(String host, int port, String id, String s builder.addTags(Tag.newBuilder().setKey("status").setValue(status).build()); } + if (tags != null) { + List tagList = tags.entrySet().stream() + .map(entry -> Tag.newBuilder() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()) + .toList(); + builder.addAllTags(tagList); + } + return builder.build(); } private static TestRecord newBeat(BeatType type, String id, Long latency) { - return newBeat(HOST_1, PORT_1, type, id, latency, null); + return newBeat(HOST_1, PORT_1, type, id, latency, null, null); } private static TestRecord newBeat(BeatType type, String id, Long latency, String beatStatus) { - return newBeat(HOST_1, PORT_1, type, id, latency, beatStatus); + return newBeat(HOST_1, PORT_1, type, id, latency, beatStatus, null); } private static TestRecord newBeat( - String host, int port, BeatType type, String id, Long latency, String beatStatus) { + BeatType type, String id, Long latency, String beatStatus, Map tags) { + return newBeat(HOST_1, PORT_1, type, id, latency, beatStatus, tags); + } + + private static TestRecord newBeat( + String host, + int port, + BeatType type, + String id, + Long latency, + String beatStatus, + Map tags) { BeatKey.Builder keyBuilder = BeatKey.newBuilder() .setServerHost(host) .setServerPort(port) @@ -85,6 +112,16 @@ private static TestRecord newBeat( keyBuilder.setStatus(beatStatus); } + if (tags != null) { + List tagList = tags.entrySet().stream() + .map(entry -> Tag.newBuilder() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()) + .toList(); + keyBuilder.addAllTags(tagList); + } + if (latency != null) { valueBuilder.setLatency(latency); } @@ -136,6 +173,18 @@ void calculateCountAndLatencyForWfRunRequest() { .isEqualTo(newMetricValue(3.)); } + @Test + void includeBeatTagsIntoMetrics() { + BeatType expectedType = BeatType.WF_RUN_REQUEST; + String expectedTypeName = expectedType.name().toLowerCase(); + + inputTopic.pipeInput(newBeat(expectedType, getRandomId(), 20L, "ok", Map.of("my_tag", "value"))); + + assertThat(getCount()).isEqualTo(3); + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_avg", "ok", Map.of("my_tag", "value")))) + .isEqualTo(newMetricValue(20.)); + } + @Test void calculateCountForExhaustedRetries() { BeatType expectedType = BeatType.GET_WF_RUN_EXHAUSTED_RETRIES; @@ -270,9 +319,9 @@ void calculateCountAndLatencyForTaskRunWithDuplicatedAndTwoServers() { inputTopic.pipeInput(newBeat(expectedType, expectedUniqueId, 10L)); inputTopic.pipeInput(newBeat(expectedType, expectedUniqueId, 30L)); - inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 20L, null)); - inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 10L, null)); - inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 30L, null)); + inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 20L, null, null)); + inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 10L, null, null)); + inputTopic.pipeInput(newBeat(HOST_2, PORT_2, expectedType, expectedUniqueId, 30L, null, null)); assertThat(getCount()).isEqualTo(8); diff --git a/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java b/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java index 005417c13..07a679da0 100644 --- a/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java +++ b/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java @@ -44,4 +44,25 @@ void getCommonTags() { assertThat(output).contains(entry("application_id", "my_id"), entry("extra", "extra_tag")); } + + @Test + void getMetronomeExtraTags() { + Map input = Map.of("lh.canary.metronome.beat.extra.tags.my_tag", "extra_tag"); + + CanaryConfig canaryConfig = new CanaryConfig(input); + + Map output = canaryConfig.getMetronomeBeatExtraTags(); + + assertThat(output).contains(entry("my_tag", "extra_tag")); + } + + @Test + void getEmptyMetronomeExtraTags() { + CanaryConfig canaryConfig = new CanaryConfig(Map.of()); + + Map output = canaryConfig.getMetronomeBeatExtraTags(); + + assertThat(output).isEmpty(); + assertThat(output).isNotNull(); + } } diff --git a/sdk-python/examples/basic/example_basic.py b/sdk-python/examples/basic/example_basic.py index fe349646c..bed79022f 100644 --- a/sdk-python/examples/basic/example_basic.py +++ b/sdk-python/examples/basic/example_basic.py @@ -46,4 +46,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/sdk-python/examples/user_tasks/user_tasks.py b/sdk-python/examples/user_tasks/user_tasks.py index dcd6f9b14..d61d3286f 100644 --- a/sdk-python/examples/user_tasks/user_tasks.py +++ b/sdk-python/examples/user_tasks/user_tasks.py @@ -20,19 +20,21 @@ def get_config() -> LHConfig: config.load(config_path) return config + async def get_user_task_def() -> PutUserTaskDefRequest: return PutUserTaskDefRequest( - name="person-details", - fields=[ - UserTaskField( - name="PersonDetails", - description="Person complementary information", - display_name="Other Details", - required=True, - type=VariableType.STR - ) - ] -) + name="person-details", + fields=[ + UserTaskField( + name="PersonDetails", + description="Person complementary information", + display_name="Other Details", + required=True, + type=VariableType.STR, + ) + ], + ) + def get_workflow() -> Workflow: def my_entrypoint(wf: WorkflowThread) -> None: @@ -42,16 +44,22 @@ def my_entrypoint(wf: WorkflowThread) -> None: arg1 = "Sam" arg2 = {"identification": "1258796641-4", "Address": "NA-Street", "Age": 28} - wf.schedule_reminder_task(user_task_output, delay_in_seconds, task_def_name, arg1, arg2) + wf.schedule_reminder_task( + user_task_output, delay_in_seconds, task_def_name, arg1, arg2 + ) return Workflow("example-user-tasks", my_entrypoint) -async def greeting(name: str, person_details: dict[str, Any], ctx: WorkerContext) -> str: + +async def greeting( + name: str, person_details: dict[str, Any], ctx: WorkerContext +) -> str: msg = f"Hello {name}!. WfRun {ctx.wf_run_id.id} Person: {person_details}" print(msg) await asyncio.sleep(random.uniform(0.5, 1.5)) return msg + async def main() -> None: config = get_config() wf = get_workflow() @@ -66,4 +74,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/sdk-python/littlehorse/workflow.py b/sdk-python/littlehorse/workflow.py index 496552769..7ac64e43a 100644 --- a/sdk-python/littlehorse/workflow.py +++ b/sdk-python/littlehorse/workflow.py @@ -130,14 +130,15 @@ def to_variable_assignment(value: Any) -> VariableAssignment: json_path=json_path, variable_name=variable_name, ) - + if isinstance(value, LHExpression): expression: LHExpression = value return VariableAssignment( expression=VariableAssignment.Expression( lhs=to_variable_assignment(expression.lhs()), operation=expression.operation(), - rhs=to_variable_assignment(expression.rhs())) + rhs=to_variable_assignment(expression.rhs()), + ) ) return VariableAssignment( @@ -153,39 +154,40 @@ def __init__(self, lhs: Any, operation: VariableMutationType, rhs: Any) -> None: def add(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.ADD, other) - + def subtract(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.SUBTRACT, other) - + def multiply(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.MULTIPLY, other) - + def divide(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.DIVIDE, other) - + def extend(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.EXTEND, other) - + def remove_if_present(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_IF_PRESENT, other) - + def remove_index(self, index: Optional[Union[int, Any]] = None) -> LHExpression: if index is None: raise ValueError("Expected 'index' to be set, but it was None.") return LHExpression(self, VariableMutationType.REMOVE_INDEX, index) - + def remove_key(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_KEY, other) - + def lhs(self) -> Any: return self._lhs - + def rhs(self) -> Any: return self._rhs - + def operation(self) -> Any: return self._operation + class WorkflowCondition: def __init__(self, left_hand: Any, comparator: Comparator, right_hand: Any) -> None: """Returns a WorkflowCondition that can be used in @@ -334,30 +336,30 @@ def with_json_path(self, json_path: str) -> "NodeOutput": out = NodeOutput(self.node_name) out.json_path = json_path return out - + def add(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.ADD, other) - + def subtract(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.SUBTRACT, other) - + def multiply(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.MULTIPLY, other) - + def divide(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.DIVIDE, other) - + def extend(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.EXTEND, other) - + def remove_if_present(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_IF_PRESENT, other) - + def remove_index(self, index: Optional[Union[int, Any]] = None) -> LHExpression: if index is None: raise ValueError("Expected 'index' to be set, but it was None.") return LHExpression(self, VariableMutationType.REMOVE_INDEX, index) - + def remove_key(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_KEY, other) @@ -389,7 +391,7 @@ def __init__( TypeError: If variable_type and type(default_value) are not compatible. """ self.name = variable_name - self.type = variable_type + self.type = variable_type self.parent = parent self.default_value: Optional[VariableValue] = None self._json_path: Optional[str] = None @@ -517,7 +519,7 @@ def searchable_on( def required(self) -> "WfRunVariable": self._required = True return self - + def with_default(self, default_value: Any) -> WfRunVariable: self._set_default(default_value) @@ -558,43 +560,43 @@ def compile(self) -> ThreadVarDef: def is_equal_to(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.EQUALS, rhs) - + def is_not_equal_to(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.NOT_EQUALS, rhs) - + def is_greater_than(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.GREATER_THAN, rhs) - + def is_greater_than_eq(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.GREATER_THAN_EQ, rhs) - + def is_less_than_eq(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.LESS_THAN_EQ, rhs) def is_less_than(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.LESS_THAN, rhs) - + def does_contain(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.IN, rhs) def does_not_contain(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.NOT_IN, rhs) - + def is_in(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.IN, rhs) def is_not_in(self, rhs: Any) -> WorkflowCondition: return self.parent.condition(self, Comparator.NOT_IN, rhs) - + def assign(self, rhs: Any) -> None: self.parent.mutate(self, VariableMutationType.ASSIGN, rhs) def add(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.ADD, other) - + def subtract(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.SUBTRACT, other) - + def multiply(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.MULTIPLY, other) @@ -603,15 +605,15 @@ def divide(self, other: Any) -> LHExpression: def extend(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.EXTEND, other) - + def remove_if_present(self, other: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_IF_PRESENT, other) - + def remove_index(self, index: Optional[Union[int, Any]] = None) -> LHExpression: if index is None: raise ValueError("Expected 'index' to be set, but it was None.") return LHExpression(self, VariableMutationType.REMOVE_INDEX, index) - + def remove_key(self, key: Any) -> LHExpression: return LHExpression(self, VariableMutationType.REMOVE_KEY, key) @@ -1280,45 +1282,45 @@ def multiply(self, lhs: Any, rhs: Any) -> LHExpression: def add(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.ADD, rhs) - + def divide(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.DIVIDE, rhs) - + def subtract(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.SUBTRACT, rhs) - + def extend(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.EXTEND, rhs) - + def remove_if_present(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.REMOVE_IF_PRESENT, rhs) - + def remove_index(self, index: Optional[Union[int, Any]] = None) -> LHExpression: if index is None: raise ValueError("Expected 'index' to be set, but it was None.") return LHExpression(self, VariableMutationType.REMOVE_INDEX, index) - + def remove_key(self, lhs: Any, rhs: Any) -> LHExpression: return LHExpression(lhs, VariableMutationType.REMOVE_KEY, rhs) - + def declare_bool(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.BOOL) - + def declare_int(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.INT) - + def declare_str(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.STR) - + def declare_double(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.DOUBLE) - + def declare_bytes(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.BYTES) - + def declare_json_arr(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.JSON_ARR) - + def declare_json_obj(self, name: str) -> WfRunVariable: return self.add_variable(name, VariableType.JSON_OBJ) @@ -2114,6 +2116,7 @@ def with_task_timeout_seconds( self._default_timeout_seconds = timeout_seconds return self + def create_workflow_spec( workflow: Workflow, config: LHConfig, timeout: Optional[int] = None ) -> None: diff --git a/sdk-python/tests/test_workflow.py b/sdk-python/tests/test_workflow.py index 1604efd22..58a632bd8 100644 --- a/sdk-python/tests/test_workflow.py +++ b/sdk-python/tests/test_workflow.py @@ -74,7 +74,9 @@ def test_validate_json_path_format(self): class TestWfRunVariable(unittest.TestCase): def test_value_is_not_none(self): - variable = WfRunVariable("my-var", VariableType.STR, None, default_value="my-str") + variable = WfRunVariable( + "my-var", VariableType.STR, None, default_value="my-str" + ) self.assertEqual(variable.default_value.WhichOneof("value"), "str") self.assertEqual(variable.default_value.str, "my-str") @@ -155,7 +157,10 @@ def test_compile_variable(self): variable = WfRunVariable("my-var", VariableType.STR, None) self.assertEqual( variable.compile(), - ThreadVarDef(var_def=VariableDef(name="my-var", type=VariableType.STR), access_level=WfRunVariableAccessLevel.PRIVATE_VAR), + ThreadVarDef( + var_def=VariableDef(name="my-var", type=VariableType.STR), + access_level=WfRunVariableAccessLevel.PRIVATE_VAR, + ), ) variable = WfRunVariable("my-var", VariableType.JSON_OBJ, None) @@ -170,7 +175,9 @@ def test_compile_variable(self): self.assertEqual(variable.compile(), expected_output) def test_compile_private_variable(self): - variable = WfRunVariable("my-var", VariableType.STR, None, access_level="PRIVATE_VAR") + variable = WfRunVariable( + "my-var", VariableType.STR, None, access_level="PRIVATE_VAR" + ) expected_output = ThreadVarDef( var_def=VariableDef(name="my-var", type=VariableType.STR), access_level="PRIVATE_VAR", @@ -286,7 +293,9 @@ class MyClass: def if_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-1", variable_type=VariableType.INT, parent=thread + variable_name="variable-1", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 1, @@ -294,7 +303,9 @@ def if_condition(self, thread: WorkflowThread) -> None: thread.execute("task-a") thread.mutate( WfRunVariable( - variable_name="variable-3", variable_type=VariableType.INT, parent=thread + variable_name="variable-3", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 3, @@ -304,7 +315,9 @@ def if_condition(self, thread: WorkflowThread) -> None: def else_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-2", variable_type=VariableType.INT, parent=thread + variable_name="variable-2", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 2, @@ -313,7 +326,9 @@ def else_condition(self, thread: WorkflowThread) -> None: thread.execute("task-d") thread.mutate( WfRunVariable( - variable_name="variable-4", variable_type=VariableType.INT, parent=thread + variable_name="variable-4", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 4, @@ -359,7 +374,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-1", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=1)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=1) + ), ) ], ), @@ -378,7 +395,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-2", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ), @@ -393,7 +412,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-3", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=3)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=3) + ), ) ], ) @@ -416,7 +437,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-4", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=4)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=4) + ), ) ], ) @@ -436,7 +459,9 @@ class MyClass: def if_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-1", variable_type=VariableType.INT, parent=thread + variable_name="variable-1", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 1, @@ -445,7 +470,9 @@ def if_condition(self, thread: WorkflowThread) -> None: def else_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-2", variable_type=VariableType.INT, parent=thread + variable_name="variable-2", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 2, @@ -491,7 +518,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-2", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ), @@ -510,7 +539,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-1", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=1)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=1) + ), ) ], ), @@ -530,7 +561,9 @@ class MyClass: def if_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-2", variable_type=VariableType.INT, parent=thread + variable_name="variable-2", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 2, @@ -539,7 +572,9 @@ def if_condition(self, thread: WorkflowThread) -> None: def my_entrypoint(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-1", variable_type=VariableType.INT, parent=thread + variable_name="variable-1", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 1, @@ -549,7 +584,9 @@ def my_entrypoint(self, thread: WorkflowThread) -> None: ) thread.mutate( WfRunVariable( - variable_name="variable-3", variable_type=VariableType.INT, parent=thread + variable_name="variable-3", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 3, @@ -575,7 +612,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-1", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=1)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=1) + ), ) ], ) @@ -599,7 +638,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-2", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ), @@ -626,7 +667,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-3", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=3)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=3) + ), ) ], ) @@ -642,7 +685,9 @@ class MyClass: def my_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-1", variable_type=VariableType.INT, parent=thread + variable_name="variable-1", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 1, @@ -650,7 +695,9 @@ def my_condition(self, thread: WorkflowThread) -> None: thread.execute("my-task") thread.mutate( WfRunVariable( - variable_name="variable-2", variable_type=VariableType.INT, parent=thread + variable_name="variable-2", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 2, @@ -694,7 +741,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-1", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=1)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=1) + ), ) ], ), @@ -721,7 +770,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-2", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ) @@ -741,7 +792,9 @@ class MyClass: def my_condition(self, thread: WorkflowThread) -> None: thread.mutate( WfRunVariable( - variable_name="variable-1", variable_type=VariableType.INT, parent=thread + variable_name="variable-1", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 1, @@ -749,7 +802,9 @@ def my_condition(self, thread: WorkflowThread) -> None: thread.execute("my-task") thread.mutate( WfRunVariable( - variable_name="variable-2", variable_type=VariableType.INT, parent=thread + variable_name="variable-2", + variable_type=VariableType.INT, + parent=thread, ), VariableMutationType.ASSIGN, 2, @@ -793,7 +848,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-1", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=1)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=1) + ), ) ], ), @@ -820,7 +877,9 @@ def to_thread(self): VariableMutation( lhs_name="variable-2", operation=VariableMutationType.ASSIGN, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ) @@ -1061,7 +1120,9 @@ def my_entrypoint(thread: WorkflowThread) -> None: VariableMutation( lhs_name="value", operation=VariableMutationType.MULTIPLY, - rhs_assignment=VariableAssignment(literal_value=VariableValue(int=2)), + rhs_assignment=VariableAssignment( + literal_value=VariableValue(int=2) + ), ) ], ) @@ -1090,7 +1151,9 @@ def my_entrypoint(thread: WorkflowThread) -> None: edge = node.outgoing_edges[0] - self.assertEqual(edge.variable_mutations[0].rhs_assignment.literal_value.str, "some-value") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.literal_value.str, "some-value" + ) def test_node_output_mutations_should_also_use_variable_assignments(self): def my_entrypoint(thread: WorkflowThread) -> None: @@ -1103,12 +1166,17 @@ def my_entrypoint(thread: WorkflowThread) -> None: edge = node.outgoing_edges[0] - self.assertEqual(edge.variable_mutations[0].rhs_assignment.node_output.node_name, "1-use-the-force-TASK") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.node_output.node_name, + "1-use-the-force-TASK", + ) def test_node_output_mutations_should_carry_json_path(self): def my_entrypoint(thread: WorkflowThread) -> None: my_var = thread.add_variable("my-var", VariableType.STR) - my_var.assign(thread.execute("use-the-force").with_json_path("$.hello.there")) + my_var.assign( + thread.execute("use-the-force").with_json_path("$.hello.there") + ) wfSpec = Workflow("obiwan", my_entrypoint).compile() entrypoint = wfSpec.thread_specs[wfSpec.entrypoint_thread_name] @@ -1116,11 +1184,18 @@ def my_entrypoint(thread: WorkflowThread) -> None: edge = node.outgoing_edges[0] - self.assertEqual(edge.variable_mutations[0].rhs_assignment.node_output.node_name, "1-use-the-force-TASK") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.node_output.node_name, + "1-use-the-force-TASK", + ) - self.assertEqual(edge.variable_mutations[0].rhs_assignment.json_path, "$.hello.there") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.json_path, "$.hello.there" + ) - def test_assigning_variables_to_other_variables_should_use_variable_assignment(self): + def test_assigning_variables_to_other_variables_should_use_variable_assignment( + self, + ): def my_entrypoint(thread: WorkflowThread) -> None: my_var = thread.add_variable("my-var", VariableType.STR) other_var = thread.add_variable("other-var", VariableType.STR) @@ -1131,7 +1206,9 @@ def my_entrypoint(thread: WorkflowThread) -> None: node = entrypoint.nodes["0-entrypoint-ENTRYPOINT"] edge = node.outgoing_edges[0] - self.assertEqual(edge.variable_mutations[0].rhs_assignment.variable_name, "other-var") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.variable_name, "other-var" + ) def test_assigning_variables_to_other_variables_should_carry_json_path(self): def my_entrypoint(thread: WorkflowThread) -> None: @@ -1144,10 +1221,13 @@ def my_entrypoint(thread: WorkflowThread) -> None: node = entrypoint.nodes["0-entrypoint-ENTRYPOINT"] edge = node.outgoing_edges[0] - self.assertEqual(edge.variable_mutations[0].rhs_assignment.variable_name, "other-var") + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.variable_name, "other-var" + ) - self.assertEqual(edge.variable_mutations[0].rhs_assignment.json_path, "$.hello.there") - + self.assertEqual( + edge.variable_mutations[0].rhs_assignment.json_path, "$.hello.there" + ) class TestWorkflow(unittest.TestCase):