From 5852cb13e7c510c1a22a6c49df8cfe415316e766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Pi=C3=B1a?= Date: Mon, 16 Dec 2024 14:09:03 -0500 Subject: [PATCH 1/3] fix: prometheus metrics exporter, add new mandatory field server.id to metronome (#1200) Co-authored-by: Mijail Rondon --- canary/build.gradle | 2 +- canary/canary.properties | 1 + .../main/java/io/littlehorse/canary/Main.java | 1 + .../canary/aggregator/Aggregator.java | 3 +- .../internal/MetricStoreExporter.java | 28 ++-- .../aggregator/topology/MetricsTopology.java | 2 + .../canary/config/CanaryConfig.java | 11 +- .../metronome/internal/BeatProducer.java | 4 + .../canary/prometheus/PrometheusExporter.java | 4 +- .../prometheus/PrometheusServerExporter.java | 4 +- canary/src/main/proto/beats.proto | 1 + canary/src/main/proto/metrics.proto | 1 + .../internal/MetricStoreExporterTest.java | 131 ++++++++++++++++++ .../topology/MetricsTopologyTest.java | 23 ++- .../canary/config/CanaryConfigTest.java | 11 ++ docs/CANARY_CONFIGURATIONS.md | 10 ++ 16 files changed, 216 insertions(+), 21 deletions(-) create mode 100644 canary/src/test/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporterTest.java diff --git a/canary/build.gradle b/canary/build.gradle index 830220269..d15cd4077 100644 --- a/canary/build.gradle +++ b/canary/build.gradle @@ -35,7 +35,7 @@ dependencies { implementation 'io.javalin:javalin-micrometer:6.3.0' // Prometheus - implementation 'io.micrometer:micrometer-registry-prometheus:1.12.2' + implementation 'io.micrometer:micrometer-registry-prometheus:1.14.2' // RocksDB implementation 'org.rocksdb:rocksdbjni:9.0.0' diff --git a/canary/canary.properties b/canary/canary.properties index 64d6086db..f08680df3 100644 --- a/canary/canary.properties +++ b/canary/canary.properties @@ -14,3 +14,4 @@ lh.canary.metronome.frequency.ms=1000 lh.canary.metronome.run.threads=1 lh.canary.metronome.run.requests=300 lh.canary.metronome.run.sample.rate=1 +lh.canary.metronome.server.id=lh diff --git a/canary/src/main/java/io/littlehorse/canary/Main.java b/canary/src/main/java/io/littlehorse/canary/Main.java index 55085d54c..c79188837 100644 --- a/canary/src/main/java/io/littlehorse/canary/Main.java +++ b/canary/src/main/java/io/littlehorse/canary/Main.java @@ -67,6 +67,7 @@ private static void initialize(final String[] args) throws IOException { lhConfig.getApiBootstrapHost(), lhConfig.getApiBootstrapPort(), lhClient.getServerVersion(), + canaryConfig.getMetronomeServerId(), canaryConfig.getTopicName(), canaryConfig.toKafkaConfig().toMap(), canaryConfig.getMetronomeBeatExtraTags()); diff --git a/canary/src/main/java/io/littlehorse/canary/aggregator/Aggregator.java b/canary/src/main/java/io/littlehorse/canary/aggregator/Aggregator.java index 11d1924ed..4fc00e2b4 100644 --- a/canary/src/main/java/io/littlehorse/canary/aggregator/Aggregator.java +++ b/canary/src/main/java/io/littlehorse/canary/aggregator/Aggregator.java @@ -31,7 +31,8 @@ public Aggregator( @Override public void bindTo(final MeterRegistry registry) { - final MetricStoreExporter prometheusMetricStoreExporter = new MetricStoreExporter(kafkaStreams, METRICS_STORE); + final MetricStoreExporter prometheusMetricStoreExporter = + new MetricStoreExporter(kafkaStreams, METRICS_STORE, Duration.ofSeconds(30)); prometheusMetricStoreExporter.bindTo(registry); } } diff --git a/canary/src/main/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporter.java b/canary/src/main/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporter.java index 842a19e21..a9cef273e 100644 --- a/canary/src/main/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporter.java +++ b/canary/src/main/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporter.java @@ -9,6 +9,7 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.binder.MeterBinder; +import java.time.Duration; import java.util.*; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -23,22 +24,25 @@ import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; @Slf4j -public class MetricStoreExporter implements MeterBinder { +public class MetricStoreExporter implements MeterBinder, AutoCloseable { private final KafkaStreams kafkaStreams; private final String storeName; - private final Map currentMeters; + private final Map currentMeters = new HashMap<>(); + private final Duration refreshPeriod; + private ScheduledExecutorService mainExecutor; - public MetricStoreExporter(final KafkaStreams kafkaStreams, final String storeName) { + public MetricStoreExporter(final KafkaStreams kafkaStreams, final String storeName, final Duration refreshPeriod) { this.kafkaStreams = kafkaStreams; this.storeName = storeName; - currentMeters = new HashMap<>(); + this.refreshPeriod = refreshPeriod; } private static List toTags(final MetricKey key) { final List tags = new ArrayList<>(); tags.add(Tag.of("server", "%s:%s".formatted(key.getServerHost(), key.getServerPort()))); tags.add(Tag.of("server_version", key.getServerVersion())); + tags.add(Tag.of("server_id", key.getServerId())); tags.addAll(key.getTagsList().stream() .map(tag -> Tag.of(tag.getKey(), tag.getValue())) .toList()); @@ -47,12 +51,15 @@ private static List toTags(final MetricKey key) { @Override public void bindTo(final MeterRegistry registry) { - final ScheduledExecutorService mainExecutor = Executors.newSingleThreadScheduledExecutor(); - ShutdownHook.add("Latency Metrics Exporter", () -> { - mainExecutor.shutdownNow(); - mainExecutor.awaitTermination(1, TimeUnit.SECONDS); - }); - mainExecutor.scheduleAtFixedRate(() -> updateMetrics(registry), 30, 30, TimeUnit.SECONDS); + mainExecutor = Executors.newSingleThreadScheduledExecutor(); + ShutdownHook.add("Latency Metrics Exporter", this); + mainExecutor.scheduleAtFixedRate( + () -> updateMetrics(registry), 0, refreshPeriod.toMillis(), TimeUnit.MILLISECONDS); + } + + public void close() throws InterruptedException { + mainExecutor.shutdownNow(); + mainExecutor.awaitTermination(1, TimeUnit.SECONDS); } private void updateMetrics(final MeterRegistry registry) { @@ -72,7 +79,6 @@ private void updateMetrics(final MeterRegistry registry) { while (records.hasNext()) { final KeyValue record = records.next(); foundMetrics.add(record.key); - final PrometheusMetric current = currentMeters.get(record.key); if (current == null) { final AtomicDouble newMeter = new AtomicDouble(record.value.getValue()); diff --git a/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java b/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java index d89fa61e1..2f2b930ea 100644 --- a/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java +++ b/canary/src/main/java/io/littlehorse/canary/aggregator/topology/MetricsTopology.java @@ -172,6 +172,7 @@ private static MetricKey buildMetricKey(final BeatKey key, final String id) { .setServerVersion(key.getServerVersion()) .setServerPort(key.getServerPort()) .setServerHost(key.getServerHost()) + .setServerId(key.getServerId()) .setId("canary_%s".formatted(id)); if (key.hasStatus() && !Strings.isNullOrEmpty(key.getStatus())) { @@ -204,6 +205,7 @@ private static BeatKey removeWfId(final BeatKey key, final BeatValue value) { .setServerHost(key.getServerHost()) .setServerPort(key.getServerPort()) .setStatus(key.getStatus()) + .setServerId(key.getServerId()) .addAllTags(key.getTagsList()) .build(); } diff --git a/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java b/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java index f70e85d12..4607d26a1 100644 --- a/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java +++ b/canary/src/main/java/io/littlehorse/canary/config/CanaryConfig.java @@ -32,6 +32,7 @@ public class CanaryConfig implements Config { public static final String METRONOME_GET_RETRIES = "metronome.get.retries"; public static final String METRONOME_WORKER_ENABLE = "metronome.worker.enable"; public static final String METRONOME_DATA_PATH = "metronome.data.path"; + public static final String METRONOME_SERVER_ID = "metronome.server.id"; public static final String METRONOME_BEAT_EXTRA_TAGS = "metronome.beat.extra.tags"; public static final String METRONOME_BEAT_EXTRA_TAGS_PREFIX = "%s.".formatted(METRONOME_BEAT_EXTRA_TAGS); @@ -71,7 +72,11 @@ public KafkaConfig toKafkaConfig() { } private String getConfig(final String configName) { - return configs.get(configName).toString(); + final Object value = configs.get(configName); + if (value == null) { + throw new IllegalArgumentException("Configuration 'lh.canary." + configName + "' not found"); + } + return value.toString(); } public String getTopicName() { @@ -181,4 +186,8 @@ public int getWorkflowVersion() { public String getMetronomeDataPath() { return getConfig(METRONOME_DATA_PATH); } + + public String getMetronomeServerId() { + return getConfig(METRONOME_SERVER_ID); + } } diff --git a/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java b/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java index 96671335c..4520a102f 100644 --- a/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java +++ b/canary/src/main/java/io/littlehorse/canary/metronome/internal/BeatProducer.java @@ -25,17 +25,20 @@ public class BeatProducer { private final int lhServerPort; private final String lhServerVersion; private final String topicName; + private final String lhServerId; public BeatProducer( final String lhServerHost, final int lhServerPort, final String lhServerVersion, + final String lhServerId, final String topicName, final Map producerConfig, final Map extraTags) { this.lhServerHost = lhServerHost; this.lhServerPort = lhServerPort; this.lhServerVersion = lhServerVersion; + this.lhServerId = lhServerId; this.topicName = topicName; this.extraTags = extraTags; @@ -89,6 +92,7 @@ private BeatKey buildKey(final String id, final BeatType type, final String stat .setServerHost(lhServerHost) .setServerPort(lhServerPort) .setServerVersion(lhServerVersion) + .setServerId(lhServerId) .setId(id) .setType(type); diff --git a/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusExporter.java b/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusExporter.java index d2f055580..4e8e4707e 100644 --- a/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusExporter.java +++ b/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusExporter.java @@ -3,8 +3,8 @@ import io.littlehorse.canary.util.ShutdownHook; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.binder.MeterBinder; -import io.micrometer.prometheus.PrometheusConfig; -import io.micrometer.prometheus.PrometheusMeterRegistry; +import io.micrometer.prometheusmetrics.PrometheusConfig; +import io.micrometer.prometheusmetrics.PrometheusMeterRegistry; import java.util.Map; public class PrometheusExporter { diff --git a/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusServerExporter.java b/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusServerExporter.java index 1bb6f18c8..740e8e613 100644 --- a/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusServerExporter.java +++ b/canary/src/main/java/io/littlehorse/canary/prometheus/PrometheusServerExporter.java @@ -1,9 +1,9 @@ package io.littlehorse.canary.prometheus; import io.javalin.Javalin; +import io.javalin.http.ContentType; import io.javalin.http.Context; import io.littlehorse.canary.util.ShutdownHook; -import io.prometheus.client.exporter.common.TextFormat; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -22,6 +22,6 @@ public PrometheusServerExporter( private void printMetrics(final Context context) { log.trace("Processing metrics request"); - context.contentType(TextFormat.CONTENT_TYPE_004).result(prometheusExporter.scrape()); + context.contentType(ContentType.PLAIN).result(prometheusExporter.scrape()); } } diff --git a/canary/src/main/proto/beats.proto b/canary/src/main/proto/beats.proto index b0a2e7261..043ed4042 100644 --- a/canary/src/main/proto/beats.proto +++ b/canary/src/main/proto/beats.proto @@ -27,6 +27,7 @@ message BeatKey { optional string status = 5; optional string id = 6; repeated Tag tags = 7; + string server_id = 8; } message BeatValue { diff --git a/canary/src/main/proto/metrics.proto b/canary/src/main/proto/metrics.proto index 90abbe8fc..276da2d81 100644 --- a/canary/src/main/proto/metrics.proto +++ b/canary/src/main/proto/metrics.proto @@ -12,6 +12,7 @@ message MetricKey { string server_version = 3; string id = 4; repeated Tag tags = 5; + string server_id = 6; } message MetricValue { diff --git a/canary/src/test/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporterTest.java b/canary/src/test/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporterTest.java new file mode 100644 index 000000000..8f20872d0 --- /dev/null +++ b/canary/src/test/java/io/littlehorse/canary/aggregator/internal/MetricStoreExporterTest.java @@ -0,0 +1,131 @@ +package io.littlehorse.canary.aggregator.internal; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.when; + +import io.littlehorse.canary.proto.MetricKey; +import io.littlehorse.canary.proto.MetricValue; +import io.littlehorse.canary.proto.Tag; +import io.micrometer.prometheusmetrics.PrometheusConfig; +import io.micrometer.prometheusmetrics.PrometheusMeterRegistry; +import java.time.Duration; +import java.util.List; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class MetricStoreExporterTest { + + public static final String TEST_STORAGE = "testStorage"; + public static final String HOST = "localhost"; + + @Mock + KafkaStreams kafkaStreams; + + @Mock + ReadOnlyKeyValueStore store; + + @Mock + KeyValueIterator records; + + PrometheusMeterRegistry prometheusRegistry; + MetricStoreExporter metricExporter; + + @BeforeEach + void setUp() { + metricExporter = new MetricStoreExporter(kafkaStreams, TEST_STORAGE, Duration.ofSeconds(10)); + prometheusRegistry = new PrometheusMeterRegistry(PrometheusConfig.DEFAULT); + } + + @AfterEach + void tearDown() throws InterruptedException { + metricExporter.close(); + prometheusRegistry.close(); + } + + @Test + public void shouldScrapeSimpleMetric() throws InterruptedException { + // metrics + MetricKey key = createMetricsKey(List.of( + Tag.newBuilder().setKey("custom_tag").setValue("custom_value").build())); + MetricKey key2 = createMetricsKey(List.of()); + MetricValue value = MetricValue.newBuilder().setValue(1.0).build(); + + // records + when(records.hasNext()).thenReturn(true, false); + when(records.next()).thenReturn(KeyValue.pair(key, value)); + doNothing().when(records).close(); + + // store + when(store.all()).thenReturn(records); + + // kafka streams + when(kafkaStreams.state()).thenReturn(KafkaStreams.State.RUNNING); + when(kafkaStreams.store(any())).thenReturn(store); + + metricExporter.bindTo(prometheusRegistry); + + Thread.sleep(500); + + assertThat(prometheusRegistry.scrape()) + .contains( + "my_metric{custom_tag=\"custom_value\",server=\"localhost:2023\",server_id=\"my_server\",server_version=\"test\"} 1.0"); + } + + private static MetricKey createMetricsKey(List tags) { + return createMetricsKey(HOST, tags); + } + + private static MetricKey createMetricsKey(String host, List tags) { + return MetricKey.newBuilder() + .setServerHost(host) + .setServerPort(2023) + .setServerVersion("test") + .setId("my_metric") + .setServerId("my_server") + .addAllTags(tags) + .build(); + } + + @Test + void printMetricsWithTwoDifferentServers() throws InterruptedException { + // metrics + List tags = List.of( + Tag.newBuilder().setKey("custom_tag").setValue("custom_value").build()); + MetricKey key1 = createMetricsKey(tags); + MetricKey key2 = createMetricsKey("localhost2", tags); + MetricValue value = MetricValue.newBuilder().setValue(1.0).build(); + + // records + when(records.hasNext()).thenReturn(true, true, false); + when(records.next()).thenReturn(KeyValue.pair(key1, value), KeyValue.pair(key2, value)); + doNothing().when(records).close(); + + // store + when(store.all()).thenReturn(records); + + // kafka streams + when(kafkaStreams.state()).thenReturn(KafkaStreams.State.RUNNING); + when(kafkaStreams.store(any())).thenReturn(store); + + metricExporter.bindTo(prometheusRegistry); + + Thread.sleep(500); + System.out.printf(prometheusRegistry.scrape()); + assertThat(prometheusRegistry.scrape()) + .isEqualTo( + "# HELP my_metric \n" + "# TYPE my_metric gauge\n" + + "my_metric{custom_tag=\"custom_value\",server=\"localhost2:2023\",server_id=\"my_server\",server_version=\"test\"} 1.0\n" + + "my_metric{custom_tag=\"custom_value\",server=\"localhost:2023\",server_id=\"my_server\",server_version=\"test\"} 1.0\n"); + } +} diff --git a/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java b/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java index afb90b3ed..752198d7e 100644 --- a/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java +++ b/canary/src/test/java/io/littlehorse/canary/aggregator/topology/MetricsTopologyTest.java @@ -30,6 +30,7 @@ class MetricsTopologyTest { public static final String HOST_2 = "localhost2"; public static final int PORT_2 = 2024; + public static final String SERVER_ID = "LH"; private TopologyTestDriver testDriver; private TestInputTopic inputTopic; @@ -60,8 +61,11 @@ private static MetricKey newMetricKey(String id, String status, Map tags) { - MetricKey.Builder builder = - MetricKey.newBuilder().setServerHost(host).setServerPort(port).setId(id); + MetricKey.Builder builder = MetricKey.newBuilder() + .setServerHost(host) + .setServerPort(port) + .setId(id) + .setServerId(SERVER_ID); if (status != null) { builder.addTags(Tag.newBuilder().setKey("status").setValue(status).build()); @@ -105,6 +109,7 @@ private static TestRecord newBeat( .setServerHost(host) .setServerPort(port) .setType(type) + .setServerId(SERVER_ID) .setId(id); BeatValue.Builder valueBuilder = BeatValue.newBuilder().setTime(Timestamps.now()); @@ -180,10 +185,22 @@ void includeBeatTagsIntoMetrics() { Map expectedTags = Map.of("my_tag", "value"); inputTopic.pipeInput(newBeat(expectedType, getRandomId(), 20L, "ok", expectedTags)); + inputTopic.pipeInput(newBeat(expectedType, getRandomId(), 20L, "ok")); - assertThat(getCount()).isEqualTo(3); + assertThat(getCount()).isEqualTo(6); assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_avg", "ok", expectedTags))) .isEqualTo(newMetricValue(20.)); + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_max", "ok", expectedTags))) + .isEqualTo(newMetricValue(20.)); + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_count", "ok", expectedTags))) + .isEqualTo(newMetricValue(1.)); + + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_avg", "ok"))) + .isEqualTo(newMetricValue(20.)); + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_max", "ok"))) + .isEqualTo(newMetricValue(20.)); + assertThat(store.get(newMetricKey("canary_" + expectedTypeName + "_count", "ok"))) + .isEqualTo(newMetricValue(1.)); } @Test diff --git a/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java b/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java index 07a679da0..f34fa3fa5 100644 --- a/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java +++ b/canary/src/test/java/io/littlehorse/canary/config/CanaryConfigTest.java @@ -2,6 +2,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.entry; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.Map; import org.junit.jupiter.api.Test; @@ -65,4 +66,14 @@ void getEmptyMetronomeExtraTags() { assertThat(output).isEmpty(); assertThat(output).isNotNull(); } + + @Test + void throwsExceptionIfConfigurationIsNotFound() { + CanaryConfig canaryConfig = new CanaryConfig(Map.of()); + + IllegalArgumentException result = + assertThrows(IllegalArgumentException.class, canaryConfig::getMetronomeServerId); + + assertThat(result.getMessage()).isEqualTo("Configuration 'lh.canary.metronome.server.id' not found"); + } } diff --git a/docs/CANARY_CONFIGURATIONS.md b/docs/CANARY_CONFIGURATIONS.md index 8d7f167b8..8b906103e 100644 --- a/docs/CANARY_CONFIGURATIONS.md +++ b/docs/CANARY_CONFIGURATIONS.md @@ -15,6 +15,7 @@ * [`lh.canary.metronome.get.retries`](#lhcanarymetronomegetretries) * [`lh.canary.metronome.data.path`](#lhcanarymetronomedatapath) * [`lh.canary.metronome.beat.extra.tags.`](#lhcanarymetronomebeatextratagsadditional-tag) + * [`lh.canary.metronome.server.id`](#lhcanarymetronomeserverid) * [Kafka Configurations](#kafka-configurations) * [LH Client Configurations](#lh-client-configurations) * [Task Worker](#task-worker) @@ -166,6 +167,15 @@ For example: `lh.canary.metronome.beat.extra.tags.my_tag=my-value`. - **Default:** null - **Importance:** low +#### `lh.canary.metronome.server.id` + +Add the tag server id the prometheus metrics (**mandatory**). +For example: `lh.canary.metronome.server.id=lh`. + +- **Type:** string +- **Default:** null +- **Importance:** high + ### Kafka Configurations LH Canary supports all kafka configurations. Use the prefix `lh.canary.kafka` and append the kafka config. From 608a593025e69304d635b6008b3c7df2ab6e2954 Mon Sep 17 00:00:00 2001 From: KarlaCarvajal Date: Mon, 16 Dec 2024 15:11:01 -0600 Subject: [PATCH 2/3] fix(sdk-dotnet): fix task worker connection manager (#1191) * Fix task worker connection manager to send hosts and ports available * Fix rebalance when the boostrap sever is down and then up --- sdk-dotnet/Examples/BasicExample/MyWorker.cs | 2 +- sdk-dotnet/Examples/BasicExample/Program.cs | 4 +- .../Worker/VariableMappingTest.cs | 18 +-- sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs | 2 +- sdk-dotnet/LittleHorse.Sdk/LHConfig.cs | 10 +- .../Worker/Internal/LHServerConnection.cs | 28 ++--- .../Internal/LHServerConnectionManager.cs | 88 +++++++------- .../LittleHorse.Sdk/Worker/Internal/LHTask.cs | 105 +++++++++++++++++ .../LittleHorse.Sdk/Worker/LHTaskWorker.cs | 111 +++--------------- .../LittleHorse.Sdk/Worker/LHWorkerContext.cs | 7 +- 10 files changed, 199 insertions(+), 176 deletions(-) create mode 100644 sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHTask.cs diff --git a/sdk-dotnet/Examples/BasicExample/MyWorker.cs b/sdk-dotnet/Examples/BasicExample/MyWorker.cs index ef3d797aa..49c106989 100644 --- a/sdk-dotnet/Examples/BasicExample/MyWorker.cs +++ b/sdk-dotnet/Examples/BasicExample/MyWorker.cs @@ -4,7 +4,7 @@ namespace Examples.BasicExample { public class MyWorker { - [LHTaskMethod("greet-dotnet")] + [LHTaskMethod("greet")] public string Greeting(string name) { var message = $"Hello team, This is a Dotnet Worker"; diff --git a/sdk-dotnet/Examples/BasicExample/Program.cs b/sdk-dotnet/Examples/BasicExample/Program.cs index 2d91f7951..617a38cab 100644 --- a/sdk-dotnet/Examples/BasicExample/Program.cs +++ b/sdk-dotnet/Examples/BasicExample/Program.cs @@ -51,9 +51,9 @@ static void Main(string[] args) { var loggerFactory = _serviceProvider.GetRequiredService(); var config = GetLHConfig(args, loggerFactory); - + MyWorker executable = new MyWorker(); - var taskWorker = new LHTaskWorker(executable, "greet-dotnet", config); + var taskWorker = new LHTaskWorker(executable, "greet", config); taskWorker.RegisterTaskDef(); diff --git a/sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs b/sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs index 0e664e5f4..1f06e4050 100644 --- a/sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs +++ b/sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs @@ -39,7 +39,7 @@ public void VariableMapping_WithValidLHTypes_ShouldBeBuiltSuccessfully() foreach (var type in testAllowedTypes) { var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var result = new VariableMapping(taskDef, position, type, paramName); @@ -53,7 +53,7 @@ public void VariableMapping_WithMismatchTypesInt_ShouldThrowException() Type type1 = typeof(Int64); Type type2 = typeof(string); var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var exception = Assert.Throws( () => new VariableMapping(taskDef, 0, type2, "any param name")); @@ -67,7 +67,7 @@ public void VariableMapping_WithMismatchTypeDouble_ShouldThrowException() Type type1 = typeof(double); Type type2 = typeof(Int64); var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var exception = Assert.Throws( () => new VariableMapping(taskDef, 0, type2, "any param name")); @@ -81,7 +81,7 @@ public void VariableMapping_WithMismatchTypeString_ShouldThrowException() Type type1 = typeof(string); Type type2 = typeof(double); var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var exception = Assert.Throws( () => new VariableMapping(taskDef, 0, type2, "any param name")); @@ -95,7 +95,7 @@ public void VariableMapping_WithMismatchTypeBool_ShouldThrowException() Type type1 = typeof(bool); Type type2 = typeof(string); var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var exception = Assert.Throws( () => new VariableMapping(taskDef, 0, type2, "any param name")); @@ -109,7 +109,7 @@ public void VariableMapping_WithMismatchTypeBytes_ShouldThrowException() Type type1 = typeof(byte[]); Type type2 = typeof(string); var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var exception = Assert.Throws( () => new VariableMapping(taskDef, 0, type2, "any param name")); @@ -302,11 +302,11 @@ public void VariableMapping_WithAssignJsonStringValue_ShouldReturnCustomObject() Assert.Equal(expectedObject.Cars!.Count, actualObject.Cars!.Count); } - private TaskDef getTaskDefForTest(VariableType type) + private TaskDef? getTaskDefForTest(VariableType type) { var inputVar = new VariableDef(); inputVar.Type = type; - TaskDef taskDef = new TaskDef(); + TaskDef? taskDef = new TaskDef(); TaskDefId taskDefId = new TaskDefId(); taskDef.Id = taskDefId; taskDef.InputVars.Add(inputVar); @@ -317,7 +317,7 @@ private TaskDef getTaskDefForTest(VariableType type) private VariableMapping getVariableMappingForTest(Type type, string paramName, int position) { var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type); - TaskDef taskDef = getTaskDefForTest(variableType); + TaskDef? taskDef = getTaskDefForTest(variableType); var variableMapping = new VariableMapping(taskDef, position, type, paramName); diff --git a/sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs b/sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs index 76208d884..1b55390e1 100644 --- a/sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs +++ b/sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs @@ -5,7 +5,7 @@ namespace LittleHorse.Sdk.Helper { public static class LHHelper { - public static WfRunId GetWFRunId(TaskRunSource taskRunSource) + public static WfRunId? GetWfRunId(TaskRunSource taskRunSource) { switch (taskRunSource.TaskRunSourceCase) { diff --git a/sdk-dotnet/LittleHorse.Sdk/LHConfig.cs b/sdk-dotnet/LittleHorse.Sdk/LHConfig.cs index 1723853b2..f1c6ca057 100644 --- a/sdk-dotnet/LittleHorse.Sdk/LHConfig.cs +++ b/sdk-dotnet/LittleHorse.Sdk/LHConfig.cs @@ -116,14 +116,14 @@ private bool IsOAuth } } - public LittleHorseClient GetGrcpClientInstance() + public LittleHorseClient GetGrpcClientInstance() { - return GetGrcpClientInstance(BootstrapHost, BootstrapPort); + return GetGrpcClientInstance(BootstrapHost, BootstrapPort); } - public LittleHorseClient GetGrcpClientInstance(string host, int port) + public LittleHorseClient GetGrpcClientInstance(string host, int port) { - string channelKey = BootstrapServer; + string channelKey = $"{BootstrapProtocol}://{host}:{port}"; if (_createdChannels.ContainsKey(channelKey)) { @@ -208,7 +208,7 @@ public TaskDef GetTaskDef(string taskDefName) { try { - var client = GetGrcpClientInstance(); + var client = GetGrpcClientInstance(); var taskDefId = new TaskDefId() { Name = taskDefName diff --git a/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs index 774119c37..0691e41f9 100644 --- a/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs +++ b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs @@ -8,25 +8,25 @@ namespace LittleHorse.Sdk.Worker.Internal { public class LHServerConnection : IDisposable { - private LHServerConnectionManager _connectionManager; - private LHHostInfo _hostInfo; + private readonly LHServerConnectionManager _connectionManager; + private readonly LHHostInfo _hostInfo; private bool _running; - private LittleHorseClient _client; - private AsyncDuplexStreamingCall _call; - private ILogger? _logger; + private readonly LittleHorseClient _client; + private readonly AsyncDuplexStreamingCall _call; + private readonly ILogger? _logger; - public LHHostInfo HostInfo { get { return _hostInfo; } } + public LHHostInfo HostInfo => _hostInfo; public LHServerConnection(LHServerConnectionManager connectionManager, LHHostInfo hostInfo) { _connectionManager = connectionManager; _hostInfo = hostInfo; _logger = LHLoggerFactoryProvider.GetLogger>(); - _client = _connectionManager.Config.GetGrcpClientInstance(); + _client = _connectionManager.Config.GetGrpcClientInstance(hostInfo.Host, hostInfo.Port); _call = _client.PollTask(); } - public void Connect() + public void Open() { _running = true; Task.Run(RequestMoreWorkAsync); @@ -48,12 +48,12 @@ private async Task RequestMoreWorkAsync() if (taskToDo.Result != null) { var scheduledTask = taskToDo.Result; - var wFRunId = LHHelper.GetWFRunId(scheduledTask.Source); - _logger?.LogDebug($"Received task schedule request for wfRun {wFRunId.Id}"); + var wFRunId = LHHelper.GetWfRunId(scheduledTask.Source); + _logger?.LogDebug($"Received task schedule request for wfRun {wFRunId?.Id}"); - _connectionManager.SubmitTaskForExecution(scheduledTask, _client); + _connectionManager.SubmitTaskForExecution(scheduledTask); - _logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId.Id}"); + _logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId?.Id}"); } else { @@ -82,9 +82,9 @@ public void Dispose() _running = false; } - public bool IsSame(LHHostInfo hostInfoToCompare) + public bool IsSame(string host, int port) { - return _hostInfo.Host.Equals(hostInfoToCompare.Host) && _hostInfo.Port == hostInfoToCompare.Port; + return _hostInfo.Host.Equals(host) && _hostInfo.Port == port; } } } diff --git a/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnectionManager.cs b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnectionManager.cs index 381a3d70e..f7677be94 100644 --- a/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnectionManager.cs +++ b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnectionManager.cs @@ -17,35 +17,25 @@ public class LHServerConnectionManager : IDisposable private const int BALANCER_SLEEP_TIME = 5000; private const int MAX_REPORT_RETRIES = 5; - private LHConfig _config; - private MethodInfo _taskMethod; - private TaskDef _taskDef; - private List _mappings; - private T _executable; - private ILogger? _logger; - private LittleHorseClient _bootstrapClient; + private readonly LHConfig _config; + private readonly ILogger? _logger; + private readonly LittleHorseClient _bootstrapClient; private bool _running; private List> _runningConnections; - private Thread _rebalanceThread; - private SemaphoreSlim _semaphore; + private readonly Thread _rebalanceThread; + private readonly SemaphoreSlim _semaphore; + private readonly LHTask _task; - public LHConfig Config { get { return _config; } } - public TaskDef TaskDef { get { return _taskDef; } } + public LHConfig Config => _config; + public TaskDef TaskDef => _task.TaskDef!; public LHServerConnectionManager(LHConfig config, - MethodInfo taskMethod, - TaskDef taskDef, - List mappings, - T executable) + LHTask task) { _config = config; - _taskMethod = taskMethod; - _taskDef = taskDef; - _mappings = mappings; - _executable = executable; _logger = LHLoggerFactoryProvider.GetLogger>(); - - _bootstrapClient = config.GetGrcpClientInstance(); + _task = task; + _bootstrapClient = config.GetGrpcClientInstance(); _running = false; _runningConnections = new List>(); @@ -85,22 +75,22 @@ private void DoHeartBeat() { var request = new RegisterTaskWorkerRequest { - TaskDefId = _taskDef.Id, + TaskDefId = _task.TaskDef!.Id, TaskWorkerId = _config.WorkerId, }; var response = _bootstrapClient.RegisterTaskWorker(request); - HandleRegisterTaskWorkResponse(response); - + HandleRegisterTaskWorkerResponse(response); } catch (Exception ex) { _logger?.LogError(ex, $"Failed contacting bootstrap host {_config.BootstrapHost}:{_config.BootstrapPort}"); + _runningConnections = new List>(); } } - private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response) + private void HandleRegisterTaskWorkerResponse(RegisterTaskWorkerResponse response) { response.YourHosts.ToList().ForEach(host => { @@ -109,9 +99,9 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response) try { var newConnection = new LHServerConnection(this, host); - newConnection.Connect(); + newConnection.Open(); _runningConnections.Add(newConnection); - _logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_taskDef.Id}'"); + _logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_task.TaskDef!.Id}'"); } catch (IOException ex) { @@ -125,7 +115,7 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response) for (int i = lastIndexOfRunningConnection; i >= 0; i--) { var runningThread = _runningConnections[i]; - + if (!ShouldBeRunning(runningThread, response.YourHosts)) { _logger?.LogInformation($"Stopping worker thread for host {runningThread.HostInfo.Host} : {runningThread.HostInfo.Port}"); @@ -138,51 +128,56 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response) private bool ShouldBeRunning(LHServerConnection runningThread, RepeatedField hosts) { - return hosts.ToList().Any(host => runningThread.IsSame(host)); + return hosts.ToList().Any(host => runningThread.IsSame(host.Host, host.Port)); } private bool IsAlreadyRunning(LHHostInfo host) { - return _runningConnections.Any(conn => conn.IsSame(host)); + return _runningConnections.Any(conn => conn.IsSame(host.Host, host.Port)); } - public async void SubmitTaskForExecution(ScheduledTask scheduledTask, LittleHorseClient client) + public async void SubmitTaskForExecution(ScheduledTask scheduledTask) { await _semaphore.WaitAsync(); - DoTask(scheduledTask, client); + DoTask(scheduledTask); } - private void DoTask(ScheduledTask scheduledTask, LittleHorseClient client) + private void DoTask(ScheduledTask scheduledTask) { ReportTaskRun result = ExecuteTask(scheduledTask, LHMappingHelper.MapDateTimeFromProtoTimeStamp(scheduledTask.CreatedAt)); - _semaphore.Release(); - var wfRunId = LHHelper.GetWFRunId(scheduledTask.Source); + var wfRunId = LHHelper.GetWfRunId(scheduledTask.Source); try { var retriesLeft = MAX_REPORT_RETRIES; - _logger?.LogDebug($"Going to report task for wfRun {wfRunId.Id}"); + _logger?.LogDebug($"Going to report task for wfRun {wfRunId?.Id}"); Policy.Handle().WaitAndRetry(MAX_REPORT_RETRIES, retryAttempt => TimeSpan.FromSeconds(5), onRetry: (exception, timeSpan, retryCount, context) => - { - --retriesLeft; - _logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}"); - _logger?.LogDebug($"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}"); - }).Execute(() => RunReportTask(result)); + { + --retriesLeft; + _logger?.LogDebug( + $"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}"); + _logger?.LogDebug( + $"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}"); + }).Execute(() => RunReportTask(result)); } catch (Exception ex) { _logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {ex.Message}. No retries left."); } + finally + { + _semaphore.Release(); + } } private void RunReportTask(ReportTaskRun reportedTask) { - var response = _bootstrapClient.ReportTask(reportedTask); + _bootstrapClient.ReportTask(reportedTask); } private ReportTaskRun ExecuteTask(ScheduledTask scheduledTask, DateTime? scheduleTime) @@ -278,14 +273,15 @@ private ReportTaskRun ExecuteTask(ScheduledTask scheduledTask, DateTime? schedul private object? Invoke(ScheduledTask scheduledTask, LHWorkerContext workerContext) { - var inputs = _mappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray(); + var inputs = _task.TaskMethodMappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray(); - return _taskMethod.Invoke(_executable, inputs); + return _task.TaskMethod!.Invoke(_task.Executable, inputs); } - public void CloseConnection(LHServerConnection connection) + public void CloseConnection(string host, int port) { - var currConn = _runningConnections.Where(c => c.IsSame(connection.HostInfo)).FirstOrDefault(); + var currConn = _runningConnections.FirstOrDefault(c => + c.IsSame(host, port)); if (currConn != null) { diff --git a/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHTask.cs b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHTask.cs new file mode 100644 index 000000000..4d809f723 --- /dev/null +++ b/sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHTask.cs @@ -0,0 +1,105 @@ +using System.Reflection; +using LittleHorse.Common.Proto; +using LittleHorse.Sdk.Exceptions; +using static LittleHorse.Common.Proto.LittleHorse; + +namespace LittleHorse.Sdk.Worker.Internal; + +public class LHTask +{ + private TaskDef? _taskDef; + private MethodInfo? _taskMethod; + private List _taskMethodMappings; + private readonly string _taskDefName; + private LHTaskSignature? _taskSignature; + private readonly T _executable; + private readonly LittleHorseClient _lhClient; + + public MethodInfo? TaskMethod => _taskMethod; + public List TaskMethodMappings => _taskMethodMappings; + public T Executable => _executable; + public string TaskDefName => _taskDefName; + public TaskDef? TaskDef => _taskDef; + + public LHTask(T executable, string taskDefName, LittleHorseClient lhClient) + { + _taskDefName = taskDefName; + _executable = executable; + _lhClient = lhClient; + _taskMethodMappings = new List(); + } + + internal void PrepareLHTaskMethod() + { + _taskSignature = new LHTaskSignature(_taskDefName, _executable); + _taskMethod = _taskSignature.TaskMethod; + + ValidateTaskMethodParameters(_taskMethod, _taskSignature); + _taskMethodMappings = CreateVariableMappings(_taskMethod, _taskSignature); + } + + private void ValidateTaskMethodParameters(MethodInfo taskMethod, LHTaskSignature taskSignature) + { + _taskDef = GetTaskDef(); + if (taskSignature.HasWorkerContextAtEnd) + { + if (taskSignature.TaskMethod.GetParameters().Length - 1 != _taskDef.InputVars.Count) + { + throw new LHTaskSchemaMismatchException("Number of task method params doesn't match number of taskdef params!"); + } + } + else + { + if (taskMethod.GetParameters().Length != _taskDef.InputVars.Count) + { + throw new LHTaskSchemaMismatchException("Number of task method params doesn't match number of taskdef params!"); + } + } + } + + private List CreateVariableMappings(MethodInfo taskMethod, LHTaskSignature taskSignature) + { + var mappings = new List(); + + var taskParams = taskMethod.GetParameters(); + _taskDef = GetTaskDef(); + + for (int index = 0; index < _taskDef?.InputVars.Count; index++) + { + var taskParam = taskParams[index]; + + if (taskParam.ParameterType.IsAssignableFrom(typeof(LHWorkerContext))) + { + throw new LHTaskSchemaMismatchException("Can only have WorkerContext after all required taskDef params."); + } + + mappings.Add(CreateVariableMapping(_taskDef, index, taskParam.ParameterType, taskParam.Name)); + } + + if (taskSignature.HasWorkerContextAtEnd) + { + mappings.Add(CreateVariableMapping(_taskDef, taskParams.Count() - 1, typeof(LHWorkerContext), null)); + } + + return mappings; + } + + private VariableMapping CreateVariableMapping(TaskDef? taskDef, int index, Type type, string? paramName) + { + return new VariableMapping(taskDef!, index, type, paramName); + } + + internal TaskDef GetTaskDef() + { + if (_taskDef is null) + { + var taskDefId = new TaskDefId + { + Name = _taskDefName + }; + _taskDef = _lhClient.GetTaskDef(taskDefId); + } + + return _taskDef; + } +} \ No newline at end of file diff --git a/sdk-dotnet/LittleHorse.Sdk/Worker/LHTaskWorker.cs b/sdk-dotnet/LittleHorse.Sdk/Worker/LHTaskWorker.cs index 41db02a5d..bcf69d4c1 100644 --- a/sdk-dotnet/LittleHorse.Sdk/Worker/LHTaskWorker.cs +++ b/sdk-dotnet/LittleHorse.Sdk/Worker/LHTaskWorker.cs @@ -1,5 +1,4 @@ -using System.Reflection; -using Grpc.Core; +using Grpc.Core; using LittleHorse.Common.Proto; using LittleHorse.Sdk.Exceptions; using LittleHorse.Sdk.Helper; @@ -18,27 +17,20 @@ namespace LittleHorse.Sdk.Worker /// public class LHTaskWorker { - private LHConfig _config; - private ILogger>? _logger; - private T _executable; - private TaskDef? _taskDef; - private MethodInfo? _taskMethod; - private List _mappings; - private LHTaskSignature? _taskSignature; + private readonly LHConfig _config; + private readonly ILogger>? _logger; private LHServerConnectionManager? _manager; - private string _taskDefName; - private LittleHorseClient _grpcClient; + private readonly LittleHorseClient _lhClient; + private readonly LHTask _task; - public string TaskDefName { get => _taskDefName; } + public string TaskDefName => _task.TaskDefName; public LHTaskWorker(T executable, string taskDefName, LHConfig config) { _config = config; _logger = LHLoggerFactoryProvider.GetLogger>(); - _executable = executable; - _mappings = new List(); - _taskDefName = taskDefName; - _grpcClient = _config.GetGrcpClientInstance(); + _lhClient = _config.GetGrpcClientInstance(); + _task = new LHTask(executable, taskDefName, _lhClient); } /// @@ -53,17 +45,11 @@ public void Start() { if (!TaskDefExists()) { - throw new LHMisconfigurationException($"Couldn't find TaskDef: {_taskDefName}"); + throw new LHMisconfigurationException($"Couldn't find TaskDef: {_task.TaskDefName}"); } - _taskSignature = new LHTaskSignature(_taskDefName, _executable); - _taskMethod = _taskSignature.TaskMethod; - - ValidateTaskMethodParameters(_taskMethod, _taskSignature); - _mappings = CreateVariableMappings(_taskMethod, _taskSignature); - - _manager = new LHServerConnectionManager(_config, _taskMethod, GetTaskDef(), _mappings, _executable); - + _task.PrepareLHTaskMethod(); + _manager = new LHServerConnectionManager(_config, _task); _manager.Start(); } @@ -79,11 +65,7 @@ public bool TaskDefExists() { try { - var taskDefId = new TaskDefId - { - Name = _taskDefName, - }; - _grpcClient.GetTaskDef(taskDefId); + _task.GetTaskDef(); return true; } @@ -124,15 +106,15 @@ public void RegisterTaskDef() /// private void RegisterTaskDef(bool swallowAlreadyExists) { - _logger?.LogInformation($"Creating TaskDef: {_taskDefName}"); + _logger?.LogInformation($"Creating TaskDef: {_task.TaskDefName}"); try { - var signature = new LHTaskSignature(_taskDefName, _executable); + var signature = new LHTaskSignature(_task.TaskDefName, _task.Executable); var request = new PutTaskDefRequest { - Name = _taskDefName + Name = _task.TaskDefName }; foreach (var lhMethodParam in signature.LhMethodParams) @@ -151,7 +133,7 @@ private void RegisterTaskDef(bool swallowAlreadyExists) request.OutputSchema = signature.TaskDefOutputSchema; } - var response = _grpcClient.PutTaskDef(request); + var response = _lhClient.PutTaskDef(request); _logger?.LogInformation($"Created TaskDef:\n{LHMappingHelper.MapProtoToJson(response)}"); } @@ -159,7 +141,7 @@ private void RegisterTaskDef(bool swallowAlreadyExists) { if (swallowAlreadyExists && ex.StatusCode == StatusCode.AlreadyExists) { - _logger?.LogInformation($"TaskDef {_taskDefName} already exists!"); + _logger?.LogInformation($"TaskDef {_task.TaskDefName} already exists!"); } else { @@ -167,64 +149,5 @@ private void RegisterTaskDef(bool swallowAlreadyExists) } } } - - private TaskDef GetTaskDef() - { - if (_taskDef is null) - { - _taskDef = _config.GetTaskDef(_taskDefName); - } - - return _taskDef; - } - - private void ValidateTaskMethodParameters(MethodInfo taskMethod, LHTaskSignature taskSignature) - { - if (taskSignature.HasWorkerContextAtEnd) - { - if (taskSignature.TaskMethod.GetParameters().Length - 1 != GetTaskDef().InputVars.Count) - { - throw new LHTaskSchemaMismatchException("Number of task method params doesn't match number of taskdef params!"); - } - } - else - { - if (taskMethod.GetParameters().Length != GetTaskDef().InputVars.Count) - { - throw new LHTaskSchemaMismatchException("Number of task method params doesn't match number of taskdef params!"); - } - } - } - - private List CreateVariableMappings(MethodInfo taskMethod, LHTaskSignature taskSignature) - { - var mappings = new List(); - - var taskParams = taskMethod.GetParameters(); - - for (int index = 0; index < GetTaskDef().InputVars.Count; index++) - { - var taskParam = taskParams[index]; - - if (taskParam.ParameterType.IsAssignableFrom(typeof(LHWorkerContext))) - { - throw new LHTaskSchemaMismatchException("Can only have WorkerContext after all required taskDef params."); - } - - mappings.Add(CreateVariableMapping(GetTaskDef(), index, taskParam.ParameterType, taskParam.Name)); - } - - if (taskSignature.HasWorkerContextAtEnd) - { - mappings.Add(CreateVariableMapping(GetTaskDef(), taskParams.Count() - 1, typeof(LHWorkerContext), null)); - } - - return mappings; - } - - private VariableMapping CreateVariableMapping(TaskDef taskDef, int index, Type type, string? paramName) - { - return new VariableMapping(taskDef, index, type, paramName); - } } } diff --git a/sdk-dotnet/LittleHorse.Sdk/Worker/LHWorkerContext.cs b/sdk-dotnet/LittleHorse.Sdk/Worker/LHWorkerContext.cs index 4850f59b7..75991a537 100644 --- a/sdk-dotnet/LittleHorse.Sdk/Worker/LHWorkerContext.cs +++ b/sdk-dotnet/LittleHorse.Sdk/Worker/LHWorkerContext.cs @@ -1,5 +1,4 @@ -using System.Text; -using LittleHorse.Common.Proto; +using LittleHorse.Common.Proto; using LittleHorse.Sdk.Helper; namespace LittleHorse.Sdk.Worker @@ -35,7 +34,7 @@ public LHWorkerContext(ScheduledTask scheduleTask, DateTime? scheduleDateTime) /// public WfRunId GetWfRunId() { - return LHHelper.GetWFRunId(_scheduleTask.Source); + return LHHelper.GetWfRunId(_scheduleTask.Source)!; } /// @@ -44,7 +43,7 @@ public WfRunId GetWfRunId() /// /// @return a `NodeRunIdPb` protobuf class with the ID from the executed NodeRun. /// - public NodeRunId GetNodeRunId() + public NodeRunId? GetNodeRunId() { TaskRunSource source = _scheduleTask.Source; switch (source.TaskRunSourceCase) { From 0486f18aeaefc4ea9d9b926c659a9c20a27abda0 Mon Sep 17 00:00:00 2001 From: Jacob Snarr Date: Mon, 16 Dec 2024 18:06:25 -0500 Subject: [PATCH 3/3] fix(lhctl): Add revisionNumber argument to `wfSpec delete` example (#1193) --- lhctl/internal/wf_spec.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhctl/internal/wf_spec.go b/lhctl/internal/wf_spec.go index ea7e6a6d7..149f5d7e0 100644 --- a/lhctl/internal/wf_spec.go +++ b/lhctl/internal/wf_spec.go @@ -140,7 +140,7 @@ Returns a list of ObjectId's that can be passed into 'lhctl get wfSpec'. } var deleteWfSpecCmd = &cobra.Command{ - Use: "wfSpec ", + Use: "wfSpec ", Short: "Delete a WfSpec.", Long: `Delete a WfSpec. You must provide the name and exact version of the WfSpec to delete.