Skip to content

Commit

Permalink
Merge pull request #14 from castai/use-multipart
Browse files Browse the repository at this point in the history
feat: use multi-part requests to upload snapshots
  • Loading branch information
saumas authored Apr 8, 2021
2 parents b4514a6 + 8c1ddd4 commit cdcf67d
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 51 deletions.
9 changes: 5 additions & 4 deletions internal/cast/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package cast

import (
"bytes"
"castai-agent/internal/config"
"context"
"encoding/json"
Expand Down Expand Up @@ -42,7 +43,6 @@ func NewDefaultClient() *resty.Client {
client.SetRetryCount(defaultRetryCount)
client.SetTimeout(defaultTimeout)
client.Header.Set("X-API-Key", cfg.Key)
client.Header.Set("Content-Type", "application/json")

return client
}
Expand Down Expand Up @@ -74,14 +74,15 @@ func (c *client) RegisterCluster(ctx context.Context, req *RegisterClusterReques
func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error {
payload, err := json.Marshal(snap)
if err != nil {
return err
return fmt.Errorf("marshaling snapshot payload: %w", err)
}
buf := bytes.NewBuffer(payload)

resp, err := c.rest.R().
SetBody(&SnapshotRequest{Payload: payload}).
SetFileReader("payload", "payload.json", buf).
SetResult(&RegisterClusterResponse{}).
SetContext(ctx).
Post("/v1/agent/eks-snapshot")
Post("/v1/agent/snapshot")
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions internal/cast/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ func TestClient_SendClusterSnapshot(t *testing.T) {
},
}

httpmock.RegisterResponder(http.MethodPost, "/v1/agent/eks-snapshot", func(req *http.Request) (*http.Response, error) {
actualRequest := &SnapshotRequest{}
require.NoError(t, json.NewDecoder(req.Body).Decode(actualRequest))
httpmock.RegisterResponder(http.MethodPost, "/v1/agent/snapshot", func(req *http.Request) (*http.Response, error) {
f, _, err := req.FormFile("payload")
require.NoError(t, err)

actualSnapshot := &Snapshot{}
require.NoError(t, json.Unmarshal(actualRequest.Payload, actualSnapshot))
actualRequest := &Snapshot{}
require.NoError(t, json.NewDecoder(f).Decode(actualRequest))

require.Equal(t, snapshot, actualSnapshot)
require.Equal(t, snapshot, actualRequest)

return httpmock.NewStringResponse(http.StatusNoContent, "ok"), nil
})
Expand Down
84 changes: 57 additions & 27 deletions internal/services/collector/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,90 +3,120 @@ package collector

import (
"context"

"fmt"
"github.com/sirupsen/logrus"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/version"
"k8s.io/client-go/kubernetes"
"regexp"
"strconv"
)

// Collector is responsible for gathering K8s data from the cluster.
type Collector interface {
// Collect cluster snapshot data.
Collect(ctx context.Context) (*ClusterData, error)
// GetVersion returns the K8s cluster version. Version is nil when the collector is unable to retrieve it.
GetVersion() *version.Info
}

// NewCollector creates and configures a Collector instance.
func NewCollector(log logrus.FieldLogger, clientset kubernetes.Interface) (Collector, error) {
var v *version.Info
var minor int
if cs, ok := clientset.(*kubernetes.Clientset); ok {
sv, err := cs.ServerVersion()
if err != nil {
log.Errorf("getting cluster version: %v", err)
} else {
v = sv
m, err := strconv.Atoi(regexp.MustCompile(`^(\d+)`).FindString(sv.Minor))
if err != nil {
return nil, fmt.Errorf("parsing k8s version: %w", err)
}
minor = m
}
}

return &collector{
log: log,
clientset: clientset,
cd: &ClusterData{},
v: v,
minor: minor,
}, nil
}

type collector struct {
log logrus.FieldLogger
clientset kubernetes.Interface
cd *ClusterData
}

func NewCollector(log logrus.FieldLogger, clientset kubernetes.Interface) Collector {
var cd ClusterData
return &collector{
log: log,
clientset: clientset,
cd: &cd,
}
minor int
v *version.Info
}

func (c *collector) Collect(ctx context.Context) (*ClusterData, error) {
if err := c.collectNodes(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting nodes: %w", err)
}

if err := c.collectPods(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting pods: %w", err)
}

if err := c.collectPersistentVolumes(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting persistent volumes: %w", err)
}

if err := c.collectPersistentVolumeClaims(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting persistent volume claims: %w", err)
}

if err := c.collectDeploymentList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting deployments: %w", err)
}

if err := c.collectReplicaSetList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting replica sets: %w", err)
}

if err := c.collectDaemonSetList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting daemon sets: %w", err)
}

if err := c.collectStatefulSetList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting stateful sets: %w", err)
}

if err := c.collectReplicationControllerList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting replication controllers: %w", err)
}

