Skip to content

Commit 5940021

Browse files
phireworkbradfitz
authored andcommitted
[tailscale] net: add TCP socket creation/close hooks to SockTrace API
Extends the hooks added by #45 to also expose when TCP sockets are created or closed (meant to allow TCP stats to be read from them). We don't do this for all socket types since stats are not available for UDP sockets, and they tend to be short-lived, thus invoking the hooks would be useless overhead. Also fixes read/write hooks to not count out-of-band data, since that's usually not sent over the wire. Updates tailscale/corp#9230 Updates #58 Signed-off-by: Jenny Zhang <[email protected]> (Cherry-picked from db4dc90) (cherry picked from commit 51a96ad) (cherry picked from commit aff7d04) Signed-off-by: Brad Fitzpatrick <[email protected]>
1 parent ab59d02 commit 5940021

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

api/go1.99999.txt

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
pkg net, func ContextSockTrace(context.Context) *SockTrace #58
12
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
23
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
4+
pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58
5+
pkg net, type SockTrace struct #58
6+
pkg net, type SockTrace struct, DidCreateTCPConn func(syscall.RawConn) #58
7+
pkg net, type SockTrace struct, DidRead func(int) #58
8+
pkg net, type SockTrace struct, DidWrite func(int) #58
9+
pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn) #58
10+
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58
311
pkg net/http, func SetRoundTripEnforcer(func(*Request) error) #55

src/net/fd_posix.go

+33
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ type netFD struct {
2424
net string
2525
laddr Addr
2626
raddr Addr
27+
28+
// hooks (if provided) are called after successful reads or writes with the
29+
// number of bytes transferred.
30+
readHook func(int)
31+
writeHook func(int)
32+
closeHook func()
2733
}
2834

2935
func (fd *netFD) setAddr(laddr, raddr Addr) {
@@ -34,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
3440

3541
func (fd *netFD) Close() error {
3642
runtime.SetFinalizer(fd, nil)
43+
if fd.closeHook != nil {
44+
fd.closeHook()
45+
}
3746
return fd.pfd.Close()
3847
}
3948

@@ -44,10 +53,16 @@ func (fd *netFD) shutdown(how int) error {
4453
}
4554

4655
func (fd *netFD) closeRead() error {
56+
if fd.closeHook != nil {
57+
fd.closeHook()
58+
}
4759
return fd.shutdown(syscall.SHUT_RD)
4860
}
4961

5062
func (fd *netFD) closeWrite() error {
63+
if fd.closeHook != nil {
64+
fd.closeHook()
65+
}
5166
return fd.shutdown(syscall.SHUT_WR)
5267
}
5368

@@ -76,18 +91,27 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er
7691

7792
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
7893
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
94+
if fd.readHook != nil && err == nil {
95+
fd.readHook(n)
96+
}
7997
runtime.KeepAlive(fd)
8098
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
8199
}
82100

83101
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
84102
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
103+
if fd.readHook != nil && err == nil {
104+
fd.readHook(n)
105+
}
85106
runtime.KeepAlive(fd)
86107
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
87108
}
88109

89110
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
90111
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
112+
if fd.readHook != nil && err == nil {
113+
fd.readHook(n)
114+
}
91115
runtime.KeepAlive(fd)
92116
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
93117
}
@@ -118,18 +142,27 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e
118142

119143
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
120144
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
145+
if fd.writeHook != nil && err == nil {
146+
fd.writeHook(n)
147+
}
121148
runtime.KeepAlive(fd)
122149
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
123150
}
124151

125152
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
126153
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
154+
if fd.writeHook != nil && err == nil {
155+
fd.writeHook(n)
156+
}
127157
runtime.KeepAlive(fd)
128158
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
129159
}
130160

131161
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
132162
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
163+
if fd.writeHook != nil && err == nil {
164+
fd.writeHook(n)
165+
}
133166
runtime.KeepAlive(fd)
134167
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
135168
}

src/net/sock_posix.go

+15
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
2828
poll.CloseFunc(s)
2929
return nil, err
3030
}
31+
if trace := ContextSockTrace(ctx); trace != nil {
32+
fd.readHook = trace.DidRead
33+
fd.writeHook = trace.DidWrite
34+
if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" {
35+
c := newRawConn(fd)
36+
if trace.DidCreateTCPConn != nil {
37+
trace.DidCreateTCPConn(c)
38+
}
39+
if trace.WillCloseTCPConn != nil {
40+
fd.closeHook = func() {
41+
trace.WillCloseTCPConn(c)
42+
}
43+
}
44+
}
45+
}
3146

3247
// This function makes a network file descriptor for the
3348
// following applications:

src/net/socktrace.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package net
6+
7+
import (
8+
"context"
9+
"syscall"
10+
)
11+
12+
// SockTrace is a set of hooks to run at various operations on a network socket.
13+
// Any particular hook may be nil. Functions may be called concurrently from
14+
// different goroutines.
15+
type SockTrace struct {
16+
// DidOpenTCPConn is called when a TCP socket was created. The
17+
// underlying raw network connection that was created is provided.
18+
DidCreateTCPConn func(c syscall.RawConn)
19+
// DidRead is called after a successful read from the socket, where n bytes
20+
// were read.
21+
DidRead func(n int)
22+
// DidWrite is called after a successful write to the socket, where n bytes
23+
// were written.
24+
DidWrite func(n int)
25+
// WillOverwrite is called when the registered trace is overwritten by a
26+
// subsequent call to WithSockTrace. The provided trace is the new trace
27+
// that will be used.
28+
WillOverwrite func(trace *SockTrace)
29+
// WillCloseTCPConn is called when a TCP socket is about to be closed. The
30+
// underlying raw network connection that is being closed is provided.
31+
WillCloseTCPConn func(c syscall.RawConn)
32+
}
33+
34+
// WithSockTrace returns a new context based on the provided parent
35+
// ctx. Socket reads and writes made with the returned context will use
36+
// the provided trace hooks. Any previous hooks registered with ctx are
37+
// ovewritten (their WillOverwrite hook will be called).
38+
func WithSockTrace(ctx context.Context, trace *SockTrace) context.Context {
39+
if previous := ContextSockTrace(ctx); previous != nil && previous.WillOverwrite != nil {
40+
previous.WillOverwrite(trace)
41+
}
42+
return context.WithValue(ctx, sockTraceKey{}, trace)
43+
}
44+
45+
// ContextSockTrace returns the SockTrace associated with the
46+
// provided context. If none, it returns nil.
47+
func ContextSockTrace(ctx context.Context) *SockTrace {
48+
trace, _ := ctx.Value(sockTraceKey{}).(*SockTrace)
49+
return trace
50+
}
51+
52+
// unique type to prevent assignment.
53+
type sockTraceKey struct{}

0 commit comments

Comments
 (0)