Skip to content

Commit 73e5f34

Browse files
phireworkawly
authored andcommitted
[tailscale] net: add TCP socket creation/close hooks to SockTrace API
Extends the hooks added by #45 to also expose when TCP sockets are created or closed (meant to allow TCP stats to be read from them). We don't do this for all socket types since stats are not available for UDP sockets, and they tend to be short-lived, thus invoking the hooks would be useless overhead. Also fixes read/write hooks to not count out-of-band data, since that's usually not sent over the wire. Updates tailscale/corp#9230 Updates #58 Signed-off-by: Jenny Zhang <[email protected]> (Cherry-picked from db4dc90)
1 parent fa171a9 commit 73e5f34

File tree

4 files changed

+40
-6
lines changed

4 files changed

+40
-6
lines changed

api/go1.99999.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ pkg net, type SockTrace struct #58
66
pkg net, type SockTrace struct, DidRead func(int) #58
77
pkg net, type SockTrace struct, DidWrite func(int) #58
88
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58
9+
pkg net, type SockTrace struct, DidCreateTCPConn func(syscall.RawConn) #58
10+
pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn) #58

src/net/fd_posix.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type netFD struct {
2929
// number of bytes transferred.
3030
readHook func(int)
3131
writeHook func(int)
32+
closeHook func()
3233
}
3334

3435
func (fd *netFD) setAddr(laddr, raddr Addr) {
@@ -39,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
3940

4041
func (fd *netFD) Close() error {
4142
runtime.SetFinalizer(fd, nil)
43+
if fd.closeHook != nil {
44+
fd.closeHook()
45+
}
4246
return fd.pfd.Close()
4347
}
4448

@@ -49,10 +53,16 @@ func (fd *netFD) shutdown(how int) error {
4953
}
5054

5155
func (fd *netFD) closeRead() error {
56+
if fd.closeHook != nil {
57+
fd.closeHook()
58+
}
5259
return fd.shutdown(syscall.SHUT_RD)
5360
}
5461

5562
func (fd *netFD) closeWrite() error {
63+
if fd.closeHook != nil {
64+
fd.closeHook()
65+
}
5666
return fd.shutdown(syscall.SHUT_WR)
5767
}
5868

@@ -94,7 +104,7 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er
94104
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
95105
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
96106
if fd.readHook != nil && err == nil {
97-
fd.readHook(n + oobn)
107+
fd.readHook(n)
98108
}
99109
runtime.KeepAlive(fd)
100110
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
@@ -103,7 +113,7 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int
103113
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
104114
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
105115
if fd.readHook != nil && err == nil {
106-
fd.readHook(n + oobn)
116+
fd.readHook(n)
107117
}
108118
runtime.KeepAlive(fd)
109119
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
@@ -112,7 +122,7 @@ func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.Socka
112122
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
113123
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
114124
if fd.readHook != nil && err == nil {
115-
fd.readHook(n + oobn)
125+
fd.readHook(n)
116126
}
117127
runtime.KeepAlive(fd)
118128
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
@@ -157,7 +167,7 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e
157167
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
158168
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
159169
if fd.writeHook != nil && err == nil {
160-
fd.writeHook(n + oobn)
170+
fd.writeHook(n)
161171
}
162172
runtime.KeepAlive(fd)
163173
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
@@ -166,7 +176,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
166176
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
167177
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
168178
if fd.writeHook != nil && err == nil {
169-
fd.writeHook(n + oobn)
179+
fd.writeHook(n)
170180
}
171181
runtime.KeepAlive(fd)
172182
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
@@ -175,7 +185,7 @@ func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4)
175185
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
176186
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
177187
if fd.writeHook != nil && err == nil {
178-
fd.writeHook(n + oobn)
188+
fd.writeHook(n)
179189
}
180190
runtime.KeepAlive(fd)
181191
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)

src/net/sock_posix.go

+15
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
3131
if trace := ContextSockTrace(ctx); trace != nil {
3232
fd.readHook = trace.DidRead
3333
fd.writeHook = trace.DidWrite
34+
if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" {
35+
// Ignore newRawConn errors (they're not possible in the current
36+
// implementation, but even if they were, we don't want to
37+
// affect socket operations for a trace hook invocation).
38+
if c, err := newRawConn(fd); err == nil {
39+
if trace.DidCreateTCPConn != nil {
40+
trace.DidCreateTCPConn(c)
41+
}
42+
if trace.WillCloseTCPConn != nil {
43+
fd.closeHook = func() {
44+
trace.WillCloseTCPConn(c)
45+
}
46+
}
47+
}
48+
}
3449
}
3550

3651
// This function makes a network file descriptor for the

src/net/socktrace.go

+7
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ package net
66

77
import (
88
"context"
9+
"syscall"
910
)
1011

1112
// SockTrace is a set of hooks to run at various operations on a network socket.
1213
// Any particular hook may be nil. Functions may be called concurrently from
1314
// different goroutines.
1415
type SockTrace struct {
16+
// DidOpenTCPConn is called when a TCP socket was created. The
17+
// underlying raw network connection that was created is provided.
18+
DidCreateTCPConn func(c syscall.RawConn)
1519
// DidRead is called after a successful read from the socket, where n bytes
1620
// were read.
1721
DidRead func(n int)
@@ -22,6 +26,9 @@ type SockTrace struct {
2226
// subsequent call to WithSockTrace. The provided trace is the new trace
2327
// that will be used.
2428
WillOverwrite func(trace *SockTrace)
29+
// WillCloseTCPConn is called when a TCP socket is about to be closed. The
30+
// underlying raw network connection that is being closed is provided.
31+
WillCloseTCPConn func(c syscall.RawConn)
2532
}
2633

2734
// WithSockTrace returns a new context based on the provided parent

0 commit comments

Comments
 (0)