From c70e39ccfb622c9201cfdf90a8402bd7d6d3622c Mon Sep 17 00:00:00 2001 From: csznet Date: Tue, 21 May 2024 15:43:38 +0800 Subject: [PATCH] fixed issues/17 --- forward/forward.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/forward/forward.go b/forward/forward.go index e20d969..e0140fe 100644 --- a/forward/forward.go +++ b/forward/forward.go @@ -125,7 +125,7 @@ func Run(stats *ConnectionStats) { } innerWg.Add(1) go func() { - stats.handleTCPConnection(clientConn, ctx) + stats.handleTCPConnection(clientConn, ctx, cancel) innerWg.Done() }() } @@ -134,7 +134,7 @@ func Run(stats *ConnectionStats) { } // TCP转发 -func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context) { +func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context, cancel context.CancelFunc) { defer clientConn.Close() remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort) if err != nil { @@ -147,11 +147,17 @@ func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context. copyWG.Add(2) go func() { defer copyWG.Done() - cs.copyBytes(clientConn, remoteConn) + if err := cs.copyBytes(clientConn, remoteConn); err != nil { + log.Println("复制字节时发生错误:", err) + cancel() // Assuming `cancel` is the cancel function from the context + } }() go func() { defer copyWG.Done() - cs.copyBytes(remoteConn, clientConn) + if err := cs.copyBytes(remoteConn, clientConn); err != nil { + log.Println("复制字节时发生错误:", err) + cancel() // Assuming `cancel` is the cancel function from the context + } }() for { select { @@ -204,7 +210,7 @@ func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr } -func (cs *ConnectionStats) copyBytes(dst, src net.Conn) { +func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error { buf := bufPool.Get().([]byte) defer bufPool.Put(buf) for { @@ -216,7 +222,7 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) { _, err := dst.Write(buf[:n]) if err != nil { log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err) - break + return err } } if err == io.EOF { @@ -230,6 +236,7 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) { // 关闭连接 dst.Close() src.Close() + return nil } // 定时打印和处理流量变化