From 7397bbe480dc8216a29977bd131542ed02178cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saulius=20Ma=C5=A1nauskas?= Date: Wed, 12 May 2021 15:44:48 +0300 Subject: [PATCH] feat: send cluster deltas instead of full snapshots (#22) --- go.mod | 2 - go.sum | 9 +- internal/castai/castai.go | 107 ++----- internal/castai/castai_test.go | 58 ---- internal/castai/mock/client.go | 26 +- internal/castai/types.go | 32 +- internal/services/collector/collector.go | 236 --------------- internal/services/collector/mock/collector.go | 66 ---- internal/services/collector/types.go | 24 -- internal/services/controller/controller.go | 284 ++++++++++++++++++ .../services/controller/controller_test.go | 90 ++++++ internal/services/providers/castai/castai.go | 19 +- internal/services/providers/eks/eks.go | 61 ++-- internal/services/providers/eks/eks_test.go | 112 +++---- .../services/providers/types/mock/provider.go | 15 + internal/services/providers/types/types.go | 4 +- internal/services/version/mock/version.go | 62 ++++ internal/services/version/version.go | 51 ++++ internal/services/version/version_test.go | 44 +++ internal/services/worker/worker.go | 180 ----------- internal/services/worker/worker_test.go | 69 ----- main.go | 30 +- 22 files changed, 694 insertions(+), 887 deletions(-) delete mode 100644 internal/services/collector/collector.go delete mode 100644 internal/services/collector/mock/collector.go delete mode 100644 internal/services/collector/types.go create mode 100644 internal/services/controller/controller.go create mode 100644 internal/services/controller/controller_test.go create mode 100644 internal/services/version/mock/version.go create mode 100644 internal/services/version/version.go create mode 100644 internal/services/version/version_test.go delete mode 100644 internal/services/worker/worker.go delete mode 100644 internal/services/worker/worker_test.go diff --git a/go.mod b/go.mod index f340fdd1..b08f79e2 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,10 @@ go 1.16 require ( github.com/aws/aws-sdk-go v1.37.23 - github.com/cenkalti/backoff/v4 v4.1.0 github.com/go-resty/resty/v2 v2.5.0 github.com/golang/mock v1.4.1 github.com/google/uuid v1.1.2 github.com/jarcoal/httpmock v1.0.8 - github.com/patrickmn/go-cache v2.1.0+incompatible github.com/sirupsen/logrus v1.7.0 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.6.1 diff --git a/go.sum b/go.sum index 9aa00c16..d4b56ad0 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,6 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= -github.com/cenkalti/backoff/v4 v4.1.0 h1:c8LkOFQTzuO0WBM/ae5HdGQuZPfPxp7lqBRwQRm4fSc= -github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -91,6 +89,7 @@ github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v4.5.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v4.9.0+incompatible h1:kLcOMZeuLAJvL2BPWLMIj5oaZQobrkAqrL+WFZwQses= github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= @@ -130,6 +129,7 @@ github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4er github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -204,6 +204,7 @@ github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= @@ -295,13 +296,12 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.10.2 h1:aY/nuoWlKJud2J6U0E3NWsjlg+0GtwXxgEqthRdzlcs= github.com/onsi/gomega v1.10.2/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -687,6 +687,7 @@ k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.4.0 h1:7+X0fUguPyrKEC4WjH8iGDg3laWgMo5tMnRTIGTTxGQ= k8s.io/klog/v2 v2.4.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= +k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd h1:sOHNzJIkytDF6qadMNKhhDRpc6ODik8lVC6nOur7B2c= k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd/go.mod h1:WOJ3KddDSol4tAGcJo0Tvi+dK12EcqSLqcWsryKMpfM= k8s.io/utils v0.0.0-20201110183641-67b214c5f920/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20210111153108-fddb29f9d009 h1:0T5IaWHO3sJTEmCP6mUlBvMukxPKUQWqiI/YuiBNMiQ= diff --git a/internal/castai/castai.go b/internal/castai/castai.go index 58997da0..b58236cd 100644 --- a/internal/castai/castai.go +++ b/internal/castai/castai.go @@ -7,13 +7,10 @@ import ( "encoding/json" "fmt" "io" - "mime/multipart" "net/http" - "net/textproto" "net/url" "time" - "github.com/cenkalti/backoff/v4" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" @@ -27,8 +24,8 @@ const ( ) var ( - hdrContentType = http.CanonicalHeaderKey("Content-Type") - hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition") + hdrContentType = http.CanonicalHeaderKey("Content-Type") + hdrAPIKey = http.CanonicalHeaderKey(headerAPIKey) ) // Client responsible for communication between the agent and CAST AI API. @@ -36,12 +33,10 @@ type Client interface { // RegisterCluster sends a request to CAST AI containing discovered cluster properties used to authenticate the // cluster and register it. RegisterCluster(ctx context.Context, req *RegisterClusterRequest) (*RegisterClusterResponse, error) - // SendClusterSnapshot sends a cluster snapshot to CAST AI to enable savings estimations / autoscaling / etc. - SendClusterSnapshot(ctx context.Context, snap *Snapshot) error - // SendClusterSnapshotWithRetry sends cluster snapshot with retries to CAST AI to enable savings estimations / autoscaling / etc. - SendClusterSnapshotWithRetry(ctx context.Context, snap *Snapshot) error // GetAgentCfg is used to poll CAST AI for agent configuration which can be updated via UI or other means. GetAgentCfg(ctx context.Context, clusterID string) (*AgentCfgResponse, error) + // SendDelta sends the kubernetes state change to CAST AI. Function is noop when items are empty. + SendDelta(ctx context.Context, delta *Delta) error } // NewClient creates and configures the CAST AI client. @@ -60,7 +55,7 @@ func NewDefaultClient() *resty.Client { client.SetHostURL(fmt.Sprintf("https://%s", cfg.URL)) client.SetRetryCount(defaultRetryCount) client.SetTimeout(defaultTimeout) - client.Header.Set(headerAPIKey, cfg.Key) + client.Header.Set(hdrAPIKey, cfg.Key) return client } @@ -70,35 +65,17 @@ type client struct { rest *resty.Client } -func (c *client) RegisterCluster(ctx context.Context, req *RegisterClusterRequest) (*RegisterClusterResponse, error) { - body := &RegisterClusterResponse{} - resp, err := c.rest.R(). - SetBody(req). - SetResult(body). - SetContext(ctx). - Post("/v1/kubernetes/external-clusters") - if err != nil { - return nil, err - } - if resp.IsError() { - return nil, fmt.Errorf("request error status_code=%d body=%s", resp.StatusCode(), resp.Body()) - } - - c.log.Infof("cluster registered: %+v", body) - - return body, nil -} +func (c *client) SendDelta(ctx context.Context, delta *Delta) error { + c.log.Infof("sending delta with items[%d]", len(delta.Items)) -func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error { cfg := config.Get().API - uri, err := url.Parse(fmt.Sprintf("https://%s/v1/agent/snapshot", cfg.URL)) + uri, err := url.Parse(fmt.Sprintf("https://%s/v1/agent/cluster-delta", cfg.URL)) if err != nil { - return err + return fmt.Errorf("invalid url: %w", err) } r, w := io.Pipe() - mw := multipart.NewWriter(w) go func() { defer func() { @@ -106,27 +83,23 @@ func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error c.log.Errorf("closing pipe: %v", err) } }() - defer func() { - if err := mw.Close(); err != nil { - c.log.Errorf("closing multipart writer: %w", err) - } - }() - if err := writeSnapshotPart(mw, snap); err != nil { - c.log.Errorf("writing snapshot content: %v", err) + + if err := json.NewEncoder(w).Encode(delta); err != nil { + c.log.Errorf("marshaling delta: %v", err) } }() req, err := http.NewRequestWithContext(ctx, http.MethodPost, uri.String(), r) if err != nil { - return err + return fmt.Errorf("creating delta request: %w", err) } - req.Header.Set(hdrContentType, mw.FormDataContentType()) - req.Header.Set(headerAPIKey, cfg.Key) + req.Header.Set(hdrContentType, "application/json") + req.Header.Set(hdrAPIKey, cfg.Key) resp, err := c.rest.GetClient().Do(req) if err != nil { - return err + return fmt.Errorf("sending delta request: %w", err) } defer func() { if err := resp.Body.Close(); err != nil { @@ -139,30 +112,31 @@ func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error if _, err := buf.ReadFrom(resp.Body); err != nil { c.log.Errorf("failed reading error response body: %v", err) } - return fmt.Errorf("request failed with status_code=%d", resp.StatusCode) + return fmt.Errorf("delta request error status_code=%d body=%s", resp.StatusCode, buf.String()) } - c.log.Infof( - "snapshot with nodes[%d], pods[%d] sent, response_code=%d", - len(snap.NodeList.Items), - len(snap.PodList.Items), - resp.StatusCode, - ) + c.log.Infof("delta with items[%d] sent, response_code=%d", len(delta.Items), resp.StatusCode) return nil } -func (c *client) SendClusterSnapshotWithRetry(ctx context.Context, snap *Snapshot) error { - b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) - op := func() error { - return c.SendClusterSnapshot(ctx, snap) +func (c *client) RegisterCluster(ctx context.Context, req *RegisterClusterRequest) (*RegisterClusterResponse, error) { + body := &RegisterClusterResponse{} + resp, err := c.rest.R(). + SetBody(req). + SetResult(body). + SetContext(ctx). + Post("/v1/kubernetes/external-clusters") + if err != nil { + return nil, err } - - if err := backoff.Retry(op, b); err != nil { - return fmt.Errorf("sending snapshot data: %v", err) + if resp.IsError() { + return nil, fmt.Errorf("request error status_code=%d body=%s", resp.StatusCode(), resp.Body()) } - return nil + c.log.Infof("cluster registered: %+v", body) + + return body, nil } func (c *client) GetAgentCfg(ctx context.Context, clusterID string) (*AgentCfgResponse, error) { @@ -180,20 +154,3 @@ func (c *client) GetAgentCfg(ctx context.Context, clusterID string) (*AgentCfgRe return body, nil } - -func writeSnapshotPart(mw *multipart.Writer, snap *Snapshot) error { - header := textproto.MIMEHeader{} - header.Set(hdrContentDisposition, `form-data; name="payload"; filename="payload.json"`) - header.Set(hdrContentType, "application/json") - - bw, err := mw.CreatePart(header) - if err != nil { - return fmt.Errorf("creating payload part: %w", err) - } - - if err := json.NewEncoder(bw).Encode(snap); err != nil { - return fmt.Errorf("marshaling snapshot payload: %w", err) - } - - return nil -} diff --git a/internal/castai/castai_test.go b/internal/castai/castai_test.go index a0a82232..71c66193 100644 --- a/internal/castai/castai_test.go +++ b/internal/castai/castai_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "os" "testing" "github.com/go-resty/resty/v2" @@ -12,10 +11,6 @@ import ( "github.com/jarcoal/httpmock" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "castai-agent/internal/services/collector" ) func TestClient_RegisterCluster(t *testing.T) { @@ -40,56 +35,3 @@ func TestClient_RegisterCluster(t *testing.T) { require.NoError(t, err) require.Equal(t, registerClusterResp, got) } - -func TestClient_SendClusterSnapshot(t *testing.T) { - require.NoError(t, os.Setenv("API_KEY", "api-key")) - require.NoError(t, os.Setenv("API_URL", "localhost")) - - rest := resty.New() - httpmock.ActivateNonDefault(rest.GetClient()) - defer httpmock.Reset() - - c := NewClient(logrus.New(), rest) - - snapshot := &Snapshot{ - ClusterID: uuid.New().String(), - ClusterData: &collector.ClusterData{ - NodeList: &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - }, - }, - }, - PodList: &corev1.PodList{ - Items: []corev1.Pod{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - }, - }, - }, - }, - } - - httpmock.RegisterResponder(http.MethodPost, "https://localhost/v1/agent/snapshot", func(req *http.Request) (*http.Response, error) { - f, _, err := req.FormFile("payload") - require.NoError(t, err) - - actualRequest := &Snapshot{} - require.NoError(t, json.NewDecoder(f).Decode(actualRequest)) - - require.Equal(t, snapshot, actualRequest) - - require.Equal(t, "api-key", req.Header.Get(headerAPIKey)) - - return httpmock.NewStringResponse(http.StatusNoContent, "ok"), nil - }) - - err := c.SendClusterSnapshot(context.Background(), snapshot) - - require.NoError(t, err) -} diff --git a/internal/castai/mock/client.go b/internal/castai/mock/client.go index 30440457..c5f024bf 100644 --- a/internal/castai/mock/client.go +++ b/internal/castai/mock/client.go @@ -65,30 +65,16 @@ func (mr *MockClientMockRecorder) RegisterCluster(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterCluster", reflect.TypeOf((*MockClient)(nil).RegisterCluster), arg0, arg1) } -// SendClusterSnapshot mocks base method. -func (m *MockClient) SendClusterSnapshot(arg0 context.Context, arg1 *castai.Snapshot) error { +// SendDelta mocks base method. +func (m *MockClient) SendDelta(arg0 context.Context, arg1 *castai.Delta) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendClusterSnapshot", arg0, arg1) + ret := m.ctrl.Call(m, "SendDelta", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// SendClusterSnapshot indicates an expected call of SendClusterSnapshot. -func (mr *MockClientMockRecorder) SendClusterSnapshot(arg0, arg1 interface{}) *gomock.Call { +// SendDelta indicates an expected call of SendDelta. +func (mr *MockClientMockRecorder) SendDelta(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendClusterSnapshot", reflect.TypeOf((*MockClient)(nil).SendClusterSnapshot), arg0, arg1) -} - -// SendClusterSnapshotWithRetry mocks base method. -func (m *MockClient) SendClusterSnapshotWithRetry(arg0 context.Context, arg1 *castai.Snapshot) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendClusterSnapshotWithRetry", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendClusterSnapshotWithRetry indicates an expected call of SendClusterSnapshotWithRetry. -func (mr *MockClientMockRecorder) SendClusterSnapshotWithRetry(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendClusterSnapshotWithRetry", reflect.TypeOf((*MockClient)(nil).SendClusterSnapshotWithRetry), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDelta", reflect.TypeOf((*MockClient)(nil).SendDelta), arg0, arg1) } diff --git a/internal/castai/types.go b/internal/castai/types.go index 887812b3..12acf3bf 100644 --- a/internal/castai/types.go +++ b/internal/castai/types.go @@ -1,6 +1,6 @@ package castai -import "castai-agent/internal/services/collector" +import "time" type EKSParams struct { ClusterName string `json:"clusterName"` @@ -26,19 +26,27 @@ type RegisterClusterResponse struct { type AgentCfgResponse struct { IntervalSeconds string `json:"intervalSeconds"` + Resync bool `json:"resync"` } -type SnapshotRequest struct { - Payload []byte `json:"payload"` +type Delta struct { + ClusterID string `json:"clusterId"` + ClusterVersion string `json:"clusterVersion"` + FullSnapshot bool `json:"fullSnapshot"` + Items []*DeltaItem `json:"items"` } -type Snapshot struct { - ClusterID string `json:"clusterId"` - AccountID string `json:"accountId"` - OrganizationID string `json:"organizationId"` - ClusterProvider string `json:"clusterProvider"` - ClusterName string `json:"clusterName"` - ClusterVersion string `json:"clusterVersion"` - ClusterRegion string `json:"clusterRegion"` - *collector.ClusterData +type DeltaItem struct { + Event EventType `json:"event"` + Kind string `json:"kind"` + Data string `json:"data"` + CreatedAt time.Time `json:"createdAt"` } + +type EventType string + +const ( + EventAdd EventType = "add" + EventUpdate EventType = "update" + EventDelete EventType = "delete" +) diff --git a/internal/services/collector/collector.go b/internal/services/collector/collector.go deleted file mode 100644 index d472da60..00000000 --- a/internal/services/collector/collector.go +++ /dev/null @@ -1,236 +0,0 @@ -//go:generate mockgen -destination ./mock/collector.go . Collector -package collector - -import ( - "context" - "fmt" - "regexp" - "strconv" - - "github.com/sirupsen/logrus" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/version" - "k8s.io/client-go/kubernetes" -) - -// Collector is responsible for gathering K8s data from the cluster. -type Collector interface { - // Collect cluster snapshot data. - Collect(ctx context.Context) (*ClusterData, error) - // GetVersion returns the K8s cluster version. Version is nil when the collector is unable to retrieve it. - GetVersion() *version.Info -} - -// NewCollector creates and configures a Collector instance. -func NewCollector(log logrus.FieldLogger, clientset kubernetes.Interface) (Collector, error) { - var v *version.Info - var minor int - if cs, ok := clientset.(*kubernetes.Clientset); ok { - sv, err := cs.ServerVersion() - if err != nil { - log.Errorf("getting cluster version: %v", err) - } else { - v = sv - m, err := strconv.Atoi(regexp.MustCompile(`^(\d+)`).FindString(sv.Minor)) - if err != nil { - return nil, fmt.Errorf("parsing k8s version: %w", err) - } - minor = m - } - } - - return &collector{ - log: log, - clientset: clientset, - cd: &ClusterData{}, - v: v, - minor: minor, - }, nil -} - -type collector struct { - log logrus.FieldLogger - clientset kubernetes.Interface - cd *ClusterData - minor int - v *version.Info -} - -func (c *collector) Collect(ctx context.Context) (*ClusterData, error) { - if err := c.collectNodes(ctx); err != nil { - return nil, fmt.Errorf("collecting nodes: %w", err) - } - - if err := c.collectPods(ctx); err != nil { - return nil, fmt.Errorf("collecting pods: %w", err) - } - - if err := c.collectPersistentVolumes(ctx); err != nil { - return nil, fmt.Errorf("collecting persistent volumes: %w", err) - } - - if err := c.collectPersistentVolumeClaims(ctx); err != nil { - return nil, fmt.Errorf("collecting persistent volume claims: %w", err) - } - - if err := c.collectDeploymentList(ctx); err != nil { - return nil, fmt.Errorf("collecting deployments: %w", err) - } - - if err := c.collectReplicaSetList(ctx); err != nil { - return nil, fmt.Errorf("collecting replica sets: %w", err) - } - - if err := c.collectDaemonSetList(ctx); err != nil { - return nil, fmt.Errorf("collecting daemon sets: %w", err) - } - - if err := c.collectStatefulSetList(ctx); err != nil { - return nil, fmt.Errorf("collecting stateful sets: %w", err) - } - - if err := c.collectReplicationControllerList(ctx); err != nil { - return nil, fmt.Errorf("collecting replication controllers: %w", err) - } - - if err := c.collectServiceList(ctx); err != nil { - return nil, fmt.Errorf("collecting services: %w", err) - } - - if c.minor >= 17 { - if err := c.collectCSINodeList(ctx); err != nil { - return nil, fmt.Errorf("collecting csi nodes: %w", err) - } - } - - if err := c.collectStorageClassList(ctx); err != nil { - return nil, fmt.Errorf("collecting storage classes: %w", err) - } - - if err := c.collectJobList(ctx); err != nil { - return nil, fmt.Errorf("collecting jobs: %w", err) - } - - return c.cd, nil -} - -func (c *collector) GetVersion() *version.Info { - return c.v -} - -func (c *collector) collectNodes(ctx context.Context) error { - nodes, err := c.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.NodeList = nodes - return nil -} - -func (c *collector) collectPods(ctx context.Context) error { - pods, err := c.clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.PodList = pods - return nil -} - -func (c *collector) collectPersistentVolumes(ctx context.Context) error { - pvs, err := c.clientset.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.PersistentVolumeList = pvs - return nil -} - -func (c *collector) collectPersistentVolumeClaims(ctx context.Context) error { - pvc, err := c.clientset.CoreV1().PersistentVolumeClaims("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.PersistentVolumeClaimList = pvc - return nil -} - -func (c *collector) collectDeploymentList(ctx context.Context) error { - dpls, err := c.clientset.AppsV1().Deployments("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.DeploymentList = dpls - return nil -} - -func (c *collector) collectReplicaSetList(ctx context.Context) error { - rpsl, err := c.clientset.AppsV1().ReplicaSets("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.ReplicaSetList = rpsl - return nil -} - -func (c *collector) collectDaemonSetList(ctx context.Context) error { - dsl, err := c.clientset.AppsV1().DaemonSets("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.DaemonSetList = dsl - return nil -} - -func (c *collector) collectStatefulSetList(ctx context.Context) error { - stsl, err := c.clientset.AppsV1().StatefulSets("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.StatefulSetList = stsl - return nil -} - -func (c *collector) collectReplicationControllerList(ctx context.Context) error { - rc, err := c.clientset.CoreV1().ReplicationControllers("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.ReplicationControllerList = rc - return nil -} - -func (c *collector) collectServiceList(ctx context.Context) error { - svc, err := c.clientset.CoreV1().Services("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.ServiceList = svc - return nil -} - -func (c *collector) collectCSINodeList(ctx context.Context) error { - csin, err := c.clientset.StorageV1().CSINodes().List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.CSINodeList = csin - return nil -} - -func (c *collector) collectStorageClassList(ctx context.Context) error { - scl, err := c.clientset.StorageV1().StorageClasses().List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.StorageClassList = scl - return nil -} - -func (c *collector) collectJobList(ctx context.Context) error { - jobs, err := c.clientset.BatchV1().Jobs("").List(ctx, metav1.ListOptions{}) - if err != nil { - return err - } - c.cd.JobList = jobs - return nil -} diff --git a/internal/services/collector/mock/collector.go b/internal/services/collector/mock/collector.go deleted file mode 100644 index e9bb6865..00000000 --- a/internal/services/collector/mock/collector.go +++ /dev/null @@ -1,66 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: castai-agent/internal/services/collector (interfaces: Collector) - -// Package mock_collector is a generated GoMock package. -package mock_collector - -import ( - collector "castai-agent/internal/services/collector" - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - version "k8s.io/apimachinery/pkg/version" -) - -// MockCollector is a mock of Collector interface. -type MockCollector struct { - ctrl *gomock.Controller - recorder *MockCollectorMockRecorder -} - -// MockCollectorMockRecorder is the mock recorder for MockCollector. -type MockCollectorMockRecorder struct { - mock *MockCollector -} - -// NewMockCollector creates a new mock instance. -func NewMockCollector(ctrl *gomock.Controller) *MockCollector { - mock := &MockCollector{ctrl: ctrl} - mock.recorder = &MockCollectorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCollector) EXPECT() *MockCollectorMockRecorder { - return m.recorder -} - -// Collect mocks base method. -func (m *MockCollector) Collect(arg0 context.Context) (*collector.ClusterData, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Collect", arg0) - ret0, _ := ret[0].(*collector.ClusterData) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Collect indicates an expected call of Collect. -func (mr *MockCollectorMockRecorder) Collect(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Collect", reflect.TypeOf((*MockCollector)(nil).Collect), arg0) -} - -// GetVersion mocks base method. -func (m *MockCollector) GetVersion() *version.Info { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVersion") - ret0, _ := ret[0].(*version.Info) - return ret0 -} - -// GetVersion indicates an expected call of GetVersion. -func (mr *MockCollectorMockRecorder) GetVersion() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockCollector)(nil).GetVersion)) -} diff --git a/internal/services/collector/types.go b/internal/services/collector/types.go deleted file mode 100644 index 13b49fa1..00000000 --- a/internal/services/collector/types.go +++ /dev/null @@ -1,24 +0,0 @@ -package collector - -import ( - appv1 "k8s.io/api/apps/v1" - batchv1 "k8s.io/api/batch/v1" - corev1 "k8s.io/api/core/v1" - storagev1 "k8s.io/api/storage/v1" -) - -type ClusterData struct { - NodeList *corev1.NodeList `json:"nodeList"` - PodList *corev1.PodList `json:"podList"` - PersistentVolumeList *corev1.PersistentVolumeList `json:"persistentVolumeList"` - PersistentVolumeClaimList *corev1.PersistentVolumeClaimList `json:"persistentVolumeClaimList"` - DeploymentList *appv1.DeploymentList `json:"deploymentList"` - ReplicaSetList *appv1.ReplicaSetList `json:"replicaSetList"` - DaemonSetList *appv1.DaemonSetList `json:"daemonSetList"` - StatefulSetList *appv1.StatefulSetList `json:"statefulSetList"` - ReplicationControllerList *corev1.ReplicationControllerList `json:"replicationControllerList"` - ServiceList *corev1.ServiceList `json:"serviceList"` - CSINodeList *storagev1.CSINodeList `json:"csiNodeList"` - StorageClassList *storagev1.StorageClassList `json:"storageClassList"` - JobList *batchv1.JobList `json:"jobList"` -} diff --git a/internal/services/controller/controller.go b/internal/services/controller/controller.go new file mode 100644 index 00000000..d01b608a --- /dev/null +++ b/internal/services/controller/controller.go @@ -0,0 +1,284 @@ +package controller + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + appsv1 "k8s.io/api/apps/v1" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + storagev1 "k8s.io/api/storage/v1" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" + + "castai-agent/internal/castai" + "castai-agent/internal/services/providers/types" + "castai-agent/internal/services/version" + "castai-agent/pkg/labels" +) + +type Controller struct { + log logrus.FieldLogger + castaiclient castai.Client + provider types.Provider + queue workqueue.RateLimitingInterface + interval time.Duration + prepDuration time.Duration + informers map[reflect.Type]cache.SharedInformer + + delta *castai.Delta + mu sync.Mutex + spotCache map[string]bool +} + +func New( + log logrus.FieldLogger, + f informers.SharedInformerFactory, + castaiclient castai.Client, + provider types.Provider, + clusterID string, + interval time.Duration, + prepDuration time.Duration, + v version.Interface, +) *Controller { + typeInformerMap := map[reflect.Type]cache.SharedInformer{ + reflect.TypeOf(&corev1.Node{}): f.Core().V1().Nodes().Informer(), + reflect.TypeOf(&corev1.Pod{}): f.Core().V1().Pods().Informer(), + reflect.TypeOf(&corev1.PersistentVolume{}): f.Core().V1().PersistentVolumes().Informer(), + reflect.TypeOf(&corev1.PersistentVolumeClaim{}): f.Core().V1().PersistentVolumeClaims().Informer(), + reflect.TypeOf(&corev1.ReplicationController{}): f.Core().V1().ReplicationControllers().Informer(), + reflect.TypeOf(&corev1.Service{}): f.Core().V1().Services().Informer(), + reflect.TypeOf(&appsv1.Deployment{}): f.Apps().V1().Deployments().Informer(), + reflect.TypeOf(&appsv1.ReplicaSet{}): f.Apps().V1().ReplicaSets().Informer(), + reflect.TypeOf(&appsv1.DaemonSet{}): f.Apps().V1().DaemonSets().Informer(), + reflect.TypeOf(&appsv1.StatefulSet{}): f.Apps().V1().StatefulSets().Informer(), + reflect.TypeOf(&storagev1.StorageClass{}): f.Storage().V1().StorageClasses().Informer(), + reflect.TypeOf(&batchv1.Job{}): f.Batch().V1().Jobs().Informer(), + } + + if v.MinorInt() >= 17 { + typeInformerMap[reflect.TypeOf(&storagev1.CSINode{})] = f.Storage().V1().CSINodes().Informer() + } + + c := &Controller{ + log: log, + castaiclient: castaiclient, + provider: provider, + interval: interval, + prepDuration: prepDuration, + delta: &castai.Delta{ClusterID: clusterID, ClusterVersion: v.Full(), FullSnapshot: true}, + spotCache: map[string]bool{}, + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "castai-agent"), + informers: typeInformerMap, + } + + for typ, informer := range c.informers { + typ := typ + informer := informer + log := log.WithField("informer", typ.String()) + + var h cache.ResourceEventHandler + + if typ == reflect.TypeOf(&corev1.Node{}) { + h = cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + c.nodeAddHandler(log, castai.EventAdd, obj) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + c.nodeAddHandler(log, castai.EventUpdate, newObj) + }, + DeleteFunc: func(obj interface{}) { + c.nodeDeleteHandler(log, castai.EventDelete, obj) + }, + } + } else { + h = cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + genericHandler(log, c.queue, typ, castai.EventAdd, obj) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + genericHandler(log, c.queue, typ, castai.EventUpdate, newObj) + }, + DeleteFunc: func(obj interface{}) { + genericHandler(log, c.queue, typ, castai.EventDelete, obj) + }, + } + } + + informer.AddEventHandler(h) + } + + return c +} + +func (c *Controller) nodeAddHandler( + log logrus.FieldLogger, + event castai.EventType, + obj interface{}, +) { + node, ok := obj.(*corev1.Node) + if !ok { + log.Errorf("expected to get *corev1.Node but got %T", obj) + return + } + + spot, ok := c.spotCache[node.Name] + if !ok { + var err error + spot, err = c.provider.IsSpot(context.Background(), node) + if err != nil { + log.Warnf("failed to determine whether node %q is spot: %v", node.Name, err) + } else { + c.spotCache[node.Name] = spot + } + } + + if spot { + node.Labels[labels.Spot] = "true" + } + + genericHandler(log, c.queue, reflect.TypeOf(&corev1.Node{}), event, node) +} + +func (c *Controller) nodeDeleteHandler( + log logrus.FieldLogger, + event castai.EventType, + obj interface{}, +) { + node, ok := obj.(*corev1.Node) + if !ok { + log.Errorf("expected to get *corev1.Node but got %T", obj) + return + } + + delete(c.spotCache, node.Name) + + genericHandler(log, c.queue, reflect.TypeOf(&corev1.Node{}), event, node) +} + +func genericHandler( + log logrus.FieldLogger, + queue workqueue.RateLimitingInterface, + expected reflect.Type, + event castai.EventType, + obj interface{}, +) { + if reflect.TypeOf(obj) != expected { + log.Errorf("expected to get %v but got %T", expected, obj) + return + } + + typeName := expected.String() + kind := typeName[strings.LastIndex(typeName, ".")+1:] + + data, err := encode(obj) + if err != nil { + log.Errorf("failed to encode %T: %v", obj, err) + return + } + + queue.Add(&castai.DeltaItem{ + Event: event, + Kind: kind, + Data: data, + CreatedAt: time.Now().UTC(), + }) +} + +func encode(obj interface{}) (string, error) { + b, err := json.Marshal(obj) + if err != nil { + return "", fmt.Errorf("marshaling %T to json: %v", obj, err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func (c *Controller) Run(ctx context.Context) { + defer c.queue.ShutDown() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + syncs := make([]cache.InformerSynced, 0, len(c.informers)) + for _, informer := range c.informers { + syncs = append(syncs, informer.HasSynced) + } + + if !cache.WaitForCacheSync(ctx.Done(), syncs...) { + c.log.Errorf("failed to sync") + return + } + + go func() { + const dur = 15 * time.Second + c.log.Infof("polling agent configuration every %s", dur) + wait.Until(func() { + cfg, err := c.castaiclient.GetAgentCfg(ctx, c.delta.ClusterID) + if err != nil { + c.log.Errorf("failed getting agent configuration: %v", err) + return + } + if cfg.Resync { + c.log.Info("restarting controller to resync data") + cancel() + } + }, dur, ctx.Done()) + }() + + go func() { + c.log.Info("collecting initial cluster snapshot") + time.Sleep(c.prepDuration) + c.log.Infof("sending cluster deltas every %s", c.interval) + wait.Until(func() { + c.send(ctx) + }, c.interval, ctx.Done()) + }() + + go func() { + <-ctx.Done() + c.queue.ShutDown() + }() + + c.pollQueueUntilDone() +} + +func (c *Controller) pollQueueUntilDone() { + for { + item, done := c.queue.Get() + if done { + return + } + + di, ok := item.(*castai.DeltaItem) + if !ok { + c.log.Errorf("expected queue item to be of type %T but got %T", &castai.DeltaItem{}, item) + continue + } + + c.mu.Lock() + c.delta.Items = append(c.delta.Items, di) + c.mu.Unlock() + } +} + +func (c *Controller) send(ctx context.Context) { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.castaiclient.SendDelta(ctx, c.delta); err != nil { + c.log.Errorf("failed sending delta: %v", err) + return + } + + c.delta.Items = nil + c.delta.FullSnapshot = false +} diff --git a/internal/services/controller/controller_test.go b/internal/services/controller/controller_test.go new file mode 100644 index 00000000..a3983bab --- /dev/null +++ b/internal/services/controller/controller_test.go @@ -0,0 +1,90 @@ +package controller + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes/fake" + + "castai-agent/internal/castai" + mock_castai "castai-agent/internal/castai/mock" + mock_types "castai-agent/internal/services/providers/types/mock" + mock_version "castai-agent/internal/services/version/mock" + "castai-agent/pkg/labels" +) + +func Test(t *testing.T) { + mockctrl := gomock.NewController(t) + castaiclient := mock_castai.NewMockClient(mockctrl) + version := mock_version.NewMockInterface(mockctrl) + provider := mock_types.NewMockProvider(mockctrl) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + node := &v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node1", Labels: map[string]string{}}} + expectedNode := node.DeepCopy() + expectedNode.Labels[labels.Spot] = "true" + nodeData, err := encode(expectedNode) + require.NoError(t, err) + + pod := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Namespace: v1.NamespaceDefault, Name: "pod1"}} + podData, err := encode(pod) + require.NoError(t, err) + + clientset := fake.NewSimpleClientset(node, pod) + + version.EXPECT().MinorInt().Return(19) + version.EXPECT().Full().Return("1.19+") + + clusterID := uuid.New() + + var invocations int64 + + castaiclient.EXPECT(). + SendDelta(gomock.Any(), gomock.Any()).AnyTimes(). + DoAndReturn(func(_ context.Context, d *castai.Delta) error { + defer atomic.AddInt64(&invocations, 1) + + require.Equal(t, clusterID.String(), d.ClusterID) + require.Equal(t, "1.19+", d.ClusterVersion) + require.True(t, d.FullSnapshot) + require.Len(t, d.Items, 2) + + var actualValues []string + for _, item := range d.Items { + actualValues = append(actualValues, fmt.Sprintf("%s-%s-%s", item.Event, item.Kind, item.Data)) + } + + require.Contains(t, actualValues, fmt.Sprintf("%s-%s-%s", castai.EventAdd, "Node", nodeData)) + require.Contains(t, actualValues, fmt.Sprintf("%s-%s-%s", castai.EventAdd, "Pod", podData)) + + return nil + }) + + castaiclient.EXPECT().GetAgentCfg(gomock.Any(), gomock.Any()).AnyTimes().Return(&castai.AgentCfgResponse{}, nil) + provider.EXPECT().IsSpot(gomock.Any(), node).Return(true, nil) + + f := informers.NewSharedInformerFactory(clientset, 0) + ctrl := New(logrus.New(), f, castaiclient, provider, clusterID.String(), 15*time.Second, 100*time.Millisecond, version) + f.Start(ctx.Done()) + + go ctrl.Run(ctx) + + wait.Until(func() { + if atomic.LoadInt64(&invocations) >= 1 { + cancel() + } + }, 10*time.Millisecond, ctx.Done()) +} diff --git a/internal/services/providers/castai/castai.go b/internal/services/providers/castai/castai.go index 16fb4c42..63a16962 100644 --- a/internal/services/providers/castai/castai.go +++ b/internal/services/providers/castai/castai.go @@ -24,6 +24,13 @@ type Provider struct { log logrus.FieldLogger } +func (p *Provider) IsSpot(_ context.Context, node *v1.Node) (bool, error) { + if val, ok := node.Labels[labels.Spot]; ok && val == "true" { + return true, nil + } + return false, nil +} + func (p *Provider) RegisterCluster(_ context.Context, _ castai.Client) (*types.ClusterRegistration, error) { cfg := config.Get().CASTAI return &types.ClusterRegistration{ @@ -36,18 +43,6 @@ func (p *Provider) Name() string { return Name } -func (p *Provider) FilterSpot(_ context.Context, nodes []*v1.Node) ([]*v1.Node, error) { - var spots []*v1.Node - - for _, n := range nodes { - if val, ok := n.ObjectMeta.Labels[labels.Spot]; ok && val == "true" { - spots = append(spots, n) - } - } - - return spots, nil -} - func (p *Provider) AccountID(_ context.Context) (string, error) { return "", nil } diff --git a/internal/services/providers/eks/eks.go b/internal/services/providers/eks/eks.go index 370c8b79..dbaad643 100644 --- a/internal/services/providers/eks/eks.go +++ b/internal/services/providers/eks/eks.go @@ -3,9 +3,7 @@ package eks import ( "context" "fmt" - "time" - "github.com/patrickmn/go-cache" "github.com/sirupsen/logrus" "k8s.io/api/core/v1" @@ -13,10 +11,14 @@ import ( "castai-agent/internal/config" "castai-agent/internal/services/providers/eks/client" "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" ) const ( Name = "eks" + + LabelCapacity = "eks.amazonaws.com/capacityType" + ValueCapacitySpot = "SPOT" ) // New configures and returns an EKS provider. @@ -48,7 +50,6 @@ func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { type Provider struct { log logrus.FieldLogger awsClient client.Client - spotCache *cache.Cache } func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { @@ -85,52 +86,32 @@ func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (* }, nil } -func (p *Provider) FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node, error) { - if p.spotCache == nil { - p.spotCache = cache.New(60*time.Minute, 10*time.Minute) +func (p *Provider) IsSpot(ctx context.Context, node *v1.Node) (bool, error) { + if val, ok := node.Labels[LabelCapacity]; ok && ValueCapacitySpot == val { + return true, nil } - var spotNodes []*v1.Node - var checkPrivateDNS []string - - for _, node := range nodes { - val, exists := p.spotCache.Get(node.Name) - - if !exists { - checkPrivateDNS = append(checkPrivateDNS, node.ObjectMeta.Labels[v1.LabelHostname]) - continue - } - - spot := val.(bool) - - if spot { - spotNodes = append(spotNodes, node) - } + if val, ok := node.Labels[labels.Spot]; ok && val == "true" { + return true, nil } - if len(checkPrivateDNS) > 0 { - instances, err := p.awsClient.GetInstancesByPrivateDNS(ctx, checkPrivateDNS) - if err != nil { - return nil, err - } + hostname, ok := node.Labels[v1.LabelHostname] + if !ok { + return false, fmt.Errorf("label %s not found on node %s", v1.LabelHostname, node.Name) + } - spotInstances := make(map[string]bool, len(instances)) - for _, instance := range instances { - spotInstances[*instance.PrivateDnsName] = instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot" - } + instances, err := p.awsClient.GetInstancesByPrivateDNS(ctx, []string{hostname}) + if err != nil { + return false, fmt.Errorf("getting instances by hostname: %w", err) + } - for _, node := range nodes { - spot, ok := spotInstances[node.ObjectMeta.Labels[v1.LabelHostname]] - if ok { - p.spotCache.SetDefault(node.Name, spot) - } - if spot { - spotNodes = append(spotNodes, node) - } + for _, instance := range instances { + if instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot" { + return true, nil } } - return spotNodes, nil + return false, nil } func (p *Provider) Name() string { diff --git a/internal/services/providers/eks/eks_test.go b/internal/services/providers/eks/eks_test.go index b70f7be2..52d0cf7c 100644 --- a/internal/services/providers/eks/eks_test.go +++ b/internal/services/providers/eks/eks_test.go @@ -17,6 +17,7 @@ import ( mock_castai "castai-agent/internal/castai/mock" mock_client "castai-agent/internal/services/providers/eks/client/mock" "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" ) func TestProvider_RegisterCluster(t *testing.T) { @@ -59,9 +60,8 @@ func TestProvider_RegisterCluster(t *testing.T) { require.Equal(t, expected, got) } -func TestProvider_FilterSpot(t *testing.T) { - t.Run("no spot instances", func(t *testing.T) { - ctx := context.Background() +func TestProvider_IsSpot(t *testing.T) { + t.Run("spot instance capacity label", func(t *testing.T) { awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ @@ -69,34 +69,31 @@ func TestProvider_FilterSpot(t *testing.T) { awsClient: awsClient, } - nodes := []*v1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Labels: map[string]string{ - v1.LabelHostname: "hostname", - }, - }, - }, - } + got, err := p.IsSpot(context.Background(), &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ + LabelCapacity: ValueCapacitySpot, + }}}) - instances := []*ec2.Instance{ - { - PrivateDnsName: pointer.StringPtr("hostname"), - InstanceLifecycle: pointer.StringPtr("on-demand"), - }, - } + require.NoError(t, err) + require.True(t, got) + }) - awsClient.EXPECT().GetInstancesByPrivateDNS(ctx, []string{"hostname"}).Return(instances, nil) + t.Run("spot instance CAST AI label", func(t *testing.T) { + awsClient := mock_client.NewMockClient(gomock.NewController(t)) + + p := &Provider{ + log: logrus.New(), + awsClient: awsClient, + } - got, err := p.FilterSpot(ctx, nodes) + got, err := p.IsSpot(context.Background(), &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ + labels.Spot: "true", + }}}) require.NoError(t, err) - require.Empty(t, got) + require.True(t, got) }) - t.Run("one spot instance", func(t *testing.T) { - ctx := context.Background() + t.Run("spot instance lifecycle response", func(t *testing.T) { awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ @@ -104,45 +101,21 @@ func TestProvider_FilterSpot(t *testing.T) { awsClient: awsClient, } - spotNode := &v1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "spot", - Labels: map[string]string{ - v1.LabelHostname: "spot", - }, - }, - } - - nodes := []*v1.Node{spotNode, { - ObjectMeta: metav1.ObjectMeta{ - Name: "on-demand", - Labels: map[string]string{ - v1.LabelHostname: "on-demand", - }, - }, - }} - - instances := []*ec2.Instance{ + awsClient.EXPECT().GetInstancesByPrivateDNS(gomock.Any(), []string{"hostname"}).Return([]*ec2.Instance{ { - PrivateDnsName: pointer.StringPtr("spot"), InstanceLifecycle: pointer.StringPtr("spot"), }, - { - PrivateDnsName: pointer.StringPtr("on-demand"), - InstanceLifecycle: pointer.StringPtr("on-demand"), - }, - } + }, nil) - awsClient.EXPECT().GetInstancesByPrivateDNS(ctx, []string{"spot", "on-demand"}).Return(instances, nil) - - got, err := p.FilterSpot(ctx, nodes) + got, err := p.IsSpot(context.Background(), &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ + v1.LabelHostname: "hostname", + }}}) require.NoError(t, err) - require.Equal(t, []*v1.Node{spotNode}, got) + require.True(t, got) }) - t.Run("should use cache", func(t *testing.T) { - ctx := context.Background() + t.Run("on-demand instance", func(t *testing.T) { awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ @@ -150,34 +123,17 @@ func TestProvider_FilterSpot(t *testing.T) { awsClient: awsClient, } - nodes := []*v1.Node{ + awsClient.EXPECT().GetInstancesByPrivateDNS(gomock.Any(), []string{"hostname"}).Return([]*ec2.Instance{ { - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Labels: map[string]string{ - v1.LabelHostname: "hostname", - }, - }, - }, - } - - instances := []*ec2.Instance{ - { - PrivateDnsName: pointer.StringPtr("hostname"), InstanceLifecycle: pointer.StringPtr("on-demand"), }, - } - - awsClient.EXPECT().GetInstancesByPrivateDNS(ctx, []string{"hostname"}).Times(1).Return(instances, nil) - - got, err := p.FilterSpot(ctx, nodes) - - require.NoError(t, err) - require.Empty(t, got) + }, nil) - got, err = p.FilterSpot(ctx, nodes) + got, err := p.IsSpot(context.Background(), &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ + v1.LabelHostname: "hostname", + }}}) require.NoError(t, err) - require.Empty(t, got) + require.False(t, got) }) } diff --git a/internal/services/providers/types/mock/provider.go b/internal/services/providers/types/mock/provider.go index f15fa2c2..e85efbdd 100644 --- a/internal/services/providers/types/mock/provider.go +++ b/internal/services/providers/types/mock/provider.go @@ -97,6 +97,21 @@ func (mr *MockProviderMockRecorder) FilterSpot(arg0, arg1 interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterSpot", reflect.TypeOf((*MockProvider)(nil).FilterSpot), arg0, arg1) } +// IsSpot mocks base method. +func (m *MockProvider) IsSpot(arg0 context.Context, arg1 *v1.Node) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsSpot", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsSpot indicates an expected call of IsSpot. +func (mr *MockProviderMockRecorder) IsSpot(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSpot", reflect.TypeOf((*MockProvider)(nil).IsSpot), arg0, arg1) +} + // Name mocks base method. func (m *MockProvider) Name() string { m.ctrl.T.Helper() diff --git a/internal/services/providers/types/types.go b/internal/services/providers/types/types.go index edfa2215..11ff8176 100644 --- a/internal/services/providers/types/types.go +++ b/internal/services/providers/types/types.go @@ -13,8 +13,8 @@ import ( type Provider interface { // RegisterCluster retrieves cluster registration data needed to correctly identify the cluster. RegisterCluster(ctx context.Context, client castclient.Client) (*ClusterRegistration, error) - // FilterSpot returns a list of nodes which are configured as spot/preemtible instances. - FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node, error) + // IsSpot checks provider specific properties whether the node lifecycle is spot/preemtible. + IsSpot(ctx context.Context, node *v1.Node) (bool, error) // Name of the provider. Name() string // AccountID of the EC2 instance. diff --git a/internal/services/version/mock/version.go b/internal/services/version/mock/version.go new file mode 100644 index 00000000..77f2dd73 --- /dev/null +++ b/internal/services/version/mock/version.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: castai-agent/internal/services/version (interfaces: Interface) + +// Package mock_version is a generated GoMock package. +package mock_version + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockInterface is a mock of Interface interface. +type MockInterface struct { + ctrl *gomock.Controller + recorder *MockInterfaceMockRecorder +} + +// MockInterfaceMockRecorder is the mock recorder for MockInterface. +type MockInterfaceMockRecorder struct { + mock *MockInterface +} + +// NewMockInterface creates a new mock instance. +func NewMockInterface(ctrl *gomock.Controller) *MockInterface { + mock := &MockInterface{ctrl: ctrl} + mock.recorder = &MockInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { + return m.recorder +} + +// Full mocks base method. +func (m *MockInterface) Full() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Full") + ret0, _ := ret[0].(string) + return ret0 +} + +// Full indicates an expected call of Full. +func (mr *MockInterfaceMockRecorder) Full() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Full", reflect.TypeOf((*MockInterface)(nil).Full)) +} + +// MinorInt mocks base method. +func (m *MockInterface) MinorInt() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MinorInt") + ret0, _ := ret[0].(int) + return ret0 +} + +// MinorInt indicates an expected call of MinorInt. +func (mr *MockInterfaceMockRecorder) MinorInt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MinorInt", reflect.TypeOf((*MockInterface)(nil).MinorInt)) +} diff --git a/internal/services/version/version.go b/internal/services/version/version.go new file mode 100644 index 00000000..317cdb99 --- /dev/null +++ b/internal/services/version/version.go @@ -0,0 +1,51 @@ +//go:generate mockgen -destination ./mock/version.go . Interface +package version + +import ( + "fmt" + "regexp" + "strconv" + + "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/version" + "k8s.io/client-go/kubernetes" +) + +type Interface interface { + Full() string + MinorInt() int +} + +func Get(log logrus.FieldLogger, clientset kubernetes.Interface) (Interface, error) { + cs, ok := clientset.(*kubernetes.Clientset) + if !ok { + return nil, fmt.Errorf("expected clientset to be of type *kubernetes.Clientset but was %T", clientset) + } + + sv, err := cs.ServerVersion() + if err != nil { + return nil, fmt.Errorf("getting server version: %w", err) + } + + log.Infof("kubernetes version %s.%s", sv.Major, sv.Minor) + + m, err := strconv.Atoi(regexp.MustCompile(`^(\d+)`).FindString(sv.Minor)) + if err != nil { + return nil, fmt.Errorf("parsing minor version: %w", err) + } + + return &Version{v: sv, m: m}, nil +} + +type Version struct { + v *version.Info + m int +} + +func (v *Version) Full() string { + return v.v.Major + "." + v.v.Minor +} + +func (v *Version) MinorInt() int { + return v.m +} diff --git a/internal/services/version/version_test.go b/internal/services/version/version_test.go new file mode 100644 index 00000000..0c5720fc --- /dev/null +++ b/internal/services/version/version_test.go @@ -0,0 +1,44 @@ +package version + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/version" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +func Test(t *testing.T) { + v := version.Info{ + Major: "1", + Minor: "21+", + GitCommit: "2812f9fb0003709fc44fc34166701b377020f1c9", + } + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := json.Marshal(v) + if err != nil { + t.Errorf("unexpected encoding error: %v", err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err = w.Write(b) + require.NoError(t, err) + })) + defer s.Close() + client := kubernetes.NewForConfigOrDie(&rest.Config{Host: s.URL}) + + got, err := Get(logrus.New(), client) + if err != nil { + return + } + + require.NoError(t, err) + require.Equal(t, "1.21+", got.Full()) + require.Equal(t, 21, got.MinorInt()) +} diff --git a/internal/services/worker/worker.go b/internal/services/worker/worker.go deleted file mode 100644 index be1822b1..00000000 --- a/internal/services/worker/worker.go +++ /dev/null @@ -1,180 +0,0 @@ -package worker - -import ( - "context" - "fmt" - "strconv" - "strings" - "time" - - "github.com/sirupsen/logrus" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/util/wait" - - "castai-agent/internal/castai" - "castai-agent/internal/services/collector" - "castai-agent/internal/services/providers/types" - "castai-agent/pkg/labels" -) - -func Run( - ctx context.Context, - log logrus.FieldLogger, - reg *types.ClusterRegistration, - col collector.Collector, - castclient castai.Client, - provider types.Provider, -) error { - w := &worker{ - log: log, - col: col, - castclient: castclient, - provider: provider, - reg: reg, - intervalCh: make(chan struct{}), - } - - return w.run(ctx) -} - -type worker struct { - log logrus.FieldLogger - col collector.Collector - castclient castai.Client - provider types.Provider - reg *types.ClusterRegistration - interval time.Duration - intervalCh chan struct{} -} - -func (w *worker) run(ctx context.Context) error { - interval, err := w.getInterval(ctx) - if err != nil { - return fmt.Errorf("getting snapshot collection interval: %w", err) - } - - w.interval = *interval - - if err := w.collect(ctx); err != nil { - return fmt.Errorf("collecting first snapshot data: %w", err) - } - - go w.pollInterval(ctx) - - w.log.Infof("collecting snapshot every %s", w.interval) - - for { - select { - case <-time.After(w.interval): - case <-w.intervalCh: - case <-ctx.Done(): - return nil - } - - if err := w.collect(ctx); err != nil { - w.log.Errorf("collecting snapshot data: %v", err) - } - } -} - -func (w *worker) pollInterval(ctx context.Context) { - const dur = 15 * time.Second - w.log.Infof("polling agent configuration every %s", dur) - wait.Until(func() { - interval, err := w.getInterval(ctx) - if err != nil { - w.log.Errorf("polling interval: %v", err) - return - } - if *interval != w.interval { - w.log.Infof("snapshot collection interval changed from %s to %s", w.interval, *interval) - w.interval = *interval - w.intervalCh <- struct{}{} - } - }, dur, ctx.Done()) -} - -func (w *worker) getInterval(ctx context.Context) (*time.Duration, error) { - cfg, err := w.castclient.GetAgentCfg(ctx, w.reg.ClusterID) - if err != nil { - return nil, fmt.Errorf("getting agent configuration: %w", err) - } - - intervalSeconds, err := strconv.Atoi(cfg.IntervalSeconds) - if err != nil { - return nil, fmt.Errorf("parsing interval %q: %w", cfg.IntervalSeconds, err) - } - - remoteInterval := time.Duration(intervalSeconds) * time.Second - - return &remoteInterval, nil -} - -func (w *worker) collect(ctx context.Context) error { - cd, err := w.col.Collect(ctx) - if err != nil { - return err - } - - accountID, err := w.provider.AccountID(ctx) - if err != nil { - return fmt.Errorf("getting account id: %w", err) - } - - clusterName, err := w.provider.ClusterName(ctx) - if err != nil { - return fmt.Errorf("getting cluster name: %w", err) - } - - region, err := w.provider.ClusterRegion(ctx) - if err != nil { - return fmt.Errorf("getting cluster region: %w", err) - } - - snap := &castai.Snapshot{ - ClusterID: w.reg.ClusterID, - OrganizationID: w.reg.OrganizationID, - ClusterProvider: strings.ToUpper(w.provider.Name()), - AccountID: accountID, - ClusterName: clusterName, - ClusterRegion: region, - ClusterData: cd, - } - - if v := w.col.GetVersion(); v != nil { - snap.ClusterVersion = v.Major + "." + v.Minor - } - - if err := w.addSpotLabel(ctx, snap.NodeList); err != nil { - w.log.Errorf("adding spot labels: %v", err) - } - - ctx, cancel := context.WithTimeout(ctx, w.interval) - defer cancel() - - if err := w.castclient.SendClusterSnapshotWithRetry(ctx, snap); err != nil { - return fmt.Errorf("sending cluster snapshot: %w", err) - } - - return nil -} - -func (w *worker) addSpotLabel(ctx context.Context, nodes *v1.NodeList) error { - nodeMap := make(map[string]*v1.Node, len(nodes.Items)) - items := make([]*v1.Node, len(nodes.Items)) - for i, node := range nodes.Items { - items[i] = &nodes.Items[i] - nodeMap[node.Name] = &nodes.Items[i] - } - - spotNodes, err := w.provider.FilterSpot(ctx, items) - if err != nil { - return fmt.Errorf("filtering spot instances: %w", err) - } - - for _, node := range spotNodes { - nodeMap[node.Name].Labels[labels.Spot] = "true" - } - - return nil -} diff --git a/internal/services/worker/worker_test.go b/internal/services/worker/worker_test.go deleted file mode 100644 index cebe0cc6..00000000 --- a/internal/services/worker/worker_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package worker - -import ( - "context" - "testing" - - "github.com/golang/mock/gomock" - "github.com/google/uuid" - "github.com/stretchr/testify/require" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/version" - - "castai-agent/internal/castai" - mock_castai "castai-agent/internal/castai/mock" - "castai-agent/internal/services/collector" - mock_collector "castai-agent/internal/services/collector/mock" - "castai-agent/internal/services/providers/types" - mock_types "castai-agent/internal/services/providers/types/mock" -) - -func TestCollect(t *testing.T) { - ctx := context.Background() - mockctrl := gomock.NewController(t) - col := mock_collector.NewMockCollector(mockctrl) - provider := mock_types.NewMockProvider(mockctrl) - castclient := mock_castai.NewMockClient(mockctrl) - - reg := &types.ClusterRegistration{ - ClusterID: uuid.New().String(), - OrganizationID: uuid.New().String(), - } - - spot := v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "spot", Labels: map[string]string{}}} - onDemand := v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "on-demand"}} - - cd := &collector.ClusterData{NodeList: &v1.NodeList{Items: []v1.Node{spot, onDemand}}} - col.EXPECT().Collect(ctx).Return(cd, nil) - col.EXPECT().GetVersion().Return(&version.Info{Major: "1", Minor: "20"}) - - provider.EXPECT().AccountID(ctx).Return("accountID", nil) - provider.EXPECT().ClusterName(ctx).Return("clusterName", nil) - provider.EXPECT().ClusterRegion(ctx).Return("eu-central-1", nil) - provider.EXPECT().Name().Return("eks") - provider.EXPECT().FilterSpot(ctx, []*v1.Node{&spot, &onDemand}).Return([]*v1.Node{&spot}, nil) - - castclient.EXPECT().SendClusterSnapshotWithRetry(gomock.Any(), &castai.Snapshot{ - ClusterID: reg.ClusterID, - OrganizationID: reg.OrganizationID, - AccountID: "accountID", - ClusterProvider: "EKS", - ClusterName: "clusterName", - ClusterRegion: "eu-central-1", - ClusterData: cd, - ClusterVersion: "1.20", - }).Return(nil) - - w := &worker{ - reg: reg, - col: col, - provider: provider, - castclient: castclient, - } - - err := w.collect(ctx) - - require.NoError(t, err) - require.Equal(t, map[string]string{"scheduling.cast.ai/spot": "true"}, spot.ObjectMeta.Labels) -} diff --git a/main.go b/main.go index b2ec7bb6..f30cf68e 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,11 @@ import ( "context" "fmt" "io/ioutil" + "time" "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" @@ -13,9 +16,9 @@ import ( "castai-agent/internal/castai" "castai-agent/internal/config" - "castai-agent/internal/services/collector" + "castai-agent/internal/services/controller" "castai-agent/internal/services/providers" - "castai-agent/internal/services/worker" + "castai-agent/internal/services/version" ) func main() { @@ -35,13 +38,15 @@ func run(ctx context.Context, log logrus.FieldLogger) error { log = log.WithField("provider", provider.Name()) - castclient := castai.NewClient(log, castai.NewDefaultClient()) + castaiclient := castai.NewClient(log, castai.NewDefaultClient()) - reg, err := provider.RegisterCluster(ctx, castclient) + reg, err := provider.RegisterCluster(ctx, castaiclient) if err != nil { return fmt.Errorf("registering cluster: %w", err) } + log = log.WithField("cluster_id", reg.ClusterID) + restconfig, err := retrieveKubeConfig() if err != nil { return err @@ -52,12 +57,19 @@ func run(ctx context.Context, log logrus.FieldLogger) error { return err } - col, err := collector.NewCollector(log, clientset) - if err != nil { - return fmt.Errorf("initializing snapshot collector: %w", err) - } + wait.Until(func() { + v, err := version.Get(log, clientset) + if err != nil { + panic(fmt.Errorf("failed getting kubernetes version: %v", err)) + } + + f := informers.NewSharedInformerFactory(clientset, 0) + ctrl := controller.New(log, f, castaiclient, provider, reg.ClusterID, 15*time.Second, 30*time.Second, v) + f.Start(ctx.Done()) + ctrl.Run(ctx) + }, 0, ctx.Done()) - return worker.Run(ctx, log, reg, col, castclient, provider) + return nil } func kubeConfigFromEnv() (*rest.Config, error) {