diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0535f056..66adcc3a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -65,12 +65,19 @@ jobs: with: go-version: "${{env.GO_VERSION}}" + - name: Generate certificates + run: make -C e2e/certs certs + - name: Build Docker image run: make update-devel-image - name: Run e2e test - run: cd e2e && make run-e2e + run: make -C e2e run-e2e + + - name: Docker Compose file + if: failure() + run: cat e2e/docker-compose.yaml - - name: Dump Logs + - name: Docker Logs if: failure() - run: cd e2e && make dump-logs + run: make -C e2e dump-logs diff --git a/.golangci.yml b/.golangci.yml index 41c56a15..0e9df4e3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -88,6 +88,7 @@ issues: linters: - bodyclose - funlen + - gochecknoglobals - gocognit - gomnd - gosec diff --git a/bind/flag.go b/bind/flag.go index 0245a7db..ee75f81c 100644 --- a/bind/flag.go +++ b/bind/flag.go @@ -46,7 +46,8 @@ func PAC(fs *pflag.FlagSet, pac **url.URL) { fs.VarP(anyflag.NewValue[*url.URL](*pac, pac, fileurl.ParseFilePathOrURL), "pac", "p", ""+ "Proxy Auto-Configuration file to use for upstream proxy selection. "+ - "It can be a local file or a URL, you can also use '-' to read from stdin. ") + "It can be a local file or a URL, you can also use '-' to read from stdin. "+ + "The data URI scheme is supported, the format is data:base64,. ") } func RequestHeaders(fs *pflag.FlagSet, headers *[]header.Header) { @@ -109,9 +110,7 @@ func HTTPTransportConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPTransportConfig) "The maximum amount of time a dial will wait for a connect to complete. "+ "With or without a timeout, the operating system may impose its own earlier timeout. For instance, TCP timeouts are often around 3 minutes. ") - fs.DurationVar(&cfg.TLSHandshakeTimeout, - "http-tls-handshake-timeout", cfg.TLSHandshakeTimeout, - "The maximum amount of time waiting to wait for a TLS handshake. Zero means no limit.") + TLSClientConfig(fs, &cfg.TLSClientConfig) fs.DurationVar(&cfg.IdleConnTimeout, "http-idle-conn-timeout", cfg.IdleConnTimeout, @@ -123,10 +122,23 @@ func HTTPTransportConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPTransportConfig) "The amount of time to wait for a server's response headers after fully writing the request (including its body, if any)."+ "This time does not include the time to read the response body. "+ "Zero means no limit. ") +} - fs.BoolVar(&cfg.TLSConfig.InsecureSkipVerify, "insecure", cfg.TLSConfig.InsecureSkipVerify, +func TLSClientConfig(fs *pflag.FlagSet, cfg *forwarder.TLSClientConfig) { + fs.DurationVar(&cfg.HandshakeTimeout, + "http-tls-handshake-timeout", cfg.HandshakeTimeout, + "The maximum amount of time waiting to wait for a TLS handshake. Zero means no limit.") + + fs.BoolVar(&cfg.InsecureSkipVerify, "insecure", cfg.InsecureSkipVerify, "Don't verify the server's certificate chain and host name. "+ "Enable to work with self-signed certificates. ") + + fs.StringSliceVar(&cfg.CAFiles, + "cacert-file", cfg.CAFiles, ""+ + "Add your own CA certificates to verify against. "+ + "The system root certificates will be used in addition to any certificates in this list. "+ + "Can be a path to a file or \"data:\" followed by base64 encoded certificate. "+ + "Use this flag multiple times to specify multiple CA certificate files. ") } func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix string, schemes ...forwarder.Scheme) { @@ -167,13 +179,7 @@ func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix "For https and h2 protocols, if TLS certificate is not specified, "+ "the server will use a self-signed certificate. ") - fs.StringVar(&cfg.CertFile, - namePrefix+"tls-cert-file", cfg.CertFile, ""+ - "TLS certificate to use if the server protocol is https or h2. ") - - fs.StringVar(&cfg.KeyFile, - namePrefix+"tls-key-file", cfg.KeyFile, ""+ - "TLS private key to use if the server protocol is https or h2. ") + TLSServerConfig(fs, &cfg.TLSServerConfig, namePrefix) } fs.DurationVar(&cfg.ReadHeaderTimeout, @@ -200,6 +206,18 @@ func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix "Setting this to none disables logging. ") } +func TLSServerConfig(fs *pflag.FlagSet, cfg *forwarder.TLSServerConfig, namePrefix string) { + fs.StringVar(&cfg.CertFile, + namePrefix+"cert-file", cfg.CertFile, ""+ + "TLS certificate to use if the server protocol is https or h2. "+ + "Can be a path to a file or \"data:\" followed by base64 encoded certificate. ") + + fs.StringVar(&cfg.KeyFile, + namePrefix+"key-file", cfg.KeyFile, ""+ + "TLS private key to use if the server protocol is https or h2. "+ + "Can be a path to a file or \"data:\" followed by base64 encoded key. ") +} + func LogConfig(fs *pflag.FlagSet, cfg *log.Config) { fs.VarP(anyflag.NewValue[*os.File](nil, &cfg.File, forwarder.OpenFileParser(os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600, 0o700)), diff --git a/cmd/forwarder/pac/eval/eval.go b/cmd/forwarder/pac/eval/eval.go index 662aeb8e..c0ecd5d1 100644 --- a/cmd/forwarder/pac/eval/eval.go +++ b/cmd/forwarder/pac/eval/eval.go @@ -34,9 +34,12 @@ func (c *command) RunE(cmd *cobra.Command, args []string) error { } resolver = r } - t := forwarder.NewHTTPTransport(c.httpTransportConfig, resolver) + t, err := forwarder.NewHTTPTransport(c.httpTransportConfig, resolver) + if err != nil { + return err + } - script, err := forwarder.ReadURL(c.pac, t) + script, err := forwarder.ReadURLString(c.pac, t) if err != nil { return fmt.Errorf("read PAC file: %w", err) } diff --git a/cmd/forwarder/pac/server/server.go b/cmd/forwarder/pac/server/server.go index b2df7dc3..2fca3610 100644 --- a/cmd/forwarder/pac/server/server.go +++ b/cmd/forwarder/pac/server/server.go @@ -36,9 +36,12 @@ func (c *command) RunE(cmd *cobra.Command, args []string) error { logger := stdlog.New(c.logConfig) logger.Debugf("configuration\n%s", config) - t := forwarder.NewHTTPTransport(c.httpTransportConfig, nil) + t, err := forwarder.NewHTTPTransport(c.httpTransportConfig, nil) + if err != nil { + return err + } - script, err := forwarder.ReadURL(c.pac, t) + script, err := forwarder.ReadURLString(c.pac, t) if err != nil { return fmt.Errorf("read PAC file: %w", err) } diff --git a/cmd/forwarder/root.go b/cmd/forwarder/root.go index b5414f77..b97f0334 100644 --- a/cmd/forwarder/root.go +++ b/cmd/forwarder/root.go @@ -56,8 +56,12 @@ func rootCommand() *cobra.Command { Prefix: []string{"dns"}, }, { - Name: "HTTP client options", - Prefix: []string{"http", "insecure"}, + Name: "HTTP client options", + Prefix: []string{ + "http", + "cacert-file", + "insecure", + }, }, { Name: "Logging options", diff --git a/cmd/forwarder/run/run.go b/cmd/forwarder/run/run.go index fec48235..51dabef2 100644 --- a/cmd/forwarder/run/run.go +++ b/cmd/forwarder/run/run.go @@ -71,12 +71,16 @@ func (c *command) RunE(cmd *cobra.Command, args []string) error { } resolver = r } - rt = forwarder.NewHTTPTransport(c.httpTransportConfig, resolver) + var err error + rt, err = forwarder.NewHTTPTransport(c.httpTransportConfig, resolver) + if err != nil { + return err + } } if c.pac != nil { var err error - script, err = forwarder.ReadURL(c.pac, rt) + script, err = forwarder.ReadURLString(c.pac, rt) if err != nil { return fmt.Errorf("read PAC file: %w", err) } diff --git a/e2e/README.md b/e2e/README.md index 7c9e3cf8..e8e104b6 100644 --- a/e2e/README.md +++ b/e2e/README.md @@ -2,8 +2,9 @@ ## Running the e2e tests with the test runner -1. Build the forwarder image `make -C ../ update-devel-image` -1. Start the test runner `make run-e2e` +1. Generate certificates `make -C certs certs`, this can be done once +1. Build the forwarder image `make -C ../ update-devel-image`, this needs to be done after each forwarder code change +2. Start the test runner `make run-e2e` 1. The test runner will run all the tests sequentially and output the results to the console 1. If one of the test fails, the procedure will stop, test output will be printed 1. Environment will not be pruned once the error occurred, remember to manually clean it up with `make down` diff --git a/e2e/certs/.gitignore b/e2e/certs/.gitignore new file mode 100755 index 00000000..8efb4468 --- /dev/null +++ b/e2e/certs/.gitignore @@ -0,0 +1,3 @@ +ca.srl +*.crt +*.key diff --git a/e2e/certs/Makefile b/e2e/certs/Makefile new file mode 100644 index 00000000..20ec7a6e --- /dev/null +++ b/e2e/certs/Makefile @@ -0,0 +1,7 @@ +.PHONY: certs +certs: + @./gen.sh + +.PHONY: test +test: + @go test -v -tags manual . diff --git a/e2e/certs/cert_test.go b/e2e/certs/cert_test.go new file mode 100755 index 00000000..c0a69c0e --- /dev/null +++ b/e2e/certs/cert_test.go @@ -0,0 +1,52 @@ +//go:build manual + +package certs_test + +import ( + "crypto/tls" + "net/http" + "testing" + + "github.com/saucelabs/forwarder" +) + +func TestCertificate(t *testing.T) { + server := http.Server{ + Addr: "127.0.0.1:8443", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, world!")) + }), + } + defer server.Close() + + go server.ListenAndServeTLS("httpbin.crt", "httpbin.key") + + tlsCfg := &tls.Config{ + ServerName: "httpbin", + } + cfg := forwarder.TLSClientConfig{ + CAFiles: []string{ + "./ca.crt", + }, + } + cfg.ConfigureTLSConfig(tlsCfg) + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = tlsCfg + + req, err := http.NewRequest("GET", "https://"+server.Addr, http.NoBody) + if err != nil { + t.Fatal(err) + } + + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + + if res.StatusCode != http.StatusOK { + t.Fatal("unexpected status code:", res.StatusCode) + } + + tr.CloseIdleConnections() +} diff --git a/e2e/certs/gen.sh b/e2e/certs/gen.sh new file mode 100755 index 00000000..dc47eae8 --- /dev/null +++ b/e2e/certs/gen.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -eu -o pipefail + +# Common variables +CA_ORGANIZATION="Sauce Labs Inc." +CA_KEY="ca.key" +CA_CERT="ca.crt" +CA_SUBJECT="/C=US/O=${CA_ORGANIZATION}" + +EC_CURVE="prime256v1" + +# Generate CA key and self-signed certificate with SHA-256 +openssl ecparam -genkey -name ${EC_CURVE} -out ${CA_KEY} +openssl req -new -x509 -sha256 -days 365 -nodes -key ${CA_KEY} -subj "${CA_SUBJECT}" -out ${CA_CERT} \ +-extensions v3_ca -config <(cat /etc/ssl/openssl.cnf - << EOF +[v3_ca] +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid:always,issuer +basicConstraints=critical,CA:true +keyUsage=critical,keyCertSign,cRLSign +EOF +) + +# Function to generate certificates for each host name +generate_certificate() { + local HOST_NAME="$1" + local KEY="${HOST_NAME}.key" + local CSR="${HOST_NAME}.csr" + local CERT="${HOST_NAME}.crt" + local SUBJECT="/C=US/O=${CA_ORGANIZATION}/CN=${HOST_NAME}" + + # Generate host key and certificate signing request (CSR) + openssl ecparam -genkey -name ${EC_CURVE} -out ${KEY} + openssl req -new -key ${KEY} -subj "${SUBJECT}" -out ${CSR} + + # Sign the CSR with the CA to generate the host certificate + openssl x509 -req -sha256 -days 365 -in ${CSR} -CA ${CA_CERT} -CAkey ${CA_KEY} -CAcreateserial -out ${CERT}\ + -extensions v3_req -extfile <(cat /etc/ssl/openssl.cnf - << EOF +[v3_req] +basicConstraints=critical,CA:FALSE +authorityKeyIdentifier=keyid,issuer +subjectAltName=@alt_names +keyUsage=digitalSignature,keyEncipherment +[ alt_names ] +DNS.1 = ${HOST_NAME} +DNS.2 = localhost +EOF + ) + + # Remove the CSR (not needed anymore) + rm ${CSR} +} + +# Generate certificates for each host name +generate_certificate "proxy" +generate_certificate "upstream-proxy" +generate_certificate "httpbin" + +chmod 644 *.key *.crt diff --git a/e2e/forwarder/service.go b/e2e/forwarder/service.go index 00567102..dd4b373b 100644 --- a/e2e/forwarder/service.go +++ b/e2e/forwarder/service.go @@ -68,13 +68,24 @@ func HttpbinService() *Service { func (s *Service) WithProtocol(protocol string) *Service { s.Environment["FORWARDER_PROTOCOL"] = protocol + + if protocol == "https" || protocol == "h2" { + s.Environment["FORWARDER_CERT_FILE"] = "/etc/forwarder/certs/" + s.Name + ".crt" + s.Environment["FORWARDER_KEY_FILE"] = "/etc/forwarder/private/" + s.Name + ".key" + s.Volumes = append(s.Volumes, + "./certs/"+s.Name+".crt:/etc/forwarder/certs/"+s.Name+".crt:ro", + "./certs/"+s.Name+".key:/etc/forwarder/private/"+s.Name+".key:ro", + ) + } + return s } func (s *Service) WithUpstream(name, protocol string) *Service { s.Environment["FORWARDER_PROXY"] = protocol + "://" + name + ":3128" if protocol == "https" { - s.Environment["FORWARDER_INSECURE"] = "true" + s.Environment["FORWARDER_CACERT_FILE"] = "/etc/forwarder/certs/ca-certificates.crt" + s.Volumes = append(s.Volumes, "./certs/ca.crt:/etc/forwarder/certs/ca-certificates.crt:ro") } return s } diff --git a/fileurl/fileurl_test.go b/fileurl/fileurl_test.go index 9ccbc97f..d54b1183 100644 --- a/fileurl/fileurl_test.go +++ b/fileurl/fileurl_test.go @@ -97,6 +97,13 @@ func TestParseFilePathOrURL(t *testing.T) { Path: "/path/to/file", }, }, + { + input: "data:text/plain;base64,U2F1Y2VMYWJzCg==", + want: url.URL{ + Scheme: "data", + Opaque: "text/plain;base64,U2F1Y2VMYWJzCg==", + }, + }, } for i := range tests { diff --git a/http_proxy.go b/http_proxy.go index 8cb50541..8a3ed39d 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -132,8 +132,10 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher, if rt == nil { log.Infof("HTTP transport not configured, using standard library default") rt = http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert // we know it's a *http.Transport - } else if _, ok := rt.(*http.Transport); !ok { + } else if tr, ok := rt.(*http.Transport); !ok { log.Debugf("using custom HTTP transport %T", rt) + } else if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil { + log.Infof("using custom root CA certificates") } hp := &HTTPProxy{ config: *cfg, @@ -160,11 +162,9 @@ func (hp *HTTPProxy) configureHTTPS() error { hp.log.Debugf("loading TLS certificate from %s and %s", hp.config.CertFile, hp.config.KeyFile) } - tlsCfg := httpsTLSConfigTemplate() - err := LoadCertificateFromTLSConfig(tlsCfg, &hp.config.TLSConfig) - hp.TLSConfig = tlsCfg + hp.TLSConfig = httpsTLSConfigTemplate() - return err + return hp.config.ConfigureTLSConfig(hp.TLSConfig) } func (hp *HTTPProxy) configureProxy() { diff --git a/http_server.go b/http_server.go index 14c7df83..b7f76616 100644 --- a/http_server.go +++ b/http_server.go @@ -77,7 +77,7 @@ func h2TLSConfigTemplate() *tls.Config { type HTTPServerConfig struct { Protocol Scheme Addr string - TLSConfig + TLSServerConfig ReadTimeout time.Duration ReadHeaderTimeout time.Duration WriteTimeout time.Duration @@ -171,12 +171,10 @@ func (hs *HTTPServer) configureHTTPS() error { hs.log.Debugf("loading TLS certificate from %s and %s", hs.config.CertFile, hs.config.KeyFile) } - tlsCfg := httpsTLSConfigTemplate() - err := LoadCertificateFromTLSConfig(tlsCfg, &hs.config.TLSConfig) - hs.srv.TLSConfig = tlsCfg + hs.srv.TLSConfig = httpsTLSConfigTemplate() hs.srv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - return err + return hs.config.ConfigureTLSConfig(hs.srv.TLSConfig) } func (hs *HTTPServer) configureHTTP2() error { @@ -186,11 +184,9 @@ func (hs *HTTPServer) configureHTTP2() error { hs.log.Debugf("loading TLS certificate from %s and %s", hs.config.CertFile, hs.config.KeyFile) } - tlsCfg := h2TLSConfigTemplate() - err := LoadCertificateFromTLSConfig(tlsCfg, &hs.config.TLSConfig) - hs.srv.TLSConfig = tlsCfg + hs.srv.TLSConfig = h2TLSConfigTemplate() - return err + return hs.config.ConfigureTLSConfig(hs.srv.TLSConfig) } func (hs *HTTPServer) Run(ctx context.Context) error { diff --git a/http_transport.go b/http_transport.go index dbbbc9d1..d9247d0b 100644 --- a/http_transport.go +++ b/http_transport.go @@ -32,9 +32,7 @@ type HTTPTransportConfig struct { // If negative, keep-alive probes are disabled. KeepAlive time.Duration - // TLSHandshakeTimeout specifies the maximum amount of time waiting to - // wait for a TLS handshake. Zero means no timeout. - TLSHandshakeTimeout time.Duration + TLSClientConfig // MaxIdleConns controls the maximum number of idle (keep-alive) // connections across all hosts. Zero means no limit. @@ -72,42 +70,46 @@ type HTTPTransportConfig struct { // waiting for the server to approve. // This time does not include the time to send the request header. ExpectContinueTimeout time.Duration - - TLSConfig } func DefaultHTTPTransportConfig() *HTTPTransportConfig { // The default values are taken from [hashicorp/go-cleanhttp](https://github.com/hashicorp/go-cleanhttp/blob/a0807dd79fc1680a7b1f2d5a2081d92567aab97d/cleanhttp.go#L19. return &HTTPTransportConfig{ - DialTimeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + DialTimeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + TLSClientConfig: TLSClientConfig{ + HandshakeTimeout: 10 * time.Second, + }, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, } } -func NewHTTPTransport(cfg *HTTPTransportConfig, r *net.Resolver) *http.Transport { +func NewHTTPTransport(cfg *HTTPTransportConfig, r *net.Resolver) (*http.Transport, error) { d := &net.Dialer{ Timeout: cfg.DialTimeout, KeepAlive: cfg.KeepAlive, Resolver: r, } + tlsCfg := new(tls.Config) + + if err := cfg.ConfigureTLSConfig(tlsCfg); err != nil { + return nil, err + } + return &http.Transport{ Proxy: nil, Dial: d.Dial, DialContext: d.DialContext, + TLSClientConfig: tlsCfg, + TLSHandshakeTimeout: cfg.TLSClientConfig.HandshakeTimeout, MaxIdleConns: cfg.MaxIdleConns, IdleConnTimeout: cfg.IdleConnTimeout, - TLSHandshakeTimeout: cfg.TLSHandshakeTimeout, ExpectContinueTimeout: cfg.ExpectContinueTimeout, ForceAttemptHTTP2: true, MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: cfg.InsecureSkipVerify, //nolint:gosec // for self-signed certificates - }, - } + }, nil } diff --git a/readurl.go b/readurl.go index 313b0a68..3f3c659c 100644 --- a/readurl.go +++ b/readurl.go @@ -7,40 +7,72 @@ package forwarder import ( + "encoding/base64" "fmt" "io" "net/http" "net/url" "os" + "strings" ) -// ReadURL can read a local file, http or https URL or stdin. -func ReadURL(u *url.URL, rt http.RoundTripper) (string, error) { +// ReadURLString can read base64 encoded data, local file, http or https URL or stdin and return it as a string. +func ReadURLString(u *url.URL, rt http.RoundTripper) (string, error) { + b, err := ReadURL(u, rt) + if err != nil { + return "", err + } + return string(b), nil +} + +// ReadURL can read base64 encoded data, local file, http or https URL or stdin. +func ReadURL(u *url.URL, rt http.RoundTripper) ([]byte, error) { switch u.Scheme { + case "data": + return readData(u) case "file": return readFile(u) case "http", "https": return readHTTP(u, rt) default: - return "", fmt.Errorf("unsupported scheme %q, supported schemes are: file, http and https", u.Scheme) + return nil, fmt.Errorf("unsupported scheme %q, supported schemes are: file, http and https", u.Scheme) + } +} + +func readData(u *url.URL) ([]byte, error) { + v := strings.TrimPrefix(u.Opaque, "//") + + idx := strings.IndexByte(v, ',') + if idx != -1 { + if v[:idx] != "base64" { + return nil, fmt.Errorf("invalid data URI, the only supported format is: data:base64,") + } + v = v[idx+1:] } + + b, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return nil, err + } + + return b, nil } -func readFile(u *url.URL) (string, error) { +func readFile(u *url.URL) ([]byte, error) { if u.Host != "" { - return "", fmt.Errorf("invalid file URL %q, host is not allowed", u.String()) + return nil, fmt.Errorf("invalid file URL %q, host is not allowed", u.String()) } if u.User != nil { - return "", fmt.Errorf("invalid file URL %q, user is not allowed", u.String()) + return nil, fmt.Errorf("invalid file URL %q, user is not allowed", u.String()) } if u.RawQuery != "" { - return "", fmt.Errorf("invalid file URL %q, query is not allowed", u.String()) + return nil, fmt.Errorf("invalid file URL %q, query is not allowed", u.String()) } if u.Fragment != "" { - return "", fmt.Errorf("invalid file URL %q, fragment is not allowed", u.String()) + return nil, fmt.Errorf("invalid file URL %q, fragment is not allowed", u.String()) } if u.Path == "" { - return "", fmt.Errorf("invalid file URL %q, path is empty", u.String()) + return nil, fmt.Errorf("invalid file URL %q, path is empty", u.String()) } if u.Path == "-" { @@ -49,42 +81,49 @@ func readFile(u *url.URL) (string, error) { f, err := os.Open(u.Path) if err != nil { - return "", err + return nil, err } return readAndClose(f) } -func readAndClose(r io.ReadCloser) (string, error) { +func readAndClose(r io.ReadCloser) ([]byte, error) { defer r.Close() - b, err := io.ReadAll(r) - if err != nil { - return "", err - } - return string(b), nil + return io.ReadAll(r) } -func readHTTP(u *url.URL, rt http.RoundTripper) (string, error) { +func readHTTP(u *url.URL, rt http.RoundTripper) ([]byte, error) { c := http.Client{ Transport: rt, } req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) //nolint:noctx // timeout is set in the transport if err != nil { - return "", err + return nil, err } resp, err := c.Do(req) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code %d", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) } b, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return nil, err } - return string(b), nil + return b, nil +} + +func ReadFileOrBase64(name string) ([]byte, error) { + if strings.HasPrefix(name, "data:") { + return readData(&url.URL{ + Scheme: "data", + Opaque: name[5:], + }) + } + + return os.ReadFile(name) } diff --git a/readurl_test.go b/readurl_test.go new file mode 100644 index 00000000..b84e0e0d --- /dev/null +++ b/readurl_test.go @@ -0,0 +1,74 @@ +// Copyright 2023 Sauce Labs Inc. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package forwarder + +import ( + "net/url" + "testing" +) + +var base64Tests = []struct { + decoded, encoded string +}{ + // RFC 3548 examples + {"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"}, + {"\x14\xfb\x9c\x03\xd9", "FPucA9k="}, + {"\x14\xfb\x9c\x03", "FPucAw=="}, + + // RFC 4648 examples + {"", ""}, + {"f", "Zg=="}, + {"fo", "Zm8="}, + {"foo", "Zm9v"}, + {"foob", "Zm9vYg=="}, + {"fooba", "Zm9vYmE="}, + {"foobar", "Zm9vYmFy"}, + + // Wikipedia examples + {"sure.", "c3VyZS4="}, + {"sure", "c3VyZQ=="}, + {"sur", "c3Vy"}, + {"su", "c3U="}, + {"leasure.", "bGVhc3VyZS4="}, + {"easure.", "ZWFzdXJlLg=="}, + {"asure.", "YXN1cmUu"}, + {"sure.", "c3VyZS4="}, +} + +func TestReadURLData(t *testing.T) { + for i := range base64Tests { + tc := base64Tests[i] + t.Run(tc.encoded, func(t *testing.T) { + u := url.URL{ + Scheme: "data", + Opaque: "//base64," + tc.encoded, + } + b, err := ReadURLString(&u, nil) + if err != nil { + t.Fatal(err) + } + if b != tc.decoded { + t.Fatalf("expected %q, got %q", tc.decoded, b) + } + }) + } +} + +func TestReadFileOrBase64(t *testing.T) { + for i := range base64Tests { + tc := base64Tests[i] + t.Run(tc.encoded, func(t *testing.T) { + b, err := ReadFileOrBase64("data:" + tc.encoded) + if err != nil { + t.Fatal(err) + } + if string(b) != tc.decoded { + t.Fatalf("expected %q, got %q", tc.decoded, b) + } + }) + } +} diff --git a/tls.go b/tls.go index 843f5ccc..b26b5544 100644 --- a/tls.go +++ b/tls.go @@ -8,11 +8,18 @@ package forwarder import ( "crypto/tls" + "crypto/x509" + "fmt" + "time" "github.com/saucelabs/forwarder/utils/certutil" ) -type TLSConfig struct { +type TLSClientConfig struct { + // HandshakeTimeout specifies the maximum amount of time waiting to + // wait for a TLS handshake. Zero means no timeout. + HandshakeTimeout time.Duration + // InsecureSkipVerify controls whether a client verifies the server's // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls // accepts any certificate presented by the server and any host name in that @@ -21,6 +28,47 @@ type TLSConfig struct { // testing or in combination with VerifyConnection or VerifyPeerCertificate. InsecureSkipVerify bool + // CAFiles is a list of paths to CA certificate files. + // If this is set, the system root CA pool will be supplemented with certificates from these files. + CAFiles []string +} + +func (c *TLSClientConfig) ConfigureTLSConfig(tlsCfg *tls.Config) error { + tlsCfg.InsecureSkipVerify = c.InsecureSkipVerify + + if err := c.loadRootCAs(tlsCfg); err != nil { + return fmt.Errorf("load CAs: %w", err) + } + + return nil +} + +func (c *TLSClientConfig) loadRootCAs(tlsCfg *tls.Config) error { + if len(c.CAFiles) == 0 { + return nil + } + + rootCAs, err := x509.SystemCertPool() + if err != nil { + return err + } + + for _, name := range c.CAFiles { + b, err := ReadFileOrBase64(name) + if err != nil { + return err + } + if !rootCAs.AppendCertsFromPEM(b) { + return fmt.Errorf("append certificate %q", name) + } + } + + tlsCfg.RootCAs = rootCAs + + return nil +} + +type TLSServerConfig struct { // CertFile is the path to the TLS certificate. CertFile string @@ -28,20 +76,40 @@ type TLSConfig struct { KeyFile string } -func LoadCertificateFromTLSConfig(dst *tls.Config, src *TLSConfig) error { +func (c *TLSServerConfig) ConfigureTLSConfig(tlsCfg *tls.Config) error { + if err := c.loadCertificate(tlsCfg); err != nil { + return fmt.Errorf("load certificate: %w", err) + } + + return nil +} + +func (c *TLSServerConfig) loadCertificate(tlsCfg *tls.Config) error { var ( cert tls.Certificate err error ) - if src.CertFile == "" && src.KeyFile == "" { + if c.CertFile == "" && c.KeyFile == "" { cert, err = certutil.RSASelfSignedCert().Gen() } else { - cert, err = tls.LoadX509KeyPair(src.CertFile, src.KeyFile) + cert, err = loadX509KeyPair(c.CertFile, c.KeyFile) } if err == nil { - dst.Certificates = append(dst.Certificates, cert) + tlsCfg.Certificates = append(tlsCfg.Certificates, cert) } return err } + +func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) { + certPEMBlock, err := ReadFileOrBase64(certFile) + if err != nil { + return tls.Certificate{}, err + } + keyPEMBlock, err := ReadFileOrBase64(keyFile) + if err != nil { + return tls.Certificate{}, err + } + return tls.X509KeyPair(certPEMBlock, keyPEMBlock) +}