Skip to content

Commit

Permalink
KUBE-556: get GCP token source at startup (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trojan295 authored Sep 25, 2024
1 parent 0e2b0d8 commit 84a1cde
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 179 deletions.
16 changes: 9 additions & 7 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"net/http"
"path"
"runtime"
"time"
Expand All @@ -13,7 +12,6 @@ import (
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"

"cloud-proxy/internal/cloud/gcp"
"cloud-proxy/internal/cloud/gcp/gcpauth"
Expand All @@ -33,6 +31,13 @@ func main() {
cfg := config.Get()
logger := setupLogger(cfg)

ctx := context.Background()

tokenSource, err := gcpauth.NewTokenSource(ctx)
if err != nil {
logger.WithError(err).Panicf("Failed to create GCP credentials source")
}

dialOpts := make([]grpc.DialOption, 0)
if cfg.CastAI.DisableGRPCTLS {
// ONLY For testing purposes.
Expand Down Expand Up @@ -72,14 +77,11 @@ func main() {
}
}(conn)

ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", cfg.CastAI.APIKey),
))
client := proxy.New(conn, gcp.New(tokenSource), logger,
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.CastAI.APIKey, cfg.KeepAlive, cfg.KeepAliveTimeout)

go startHealthServer(logger, cfg.HealthAddress)

client := proxy.New(conn, gcp.New(gcpauth.NewCredentialsSource(), http.DefaultClient), logger,
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.KeepAlive, cfg.KeepAliveTimeout)
err = client.Run(ctx)
if err != nil {
logger.Panicf("Failed to run client: %v", err)
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module cloud-proxy
go 1.23.1

require (
cloud.google.com/go/compute v1.28.0
cloud.google.com/go/container v1.39.0
github.com/golang/mock v1.6.0
github.com/google/uuid v1.6.0
Expand Down Expand Up @@ -56,7 +55,6 @@ require (
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
Expand Down
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U=
cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk=
cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY=
cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc=
cloud.google.com/go/compute v1.28.0 h1:OPtBxMcheSS+DWfci803qvPly3d4w7Eu5ztKBcFfzwk=
cloud.google.com/go/compute v1.28.0/go.mod h1:DEqZBtYrDnD5PvjsKwb3onnhX+qjdCVM7eshj1XdjV4=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
cloud.google.com/go/container v1.39.0 h1:Q1oW01ENxkkG3uf1oYoTmHPdvP+yhFCIuCJ4mk2RwkQ=
Expand Down Expand Up @@ -201,8 +199,6 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU=
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4=
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed h1:3RgNmBoI9MZhsj3QxC+AP/qQhNwpCLOvYDYYsFrhFt0=
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
Expand Down
21 changes: 6 additions & 15 deletions internal/cloud/gcp/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,27 @@
package gcp

import (
"context"
"fmt"
"net/http"

"cloud-proxy/internal/cloud/gcp/gcpauth"
"golang.org/x/oauth2"
)

type Credentials interface {
GetToken() (string, error)
}
type Client struct {
credentials Credentials
httpClient *http.Client
httpClient *http.Client
}

func New(credentials *gcpauth.CredentialsSource, client *http.Client) *Client {
return &Client{credentials: credentials, httpClient: client}
func New(tokenSource oauth2.TokenSource) *Client {
client := oauth2.NewClient(context.Background(), tokenSource)
return &Client{httpClient: client}
}

func (c *Client) DoHTTPRequest(request *http.Request) (*http.Response, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}

token, err := c.credentials.GetToken()
if err != nil {
return nil, fmt.Errorf("credentialsSrc.GetToken: error: %w", err)
}
// Set the authorize header manually since we can't rely on mothership auth.
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

resp, err := c.httpClient.Do(request)
if err != nil {
return nil, fmt.Errorf("httpClient.Do: request %+v error: %w", request, err)
Expand Down
38 changes: 5 additions & 33 deletions internal/cloud/gcp/gcpauth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,16 @@ package gcpauth

import (
"context"
"fmt"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/serviceusage/v1"
)

func NewCredentialsSource(scopes ...string) *CredentialsSource {
func NewTokenSource(ctx context.Context, scopes ...string) (oauth2.TokenSource, error) {
if len(scopes) == 0 {
scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
}
return &CredentialsSource{
scopes: scopes,
}
}

type CredentialsSource struct {
scopes []string
}

// TODO: check if we should be doing it constantly; cache them; cache the token or something else.

func (src *CredentialsSource) getDefaultCredentials() (*google.Credentials, error) {
defaultCreds, err := google.FindDefaultCredentials(context.Background(), src.scopes...)
if err != nil {
return nil, fmt.Errorf("could not load default credentials: %w", err)
}
return defaultCreds, nil
}

func (src *CredentialsSource) GetToken() (string, error) {
credentials, err := src.getDefaultCredentials()
if err != nil {
return "", fmt.Errorf("cannot load GCP credentials: %w", err)
}

token, err := credentials.TokenSource.Token()
if err != nil {
return "", fmt.Errorf("cannot get access token from src (%T): %w", credentials.TokenSource, err)
scopes = []string{serviceusage.CloudPlatformScope}
}

return token.AccessToken, nil
return google.DefaultTokenSource(ctx, scopes...)
}
52 changes: 33 additions & 19 deletions internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha"
)
Expand All @@ -29,6 +30,7 @@ type CloudClient interface {

type Client struct {
grpcConn *grpc.ClientConn
apiKey string
cloudClient CloudClient
log *logrus.Logger
podName string
Expand All @@ -44,9 +46,10 @@ type Client struct {
version string
}

func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version string, keepalive, keepaliveTimeout time.Duration) *Client {
func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version, apiKey string, keepalive, keepaliveTimeout time.Duration) *Client {
c := &Client{
grpcConn: grpcConn,
apiKey: apiKey,
cloudClient: cloudClient,
log: logger,
podName: podName,
Expand All @@ -60,22 +63,26 @@ func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logg
}

func (c *Client) Run(ctx context.Context) error {
authCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", c.apiKey),
))

t := time.NewTimer(time.Millisecond)

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-authCtx.Done():
return authCtx.Err()
case <-t.C:
c.log.Info("Starting proxy client")
stream, closeStream, err := c.getStream(ctx)
stream, closeStream, err := c.getStream(authCtx)
if err != nil {
c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
continue
}

err = c.run(ctx, stream, closeStream)
err = c.run(authCtx, stream, closeStream)
if err != nil {
c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
Expand Down Expand Up @@ -134,7 +141,12 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return fmt.Errorf("c.Connect: %w", err)
}

go c.sendKeepAlive(stream)
keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(keepAliveCh)
go c.sendKeepAlive(stream, keepAliveCh)

messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(messageRespCh)

go func() {
for {
Expand All @@ -160,7 +172,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Handling message from castai")
go c.handleMessage(in, stream)
go c.handleMessage(in, messageRespCh)
}
}()

Expand All @@ -170,6 +182,14 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return ctx.Err()
case <-stream.Context().Done():
return fmt.Errorf("stream closed %w", stream.Context().Err())
case req := <-keepAliveCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send keep alive")
}
case req := <-messageRespCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send message response")
}
case <-time.After(time.Duration(c.keepAlive.Load())):
if !c.isAlive() {
if err := c.lastSeenError.Load(); err != nil {
Expand All @@ -181,7 +201,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}
}

func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
if in == nil {
c.log.Error("nil message")
return
Expand All @@ -202,7 +222,7 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
} else {
c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId())
}
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{
respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{
Response: &cloudproxyv1alpha.ClusterResponse{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -213,9 +233,6 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
HttpResponse: resp,
},
},
})
if err != nil {
c.log.Errorf("error sending response for msg_id=%v %v", in.GetMessageId(), err)
}
}

Expand Down Expand Up @@ -261,7 +278,7 @@ func (c *Client) isAlive() bool {
return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load()
}

func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, sendCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
ticker := time.NewTimer(time.Duration(c.keepAlive.Load()))
defer ticker.Stop()

Expand All @@ -277,7 +294,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
return
}
c.log.Debug("Sending keep-alive to castai")
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{

sendCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_ClientStats{
ClientStats: &cloudproxyv1alpha.ClientStats{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -290,12 +308,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
},
},
},
})
if err != nil {
c.lastSeen.Store(0)
c.log.Errorf("error sending keep alive message: %v", err)
return
}

ticker.Reset(time.Duration(c.keepAlive.Load()))
}
}
Expand Down
Loading

0 comments on commit 84a1cde

Please sign in to comment.