diff --git a/chotki.go b/chotki.go index c1f4ca4..a08d076 100644 --- a/chotki.go +++ b/chotki.go @@ -395,37 +395,37 @@ func (cho *Chotki) ObjectMapper() *ORM { return NewORM(cho, cho.db.NewSnapshot()) } -func (cho *Chotki) RestoreNet(ctx context.Context) error { +func (cho *Chotki) RestoreNet() error { i := cho.db.NewIter(&pebble.IterOptions{}) defer i.Close() for i.SeekGE([]byte{'l'}); i.Valid() && i.Key()[0] == 'L'; i.Next() { address := string(i.Key()[1:]) - _ = cho.net.Listen(ctx, address) + _ = cho.net.Listen(address) } for i.SeekGE([]byte{'c'}); i.Valid() && i.Key()[0] == 'C'; i.Next() { address := string(i.Key()[1:]) - _ = cho.net.Connect(ctx, address) + _ = cho.net.Connect(address) } return nil } -func (cho *Chotki) Listen(ctx context.Context, addr string) error { - return cho.net.Listen(ctx, addr) +func (cho *Chotki) Listen(addr string) error { + return cho.net.Listen(addr) } func (cho *Chotki) Unlisten(addr string) error { return cho.net.Unlisten(addr) } -func (cho *Chotki) Connect(ctx context.Context, addr string) error { - return cho.net.Connect(ctx, addr) +func (cho *Chotki) Connect(addr string) error { + return cho.net.Connect(addr) } -func (cho *Chotki) ConnectPool(ctx context.Context, name string, addrs []string) error { - return cho.net.ConnectPool(ctx, name, addrs) +func (cho *Chotki) ConnectPool(name string, addrs []string) error { + return cho.net.ConnectPool(name, addrs) } func (cho *Chotki) Disconnect(addr string) error { diff --git a/protocol/net.go b/protocol/net.go index 73c5cbb..22bbe53 100644 --- a/protocol/net.go +++ b/protocol/net.go @@ -9,7 +9,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "time" "github.com/drpcorg/chotki/utils" @@ -44,15 +43,15 @@ type DestroyCallback func(name string, p Traced) // HTTP/RPC server as, for example, we cannot let one slow receiver delay // event transmission to all the other receivers. type Net struct { - closed atomic.Bool - wg sync.WaitGroup log utils.Logger onInstall InstallCallback onDestroy DestroyCallback - conns *xsync.MapOf[string, *Peer] - listens *xsync.MapOf[string, net.Listener] + conns *xsync.MapOf[string, *Peer] + listens *xsync.MapOf[string, net.Listener] + ctx context.Context + cancelCtx context.CancelFunc tlsConfig *tls.Config readBufferTcpSize int @@ -97,8 +96,11 @@ func (opt *TcpBufferSizeOpt) Apply(n *Net) { } func NewNet(log utils.Logger, install InstallCallback, destroy DestroyCallback, opts ...NetOpt) *Net { + ctx, cancel := context.WithCancel(context.Background()) net := &Net{ log: log, + cancelCtx: cancel, + ctx: ctx, conns: xsync.NewMapOf[string, *Peer](), listens: xsync.NewMapOf[string, net.Listener](), onInstall: install, @@ -128,7 +130,7 @@ func (n *Net) GetStats() NetStats { } func (n *Net) Close() error { - n.closed.Store(true) + n.cancelCtx() n.listens.Range(func(_ string, v net.Listener) bool { v.Close() @@ -149,11 +151,11 @@ func (n *Net) Close() error { return nil } -func (n *Net) Connect(ctx context.Context, addr string) (err error) { - return n.ConnectPool(ctx, addr, []string{addr}) +func (n *Net) Connect(addr string) (err error) { + return n.ConnectPool(addr, []string{addr}) } -func (n *Net) ConnectPool(ctx context.Context, name string, addrs []string) (err error) { +func (n *Net) ConnectPool(name string, addrs []string) (err error) { // nil is needed so that Connect cannot be called // while KeepConnecting is connects if _, ok := n.conns.LoadOrStore(name, nil); ok { @@ -162,7 +164,7 @@ func (n *Net) ConnectPool(ctx context.Context, name string, addrs []string) (err n.wg.Add(1) go func() { - n.KeepConnecting(ctx, fmt.Sprintf("connect:%s", name), addrs) + n.KeepConnecting(fmt.Sprintf("connect:%s", name), addrs) n.wg.Done() }() @@ -179,14 +181,14 @@ func (de *Net) Disconnect(name string) (err error) { return nil } -func (n *Net) Listen(ctx context.Context, addr string) error { +func (n *Net) Listen(addr string) error { // nil is needed so that Listen cannot be called // while creating listener if _, ok := n.listens.LoadOrStore(addr, nil); ok { return ErrAddressDuplicated } - listener, err := n.createListener(ctx, addr) + listener, err := n.createListener(addr) if err != nil { n.listens.Delete(addr) return err @@ -197,7 +199,7 @@ func (n *Net) Listen(ctx context.Context, addr string) error { n.wg.Add(1) go func() { - n.KeepListening(ctx, addr) + n.KeepListening(addr) n.wg.Done() }() @@ -213,21 +215,13 @@ func (de *Net) Unlisten(addr string) error { return listener.Close() } -func (n *Net) KeepConnecting(ctx context.Context, name string, addrs []string) { +func (n *Net) KeepConnecting(name string, addrs []string) { connBackoff := MIN_RETRY_PERIOD - - for !n.closed.Load() { - select { - case <-ctx.Done(): - break - default: - // continue - } - + for n.ctx.Err() == nil { var err error var conn net.Conn for _, addr := range addrs { - conn, err = n.createConn(ctx, addr) + conn, err = n.createConn(addr) if err == nil { break } @@ -238,18 +232,18 @@ func (n *Net) KeepConnecting(ctx context.Context, name string, addrs []string) { select { case <-time.After(connBackoff): - case <-ctx.Done(): + case <-n.ctx.Done(): break } connBackoff = min(MAX_RETRY_PERIOD, connBackoff*2) continue } - n.setTCPBuffersSize(n.log.WithDefaultArgs(ctx, "name", name), conn) + n.setTCPBuffersSize(n.log.WithDefaultArgs(context.Background(), "name", name), conn) n.log.Info("net: connected", "name", name) connBackoff = MIN_RETRY_PERIOD - n.keepPeer(ctx, name, conn) + n.keepPeer(name, conn) } } @@ -277,15 +271,8 @@ func (n *Net) setTCPBuffersSize(ctx context.Context, conn net.Conn) { } } -func (n *Net) KeepListening(ctx context.Context, addr string) { - for !n.closed.Load() { - select { - case <-ctx.Done(): - break - default: - // continue - } - +func (n *Net) KeepListening(addr string) { + for n.ctx.Err() == nil { listener, ok := n.listens.Load(addr) if !ok { break @@ -304,10 +291,10 @@ func (n *Net) KeepListening(ctx context.Context, addr string) { remoteAddr := conn.RemoteAddr().String() n.log.Info("net: accept connection", "addr", addr, "remoteAddr", remoteAddr) - n.setTCPBuffersSize(n.log.WithDefaultArgs(ctx, "addr", addr, "remoteAdds", remoteAddr), conn) + n.setTCPBuffersSize(n.log.WithDefaultArgs(context.Background(), "addr", addr, "remoteAdds", remoteAddr), conn) n.wg.Add(1) go func() { - n.keepPeer(ctx, fmt.Sprintf("listen:%s:%s", uuid.Must(uuid.NewV7()).String(), remoteAddr), conn) + n.keepPeer(fmt.Sprintf("listen:%s:%s", uuid.Must(uuid.NewV7()).String(), remoteAddr), conn) defer n.wg.Done() }() } @@ -321,7 +308,7 @@ func (n *Net) KeepListening(ctx context.Context, addr string) { n.log.Info("net: listener closed", "addr", addr) } -func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) { +func (n *Net) keepPeer(name string, conn net.Conn) { peer := &Peer{ inout: n.onInstall(name), conn: conn, @@ -331,7 +318,7 @@ func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) { } n.conns.Store(name, peer) - readErr, writeErr, closeErr := peer.Keep(ctx) + readErr, writeErr, closeErr := peer.Keep(n.ctx) if readErr != nil { n.log.Error("net: couldn't read from peer", "name", name, "err", readErr, "trace_id", peer.GetTraceId()) } @@ -346,7 +333,7 @@ func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) { n.onDestroy(name, peer) } -func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, error) { +func (n *Net) createListener(addr string) (net.Listener, error) { connType, address, err := parseAddr(addr) if err != nil { return nil, err @@ -356,13 +343,13 @@ func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, er switch connType { case TCP: config := net.ListenConfig{} - if listener, err = config.Listen(ctx, "tcp", address); err != nil { + if listener, err = config.Listen(n.ctx, "tcp", address); err != nil { return nil, err } case TLS: config := net.ListenConfig{} - if listener, err = config.Listen(ctx, "tcp", address); err != nil { + if listener, err = config.Listen(n.ctx, "tcp", address); err != nil { return nil, err } @@ -375,7 +362,7 @@ func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, er return listener, nil } -func (n *Net) createConn(ctx context.Context, addr string) (net.Conn, error) { +func (n *Net) createConn(addr string) (net.Conn, error) { connType, address, err := parseAddr(addr) if err != nil { return nil, err @@ -385,14 +372,14 @@ func (n *Net) createConn(ctx context.Context, addr string) (net.Conn, error) { switch connType { case TCP: d := net.Dialer{Timeout: time.Minute} - if conn, err = d.DialContext(ctx, "tcp", address); err != nil { + if conn, err = d.DialContext(n.ctx, "tcp", address); err != nil { return nil, err } case TLS: d := tls.Dialer{Config: n.tlsConfig} - if conn, err = d.DialContext(ctx, "tcp", address); err != nil { + if conn, err = d.DialContext(n.ctx, "tcp", address); err != nil { return nil, err } diff --git a/protocol/net_test.go b/protocol/net_test.go index 1cf767d..22d2f2d 100644 --- a/protocol/net_test.go +++ b/protocol/net_test.go @@ -82,7 +82,7 @@ func TestTCPDepot_Connect(t *testing.T) { return &TracedQueue[Records, []byte]{lCon} }, func(_ string, t Traced) { lCon.Close() }, &NetTlsConfigOpt{tlsConfig("a.chotki.local")}) - err := l.Listen(context.Background(), loop) + err := l.Listen(loop) assert.Nil(t, err) cCon := utils.NewFDQueue[Records](16, time.Millisecond, 0) @@ -90,7 +90,7 @@ func TestTCPDepot_Connect(t *testing.T) { return &TracedQueue[Records, []byte]{cCon} }, func(_ string, t Traced) { cCon.Close() }, &NetTlsConfigOpt{tlsConfig("b.chotki.local")}) - err = c.Connect(context.Background(), loop) + err = c.Connect(loop) assert.Nil(t, err) time.Sleep(time.Second) // Wait connection, todo use events @@ -139,7 +139,7 @@ func TestTCPDepot_ConnectFailed(t *testing.T) { return &TracedQueue[Records, []byte]{cCon} }, func(_ string, t Traced) { cCon.Close() }, &NetTlsConfigOpt{tlsConfig("b.chotki.local")}) - err := c.Connect(context.Background(), loop) + err := c.Connect(loop) assert.Nil(t, err) time.Sleep(time.Second) // Wait connection, todo use events diff --git a/repl/commands.go b/repl/commands.go index adf5c7f..4a9d2a2 100644 --- a/repl/commands.go +++ b/repl/commands.go @@ -421,7 +421,7 @@ func (repl *REPL) CommandListen(arg *rdx.RDX) (id rdx.ID, err error) { } addr := rdx.Snative(rdx.Sparse(string(arg.Text))) if err == nil { - err = repl.Host.Listen(context.Background(), addr) + err = repl.Host.Listen(addr) } return } @@ -434,7 +434,7 @@ func (repl *REPL) CommandConnect(arg *rdx.RDX) (id rdx.ID, err error) { } addr := rdx.Snative(rdx.Sparse(string(arg.Text))) if err == nil { - err = repl.Host.Connect(context.Background(), addr) + err = repl.Host.Connect(addr) } return }