Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KUBE-556: get GCP token source at startup #14

Merged
merged 8 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"net/http"
"path"
"runtime"
"time"
Expand All @@ -13,7 +12,6 @@ import (
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"

"cloud-proxy/internal/cloud/gcp"
"cloud-proxy/internal/cloud/gcp/gcpauth"
Expand All @@ -33,6 +31,13 @@ func main() {
cfg := config.Get()
logger := setupLogger(cfg)

ctx := context.Background()

tokenSource, err := gcpauth.NewTokenSource(ctx)
if err != nil {
logger.WithError(err).Panicf("Failed to create GCP credentials source")
}

dialOpts := make([]grpc.DialOption, 0)
if cfg.CastAI.DisableGRPCTLS {
// ONLY For testing purposes.
Expand Down Expand Up @@ -72,14 +77,11 @@ func main() {
}
}(conn)

ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", cfg.CastAI.APIKey),
))
client := proxy.New(conn, gcp.New(tokenSource), logger,
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.CastAI.APIKey, cfg.KeepAlive, cfg.KeepAliveTimeout)

go startHealthServer(logger, cfg.HealthAddress)

client := proxy.New(conn, gcp.New(gcpauth.NewCredentialsSource(), http.DefaultClient), logger,
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.KeepAlive, cfg.KeepAliveTimeout)
err = client.Run(ctx)
if err != nil {
logger.Panicf("Failed to run client: %v", err)
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module cloud-proxy
go 1.23.1

require (
cloud.google.com/go/compute v1.28.0
cloud.google.com/go/container v1.39.0
github.com/golang/mock v1.6.0
github.com/google/uuid v1.6.0
Expand Down Expand Up @@ -56,7 +55,6 @@ require (
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
Expand Down
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U=
cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk=
cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY=
cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc=
cloud.google.com/go/compute v1.28.0 h1:OPtBxMcheSS+DWfci803qvPly3d4w7Eu5ztKBcFfzwk=
cloud.google.com/go/compute v1.28.0/go.mod h1:DEqZBtYrDnD5PvjsKwb3onnhX+qjdCVM7eshj1XdjV4=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
cloud.google.com/go/container v1.39.0 h1:Q1oW01ENxkkG3uf1oYoTmHPdvP+yhFCIuCJ4mk2RwkQ=
Expand Down Expand Up @@ -201,8 +199,6 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU=
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4=
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed h1:3RgNmBoI9MZhsj3QxC+AP/qQhNwpCLOvYDYYsFrhFt0=
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
Expand Down
21 changes: 6 additions & 15 deletions internal/cloud/gcp/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,27 @@
package gcp

import (
"context"
"fmt"
"net/http"

"cloud-proxy/internal/cloud/gcp/gcpauth"
"golang.org/x/oauth2"
)

type Credentials interface {
GetToken() (string, error)
}
type Client struct {
credentials Credentials
httpClient *http.Client
httpClient *http.Client
}

func New(credentials *gcpauth.CredentialsSource, client *http.Client) *Client {
return &Client{credentials: credentials, httpClient: client}
func New(tokenSource oauth2.TokenSource) *Client {
client := oauth2.NewClient(context.Background(), tokenSource)
return &Client{httpClient: client}
}

func (c *Client) DoHTTPRequest(request *http.Request) (*http.Response, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}

token, err := c.credentials.GetToken()
if err != nil {
return nil, fmt.Errorf("credentialsSrc.GetToken: error: %w", err)
}
// Set the Authorization header manually since we can't rely on mothership auth.
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

resp, err := c.httpClient.Do(request)
if err != nil {
return nil, fmt.Errorf("httpClient.Do: request %+v error: %w", request, err)
Expand Down
38 changes: 5 additions & 33 deletions internal/cloud/gcp/gcpauth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,16 @@ package gcpauth

import (
"context"
"fmt"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/serviceusage/v1"
)

func NewCredentialsSource(scopes ...string) *CredentialsSource {
func NewTokenSource(ctx context.Context, scopes ...string) (oauth2.TokenSource, error) {
if len(scopes) == 0 {
Trojan295 marked this conversation as resolved.
Show resolved Hide resolved
scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
}
return &CredentialsSource{
scopes: scopes,
}
}

type CredentialsSource struct {
scopes []string
}

// TODO: decide whether the default credentials should be looked up on every GetToken call, or whether the credentials (or the token itself) should be cached.

func (src *CredentialsSource) getDefaultCredentials() (*google.Credentials, error) {
defaultCreds, err := google.FindDefaultCredentials(context.Background(), src.scopes...)
if err != nil {
return nil, fmt.Errorf("could not load default credentials: %w", err)
}
return defaultCreds, nil
}

func (src *CredentialsSource) GetToken() (string, error) {
credentials, err := src.getDefaultCredentials()
if err != nil {
return "", fmt.Errorf("cannot load GCP credentials: %w", err)
}

token, err := credentials.TokenSource.Token()
if err != nil {
return "", fmt.Errorf("cannot get access token from src (%T): %w", credentials.TokenSource, err)
scopes = []string{serviceusage.CloudPlatformScope}
}

return token.AccessToken, nil
return google.DefaultTokenSource(ctx, scopes...)
}
52 changes: 33 additions & 19 deletions internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha"
)
Expand All @@ -29,6 +30,7 @@ type CloudClient interface {

type Client struct {
grpcConn *grpc.ClientConn
apiKey string
cloudClient CloudClient
log *logrus.Logger
podName string
Expand All @@ -44,9 +46,10 @@ type Client struct {
version string
}

func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version string, keepalive, keepaliveTimeout time.Duration) *Client {
func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version, apiKey string, keepalive, keepaliveTimeout time.Duration) *Client {
c := &Client{
grpcConn: grpcConn,
apiKey: apiKey,
cloudClient: cloudClient,
log: logger,
podName: podName,
Expand All @@ -60,22 +63,26 @@ func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logg
}

func (c *Client) Run(ctx context.Context) error {
authCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", c.apiKey),
))

t := time.NewTimer(time.Millisecond)

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-authCtx.Done():
return authCtx.Err()
case <-t.C:
c.log.Info("Starting proxy client")
stream, closeStream, err := c.getStream(ctx)
stream, closeStream, err := c.getStream(authCtx)
if err != nil {
c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
continue
}

err = c.run(ctx, stream, closeStream)
err = c.run(authCtx, stream, closeStream)
if err != nil {
c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
Expand Down Expand Up @@ -134,7 +141,12 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return fmt.Errorf("c.Connect: %w", err)
}

go c.sendKeepAlive(stream)
keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(keepAliveCh)
go c.sendKeepAlive(stream, keepAliveCh)

messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(messageRespCh)

go func() {
for {
Expand All @@ -160,7 +172,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Handling message from castai")
go c.handleMessage(in, stream)
go c.handleMessage(in, messageRespCh)
}
}()

Expand All @@ -170,6 +182,14 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return ctx.Err()
case <-stream.Context().Done():
return fmt.Errorf("stream closed %w", stream.Context().Err())
case req := <-keepAliveCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send keep alive")
}
case req := <-messageRespCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send message response")
}
case <-time.After(time.Duration(c.keepAlive.Load())):
if !c.isAlive() {
if err := c.lastSeenError.Load(); err != nil {
Expand All @@ -181,7 +201,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}
}

func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
if in == nil {
c.log.Error("nil message")
return
Expand All @@ -202,7 +222,7 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
} else {
c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId())
}
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{
respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{
Response: &cloudproxyv1alpha.ClusterResponse{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -213,9 +233,6 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
HttpResponse: resp,
},
},
})
if err != nil {
c.log.Errorf("error sending response for msg_id=%v %v", in.GetMessageId(), err)
}
}

Expand Down Expand Up @@ -261,7 +278,7 @@ func (c *Client) isAlive() bool {
return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load()
}

func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, sendCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
ticker := time.NewTimer(time.Duration(c.keepAlive.Load()))
defer ticker.Stop()

Expand All @@ -277,7 +294,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
return
}
c.log.Debug("Sending keep-alive to castai")
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{

sendCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_ClientStats{
ClientStats: &cloudproxyv1alpha.ClientStats{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -290,12 +308,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
},
},
},
})
if err != nil {
c.lastSeen.Store(0)
c.log.Errorf("error sending keep alive message: %v", err)
return
}

ticker.Reset(time.Duration(c.keepAlive.Load()))
}
}
Expand Down
Loading
Loading