From c789b7ea1a02dc443735d426573b5219826cb3ef Mon Sep 17 00:00:00 2001 From: luke-lombardi <33990301+luke-lombardi@users.noreply.github.com> Date: Thu, 26 Dec 2024 11:20:27 -0500 Subject: [PATCH] Hotfix to allow re-registering of ttld machines (#810) --- pkg/api/v1/machine.go | 7 ++++++- pkg/repository/base.go | 2 +- pkg/repository/provider_redis.go | 12 ++++++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pkg/api/v1/machine.go b/pkg/api/v1/machine.go index 1b8235277..d26a4e3b1 100644 --- a/pkg/api/v1/machine.go +++ b/pkg/api/v1/machine.go @@ -83,6 +83,11 @@ func (g *MachineGroup) RegisterMachine(ctx echo.Context) error { hostName = fmt.Sprintf("%s.%s.%s", request.HostName, g.config.Tailscale.User, g.config.Tailscale.HostName) } + poolConfig, ok := g.config.Worker.Pools[request.PoolName] + if !ok { + return HTTPInternalServerError("Invalid pool name") + } + err = g.providerRepo.RegisterMachine(request.ProviderName, request.PoolName, request.MachineID, &types.ProviderMachineState{ MachineId: request.MachineID, Token: request.Token, @@ -90,7 +95,7 @@ func (g *MachineGroup) RegisterMachine(ctx echo.Context) error { Cpu: cpu, Memory: memory, GpuCount: uint32(gpuCount), - }) + }, &poolConfig) if err != nil { return HTTPInternalServerError("Failed to register machine") } diff --git a/pkg/repository/base.go b/pkg/repository/base.go index e817a7e9d..d2839882f 100755 --- a/pkg/repository/base.go +++ b/pkg/repository/base.go @@ -158,7 +158,7 @@ type ProviderRepository interface { RemoveMachine(providerName, poolName, machineId string) error SetMachineKeepAlive(providerName, poolName, machineId, agentVersion string, metrics *types.ProviderMachineMetrics) error SetLastWorkerSeen(providerName, poolName, machineId string) error - RegisterMachine(providerName, poolName, machineId string, newMachineInfo *types.ProviderMachineState) error + RegisterMachine(providerName, poolName, machineId string, newMachineInfo *types.ProviderMachineState, poolConfig *types.WorkerPoolConfig) error WaitForMachineRegistration(providerName, poolName, machineId string) (*types.ProviderMachineState, error) ListAllMachines(providerName, poolName string, useLock bool) ([]*types.ProviderMachine, error) SetMachineLock(providerName, poolName, machineId string) error diff --git a/pkg/repository/provider_redis.go b/pkg/repository/provider_redis.go index 7ee2dc660..9dca4df20 100644 --- a/pkg/repository/provider_redis.go +++ b/pkg/repository/provider_redis.go @@ -299,12 +299,20 @@ func (r *ProviderRedisRepository) RemoveMachine(providerName, poolName, machineI return nil } -func (r *ProviderRedisRepository) RegisterMachine(providerName, poolName, machineId string, newMachineInfo *types.ProviderMachineState) error { +func (r *ProviderRedisRepository) RegisterMachine(providerName, poolName, machineId string, newMachineInfo *types.ProviderMachineState, poolConfig *types.WorkerPoolConfig) error { stateKey := common.RedisKeys.ProviderMachineState(providerName, poolName, machineId) machineInfo, err := r.getMachineStateFromKey(stateKey) if err != nil { - return fmt.Errorf("failed to get machine state <%v>: %w", stateKey, err) + // TODO: This is a temporary fix to allow the machine to be registered + // without having to update the machine state, in the future we should tie + // registration token to machine ID and store that somewhere else persistently + machineInfo = &types.ProviderMachineState{} + machineInfo.Gpu = poolConfig.GPUType + machineInfo.Created = fmt.Sprintf("%d", time.Now().UTC().Unix()) + machineInfo.LastKeepalive = fmt.Sprintf("%d", time.Now().UTC().Unix()) + machineInfo.PoolName = newMachineInfo.PoolName + machineInfo.MachineId = newMachineInfo.MachineId } machineInfo.HostName = newMachineInfo.HostName