Skip to content

Commit

Permalink
refactor closing in net & peer, remove ctx from some functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Termina1 committed Nov 22, 2024
1 parent 8a4113d commit 116b653
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 60 deletions.
18 changes: 9 additions & 9 deletions chotki.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,37 +395,37 @@ func (cho *Chotki) ObjectMapper() *ORM {
return NewORM(cho, cho.db.NewSnapshot())
}

func (cho *Chotki) RestoreNet(ctx context.Context) error {
func (cho *Chotki) RestoreNet() error {
i := cho.db.NewIter(&pebble.IterOptions{})
defer i.Close()

for i.SeekGE([]byte{'l'}); i.Valid() && i.Key()[0] == 'L'; i.Next() {
address := string(i.Key()[1:])
_ = cho.net.Listen(ctx, address)
_ = cho.net.Listen(address)
}

for i.SeekGE([]byte{'c'}); i.Valid() && i.Key()[0] == 'C'; i.Next() {
address := string(i.Key()[1:])
_ = cho.net.Connect(ctx, address)
_ = cho.net.Connect(address)
}

return nil
}

func (cho *Chotki) Listen(ctx context.Context, addr string) error {
return cho.net.Listen(ctx, addr)
func (cho *Chotki) Listen(addr string) error {
return cho.net.Listen(addr)
}

func (cho *Chotki) Unlisten(addr string) error {
return cho.net.Unlisten(addr)
}

func (cho *Chotki) Connect(ctx context.Context, addr string) error {
return cho.net.Connect(ctx, addr)
func (cho *Chotki) Connect(addr string) error {
return cho.net.Connect(addr)
}

func (cho *Chotki) ConnectPool(ctx context.Context, name string, addrs []string) error {
return cho.net.ConnectPool(ctx, name, addrs)
func (cho *Chotki) ConnectPool(name string, addrs []string) error {
return cho.net.ConnectPool(name, addrs)
}

func (cho *Chotki) Disconnect(addr string) error {
Expand Down
79 changes: 33 additions & 46 deletions protocol/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/drpcorg/chotki/utils"
Expand Down Expand Up @@ -44,15 +43,15 @@ type DestroyCallback func(name string, p Traced)
// HTTP/RPC server as, for example, we cannot let one slow receiver delay
// event transmission to all the other receivers.
type Net struct {
closed atomic.Bool

wg sync.WaitGroup
log utils.Logger
onInstall InstallCallback
onDestroy DestroyCallback

conns *xsync.MapOf[string, *Peer]
listens *xsync.MapOf[string, net.Listener]
conns *xsync.MapOf[string, *Peer]
listens *xsync.MapOf[string, net.Listener]
ctx context.Context
cancelCtx context.CancelFunc

tlsConfig *tls.Config
readBufferTcpSize int
Expand Down Expand Up @@ -97,8 +96,11 @@ func (opt *TcpBufferSizeOpt) Apply(n *Net) {
}

func NewNet(log utils.Logger, install InstallCallback, destroy DestroyCallback, opts ...NetOpt) *Net {
ctx, cancel := context.WithCancel(context.Background())
net := &Net{
log: log,
cancelCtx: cancel,
ctx: ctx,
conns: xsync.NewMapOf[string, *Peer](),
listens: xsync.NewMapOf[string, net.Listener](),
onInstall: install,
Expand Down Expand Up @@ -128,7 +130,7 @@ func (n *Net) GetStats() NetStats {
}

func (n *Net) Close() error {
n.closed.Store(true)
n.cancelCtx()

n.listens.Range(func(_ string, v net.Listener) bool {
v.Close()
Expand All @@ -149,11 +151,11 @@ func (n *Net) Close() error {
return nil
}

func (n *Net) Connect(ctx context.Context, addr string) (err error) {
return n.ConnectPool(ctx, addr, []string{addr})
func (n *Net) Connect(addr string) (err error) {
return n.ConnectPool(addr, []string{addr})
}

func (n *Net) ConnectPool(ctx context.Context, name string, addrs []string) (err error) {
func (n *Net) ConnectPool(name string, addrs []string) (err error) {
// nil is needed so that Connect cannot be called
// while KeepConnecting is connects
if _, ok := n.conns.LoadOrStore(name, nil); ok {
Expand All @@ -162,7 +164,7 @@ func (n *Net) ConnectPool(ctx context.Context, name string, addrs []string) (err

n.wg.Add(1)
go func() {
n.KeepConnecting(ctx, fmt.Sprintf("connect:%s", name), addrs)
n.KeepConnecting(fmt.Sprintf("connect:%s", name), addrs)
n.wg.Done()
}()

Expand All @@ -179,14 +181,14 @@ func (de *Net) Disconnect(name string) (err error) {
return nil
}

func (n *Net) Listen(ctx context.Context, addr string) error {
func (n *Net) Listen(addr string) error {
// nil is needed so that Listen cannot be called
// while creating listener
if _, ok := n.listens.LoadOrStore(addr, nil); ok {
return ErrAddressDuplicated
}

listener, err := n.createListener(ctx, addr)
listener, err := n.createListener(addr)
if err != nil {
n.listens.Delete(addr)
return err
Expand All @@ -197,7 +199,7 @@ func (n *Net) Listen(ctx context.Context, addr string) error {

n.wg.Add(1)
go func() {
n.KeepListening(ctx, addr)
n.KeepListening(addr)
n.wg.Done()
}()

Expand All @@ -213,21 +215,13 @@ func (de *Net) Unlisten(addr string) error {
return listener.Close()
}

func (n *Net) KeepConnecting(ctx context.Context, name string, addrs []string) {
func (n *Net) KeepConnecting(name string, addrs []string) {
connBackoff := MIN_RETRY_PERIOD

for !n.closed.Load() {
select {
case <-ctx.Done():
break
default:
// continue
}

for n.ctx.Err() == nil {
var err error
var conn net.Conn
for _, addr := range addrs {
conn, err = n.createConn(ctx, addr)
conn, err = n.createConn(addr)
if err == nil {
break
}
Expand All @@ -238,18 +232,18 @@ func (n *Net) KeepConnecting(ctx context.Context, name string, addrs []string) {

select {
case <-time.After(connBackoff):
case <-ctx.Done():
case <-n.ctx.Done():
break
}
connBackoff = min(MAX_RETRY_PERIOD, connBackoff*2)

continue
}
n.setTCPBuffersSize(n.log.WithDefaultArgs(ctx, "name", name), conn)
n.setTCPBuffersSize(n.log.WithDefaultArgs(context.Background(), "name", name), conn)
n.log.Info("net: connected", "name", name)

connBackoff = MIN_RETRY_PERIOD
n.keepPeer(ctx, name, conn)
n.keepPeer(name, conn)
}
}

Expand Down Expand Up @@ -277,15 +271,8 @@ func (n *Net) setTCPBuffersSize(ctx context.Context, conn net.Conn) {
}
}

func (n *Net) KeepListening(ctx context.Context, addr string) {
for !n.closed.Load() {
select {
case <-ctx.Done():
break
default:
// continue
}

func (n *Net) KeepListening(addr string) {
for n.ctx.Err() == nil {
listener, ok := n.listens.Load(addr)
if !ok {
break
Expand All @@ -304,10 +291,10 @@ func (n *Net) KeepListening(ctx context.Context, addr string) {

remoteAddr := conn.RemoteAddr().String()
n.log.Info("net: accept connection", "addr", addr, "remoteAddr", remoteAddr)
n.setTCPBuffersSize(n.log.WithDefaultArgs(ctx, "addr", addr, "remoteAdds", remoteAddr), conn)
n.setTCPBuffersSize(n.log.WithDefaultArgs(context.Background(), "addr", addr, "remoteAdds", remoteAddr), conn)
n.wg.Add(1)
go func() {
n.keepPeer(ctx, fmt.Sprintf("listen:%s:%s", uuid.Must(uuid.NewV7()).String(), remoteAddr), conn)
n.keepPeer(fmt.Sprintf("listen:%s:%s", uuid.Must(uuid.NewV7()).String(), remoteAddr), conn)
defer n.wg.Done()
}()
}
Expand All @@ -321,7 +308,7 @@ func (n *Net) KeepListening(ctx context.Context, addr string) {
n.log.Info("net: listener closed", "addr", addr)
}

func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) {
func (n *Net) keepPeer(name string, conn net.Conn) {
peer := &Peer{
inout: n.onInstall(name),
conn: conn,
Expand All @@ -331,7 +318,7 @@ func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) {
}
n.conns.Store(name, peer)

readErr, writeErr, closeErr := peer.Keep(ctx)
readErr, writeErr, closeErr := peer.Keep(n.ctx)
if readErr != nil {
n.log.Error("net: couldn't read from peer", "name", name, "err", readErr, "trace_id", peer.GetTraceId())
}
Expand All @@ -346,7 +333,7 @@ func (n *Net) keepPeer(ctx context.Context, name string, conn net.Conn) {
n.onDestroy(name, peer)
}

func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, error) {
func (n *Net) createListener(addr string) (net.Listener, error) {
connType, address, err := parseAddr(addr)
if err != nil {
return nil, err
Expand All @@ -356,13 +343,13 @@ func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, er
switch connType {
case TCP:
config := net.ListenConfig{}
if listener, err = config.Listen(ctx, "tcp", address); err != nil {
if listener, err = config.Listen(n.ctx, "tcp", address); err != nil {
return nil, err
}

case TLS:
config := net.ListenConfig{}
if listener, err = config.Listen(ctx, "tcp", address); err != nil {
if listener, err = config.Listen(n.ctx, "tcp", address); err != nil {
return nil, err
}

Expand All @@ -375,7 +362,7 @@ func (n *Net) createListener(ctx context.Context, addr string) (net.Listener, er
return listener, nil
}

func (n *Net) createConn(ctx context.Context, addr string) (net.Conn, error) {
func (n *Net) createConn(addr string) (net.Conn, error) {
connType, address, err := parseAddr(addr)
if err != nil {
return nil, err
Expand All @@ -385,14 +372,14 @@ func (n *Net) createConn(ctx context.Context, addr string) (net.Conn, error) {
switch connType {
case TCP:
d := net.Dialer{Timeout: time.Minute}
if conn, err = d.DialContext(ctx, "tcp", address); err != nil {
if conn, err = d.DialContext(n.ctx, "tcp", address); err != nil {
return nil, err
}

case TLS:
d := tls.Dialer{Config: n.tlsConfig}

if conn, err = d.DialContext(ctx, "tcp", address); err != nil {
if conn, err = d.DialContext(n.ctx, "tcp", address); err != nil {
return nil, err
}

Expand Down
6 changes: 3 additions & 3 deletions protocol/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ func TestTCPDepot_Connect(t *testing.T) {
return &TracedQueue[Records, []byte]{lCon}
}, func(_ string, t Traced) { lCon.Close() }, &NetTlsConfigOpt{tlsConfig("a.chotki.local")})

err := l.Listen(context.Background(), loop)
err := l.Listen(loop)
assert.Nil(t, err)

cCon := utils.NewFDQueue[Records](16, time.Millisecond, 0)
c := NewNet(log, func(_ string) FeedDrainCloserTraced {
return &TracedQueue[Records, []byte]{cCon}
}, func(_ string, t Traced) { cCon.Close() }, &NetTlsConfigOpt{tlsConfig("b.chotki.local")})

err = c.Connect(context.Background(), loop)
err = c.Connect(loop)
assert.Nil(t, err)
time.Sleep(time.Second) // Wait connection, todo use events

Expand Down Expand Up @@ -139,7 +139,7 @@ func TestTCPDepot_ConnectFailed(t *testing.T) {
return &TracedQueue[Records, []byte]{cCon}
}, func(_ string, t Traced) { cCon.Close() }, &NetTlsConfigOpt{tlsConfig("b.chotki.local")})

err := c.Connect(context.Background(), loop)
err := c.Connect(loop)
assert.Nil(t, err)
time.Sleep(time.Second) // Wait connection, todo use events

Expand Down
4 changes: 2 additions & 2 deletions repl/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func (repl *REPL) CommandListen(arg *rdx.RDX) (id rdx.ID, err error) {
}
addr := rdx.Snative(rdx.Sparse(string(arg.Text)))
if err == nil {
err = repl.Host.Listen(context.Background(), addr)
err = repl.Host.Listen(addr)
}
return
}
Expand All @@ -434,7 +434,7 @@ func (repl *REPL) CommandConnect(arg *rdx.RDX) (id rdx.ID, err error) {
}
addr := rdx.Snative(rdx.Sparse(string(arg.Text)))
if err == nil {
err = repl.Host.Connect(context.Background(), addr)
err = repl.Host.Connect(addr)
}
return
}
Expand Down

0 comments on commit 116b653

Please sign in to comment.