From 6b8ec2f1123cf964bafff66e40ba809c541af8b2 Mon Sep 17 00:00:00 2001 From: Jilks Smith Date: Tue, 3 Sep 2024 12:04:31 +0300 Subject: [PATCH] Update grpc.go Signed-off-by: Jilks Smith --- internal/server/grpc/grpc.go | 47 ++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index a3ef4ff18..97f1e56c0 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -18,6 +18,7 @@ import ( "math/big" "net" "os" + "strings" "time" "github.com/google/go-sev-guest/client" @@ -207,35 +208,39 @@ func (s *Server) Stop() error { } func loadCertFile(certFile string) ([]byte, error) { - if certFile != "" { - return os.ReadFile(certFile) + if len(certFile) < 1000 && !strings.Contains(certFile, "\n") { + data, err := os.ReadFile(certFile) + if err == nil { + return data, nil + } } - return []byte{}, nil + return []byte(certFile), nil } -func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) { +func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) { var cert, key []byte var err error - if _, err = os.Stat(certfile); err == nil { - cert, err = os.ReadFile(certfile) - if err != nil { - return tls.Certificate{}, err + + readFileOrData := func(input string) ([]byte, error) { + if len(input) < 1000 && !strings.Contains(input, "\n") { + data, err := os.ReadFile(input) + if err == nil { + return data, nil + } } - } else if os.IsNotExist(err) { - cert = []byte(certfile) - } else { - return tls.Certificate{}, err + return []byte(input), nil } - if _, err := os.Stat(keyfile); err == nil { - key, err = os.ReadFile(keyfile) - if err != nil { - return tls.Certificate{}, err - } - } else if os.IsNotExist(err) { - key = []byte(keyfile) - } else { - return tls.Certificate{}, err + + cert, err = readFileOrData(certFile) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err) + } + + key, err = readFileOrData(keyFile) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err) } + return tls.X509KeyPair(cert, key) }