Skip to content

Commit

Permalink
fix(sdk-dotnet): fix task worker connection manager (#1191)
Browse files Browse the repository at this point in the history
* Fix task worker connection manager to send hosts and ports available
* Fix rebalance when the boostrap sever is down and then up
  • Loading branch information
KarlaCarvajal authored Dec 16, 2024
1 parent 5852cb1 commit 608a593
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 176 deletions.
2 changes: 1 addition & 1 deletion sdk-dotnet/Examples/BasicExample/MyWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Examples.BasicExample
{
public class MyWorker
{
[LHTaskMethod("greet-dotnet")]
[LHTaskMethod("greet")]
public string Greeting(string name)
{
var message = $"Hello team, This is a Dotnet Worker";
Expand Down
4 changes: 2 additions & 2 deletions sdk-dotnet/Examples/BasicExample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ static void Main(string[] args)
{
var loggerFactory = _serviceProvider.GetRequiredService<ILoggerFactory>();
var config = GetLHConfig(args, loggerFactory);

MyWorker executable = new MyWorker();
var taskWorker = new LHTaskWorker<MyWorker>(executable, "greet-dotnet", config);
var taskWorker = new LHTaskWorker<MyWorker>(executable, "greet", config);

taskWorker.RegisterTaskDef();

Expand Down
18 changes: 9 additions & 9 deletions sdk-dotnet/LittleHorse.Sdk.Tests/Worker/VariableMappingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void VariableMapping_WithValidLHTypes_ShouldBeBuiltSuccessfully()
foreach (var type in testAllowedTypes)
{
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var result = new VariableMapping(taskDef, position, type, paramName);

Expand All @@ -53,7 +53,7 @@ public void VariableMapping_WithMismatchTypesInt_ShouldThrowException()
Type type1 = typeof(Int64);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -67,7 +67,7 @@ public void VariableMapping_WithMismatchTypeDouble_ShouldThrowException()
Type type1 = typeof(double);
Type type2 = typeof(Int64);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -81,7 +81,7 @@ public void VariableMapping_WithMismatchTypeString_ShouldThrowException()
Type type1 = typeof(string);
Type type2 = typeof(double);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -95,7 +95,7 @@ public void VariableMapping_WithMismatchTypeBool_ShouldThrowException()
Type type1 = typeof(bool);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand All @@ -109,7 +109,7 @@ public void VariableMapping_WithMismatchTypeBytes_ShouldThrowException()
Type type1 = typeof(byte[]);
Type type2 = typeof(string);
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type1);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var exception = Assert.Throws<LHTaskSchemaMismatchException>(
() => new VariableMapping(taskDef, 0, type2, "any param name"));
Expand Down Expand Up @@ -302,11 +302,11 @@ public void VariableMapping_WithAssignJsonStringValue_ShouldReturnCustomObject()
Assert.Equal(expectedObject.Cars!.Count, actualObject.Cars!.Count);
}

private TaskDef getTaskDefForTest(VariableType type)
private TaskDef? getTaskDefForTest(VariableType type)
{
var inputVar = new VariableDef();
inputVar.Type = type;
TaskDef taskDef = new TaskDef();
TaskDef? taskDef = new TaskDef();
TaskDefId taskDefId = new TaskDefId();
taskDef.Id = taskDefId;
taskDef.InputVars.Add(inputVar);
Expand All @@ -317,7 +317,7 @@ private TaskDef getTaskDefForTest(VariableType type)
private VariableMapping getVariableMappingForTest(Type type, string paramName, int position)
{
var variableType = LHMappingHelper.MapDotNetTypeToLHVariableType(type);
TaskDef taskDef = getTaskDefForTest(variableType);
TaskDef? taskDef = getTaskDefForTest(variableType);

var variableMapping = new VariableMapping(taskDef, position, type, paramName);

Expand Down
2 changes: 1 addition & 1 deletion sdk-dotnet/LittleHorse.Sdk/Helper/LHHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace LittleHorse.Sdk.Helper
{
public static class LHHelper
{
public static WfRunId GetWFRunId(TaskRunSource taskRunSource)
public static WfRunId? GetWfRunId(TaskRunSource taskRunSource)
{
switch (taskRunSource.TaskRunSourceCase)
{
Expand Down
10 changes: 5 additions & 5 deletions sdk-dotnet/LittleHorse.Sdk/LHConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ private bool IsOAuth
}
}

public LittleHorseClient GetGrcpClientInstance()
public LittleHorseClient GetGrpcClientInstance()
{
return GetGrcpClientInstance(BootstrapHost, BootstrapPort);
return GetGrpcClientInstance(BootstrapHost, BootstrapPort);
}

public LittleHorseClient GetGrcpClientInstance(string host, int port)
public LittleHorseClient GetGrpcClientInstance(string host, int port)
{
string channelKey = BootstrapServer;
string channelKey = $"{BootstrapProtocol}://{host}:{port}";

if (_createdChannels.ContainsKey(channelKey))
{
Expand Down Expand Up @@ -208,7 +208,7 @@ public TaskDef GetTaskDef(string taskDefName)
{
try
{
var client = GetGrcpClientInstance();
var client = GetGrpcClientInstance();
var taskDefId = new TaskDefId()
{
Name = taskDefName
Expand Down
28 changes: 14 additions & 14 deletions sdk-dotnet/LittleHorse.Sdk/Worker/Internal/LHServerConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@ namespace LittleHorse.Sdk.Worker.Internal
{
public class LHServerConnection<T> : IDisposable
{
private LHServerConnectionManager<T> _connectionManager;
private LHHostInfo _hostInfo;
private readonly LHServerConnectionManager<T> _connectionManager;
private readonly LHHostInfo _hostInfo;
private bool _running;
private LittleHorseClient _client;
private AsyncDuplexStreamingCall<PollTaskRequest, PollTaskResponse> _call;
private ILogger? _logger;
private readonly LittleHorseClient _client;
private readonly AsyncDuplexStreamingCall<PollTaskRequest, PollTaskResponse> _call;
private readonly ILogger? _logger;

public LHHostInfo HostInfo { get { return _hostInfo; } }
public LHHostInfo HostInfo => _hostInfo;

public LHServerConnection(LHServerConnectionManager<T> connectionManager, LHHostInfo hostInfo)
{
_connectionManager = connectionManager;
_hostInfo = hostInfo;
_logger = LHLoggerFactoryProvider.GetLogger<LHServerConnection<T>>();
_client = _connectionManager.Config.GetGrcpClientInstance();
_client = _connectionManager.Config.GetGrpcClientInstance(hostInfo.Host, hostInfo.Port);
_call = _client.PollTask();
}

public void Connect()
public void Open()
{
_running = true;
Task.Run(RequestMoreWorkAsync);
Expand All @@ -48,12 +48,12 @@ private async Task RequestMoreWorkAsync()
if (taskToDo.Result != null)
{
var scheduledTask = taskToDo.Result;
var wFRunId = LHHelper.GetWFRunId(scheduledTask.Source);
_logger?.LogDebug($"Received task schedule request for wfRun {wFRunId.Id}");
var wFRunId = LHHelper.GetWfRunId(scheduledTask.Source);
_logger?.LogDebug($"Received task schedule request for wfRun {wFRunId?.Id}");

_connectionManager.SubmitTaskForExecution(scheduledTask, _client);
_connectionManager.SubmitTaskForExecution(scheduledTask);

_logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId.Id}");
_logger?.LogDebug($"Scheduled task on threadpool for wfRun {wFRunId?.Id}");
}
else
{
Expand Down Expand Up @@ -82,9 +82,9 @@ public void Dispose()
_running = false;
}

public bool IsSame(LHHostInfo hostInfoToCompare)
public bool IsSame(string host, int port)
{
return _hostInfo.Host.Equals(hostInfoToCompare.Host) && _hostInfo.Port == hostInfoToCompare.Port;
return _hostInfo.Host.Equals(host) && _hostInfo.Port == port;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,25 @@ public class LHServerConnectionManager<T> : IDisposable
private const int BALANCER_SLEEP_TIME = 5000;
private const int MAX_REPORT_RETRIES = 5;

private LHConfig _config;
private MethodInfo _taskMethod;
private TaskDef _taskDef;
private List<VariableMapping> _mappings;
private T _executable;
private ILogger? _logger;
private LittleHorseClient _bootstrapClient;
private readonly LHConfig _config;
private readonly ILogger? _logger;
private readonly LittleHorseClient _bootstrapClient;
private bool _running;
private List<LHServerConnection<T>> _runningConnections;
private Thread _rebalanceThread;
private SemaphoreSlim _semaphore;
private readonly Thread _rebalanceThread;
private readonly SemaphoreSlim _semaphore;
private readonly LHTask<T> _task;

public LHConfig Config { get { return _config; } }
public TaskDef TaskDef { get { return _taskDef; } }
public LHConfig Config => _config;
public TaskDef TaskDef => _task.TaskDef!;

public LHServerConnectionManager(LHConfig config,
MethodInfo taskMethod,
TaskDef taskDef,
List<VariableMapping> mappings,
T executable)
LHTask<T> task)
{
_config = config;
_taskMethod = taskMethod;
_taskDef = taskDef;
_mappings = mappings;
_executable = executable;
_logger = LHLoggerFactoryProvider.GetLogger<LHServerConnectionManager<T>>();

_bootstrapClient = config.GetGrcpClientInstance();
_task = task;
_bootstrapClient = config.GetGrpcClientInstance();

_running = false;
_runningConnections = new List<LHServerConnection<T>>();
Expand Down Expand Up @@ -85,22 +75,22 @@ private void DoHeartBeat()
{
var request = new RegisterTaskWorkerRequest
{
TaskDefId = _taskDef.Id,
TaskDefId = _task.TaskDef!.Id,
TaskWorkerId = _config.WorkerId,
};

var response = _bootstrapClient.RegisterTaskWorker(request);

HandleRegisterTaskWorkResponse(response);

HandleRegisterTaskWorkerResponse(response);
}
catch (Exception ex)
{
_logger?.LogError(ex, $"Failed contacting bootstrap host {_config.BootstrapHost}:{_config.BootstrapPort}");
_runningConnections = new List<LHServerConnection<T>>();
}
}

private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)
private void HandleRegisterTaskWorkerResponse(RegisterTaskWorkerResponse response)
{
response.YourHosts.ToList().ForEach(host =>
{
Expand All @@ -109,9 +99,9 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)
try
{
var newConnection = new LHServerConnection<T>(this, host);
newConnection.Connect();
newConnection.Open();
_runningConnections.Add(newConnection);
_logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_taskDef.Id}'");
_logger?.LogInformation($"Adding connection to: {host.Host}:{host.Port} for task '{_task.TaskDef!.Id}'");
}
catch (IOException ex)
{
Expand All @@ -125,7 +115,7 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)
for (int i = lastIndexOfRunningConnection; i >= 0; i--)
{
var runningThread = _runningConnections[i];

if (!ShouldBeRunning(runningThread, response.YourHosts))
{
_logger?.LogInformation($"Stopping worker thread for host {runningThread.HostInfo.Host} : {runningThread.HostInfo.Port}");
Expand All @@ -138,51 +128,56 @@ private void HandleRegisterTaskWorkResponse(RegisterTaskWorkerResponse response)

private bool ShouldBeRunning(LHServerConnection<T> runningThread, RepeatedField<LHHostInfo> hosts)
{
return hosts.ToList().Any(host => runningThread.IsSame(host));
return hosts.ToList().Any(host => runningThread.IsSame(host.Host, host.Port));
}

private bool IsAlreadyRunning(LHHostInfo host)
{
return _runningConnections.Any(conn => conn.IsSame(host));
return _runningConnections.Any(conn => conn.IsSame(host.Host, host.Port));
}

public async void SubmitTaskForExecution(ScheduledTask scheduledTask, LittleHorseClient client)
public async void SubmitTaskForExecution(ScheduledTask scheduledTask)
{
await _semaphore.WaitAsync();

DoTask(scheduledTask, client);
DoTask(scheduledTask);
}

private void DoTask(ScheduledTask scheduledTask, LittleHorseClient client)
private void DoTask(ScheduledTask scheduledTask)
{
ReportTaskRun result = ExecuteTask(scheduledTask, LHMappingHelper.MapDateTimeFromProtoTimeStamp(scheduledTask.CreatedAt));
_semaphore.Release();

var wfRunId = LHHelper.GetWFRunId(scheduledTask.Source);
var wfRunId = LHHelper.GetWfRunId(scheduledTask.Source);

try
{
var retriesLeft = MAX_REPORT_RETRIES;

_logger?.LogDebug($"Going to report task for wfRun {wfRunId.Id}");
_logger?.LogDebug($"Going to report task for wfRun {wfRunId?.Id}");
Policy.Handle<Exception>().WaitAndRetry(MAX_REPORT_RETRIES,
retryAttempt => TimeSpan.FromSeconds(5),
onRetry: (exception, timeSpan, retryCount, context) =>
{
--retriesLeft;
_logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}");
_logger?.LogDebug($"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}");
}).Execute(() => RunReportTask(result));
{
--retriesLeft;
_logger?.LogDebug(
$"Failed to report task for wfRun {wfRunId}: {exception.Message}. Retries left: {retriesLeft}");
_logger?.LogDebug(
$"Retrying reportTask rpc on taskRun {LHHelper.TaskRunIdToString(result.TaskRunId)}");
}).Execute(() => RunReportTask(result));
}
catch (Exception ex)
{
_logger?.LogDebug($"Failed to report task for wfRun {wfRunId}: {ex.Message}. No retries left.");
}
finally
{
_semaphore.Release();
}
}

private void RunReportTask(ReportTaskRun reportedTask)
{
var response = _bootstrapClient.ReportTask(reportedTask);
_bootstrapClient.ReportTask(reportedTask);
}

private ReportTaskRun ExecuteTask(ScheduledTask scheduledTask, DateTime? scheduleTime)
Expand Down Expand Up @@ -278,14 +273,15 @@ private ReportTaskRun ExecuteTask(ScheduledTask scheduledTask, DateTime? schedul

private object? Invoke(ScheduledTask scheduledTask, LHWorkerContext workerContext)
{
var inputs = _mappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray();
var inputs = _task.TaskMethodMappings.Select(mapping => mapping.Assign(scheduledTask, workerContext)).ToArray();

return _taskMethod.Invoke(_executable, inputs);
return _task.TaskMethod!.Invoke(_task.Executable, inputs);
}

public void CloseConnection(LHServerConnection<T> connection)
public void CloseConnection(string host, int port)
{
var currConn = _runningConnections.Where(c => c.IsSame(connection.HostInfo)).FirstOrDefault();
var currConn = _runningConnections.FirstOrDefault(c =>
c.IsSame(host, port));

if (currConn != null)
{
Expand Down
Loading

0 comments on commit 608a593

Please sign in to comment.