Skip to content

Commit d954c01

Browse files
committed
feat: add gpu count option
Signed-off-by: hlts2 <[email protected]>
1 parent ea3cf19 commit d954c01

File tree

4 files changed

+102
-65
lines changed

4 files changed

+102
-65
lines changed

main.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ func run(ctx context.Context) error {
2828
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
2929
defer stop()
3030

31-
w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount,
31+
w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID,
3232
watcher.WithRebootTimeWindowMinutes(rebootTimeWindowMinutes),
33+
watcher.WithDesiredGPUCount(nodeDesiredGPUCount),
3334
)
3435
if err != nil {
3536
return err

pkg/watcher/options.go

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package watcher
22

33
import (
4+
"log/slog"
45
"strconv"
56
"time"
67

@@ -13,6 +14,7 @@ type Option func(*watcher)
1314

1415
var defaultOptions = []Option{
1516
WithRebootTimeWindowMinutes("40"),
17+
WithDesiredGPUCount("0"),
1618
}
1719

1820
// WithKubernetesClient returns Option to set Kubernetes API client.
@@ -48,6 +50,20 @@ func WithRebootTimeWindowMinutes(s string) Option {
4850
n, err := strconv.Atoi(s)
4951
if err == nil && n > 0 {
5052
w.rebootTimeWindowMinutes = time.Duration(n)
53+
} else {
54+
slog.Info("RebootTimeWindowMinutes is invalid", "value", s)
55+
}
56+
}
57+
}
58+
59+
// WithDesiredGPUCount returns Option to set the desired GPU count per node.
60+
func WithDesiredGPUCount(s string) Option {
61+
return func(w *watcher) {
62+
n, err := strconv.Atoi(s)
63+
if err == nil && n >= 0 {
64+
w.nodeDesiredGPUCount = n
65+
} else {
66+
slog.Info("DesiredGPUCount is invalid", "value", s)
5167
}
5268
}
5369
}

pkg/watcher/watcher.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ type watcher struct {
4242
nodeSelector *metav1.LabelSelector
4343
}
4444

45-
func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount string, opts ...Option) (Watcher, error) {
45+
func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID string, opts ...Option) (Watcher, error) {
4646
w := &watcher{
4747
clusterID: clusterID,
4848
apiKey: apiKey,
@@ -63,12 +63,6 @@ func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePool
6363
return nil, fmt.Errorf("CIVO_API_KEY not set")
6464
}
6565

66-
n, err := strconv.Atoi(nodeDesiredGPUCount)
67-
if err != nil {
68-
return nil, fmt.Errorf("CIVO_NODE_DESIRED_GPU_COUNT has an invalid value, %s: %w", nodeDesiredGPUCount, err)
69-
}
70-
71-
w.nodeDesiredGPUCount = n
7266
w.nodeSelector = &metav1.LabelSelector{
7367
MatchLabels: map[string]string{
7468
nodePoolLabelKey: nodePoolID,
@@ -212,7 +206,7 @@ func isNodeReady(node *corev1.Node) bool {
212206

213207
func isNodeDesiredGPU(node *corev1.Node, desired int) bool {
214208
if desired == 0 {
215-
slog.Info("Desired GPU count is set to 0", "node", node.GetName())
209+
slog.Info("Desired GPU count is set to 0, so the GPU count check was skipped", "node", node.GetName())
216210
return true
217211
}
218212

pkg/watcher/watcher_test.go

+82-56
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@ var (
2828

2929
func TestNew(t *testing.T) {
3030
type args struct {
31-
clusterID string
32-
region string
33-
apiKey string
34-
apiURL string
35-
nodePoolID string
36-
nodeDesiredGPUCount string
37-
opts []Option
31+
clusterID string
32+
region string
33+
apiKey string
34+
apiURL string
35+
nodePoolID string
36+
opts []Option
3837
}
3938
type test struct {
4039
name string
@@ -47,17 +46,15 @@ func TestNew(t *testing.T) {
4746
{
4847
name: "Returns no error when given valid input",
4948
args: args{
50-
clusterID: testClusterID,
51-
region: testRegion,
52-
apiKey: testApiKey,
53-
apiURL: testApiURL,
54-
nodePoolID: testNodePoolID,
55-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
49+
clusterID: testClusterID,
50+
region: testRegion,
51+
apiKey: testApiKey,
52+
apiURL: testApiURL,
53+
nodePoolID: testNodePoolID,
5654
opts: []Option{
5755
WithKubernetesClient(fake.NewSimpleClientset()),
5856
WithCivoClient(&FakeClient{}),
59-
WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used.
60-
WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used.
57+
WithDesiredGPUCount(testNodeDesiredGPUCount),
6158
},
6259
},
6360
checkFunc: func(w *watcher) error {
@@ -97,51 +94,66 @@ func TestNew(t *testing.T) {
9794
},
9895
},
9996
{
100-
name: "Returns an error when clusterID is missing",
97+
name: "Returns no error when input is invalid, but default value is set",
10198
args: args{
102-
region: testRegion,
103-
apiKey: testApiKey,
104-
apiURL: testApiURL,
105-
nodePoolID: testNodePoolID,
106-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
99+
clusterID: testClusterID,
100+
region: testRegion,
101+
apiKey: testApiKey,
102+
apiURL: testApiURL,
103+
nodePoolID: testNodePoolID,
107104
opts: []Option{
108105
WithKubernetesClient(fake.NewSimpleClientset()),
109106
WithCivoClient(&FakeClient{}),
107+
WithDesiredGPUCount("invalid"), // It is invalid, but the default count (0) will be used.
108+
WithDesiredGPUCount("-1"), // It is invalid, but the default count (0) will be used.
109+
WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used.
110+
WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used.
110111
},
111112
},
112-
wantErr: true,
113+
checkFunc: func(w *watcher) error {
114+
if w.nodeDesiredGPUCount != 0 {
115+
return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0)
116+
}
117+
if w.rebootTimeWindowMinutes != testRebootTimeWindowMinutes {
118+
return fmt.Errorf("w.rebootTimeWindowMinutes mismatch: got %v, want %v", w.rebootTimeWindowMinutes, testRebootTimeWindowMinutes)
119+
}
120+
return nil
121+
},
113122
},
114123
{
115-
name: "Returns an error when nodeDesiredGPUCount is invalid",
124+
name: "Returns no error when nodeDesiredGPUCount is 0",
116125
args: args{
117-
clusterID: testClusterID,
118-
region: testRegion,
119-
apiKey: testApiKey,
120-
apiURL: testApiURL,
121-
nodePoolID: testNodePoolID,
122-
nodeDesiredGPUCount: "invalid_number",
126+
clusterID: testClusterID,
127+
region: testRegion,
128+
apiKey: testApiKey,
129+
apiURL: testApiURL,
130+
nodePoolID: testNodePoolID,
123131
opts: []Option{
124132
WithKubernetesClient(fake.NewSimpleClientset()),
125133
WithCivoClient(&FakeClient{}),
134+
WithDesiredGPUCount("0"),
126135
},
127136
},
128-
wantErr: true,
137+
checkFunc: func(w *watcher) error {
138+
if w.nodeDesiredGPUCount != 0 {
139+
return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0)
140+
}
141+
return nil
142+
},
129143
},
130144
{
131-
name: "Returns an error when nodeDesiredGPUCount is 0",
145+
name: "Returns an error when clusterID is missing",
132146
args: args{
133-
clusterID: testClusterID,
134-
region: testRegion,
135-
apiKey: testApiKey,
136-
apiURL: testApiURL,
137-
nodePoolID: testNodePoolID,
138-
nodeDesiredGPUCount: "0",
147+
region: testRegion,
148+
apiKey: testApiKey,
149+
apiURL: testApiURL,
150+
nodePoolID: testNodePoolID,
139151
opts: []Option{
140152
WithKubernetesClient(fake.NewSimpleClientset()),
141153
WithCivoClient(&FakeClient{}),
142154
},
143155
},
144-
wantErr: false,
156+
wantErr: true,
145157
},
146158
}
147159

@@ -153,7 +165,6 @@ func TestNew(t *testing.T) {
153165
test.args.region,
154166
test.args.clusterID,
155167
test.args.nodePoolID,
156-
test.args.nodeDesiredGPUCount,
157168
test.args.opts...)
158169
if (err != nil) != test.wantErr {
159170
t.Errorf("error = %v, wantErr %v", err, test.wantErr)
@@ -177,9 +188,8 @@ func TestNew(t *testing.T) {
177188

178189
func TestRun(t *testing.T) {
179190
type args struct {
180-
opts []Option
181-
nodeDesiredGPUCount string
182-
nodePoolID string
191+
opts []Option
192+
nodePoolID string
183193
}
184194
type test struct {
185195
name string
@@ -195,9 +205,9 @@ func TestRun(t *testing.T) {
195205
opts: []Option{
196206
WithKubernetesClient(fake.NewSimpleClientset()),
197207
WithCivoClient(&FakeClient{}),
208+
WithDesiredGPUCount(testNodeDesiredGPUCount),
198209
},
199-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
200-
nodePoolID: testNodePoolID,
210+
nodePoolID: testNodePoolID,
201211
},
202212
beforeFunc: func(w *watcher) {
203213
t.Helper()
@@ -241,9 +251,9 @@ func TestRun(t *testing.T) {
241251
opts: []Option{
242252
WithKubernetesClient(fake.NewSimpleClientset()),
243253
WithCivoClient(&FakeClient{}),
254+
WithDesiredGPUCount(testNodeDesiredGPUCount),
244255
},
245-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
246-
nodePoolID: testNodePoolID,
256+
nodePoolID: testNodePoolID,
247257
},
248258
beforeFunc: func(w *watcher) {
249259
t.Helper()
@@ -298,9 +308,9 @@ func TestRun(t *testing.T) {
298308
opts: []Option{
299309
WithKubernetesClient(fake.NewSimpleClientset()),
300310
WithCivoClient(&FakeClient{}),
311+
WithDesiredGPUCount(testNodeDesiredGPUCount),
301312
},
302-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
303-
nodePoolID: testNodePoolID,
313+
nodePoolID: testNodePoolID,
304314
},
305315
beforeFunc: func(w *watcher) {
306316
t.Helper()
@@ -351,9 +361,9 @@ func TestRun(t *testing.T) {
351361
opts: []Option{
352362
WithKubernetesClient(fake.NewSimpleClientset()),
353363
WithCivoClient(&FakeClient{}),
364+
WithDesiredGPUCount(testNodeDesiredGPUCount),
354365
},
355-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
356-
nodePoolID: testNodePoolID,
366+
nodePoolID: testNodePoolID,
357367
},
358368
beforeFunc: func(w *watcher) {
359369
t.Helper()
@@ -394,9 +404,9 @@ func TestRun(t *testing.T) {
394404
opts: []Option{
395405
WithKubernetesClient(fake.NewSimpleClientset()),
396406
WithCivoClient(&FakeClient{}),
407+
WithDesiredGPUCount(testNodeDesiredGPUCount),
397408
},
398-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
399-
nodePoolID: testNodePoolID,
409+
nodePoolID: testNodePoolID,
400410
},
401411
beforeFunc: func(w *watcher) {
402412
t.Helper()
@@ -415,9 +425,9 @@ func TestRun(t *testing.T) {
415425
opts: []Option{
416426
WithKubernetesClient(fake.NewSimpleClientset()),
417427
WithCivoClient(&FakeClient{}),
428+
WithDesiredGPUCount(testNodeDesiredGPUCount),
418429
},
419-
nodeDesiredGPUCount: testNodeDesiredGPUCount,
420-
nodePoolID: testNodePoolID,
430+
nodePoolID: testNodePoolID,
421431
},
422432
beforeFunc: func(w *watcher) {
423433
t.Helper()
@@ -462,7 +472,7 @@ func TestRun(t *testing.T) {
462472
for _, test := range tests {
463473
t.Run(test.name, func(t *testing.T) {
464474
w, err := NewWatcher(t.Context(),
465-
testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.nodeDesiredGPUCount, test.args.opts...)
475+
testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.opts...)
466476
if err != nil {
467477
t.Fatal(err)
468478
}
@@ -682,6 +692,19 @@ func TestIsNodeDesiredGPU(t *testing.T) {
682692
desired: 8,
683693
want: true,
684694
},
695+
{
696+
name: "Returns true when desired GPU count is 0, so count check is skipped",
697+
node: &corev1.Node{
698+
ObjectMeta: metav1.ObjectMeta{
699+
Name: "node-01",
700+
},
701+
Status: corev1.NodeStatus{
702+
Allocatable: corev1.ResourceList{},
703+
},
704+
},
705+
desired: 0,
706+
want: true,
707+
},
685708
{
686709
name: "Returns false when GPU count is 0",
687710
node: &corev1.Node{
@@ -744,6 +767,7 @@ func TestRebootNode(t *testing.T) {
744767
opts: []Option{
745768
WithKubernetesClient(fake.NewSimpleClientset()),
746769
WithCivoClient(&FakeClient{}),
770+
WithDesiredGPUCount(testNodeDesiredGPUCount),
747771
},
748772
},
749773
beforeFunc: func(t *testing.T, w *watcher) {
@@ -772,6 +796,7 @@ func TestRebootNode(t *testing.T) {
772796
opts: []Option{
773797
WithKubernetesClient(fake.NewSimpleClientset()),
774798
WithCivoClient(&FakeClient{}),
799+
WithDesiredGPUCount(testNodeDesiredGPUCount),
775800
},
776801
},
777802
beforeFunc: func(t *testing.T, w *watcher) {
@@ -791,6 +816,7 @@ func TestRebootNode(t *testing.T) {
791816
opts: []Option{
792817
WithKubernetesClient(fake.NewSimpleClientset()),
793818
WithCivoClient(&FakeClient{}),
819+
WithDesiredGPUCount(testNodeDesiredGPUCount),
794820
},
795821
},
796822
beforeFunc: func(t *testing.T, w *watcher) {
@@ -818,7 +844,7 @@ func TestRebootNode(t *testing.T) {
818844
for _, test := range tests {
819845
t.Run(test.name, func(t *testing.T) {
820846
w, err := NewWatcher(t.Context(),
821-
testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, testNodeDesiredGPUCount, test.args.opts...)
847+
testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.args.opts...)
822848
if err != nil {
823849
t.Fatal(err)
824850
}

0 commit comments

Comments
 (0)