Skip to content

Commit 584ea78

Browse files
committed
[tailscale] net: add SockTrace API
Loosely inspired by nettrace/httptrace, allows functions to be called when sockets are read from or written to. The hooks are specified via the context (with a WithSockTrace function). Only implemented for network sockets on POSIX systems. Updates tailscale/corp#9230 Updates #58 Signed-off-by: Jenny Zhang <[email protected]> (Cherry-picked from fb11c0d)
1 parent 9c10558 commit 584ea78

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

api/go1.99999.txt

+6
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
22
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
3+
pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58
4+
pkg net, func ContextSockTrace(context.Context) *SockTrace #58
5+
pkg net, type SockTrace struct #58
6+
pkg net, type SockTrace struct, DidRead func(int) #58
7+
pkg net, type SockTrace struct, DidWrite func(int) #58
8+
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58

src/net/fd_posix.go

+47
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ type netFD struct {
2424
net string
2525
laddr Addr
2626
raddr Addr
27+
28+
// hooks (if provided) are called after successful reads or writes with the
29+
// number of bytes transferred.
30+
readHook func(int)
31+
writeHook func(int)
2732
}
2833

2934
func (fd *netFD) setAddr(laddr, raddr Addr) {
@@ -53,83 +58,125 @@ func (fd *netFD) closeWrite() error {
5358

5459
func (fd *netFD) Read(p []byte) (n int, err error) {
5560
n, err = fd.pfd.Read(p)
61+
if fd.readHook != nil && err == nil {
62+
fd.readHook(n)
63+
}
5664
runtime.KeepAlive(fd)
5765
return n, wrapSyscallError(readSyscallName, err)
5866
}
5967

6068
func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
6169
n, sa, err = fd.pfd.ReadFrom(p)
70+
if fd.readHook != nil && err == nil {
71+
fd.readHook(n)
72+
}
6273
runtime.KeepAlive(fd)
6374
return n, sa, wrapSyscallError(readFromSyscallName, err)
6475
}
6576
func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) {
6677
n, err = fd.pfd.ReadFromInet4(p, from)
78+
if fd.readHook != nil && err == nil {
79+
fd.readHook(n)
80+
}
6781
runtime.KeepAlive(fd)
6882
return n, wrapSyscallError(readFromSyscallName, err)
6983
}
7084

7185
func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) {
7286
n, err = fd.pfd.ReadFromInet6(p, from)
87+
if fd.readHook != nil && err == nil {
88+
fd.readHook(n)
89+
}
7390
runtime.KeepAlive(fd)
7491
return n, wrapSyscallError(readFromSyscallName, err)
7592
}
7693

7794
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
7895
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
96+
if fd.readHook != nil && err == nil {
97+
fd.readHook(n + oobn)
98+
}
7999
runtime.KeepAlive(fd)
80100
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
81101
}
82102

83103
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
84104
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
105+
if fd.readHook != nil && err == nil {
106+
fd.readHook(n + oobn)
107+
}
85108
runtime.KeepAlive(fd)
86109
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
87110
}
88111

89112
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
90113
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
114+
if fd.readHook != nil && err == nil {
115+
fd.readHook(n + oobn)
116+
}
91117
runtime.KeepAlive(fd)
92118
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
93119
}
94120

95121
func (fd *netFD) Write(p []byte) (nn int, err error) {
96122
nn, err = fd.pfd.Write(p)
123+
if fd.writeHook != nil && err == nil {
124+
fd.writeHook(nn)
125+
}
97126
runtime.KeepAlive(fd)
98127
return nn, wrapSyscallError(writeSyscallName, err)
99128
}
100129

101130
func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
102131
n, err = fd.pfd.WriteTo(p, sa)
132+
if fd.writeHook != nil && err == nil {
133+
fd.writeHook(n)
134+
}
103135
runtime.KeepAlive(fd)
104136
return n, wrapSyscallError(writeToSyscallName, err)
105137
}
106138

107139
func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
108140
n, err = fd.pfd.WriteToInet4(p, sa)
141+
if fd.writeHook != nil && err == nil {
142+
fd.writeHook(n)
143+
}
109144
runtime.KeepAlive(fd)
110145
return n, wrapSyscallError(writeToSyscallName, err)
111146
}
112147

113148
func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
114149
n, err = fd.pfd.WriteToInet6(p, sa)
150+
if fd.writeHook != nil && err == nil {
151+
fd.writeHook(n)
152+
}
115153
runtime.KeepAlive(fd)
116154
return n, wrapSyscallError(writeToSyscallName, err)
117155
}
118156

119157
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
120158
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
159+
if fd.writeHook != nil && err == nil {
160+
fd.writeHook(n + oobn)
161+
}
121162
runtime.KeepAlive(fd)
122163
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
123164
}
124165

125166
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
126167
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
168+
if fd.writeHook != nil && err == nil {
169+
fd.writeHook(n + oobn)
170+
}
127171
runtime.KeepAlive(fd)
128172
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
129173
}
130174

131175
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
132176
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
177+
if fd.writeHook != nil && err == nil {
178+
fd.writeHook(n + oobn)
179+
}
133180
runtime.KeepAlive(fd)
134181
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
135182
}

src/net/sock_posix.go

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
2828
poll.CloseFunc(s)
2929
return nil, err
3030
}
31+
if trace := ContextSockTrace(ctx); trace != nil {
32+
fd.readHook = trace.DidRead
33+
fd.writeHook = trace.DidWrite
34+
}
3135

3236
// This function makes a network file descriptor for the
3337
// following applications:

src/net/socktrace.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package net
6+
7+
import (
8+
"context"
9+
)
10+
11+
// SockTrace is a set of hooks to run at various operations on a network socket.
12+
// Any particular hook may be nil. Functions may be called concurrently from
13+
// different goroutines.
14+
type SockTrace struct {
15+
// DidRead is called after a successful read from the socket, where n bytes
16+
// were read.
17+
DidRead func(n int)
18+
// DidWrite is called after a successful write to the socket, where n bytes
19+
// were written.
20+
DidWrite func(n int)
21+
// WillOverwrite is called when the registered trace is overwritten by a
22+
// subsequent call to WithSockTrace. The provided trace is the new trace
23+
// that will be used.
24+
WillOverwrite func(trace *SockTrace)
25+
}
26+
27+
// WithSockTrace returns a new context based on the provided parent
28+
// ctx. Socket reads and writes made with the returned context will use
29+
// the provided trace hooks. Any previous hooks registered with ctx are
30+
// ovewritten (their WillOverwrite hook will be called).
31+
func WithSockTrace(ctx context.Context, trace *SockTrace) context.Context {
32+
if previous := ContextSockTrace(ctx); previous != nil && previous.WillOverwrite != nil {
33+
previous.WillOverwrite(trace)
34+
}
35+
return context.WithValue(ctx, sockTraceKey{}, trace)
36+
}
37+
38+
// ContextSockTrace returns the SockTrace associated with the
39+
// provided context. If none, it returns nil.
40+
func ContextSockTrace(ctx context.Context) *SockTrace {
41+
trace, _ := ctx.Value(sockTraceKey{}).(*SockTrace)
42+
return trace
43+
}
44+
45+
// unique type to prevent assignment.
46+
type sockTraceKey struct{}

0 commit comments

Comments
 (0)