diff --git a/prometheus/prometheus.go b/prometheus/prometheus.go index a21b40c..1768e7a 100644 --- a/prometheus/prometheus.go +++ b/prometheus/prometheus.go @@ -31,7 +31,11 @@ func (bat authTransport) RoundTrip(req *http.Request) (*http.Response, error) { if bat.username != "" { req.SetBasicAuth(bat.username, bat.password) } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bat.token)) + + if bat.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bat.token)) + } + return bat.Transport.RoundTrip(req) } diff --git a/prometheus/prometheus_test.go b/prometheus/prometheus_test.go index a919a76..69f2c36 100644 --- a/prometheus/prometheus_test.go +++ b/prometheus/prometheus_test.go @@ -1,6 +1,7 @@ package prometheus import ( + "encoding/base64" "fmt" "net/http" "time" @@ -33,6 +34,56 @@ var _ = Describe("Tests for Prometheus", func() { Expect(count).To(BeEquivalentTo(0)) Expect(err).To(BeNil()) }) + + It("Test2 bearer header is used when token is provided", func() { + url := "https://example.com/api" + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + fmt.Println("Failed to create request:", err) + return + } + _, err = bat.RoundTrip(req) + Expect(req.Header.Get("Authorization")).To(Equal("Bearer someRandomToken")) + //Asserting no of times mocks are called + Expect(count).To(BeEquivalentTo(0)) + Expect(err).To(BeNil()) + }) + + It("Test3 basic auth header is used when no token is provided", func() { + bat.token = "" + url := "https://example.com/api" + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + fmt.Println("Failed to create request:", err) + return + } + _, err = bat.RoundTrip(req) + + encodedAuthHeader := base64.StdEncoding.EncodeToString([]byte("someRandomUsername:someRandomPassword")) + + Expect(req.Header.Get("Authorization")).To(Equal("Basic " + encodedAuthHeader)) + //Asserting no of times mocks are called + Expect(count).To(BeEquivalentTo(0)) + Expect(err).To(BeNil()) + }) + + It("Test4 no auth header set when auth details are omitted", func() { + bat.token = "" + bat.username = "" + bat.password = "" + url := "https://example.com/api" + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + fmt.Println("Failed to create request:", err) + return + } + _, err = bat.RoundTrip(req) + + Expect(req.Header.Get("Authorization")).To(Equal("")) + //Asserting no of times mocks are called + Expect(count).To(BeEquivalentTo(0)) + Expect(err).To(BeNil()) + }) }) Context("Tests for NewClient()", func() {