diff --git a/internal/cast/cast.go b/internal/cast/cast.go index c41ffd51..d40ba76f 100644 --- a/internal/cast/cast.go +++ b/internal/cast/cast.go @@ -2,6 +2,7 @@ package cast import ( + "bytes" "castai-agent/internal/config" "context" "encoding/json" @@ -42,7 +43,6 @@ func NewDefaultClient() *resty.Client { client.SetRetryCount(defaultRetryCount) client.SetTimeout(defaultTimeout) client.Header.Set("X-API-Key", cfg.Key) - client.Header.Set("Content-Type", "application/json") return client } @@ -74,14 +74,15 @@ func (c *client) RegisterCluster(ctx context.Context, req *RegisterClusterReques func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error { payload, err := json.Marshal(snap) if err != nil { - return err + return fmt.Errorf("marshaling snapshot payload: %w", err) } + buf := bytes.NewBuffer(payload) resp, err := c.rest.R(). - SetBody(&SnapshotRequest{Payload: payload}). + SetFileReader("payload", "payload.json", buf). SetResult(&RegisterClusterResponse{}). SetContext(ctx). - Post("/v1/agent/eks-snapshot") + Post("/v1/agent/snapshot") if err != nil { return err } diff --git a/internal/cast/cast_test.go b/internal/cast/cast_test.go index 56817c54..c9a89863 100644 --- a/internal/cast/cast_test.go +++ b/internal/cast/cast_test.go @@ -69,14 +69,14 @@ func TestClient_SendClusterSnapshot(t *testing.T) { }, } - httpmock.RegisterResponder(http.MethodPost, "/v1/agent/eks-snapshot", func(req *http.Request) (*http.Response, error) { - actualRequest := &SnapshotRequest{} - require.NoError(t, json.NewDecoder(req.Body).Decode(actualRequest)) + httpmock.RegisterResponder(http.MethodPost, "/v1/agent/snapshot", func(req *http.Request) (*http.Response, error) { + f, _, err := req.FormFile("payload") + require.NoError(t, err) - actualSnapshot := &Snapshot{} - require.NoError(t, json.Unmarshal(actualRequest.Payload, actualSnapshot)) + actualRequest := &Snapshot{} + require.NoError(t, json.NewDecoder(f).Decode(actualRequest)) - require.Equal(t, snapshot, actualSnapshot) + require.Equal(t, snapshot, actualRequest) return httpmock.NewStringResponse(http.StatusNoContent, "ok"), nil }) diff --git a/internal/services/collector/collector.go b/internal/services/collector/collector.go index 5c5e4669..4a60bfa7 100644 --- a/internal/services/collector/collector.go +++ b/internal/services/collector/collector.go @@ -3,90 +3,120 @@ package collector import ( "context" - + "fmt" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/version" "k8s.io/client-go/kubernetes" + "regexp" + "strconv" ) +// Collector is responsible for gathering K8s data from the cluster. type Collector interface { + // Collect cluster snapshot data. Collect(ctx context.Context) (*ClusterData, error) + // GetVersion returns the K8s cluster version. Version is nil when the collector is unable to retrieve it. + GetVersion() *version.Info +} + +// NewCollector creates and configures a Collector instance. +func NewCollector(log logrus.FieldLogger, clientset kubernetes.Interface) (Collector, error) { + var v *version.Info + var minor int + if cs, ok := clientset.(*kubernetes.Clientset); ok { + sv, err := cs.ServerVersion() + if err != nil { + log.Errorf("getting cluster version: %v", err) + } else { + v = sv + m, err := strconv.Atoi(regexp.MustCompile(`^(\d+)`).FindString(sv.Minor)) + if err != nil { + return nil, fmt.Errorf("parsing k8s version: %w", err) + } + minor = m + } + } + + return &collector{ + log: log, + clientset: clientset, + cd: &ClusterData{}, + v: v, + minor: minor, + }, nil } type collector struct { log logrus.FieldLogger clientset kubernetes.Interface cd *ClusterData -} - -func NewCollector(log logrus.FieldLogger, clientset kubernetes.Interface) Collector { - var cd ClusterData - return &collector{ - log: log, - clientset: clientset, - cd: &cd, - } + minor int + v *version.Info } func (c *collector) Collect(ctx context.Context) (*ClusterData, error) { if err := c.collectNodes(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting nodes: %w", err) } if err := c.collectPods(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting pods: %w", err) } if err := c.collectPersistentVolumes(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting persistent volumes: %w", err) } if err := c.collectPersistentVolumeClaims(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting persistent volume claims: %w", err) } if err := c.collectDeploymentList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting deployments: %w", err) } if err := c.collectReplicaSetList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting replica sets: %w", err) } if err := c.collectDaemonSetList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting daemon sets: %w", err) } if err := c.collectStatefulSetList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting stateful sets: %w", err) } if err := c.collectReplicationControllerList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting replication controllers: %w", err) } if err := c.collectServiceList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting services: %w", err) } - - if err := c.collectCSINodeList(ctx); err != nil { - // https://kubernetes-csi.github.io/docs/csi-node-object.html - // GA since 1.17 - c.log.Debugf("could not get CSINodes: %v", err) + if c.minor >= 17 { + if err := c.collectCSINodeList(ctx); err != nil { + return nil, fmt.Errorf("collecting csi nodes: %w", err) + } } if err := c.collectStorageClassList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting storage classes: %w", err) } if err := c.collectJobList(ctx); err != nil { - return nil, err + return nil, fmt.Errorf("collecting jobs: %w", err) } return c.cd, nil } +func (c *collector) GetVersion() *version.Info { + return c.v +} + func (c *collector) collectNodes(ctx context.Context) error { nodes, err := c.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) if err != nil { diff --git a/internal/services/collector/mock/collector.go b/internal/services/collector/mock/collector.go index 1901bb90..e9bb6865 100644 --- a/internal/services/collector/mock/collector.go +++ b/internal/services/collector/mock/collector.go @@ -10,6 +10,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + version "k8s.io/apimachinery/pkg/version" ) // MockCollector is a mock of Collector interface. @@ -49,3 +50,17 @@ func (mr *MockCollectorMockRecorder) Collect(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Collect", reflect.TypeOf((*MockCollector)(nil).Collect), arg0) } + +// GetVersion mocks base method. +func (m *MockCollector) GetVersion() *version.Info { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVersion") + ret0, _ := ret[0].(*version.Info) + return ret0 +} + +// GetVersion indicates an expected call of GetVersion. +func (mr *MockCollectorMockRecorder) GetVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockCollector)(nil).GetVersion)) +} diff --git a/main.go b/main.go index c2c2c866..d2362949 100644 --- a/main.go +++ b/main.go @@ -60,14 +60,17 @@ func run(ctx context.Context, log logrus.FieldLogger) error { return err } - col := collector.NewCollector(log, clientset) + col, err := collector.NewCollector(log, clientset) + if err != nil { + return fmt.Errorf("initializing snapshot collector: %w", err) + } const interval = 15 * time.Second ticker := time.NewTicker(interval) defer ticker.Stop() for { - if err := collect(ctx, log, c, col, provider, clientset, castclient); err != nil { + if err := collect(ctx, log, c, col, provider, castclient); err != nil { log.Errorf("collecting snapshot data: %v", err) } @@ -86,7 +89,6 @@ func collect( c *cast.RegisterClusterResponse, col collector.Collector, provider providers.Provider, - clientset kubernetes.Interface, castclient cast.Client, ) error { cd, err := col.Collect(ctx) @@ -119,14 +121,8 @@ func collect( ClusterData: cd, } - if cs, ok := clientset.(*kubernetes.Clientset); ok { - version, err := cs.ServerVersion() - if err != nil { - log.Errorf("getting cluster version: %v", version) - } - if version != nil { - snap.ClusterVersion = version.GitVersion - } + if v := col.GetVersion(); v != nil { + snap.ClusterVersion = v.Major + "." + v.Minor } if err := addSpotLabel(ctx, provider, snap.NodeList); err != nil { diff --git a/main_test.go b/main_test.go index d04697c8..131af4e5 100644 --- a/main_test.go +++ b/main_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" + "k8s.io/apimachinery/pkg/version" "testing" ) @@ -22,7 +22,6 @@ func TestCollect(t *testing.T) { mockctrl := gomock.NewController(t) col := mock_collector.NewMockCollector(mockctrl) provider := mock_providers.NewMockProvider(mockctrl) - clientset := fake.NewSimpleClientset() castclient := mock_cast.NewMockClient(mockctrl) c := &cast.RegisterClusterResponse{Cluster: cast.Cluster{ID: uuid.New().String(), OrganizationID: uuid.New().String()}} @@ -32,6 +31,7 @@ func TestCollect(t *testing.T) { cd := &collector.ClusterData{NodeList: &v1.NodeList{Items: []v1.Node{spot, onDemand}}} col.EXPECT().Collect(ctx).Return(cd, nil) + col.EXPECT().GetVersion().Return(&version.Info{Major: "1", Minor: "20"}) provider.EXPECT().AccountID(ctx).Return("accountID", nil) provider.EXPECT().ClusterName(ctx).Return("clusterName", nil) @@ -47,9 +47,10 @@ func TestCollect(t *testing.T) { ClusterName: "clusterName", ClusterRegion: "eu-central-1", ClusterData: cd, + ClusterVersion: "1.20", }).Return(nil) - err := collect(ctx, logrus.New(), c, col, provider, clientset, castclient) + err := collect(ctx, logrus.New(), c, col, provider, castclient) require.NoError(t, err)