diff --git a/libs/pindialer/dialer.go b/libs/pindialer/dialer.go index b71e37320..94e4baf08 100644 --- a/libs/pindialer/dialer.go +++ b/libs/pindialer/dialer.go @@ -2,20 +2,15 @@ package pindialer import ( - "context" "crypto/sha256" "crypto/tls" + "crypto/x509" "encoding/base64" "errors" - "fmt" - "net" ) -// ContextDialer is a function connecting to the address on the named network -type ContextDialer func(ctx context.Context, network, addr string) (net.Conn, error) - -func validateChain(fingerprint string, connstate tls.ConnectionState) error { - for _, chain := range connstate.VerifiedChains { +func validateChain(fingerprint string, verifiedChains [][]*x509.Certificate) error { + for _, chain := range verifiedChains { for _, cert := range chain { hash := sha256.Sum256(cert.RawSubjectPublicKeyInfo) digest := base64.StdEncoding.EncodeToString(hash[:]) @@ -27,22 +22,13 @@ func validateChain(fingerprint string, connstate tls.ConnectionState) error { return errors.New("the server certificate was not valid") } -// MakeContextDialer returns a ContextDialer that only succeeds on connection to a TLS secured address with the pinned fingerprint -func MakeContextDialer(fingerprint string) ContextDialer { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := tls.Dial(network, addr, nil) - if err != nil { - return c, err - } - select { - case <-ctx.Done(): - return nil, fmt.Errorf("context completed") - default: - if err := validateChain(fingerprint, c.ConnectionState()); err != nil { - return nil, fmt.Errorf("failed to validate certificate chain: %w", err) - } - } - return c, nil +// Get tls.Config that validates the connection certificate chain against the +// given fingerprint. +func GetTLSConfig(fingerprint string) *tls.Config { + return &tls.Config{ + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return validateChain(fingerprint, verifiedChains) + }, } } diff --git a/libs/wallet/provider/uphold/uphold.go b/libs/wallet/provider/uphold/uphold.go index b9cefa0ad..b99369493 100644 --- a/libs/wallet/provider/uphold/uphold.go +++ b/libs/wallet/provider/uphold/uphold.go @@ -123,11 +123,9 @@ func init() { proxy = nil } - fingerprintDialer := pindialer.MakeContextDialer(upholdCertFingerprint) - // Uphold reports HTTP 401 error when connecting with HTTP2, so disable // HTTP/2 via setting TLSNextProto to an empty map. We do not need to set - // this field on defaultHTTPClient as we set DialTLSContext without setting + // this field on defaultHTTPClient as we set TLSClientConfig without setting // ForceAttemptHTTP2 and that disables HTTP/2 also. But for clarity we // always set TLSNextProto. disableHTTP2 := make( @@ -138,9 +136,9 @@ func init() { Timeout: httpTimeout, Transport: middleware.InstrumentRoundTripper( &http.Transport{ - DialTLSContext: fingerprintDialer, - Proxy: proxy, - TLSNextProto: disableHTTP2, + Proxy: proxy, + TLSClientConfig: pindialer.GetTLSConfig(upholdCertFingerprint), + TLSNextProto: disableHTTP2, }, "uphold"), } httpClientNoFP = &http.Client{ @@ -889,12 +887,16 @@ func (resp upholdTransactionResponse) ToTransactionInfo() *walletutils.Transacti return &txInfo } -// SubmitTransaction submits the base64 encoded transaction for verification but -// does not move funds unless confirm is set to true. +// SubmitTransaction creates a transaction and, when `confirm` is true, also +// submits it. If `conform` is false, ConfirmTransaction() should be used later +// to actually submit the transaction. func (w *Wallet) SubmitTransaction(ctx context.Context, transactionB64 string, confirm bool) (*walletutils.TransactionInfo, error) { return w.submitTransaction(ctx, defaultHTTPClient, transactionB64, confirm) } +// The implementation helper for `SubmitTransaction()` that takes an extra +// client argument to use fingeprinting/non-fingerprinting client depending on +// the caller needs. func (w *Wallet) submitTransaction( ctx context.Context, client *http.Client, diff --git a/libs/wallet/provider/uphold/uphold_test.go b/libs/wallet/provider/uphold/uphold_test.go index 84d5f023f..74f39eef3 100644 --- a/libs/wallet/provider/uphold/uphold_test.go +++ b/libs/wallet/provider/uphold/uphold_test.go @@ -2,6 +2,7 @@ package uphold import ( "context" + "crypto/tls" "encoding/hex" "errors" "net/http" @@ -307,28 +308,53 @@ func TestFingerprintCheck(t *testing.T) { var proxy func(*http.Request) (*url.URL, error) wrongFingerprint := "IYSLsapSKlkofKfi6M2hmS4gzXbQKGIX/DHBWIgstw3=" + w := requireDonorWallet(t) + + req, err := w.signRegistration("randomlabel") + if err != nil { + t.Error(err) + } + + // Check fingerprint error case client := &http.Client{ Timeout: time.Second * 60, // remove middleware calls Transport: &http.Transport{ - Proxy: proxy, - DialTLSContext: pindialer.MakeContextDialer(wrongFingerprint), + Proxy: proxy, + TLSClientConfig: pindialer.GetTLSConfig(wrongFingerprint), }, } - w := requireDonorWallet(t) + _, err = client.Do(req) + assert.ErrorContains(t, err, "the server certificate was not valid") - req, err := w.signRegistration("randomlabel") - if err != nil { - t.Error(err) + // Check the fingerprint success case. + tlsConfig := pindialer.GetTLSConfig(upholdCertFingerprint) + + // VerifyConnection callback is only called after + // tlsConfig.VerifyPeerCertificate returns success. + verifyConnectionCalled := false + if tlsConfig.VerifyConnection != nil { + t.Fatalf("tlsConfig.VerifyConnection must be unset") + } + tlsConfig.VerifyConnection = func(tls.ConnectionState) error { + if verifyConnectionCalled { + t.Fatalf("Unexpected extra call to VerifyConnection") + } + verifyConnectionCalled = true + return nil } - _, err = client.Do(req) - // should fail here - if err == nil { - t.Error("unable to fail with bad cert") + client = &http.Client{ + Timeout: time.Second * 60, + Transport: &http.Transport{ + Proxy: proxy, + TLSClientConfig: tlsConfig, + }, } - assert.Equal(t, errors.Unwrap(err).Error(), "failed to validate certificate chain: the server certificate was not valid") + + _, _ = client.Do(req) + assert.Equal(t, true, verifyConnectionCalled) } func requireDonorWallet(t *testing.T) *Wallet {