if err := c.collectServiceList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting services: %w", err)
}


if err := c.collectCSINodeList(ctx); err != nil {
// https://kubernetes-csi.github.io/docs/csi-node-object.html
// GA since 1.17
c.log.Debugf("could not get CSINodes: %v", err)
if c.minor >= 17 {
if err := c.collectCSINodeList(ctx); err != nil {
return nil, fmt.Errorf("collecting csi nodes: %w", err)
}
}

if err := c.collectStorageClassList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting storage classes: %w", err)
}

if err := c.collectJobList(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("collecting jobs: %w", err)
}

return c.cd, nil
}

func (c *collector) GetVersion() *version.Info {
return c.v
}

func (c *collector) collectNodes(ctx context.Context) error {
nodes, err := c.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions internal/services/collector/mock/collector.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 7 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ func run(ctx context.Context, log logrus.FieldLogger) error {
return err
}

col := collector.NewCollector(log, clientset)
col, err := collector.NewCollector(log, clientset)
if err != nil {
return fmt.Errorf("initializing snapshot collector: %w", err)
}

const interval = 15 * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
if err := collect(ctx, log, c, col, provider, clientset, castclient); err != nil {
if err := collect(ctx, log, c, col, provider, castclient); err != nil {
log.Errorf("collecting snapshot data: %v", err)
}

Expand All @@ -86,7 +89,6 @@ func collect(
c *cast.RegisterClusterResponse,
col collector.Collector,
provider providers.Provider,
clientset kubernetes.Interface,
castclient cast.Client,
) error {
cd, err := col.Collect(ctx)
Expand Down Expand Up @@ -119,14 +121,8 @@ func collect(
ClusterData: cd,
}

if cs, ok := clientset.(*kubernetes.Clientset); ok {
version, err := cs.ServerVersion()
if err != nil {
log.Errorf("getting cluster version: %v", version)
}
if version != nil {
snap.ClusterVersion = version.GitVersion
}
if v := col.GetVersion(); v != nil {
snap.ClusterVersion = v.Major + "." + v.Minor
}

if err := addSpotLabel(ctx, provider, snap.NodeList); err != nil {
Expand Down
7 changes: 4 additions & 3 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
"k8s.io/apimachinery/pkg/version"
"testing"
)

Expand All @@ -22,7 +22,6 @@ func TestCollect(t *testing.T) {
mockctrl := gomock.NewController(t)
col := mock_collector.NewMockCollector(mockctrl)
provider := mock_providers.NewMockProvider(mockctrl)
clientset := fake.NewSimpleClientset()
castclient := mock_cast.NewMockClient(mockctrl)

c := &cast.RegisterClusterResponse{Cluster: cast.Cluster{ID: uuid.New().String(), OrganizationID: uuid.New().String()}}
Expand All @@ -32,6 +31,7 @@ func TestCollect(t *testing.T) {

cd := &collector.ClusterData{NodeList: &v1.NodeList{Items: []v1.Node{spot, onDemand}}}
col.EXPECT().Collect(ctx).Return(cd, nil)
col.EXPECT().GetVersion().Return(&version.Info{Major: "1", Minor: "20"})

provider.EXPECT().AccountID(ctx).Return("accountID", nil)
provider.EXPECT().ClusterName(ctx).Return("clusterName", nil)
Expand All @@ -47,9 +47,10 @@ func TestCollect(t *testing.T) {
ClusterName: "clusterName",
ClusterRegion: "eu-central-1",
ClusterData: cd,
ClusterVersion: "1.20",
}).Return(nil)

err := collect(ctx, logrus.New(), c, col, provider, clientset, castclient)
err := collect(ctx, logrus.New(), c, col, provider, castclient)

require.NoError(t, err)

Expand Down

0 comments on commit cdcf67d

Please sign in to comment.