From 98f04bb86000693d67ed496ee269549c53e5e864 Mon Sep 17 00:00:00 2001 From: csznet Date: Sun, 10 Mar 2024 05:10:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=B3=E9=97=AD=E8=B6=85=E6=97=B6tcp?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- forward/forward.go | 47 +++++++++++++++++++++++++++++++++++++++++++--- utils/utils.go | 5 +++++ web/web.go | 2 +- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/forward/forward.go b/forward/forward.go index db0276e..141e1ac 100644 --- a/forward/forward.go +++ b/forward/forward.go @@ -37,7 +37,7 @@ var bufPool = sync.Pool{ // 开启转发,负责分发具体转发 func Run(stats *ConnectionStats, wg *sync.WaitGroup) { defer wg.Done() - + defer releaseResources(stats) // 在函数返回时释放资源 var ctx, cancel = context.WithCancel(context.Background()) var innerWg sync.WaitGroup @@ -142,6 +142,10 @@ func (cs *ConnectionStats) handleTCPConnection(wg *sync.WaitGroup, clientConn ne defer wg.Done() defer clientConn.Close() + // 设置连接读写超时时间 + clientConn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second)) + clientConn.SetWriteDeadline(time.Now().Add(time.Duration(5) * time.Second)) + remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort) if err != nil { fmt.Println("连接远程地址时发生错误:", err) @@ -150,7 +154,6 @@ func (cs *ConnectionStats) handleTCPConnection(wg *sync.WaitGroup, clientConn ne defer remoteConn.Close() cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn) // 添加连接到列表 - var copyWG sync.WaitGroup copyWG.Add(2) @@ -254,7 +257,7 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) { // 定时打印和处理流量变化 func (cs *ConnectionStats) printStats(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() - ticker := time.NewTicker(10 * time.Second) + ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() // 在函数结束时停止定时器 for { select { @@ -277,6 +280,16 @@ func (cs *ConnectionStats) printStats(wg *sync.WaitGroup, ctx context.Context) { } cs.TotalBytesOld = cs.TotalBytes sql.UpdateForwardBytes(cs.Id, cs.TotalBytes) + fmt.Printf("【%s】端口 %s 当前连接数: %d\n", cs.Protocol, cs.LocalPort, len(cs.TCPConnections)) + } else { + if cs.Protocol == "tcp" { + for i := len(cs.TCPConnections) - 1; i >= 0; i-- { + conn := cs.TCPConnections[i] + conn.Close() + // 从连接列表中移除关闭的连接 + cs.TCPConnections = append(cs.TCPConnections[:i], cs.TCPConnections[i+1:]...) + } + } } cs.TotalBytesLock.Unlock() //当协程退出时执行 @@ -285,3 +298,31 @@ func (cs *ConnectionStats) printStats(wg *sync.WaitGroup, ctx context.Context) { } } } + +// 关闭 TCP 连接并从切片中移除 +func closeTCPConnections(stats *ConnectionStats) { + stats.TotalBytesLock.Lock() + defer stats.TotalBytesLock.Unlock() + for _, conn := range stats.TCPConnections { + conn.Close() + } + stats.TCPConnections = nil // 清空切片 +} + +// 清理缓冲区 +func cleanupBuffer() { + // 如果有剩余的缓冲区,归还给池 + for { + buf := bufPool.Get() + if buf == nil { + break + } + bufPool.Put(buf) + } +} + +// 释放资源 +func releaseResources(stats *ConnectionStats) { + closeTCPConnections(stats) + cleanupBuffer() +} diff --git a/utils/utils.go b/utils/utils.go index ad152e7..e91644d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "sync" "csz.net/goForward/conf" @@ -10,6 +11,10 @@ import ( // 增加转发并开启 func AddForward(newF conf.ConnectionStats) bool { + fmt.Print(newF) + if newF.LocalPort == conf.WebPort { + return false + } id := sql.AddForward(newF) if id > 0 { stats := &forward.ConnectionStats{ diff --git a/web/web.go b/web/web.go index f1e1206..df173f5 100644 --- a/web/web.go +++ b/web/web.go @@ -46,7 +46,7 @@ func Run() { }) } else { c.HTML(200, "msg.tmpl", gin.H{ - "msg": "添加失败,本地端口正在转发", + "msg": "添加失败,端口已占用", "suc": false, }) }