forked from refraction-networking/utls
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathu_handshake_client.go
163 lines (144 loc) · 5.17 KB
/
u_handshake_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright 2022 uTLS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"compress/zlib"
"errors"
"fmt"
"io"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)
// utlsReadServerCertificate is invoked from
// (*clientHandshakeStateTLS13).readServerCertificate() to extract the
// certificate from a message obtained via (*Conn).readHandshake().
//
// A message is only handled when the ClientHello carried a
// UtlsCompressCertExtension, at least one certificate compression algorithm
// is configured, and msg is a *utlsCompressedCertificateMsg. In that case the
// compressed message is added to the transcript and the decompressed
// certificate message is returned. In every other case the result is
// (nil, nil).
func (hs *clientHandshakeStateTLS13) utlsReadServerCertificate(msg any) (processedMsg any, err error) {
	for _, ext := range hs.uconn.Extensions {
		if _, isCompressCert := ext.(*UtlsCompressCertExtension); !isCompressCert {
			continue
		}
		// The Compressed Certificate extension was included.
		if len(hs.uconn.certCompressionAlgs) == 0 {
			continue
		}
		compressedMsg, isCompressed := msg.(*utlsCompressedCertificateMsg)
		if !isCompressed {
			continue
		}

		hs.transcript.Write(compressedMsg.marshal())
		certMsg, decompressErr := hs.decompressCert(*compressedMsg)
		if decompressErr != nil {
			return nil, fmt.Errorf("tls: failed to decompress certificate message: %w", decompressErr)
		}
		return certMsg, nil
	}
	return nil, nil
}
// decompressCert turns a compressed certificate message into the TLS 1.3
// certificate message it encodes. It is called by
// (*clientHandshakeStateTLS13).utlsReadServerCertificate() when
// UtlsCompressCertExtension is used.
//
// The algorithm named in m must be one of hs.uconn.certCompressionAlgs and
// one of brotli, zlib or zstd. The decompressed payload must be exactly
// m.uncompressedLength bytes long. On any violation an alert is sent on the
// connection and a non-nil error is returned.
func (hs *clientHandshakeStateTLS13) decompressCert(m utlsCompressedCertificateMsg) (*certificateMsgTLS13, error) {
	var (
		decompressed io.Reader
		compressed   = bytes.NewReader(m.compressedCertificateMessage)
		c            = hs.c
	)

	// Check to see if the peer responded with an algorithm we advertised.
	supportedAlg := false
	for _, alg := range hs.uconn.certCompressionAlgs {
		if m.algorithm == uint16(alg) {
			supportedAlg = true
			break
		}
	}
	if !supportedAlg {
		c.sendAlert(alertBadCertificate)
		return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm)
	}

	switch CertCompressionAlgo(m.algorithm) {
	case CertCompressionBrotli:
		decompressed = brotli.NewReader(compressed)

	case CertCompressionZlib:
		rc, err := zlib.NewReader(compressed)
		if err != nil {
			c.sendAlert(alertBadCertificate)
			return nil, fmt.Errorf("failed to open zlib reader: %w", err)
		}
		defer rc.Close()
		decompressed = rc

	case CertCompressionZstd:
		rc, err := zstd.NewReader(compressed)
		if err != nil {
			c.sendAlert(alertBadCertificate)
			return nil, fmt.Errorf("failed to open zstd reader: %w", err)
		}
		defer rc.Close()
		decompressed = rc

	default:
		c.sendAlert(alertBadCertificate)
		return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm)
	}

	// Rebuild the handshake framing in front of the decompressed body.
	rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field
	rawMsg[0] = typeCertificate
	rawMsg[1] = uint8(m.uncompressedLength >> 16)
	rawMsg[2] = uint8(m.uncompressedLength >> 8)
	rawMsg[3] = uint8(m.uncompressedLength)

	// A single Read may legally return fewer bytes than requested, and may
	// return the final bytes together with io.EOF. io.ReadFull keeps reading
	// until the buffer is full and reports success once it is.
	n, err := io.ReadFull(decompressed, rawMsg[4:])
	if err != nil {
		c.sendAlert(alertBadCertificate)
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			// If, after decompression, the specified length does not match the actual length, the party
			// receiving the invalid message MUST abort the connection with the "bad_certificate" alert.
			// https://datatracker.ietf.org/doc/html/rfc8879#section-4
			return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength)
		}
		return nil, err
	}

	// The length must match exactly, so the stream has to end right here.
	// Probe for one more byte: success means the output is longer than
	// advertised, and anything other than a clean io.EOF means the
	// compressed stream itself is malformed.
	var extra [1]byte
	if _, err := io.ReadFull(decompressed, extra[:]); err == nil {
		c.sendAlert(alertBadCertificate)
		return nil, fmt.Errorf("decompressed len exceeds specified len (%d)", m.uncompressedLength)
	} else if !errors.Is(err, io.EOF) {
		c.sendAlert(alertBadCertificate)
		return nil, fmt.Errorf("failed to finish decompression: %w", err)
	}

	certMsg := new(certificateMsgTLS13)
	if !certMsg.unmarshal(rawMsg) {
		return nil, c.sendAlert(alertUnexpectedMessage)
	}
	return certMsg, nil
}
// serverFinishedReceived is meant to be called from
// (*clientHandshakeStateTLS13).handshake(), after hs.readServerFinished()
// and before hs.sendClientCertificate(). It sends the client encrypted
// extensions and returns any error from doing so.
func (hs *clientHandshakeStateTLS13) serverFinishedReceived() error {
	return hs.sendClientEncryptedExtensions()
}
// sendClientEncryptedExtensions writes the client's encrypted extensions
// message when c.utls.hasApplicationSettings is set. The message carries the
// local application settings, is added to the handshake transcript, and is
// then written as a handshake record. When the flag is not set, nothing is
// sent and nil is returned.
func (hs *clientHandshakeStateTLS13) sendClientEncryptedExtensions() error {
	conn := hs.c
	if !conn.utls.hasApplicationSettings {
		return nil
	}

	extMsg := &utlsClientEncryptedExtensionsMsg{
		hasApplicationSettings: true,
		applicationSettings:    conn.utls.localApplicationSettings,
	}

	// The same marshaled bytes feed both the transcript and the wire.
	payload := extMsg.marshal()
	hs.transcript.Write(payload)
	_, err := conn.writeRecord(recordTypeHandshake, payload)
	return err
}
// utlsReadServerParameters copies the application-settings state from the
// server's encrypted extensions onto the connection and validates it.
//
// If the server sent application settings, the negotiated version must be at
// least TLS 1.3 and a client protocol must have been negotiated; otherwise an
// error is returned. When the client has application settings configured for
// the server-selected ALPN protocol, they are stored as the local settings.
func (hs *clientHandshakeStateTLS13) utlsReadServerParameters(encryptedExtensions *encryptedExtensionsMsg) error {
	conn := hs.c
	conn.utls.hasApplicationSettings = encryptedExtensions.utls.hasApplicationSettings
	conn.utls.peerApplicationSettings = encryptedExtensions.utls.applicationSettings

	if !conn.utls.hasApplicationSettings {
		return nil
	}

	switch {
	case hs.uconn.vers < VersionTLS13:
		return errors.New("tls: server sent application settings at invalid version")
	case len(hs.uconn.clientProtocol) == 0:
		return errors.New("tls: server sent application settings without ALPN")
	}

	// Check if the ALPN selected by the server exists in the client's list.
	// A missing entry is deliberately not treated as an error: it is ignored
	// when the client has no ALPS in use for that protocol.
	// TODO: decide whether a missing entry should fail the handshake instead.
	if localSettings, found := hs.uconn.config.ApplicationSettings[hs.serverHello.alpnProtocol]; found {
		conn.utls.localApplicationSettings = localSettings
	}
	return nil
}