diff --git a/speedtest/internal/README.md b/docs/measurement_method.md
similarity index 100%
rename from speedtest/internal/README.md
rename to docs/measurement_method.md
diff --git a/docs/speedtest_protocol_specifications.md b/docs/speedtest_protocol_specifications.md
new file mode 100644
index 0000000..3ac629f
--- /dev/null
+++ b/docs/speedtest_protocol_specifications.md
@@ -0,0 +1,46 @@
+# SpeedTest Specifications
+This document describes some of the interfaces used by speedtest, for reference only.
+
+## Native Socket Interfaces
+
+The protocol uses a plain-text data stream and ends each message with '\n'.
+The trailing '\n' and the operators are included in the total byte count.
+
+| Method | Protocol | Description |
+|--------|----------|---------------------------------------------------|
+| Greet | TCP | Say hello and get the server version information. |
+| PING | TCP | Echo with the server. |
+| Loss | TCP+UDP | Conduct a UDP packet loss test. |
+| Down | TCP | Receive data from the server. |
+| Up | TCP | Send data to the server. |
+
+### Greet
+```shell
+Client: HI\n
+Server: HELLO [Major].[Minor] ([Major].[Minor].[Patch]) [YYYY]-[MM]-[DD].[LTSCCode].[GitHash]\n
+```
+
+### PING
+```shell
+Client: PING [Local Timestamp]\n
+Server: PONG [Remote Timestamp]\n
+```
+
+### Loss
+Please see https://github.com/showwin/speedtest-go/issues/169
+
+### Down
+```shell
+Client: DOWNLOAD [Size]\n
+Server: DOWNLOAD [Random Data]\n
+```
+
+### Up
+```shell
+Client: UPLOAD [Size]\n
+Client: [Random Data]
+Server: OK [Size] [Timestamp]
+```
+
+## References
+[1] Reverse Engineering the Speedtest.net Protocol, Gökberk Yaltıraklı. https://gist.github.com/sdstrowes/411fca9d900a846a704f68547941eb97
diff --git a/speedtest.go b/speedtest.go
index cad3ef9..d257e1a 100644
--- a/speedtest.go
+++ b/speedtest.go
@@ -4,6 +4,7 @@ import (
 "context"
 "errors"
 "fmt"
+ "github.com/showwin/speedtest-go/speedtest/control"
 "github.com/showwin/speedtest-go/speedtest/transport"
 "gopkg.in/alecthomas/kingpin.v2"
 "io"
@@ -37,7 +38,8 @@ var (
 userAgent = kingpin.Flag("ua", "Set the user-agent header for the speedtest.").String()
 noDownload = kingpin.Flag("no-download", "Disable download test.").Bool()
 noUpload = kingpin.Flag("no-upload", "Disable upload test.").Bool()
- pingMode = kingpin.Flag("ping-mode", "Select a method for Ping (support icmp/tcp/http).").Default("http").String()
+ pingMode = kingpin.Flag("ping-mode", "Select a method for Ping (options: icmp/tcp/http).").Default("http").String()
+ protocol = kingpin.Flag("protocol", "Select a protocol (default: http) for speedtest (options: tcp/http).").Default("http").Short('p').String()
 unit = kingpin.Flag("unit", "Set human-readable and auto-scaled rate units for output (options: decimal-bits/decimal-bytes/binary-bits/binary-bytes).").Short('u').String()
 debug = kingpin.Flag("debug", "Enable debug mode.").Short('d').Bool()
 )
@@ -70,12 +72,13 @@ func main() {
 Source: *source,
 DnsBindSource: *dnsBindSource,
 Debug: *debug,
- PingMode: parseProto(*pingMode), // TCP as default
+ PingMode: control.ParseProto(*pingMode), // HTTP as default
 SavingMode: *savingMode,
 MaxConnections: *thread,
 CityFlag: *city,
 LocationFlag: *location,
 Keyword: *search,
+ Protocol: control.ParseProto(*protocol),
 }))
 if *showCityList {
@@ -171,43 +174,43 @@ func main() {
 accEcho := newAccompanyEcho(server, time.Millisecond*500)
 taskManager.RunWithTrigger(!*noDownload, "Download", func(task *Task) {
 accEcho.Run()
- speedtestClient.SetCallbackDownload(func(downRate speedtest.ByteRate) {
+ callback := func(downRate 
float64) { lc := accEcho.CurrentLatency() if lc == 0 { - task.Updatef("Download: %s (Latency: --)", downRate) + task.Updatef("Download: %s (Latency: --)", speedtest.ByteRate(downRate)) } else { - task.Updatef("Download: %s (Latency: %dms)", downRate, lc/1000000) + task.Updatef("Download: %s (Latency: %dms)", speedtest.ByteRate(downRate), lc/1000000) } - }) + } if *multi { - task.CheckError(server.MultiDownloadTestContext(context.Background(), servers)) + task.CheckError(server.MultiDownloadTestContext(context.Background(), servers, callback)) } else { - task.CheckError(server.DownloadTest()) + task.CheckError(server.DownloadTest(callback)) } accEcho.Stop() mean, _, std, minL, maxL := speedtest.StandardDeviation(accEcho.Latencies()) - task.Printf("Download: %s (Used: %.2fMB) (Latency: %dms Jitter: %dms Min: %dms Max: %dms)", server.DLSpeed, float64(server.Context.Manager.GetTotalDownload())/1000/1000, mean/1000000, std/1000000, minL/1000000, maxL/1000000) + task.Printf("Download: %s (Used: %.2fMB) (Latency: %dms Jitter: %dms Min: %dms Max: %dms)", server.DLSpeed, float64(server.Received)/1000/1000, mean/1000000, std/1000000, minL/1000000, maxL/1000000) task.Complete() }) taskManager.RunWithTrigger(!*noUpload, "Upload", func(task *Task) { accEcho.Run() - speedtestClient.SetCallbackUpload(func(upRate speedtest.ByteRate) { + callback := func(upRate float64) { lc := accEcho.CurrentLatency() if lc == 0 { - task.Updatef("Upload: %s (Latency: --)", upRate) + task.Updatef("Upload: %s (Latency: --)", speedtest.ByteRate(upRate)) } else { - task.Updatef("Upload: %s (Latency: %dms)", upRate, lc/1000000) + task.Updatef("Upload: %s (Latency: %dms)", speedtest.ByteRate(upRate), lc/1000000) } - }) + } if *multi { - task.CheckError(server.MultiUploadTestContext(context.Background(), servers)) + task.CheckError(server.MultiUploadTestContext(context.Background(), servers, callback)) } else { - task.CheckError(server.UploadTest()) + task.CheckError(server.UploadTest(callback)) } accEcho.Stop() mean, _, std, minL, maxL := speedtest.StandardDeviation(accEcho.Latencies()) - task.Printf("Upload: %s (Used: %.2fMB) (Latency: %dms Jitter: %dms Min: %dms Max: %dms)", server.ULSpeed, float64(server.Context.Manager.GetTotalUpload())/1000/1000, mean/1000000, std/1000000, minL/1000000, maxL/1000000) + task.Printf("Upload: %s (Used: %.2fMB) (Latency: %dms Jitter: %dms Min: %dms Max: %dms)", server.ULSpeed, float64(server.Sent)/1000/1000, mean/1000000, std/1000000, minL/1000000, maxL/1000000) task.Complete() }) @@ -220,7 +223,6 @@ func main() { taskManager.Println(server.PacketLoss.String()) } taskManager.Reset() - speedtestClient.Manager.Reset() } taskManager.Stop() @@ -309,17 +311,6 @@ func parseUnit(str string) speedtest.UnitType { } } -func parseProto(str string) speedtest.Proto { - str = strings.ToLower(str) - if str == "icmp" { - return speedtest.ICMP - } else if str == "tcp" { - return speedtest.TCP - } else { - return speedtest.HTTP - } -} - func AppInfo() { if !*jsonOutput { fmt.Println() diff --git a/speedtest/control/chunk.go b/speedtest/control/chunk.go new file mode 100644 index 0000000..3651a0d --- /dev/null +++ b/speedtest/control/chunk.go @@ -0,0 +1,35 @@ +package control + +import ( + "errors" + "io" + "sync" + "time" +) + +const DefaultReadChunkSize = 1024 // 1 KBytes with higher frequency rate feedback + +var ( + ErrDuplicateCall = errors.New("multiple calls to the same chunk handler are not allowed") +) + +type Chunk interface { + UploadHandler(size int64) Chunk + DownloadHandler(r io.Reader) error + + 
Rate() float64 + Duration() time.Duration + + Type() Proto + + Len() int64 + + Read(b []byte) (n int, err error) +} + +var BlackHole = sync.Pool{ + New: func() any { + b := make([]byte, 8192) + return &b + }, +} diff --git a/speedtest/control/controller.go b/speedtest/control/controller.go new file mode 100644 index 0000000..401e3a0 --- /dev/null +++ b/speedtest/control/controller.go @@ -0,0 +1,12 @@ +package control + +type Controller interface { + // Get Reference counter volume + Get() int64 + // Add Reference counter increment + Add(delta int64) + // Repeat Pointing to duplicate memory space + Repeat() []byte + // Done Notification processing completed + Done() <-chan struct{} +} diff --git a/speedtest/control/lb.go b/speedtest/control/lb.go new file mode 100644 index 0000000..53fc545 --- /dev/null +++ b/speedtest/control/lb.go @@ -0,0 +1,61 @@ +package control + +import ( + "math" + "sync" +) + +type Task func() error + +type TaskItem struct { + fn func() error + SlothIndex int64 + Currents int64 +} + +// LoadBalancer The implementation of Least-Connections Load Balancer with Failure Drop. +type LoadBalancer struct { + TaskQueue []*TaskItem + sync.Mutex +} + +func NewLoadBalancer() *LoadBalancer { + return &LoadBalancer{} +} + +func (lb *LoadBalancer) Len() int { + return len(lb.TaskQueue) +} + +// Add a new task to the [LoadBalancer] +// @param priority The smaller the value, the higher the priority. +func (lb *LoadBalancer) Add(task Task, priority int64) { + if task == nil { + panic("empty task is not allowed") + } + lb.TaskQueue = append(lb.TaskQueue, &TaskItem{fn: task, SlothIndex: priority, Currents: 0}) +} + +func (lb *LoadBalancer) Dispatch() { + var candidate *TaskItem + lb.Lock() + var minWeighted int64 = math.MaxInt64 + for i := 0; i < lb.Len(); i++ { + weighted := lb.TaskQueue[i].Currents * lb.TaskQueue[i].SlothIndex + if weighted < minWeighted { + minWeighted = weighted + candidate = lb.TaskQueue[i] + } + } + if candidate == nil || candidate.fn == nil { + return + } + candidate.Currents++ + lb.Unlock() + err := candidate.fn() + lb.Lock() + defer lb.Unlock() + if err == nil { + candidate.Currents-- + } +} diff --git a/speedtest/control/lb_test.go b/speedtest/control/lb_test.go new file mode 100644 index 0000000..4fc9299 --- /dev/null +++ b/speedtest/control/lb_test.go @@ -0,0 +1,117 @@ +package control + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestSLB(t *testing.T) { + lb := NewLoadBalancer() + var a int64 = 0 + + lb.Add(func() error { + atomic.AddInt64(&a, 1) + time.Sleep(time.Second * 2) + return errors.New("error") + }, 2) + + go func() { + for { + fmt.Printf("a:%d\n", a) + time.Sleep(time.Second) + } + }() + + wg := sync.WaitGroup{} + + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + for { + lb.Dispatch() + } + }() + } + + wg.Wait() +} + +func TestLB(t *testing.T) { + lb := NewLoadBalancer() + var a int64 = 0 + var b int64 = 0 + var c int64 = 0 + var d int64 = 0 + + lb.Add(func() error { + atomic.AddInt64(&a, 1) + time.Sleep(time.Second * 2) + return nil + }, 2) + + lb.Add(func() error { + atomic.AddInt64(&b, 1) + time.Sleep(time.Second * 2) + return nil + }, 1) + + lb.Add(func() error { + atomic.AddInt64(&c, 1) + time.Sleep(time.Second * 2) + fmt.Println("error") + return errors.New("error") + }, 1) + + lb.Add(func() error { + atomic.AddInt64(&d, 1) + time.Sleep(time.Second * 2) + return nil + }, 5) + + wg := sync.WaitGroup{} + + go func() { + for { + fmt.Printf("a:%d, b:%d, c:%d, d:%d\n", a, b, c, d) + 
time.Sleep(time.Second) + } + }() + + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + for { + lb.Dispatch() + } + }() + } + + wg.Wait() +} + +func BenchmarkDP(b *testing.B) { + lb := NewLoadBalancer() + lb.Add(func() error { + return nil + }, 1) + lb.Add(func() error { + return nil + }, 1) + lb.Add(func() error { + return nil + }, 1) + lb.Add(func() error { + return nil + }, 1) + lb.Add(func() error { + return nil + }, 1) + + for i := 0; i < b.N; i++ { + lb.Dispatch() + } +} diff --git a/speedtest/control/manager.go b/speedtest/control/manager.go new file mode 100644 index 0000000..13d0d09 --- /dev/null +++ b/speedtest/control/manager.go @@ -0,0 +1,22 @@ +package control + +import ( + "time" +) + +type Manager interface { + SetSamplingPeriod(duration time.Duration) Manager + SetSamplingDuration(duration time.Duration) Manager + + History() *Tracer + + // SetNThread This function name is confusing. + // Deprecated: Replaced by [DataManager.SetMaxConnections]. + SetNThread(n int) Manager + SetMaxConnections(n int) Manager + + GetSamplingPeriod() time.Duration + GetSamplingDuration() time.Duration + + GetMaxConnections() int +} diff --git a/speedtest/control/trace.go b/speedtest/control/trace.go new file mode 100644 index 0000000..5baebae --- /dev/null +++ b/speedtest/control/trace.go @@ -0,0 +1,39 @@ +package control + +const DefaultMaxTraceSize = 10 + +type Trace []Chunk + +type Tracer struct { + ts []*Trace + maxSize int +} + +func NewHistoryTracer(size int) *Tracer { + return &Tracer{ + ts: make([]*Trace, 0, size), + maxSize: size, + } +} + +func (rs *Tracer) Push(value Trace) { + if len(rs.ts) == rs.maxSize { + rs.ts = rs.ts[1:] + } + rs.ts = append(rs.ts, &value) +} + +func (rs *Tracer) Latest() Trace { + if len(rs.ts) > 0 { + return *rs.ts[len(rs.ts)-1] + } + return nil +} + +func (rs *Tracer) All() []*Trace { + return rs.ts +} + +func (rs *Tracer) Clean() { + rs.ts = make([]*Trace, 0, rs.maxSize) +} diff --git a/speedtest/control/type.go b/speedtest/control/type.go new file mode 100644 index 0000000..2bbaea5 --- /dev/null +++ b/speedtest/control/type.go @@ -0,0 +1,33 @@ +package control + +import ( + "strings" +) + +type Proto int32 + +const TypeChunkUndefined = 0 + +// control protocol and test type +const ( + TypeDownload Proto = 1 << iota + TypeUpload + TypeHTTP + TypeTCP + TypeICMP +) + +func ParseProto(str string) Proto { + str = strings.ToLower(str) + if str == "icmp" { + return TypeICMP + } else if str == "tcp" { + return TypeTCP + } else { + return TypeHTTP + } +} + +func (p Proto) Assert(u32 Proto) bool { + return p&u32 == u32 +} diff --git a/speedtest/data_manager.go b/speedtest/data_manager.go index 7ae2a48..4c34b7b 100644 --- a/speedtest/data_manager.go +++ b/speedtest/data_manager.go @@ -1,610 +1,98 @@ package speedtest import ( - "bytes" "context" "errors" - "github.com/showwin/speedtest-go/speedtest/internal" - "io" - "math" + "github.com/showwin/speedtest-go/speedtest/control" "runtime" - "sync" - "sync/atomic" "time" ) -type Manager interface { - SetRateCaptureFrequency(duration time.Duration) Manager - SetCaptureTime(duration time.Duration) Manager - - NewChunk() Chunk - - GetTotalDownload() int64 - GetTotalUpload() int64 - AddTotalDownload(value int64) - AddTotalUpload(value int64) - - GetAvgDownloadRate() float64 - GetAvgUploadRate() float64 - - GetEWMADownloadRate() float64 - GetEWMAUploadRate() float64 - - SetCallbackDownload(callback func(downRate ByteRate)) - SetCallbackUpload(callback func(upRate ByteRate)) - - RegisterDownloadHandler(fn func()) 
*TestDirection - RegisterUploadHandler(fn func()) *TestDirection - - // Wait for the upload or download task to end to avoid errors caused by core occupation - Wait() - Reset() - Snapshots() *Snapshots - - SetNThread(n int) Manager -} - -type Chunk interface { - UploadHandler(size int64) Chunk - DownloadHandler(r io.Reader) error - - GetRate() float64 - GetDuration() time.Duration - GetParent() Manager - - Read(b []byte) (n int, err error) -} - -const readChunkSize = 1024 // 1 KBytes with higher frequency rate feedback - -type DataType int32 - -const ( - typeEmptyChunk = iota - typeDownload - typeUpload -) - var ( ErrorUninitializedManager = errors.New("uninitialized manager") ) -type funcGroup struct { - fns []func() -} - -func (f *funcGroup) Add(fn func()) { - f.fns = append(f.fns, fn) -} - type DataManager struct { - SnapshotStore *Snapshots - Snapshot *Snapshot - sync.Mutex + // protocol indicates the transport that the manager should use. + // Optional: + // [control.TypeTCP] + // [control.TypeHTTP] + protocol control.Proto - repeatByte *[]byte + // estTimeout refers to the timeout threshold when establishing a connection. + // By default, we consider the connection timed out when the handshake takes + // more than 4 seconds. + estTimeout time.Duration - captureTime time.Duration - rateCaptureFrequency time.Duration - nThread int + // samplingPeriod indicates the sampling period of the sampler. + samplingPeriod time.Duration - running bool - runningRW sync.RWMutex + // samplingDuration indicates the maximum sampling duration of the sampler. + samplingDuration time.Duration - download *TestDirection - upload *TestDirection + // maxConnections refers to the maximum number of connections, the default + // is the number of logical cores of the device. + // It is recommended to set maxConnections = 8. + maxConnections int + Tracer *control.Tracer } -type TestDirection struct { - TestType int // test type - manager *DataManager // manager - totalDataVolume int64 // total send/receive data volume - RateSequence []int64 // rate history sequence - welford *internal.Welford // std/EWMA/mean - captureCallback func(realTimeRate ByteRate) // user callback - closeFunc func() // close func - *funcGroup // actually exec function -} - -func (dm *DataManager) NewDataDirection(testType int) *TestDirection { - return &TestDirection{ - TestType: testType, - manager: dm, - funcGroup: &funcGroup{}, +func NewDataManager(protocol control.Proto) *DataManager { + return &DataManager{ + maxConnections: runtime.NumCPU(), + estTimeout: time.Second * 4, + samplingPeriod: time.Second * 15, + samplingDuration: time.Millisecond * 50, + Tracer: control.NewHistoryTracer(control.DefaultMaxTraceSize), + protocol: protocol, } } -func NewDataManager() *DataManager { - r := bytes.Repeat([]byte{0xAA}, readChunkSize) // uniformly distributed sequence of bits - ret := &DataManager{ - nThread: runtime.NumCPU(), - captureTime: time.Second * 15, - rateCaptureFrequency: time.Millisecond * 50, - Snapshot: &Snapshot{}, - repeatByte: &r, - } - ret.download = ret.NewDataDirection(typeDownload) - ret.upload = ret.NewDataDirection(typeUpload) - ret.SnapshotStore = newHistorySnapshots(maxSnapshotSize) - return ret +// NewDirection +// @param ctx indicates the deadline of the sampler. +// timeout should not be greater than [DataManager.samplingDuration]. 
+func (dm *DataManager) NewDirection(ctx context.Context, testDirection control.Proto) *TestDirection { + direction := NewDataDirection(dm, dm.protocol|testDirection) + dm.Tracer.Push(direction.Trace()) + direction.ctx, direction.testCancel = context.WithCancel(ctx) + return direction } -func (dm *DataManager) SetCallbackDownload(callback func(downRate ByteRate)) { - if dm.download != nil { - dm.download.captureCallback = callback - } +func (dm *DataManager) GetMaxConnections() int { + return dm.maxConnections } -func (dm *DataManager) SetCallbackUpload(callback func(upRate ByteRate)) { - if dm.upload != nil { - dm.upload.captureCallback = callback - } -} - -func (dm *DataManager) Wait() { - oldDownTotal := dm.GetTotalDownload() - oldUpTotal := dm.GetTotalUpload() - for { - time.Sleep(dm.rateCaptureFrequency) - newDownTotal := dm.GetTotalDownload() - newUpTotal := dm.GetTotalUpload() - deltaDown := newDownTotal - oldDownTotal - deltaUp := newUpTotal - oldUpTotal - oldDownTotal = newDownTotal - oldUpTotal = newUpTotal - if deltaDown == 0 && deltaUp == 0 { - return - } - } -} - -func (dm *DataManager) RegisterUploadHandler(fn func()) *TestDirection { - if len(dm.upload.fns) < dm.nThread { - dm.upload.Add(fn) - } - return dm.upload -} - -func (dm *DataManager) RegisterDownloadHandler(fn func()) *TestDirection { - if len(dm.download.fns) < dm.nThread { - dm.download.Add(fn) - } - return dm.download -} - -func (td *TestDirection) GetTotalDataVolume() int64 { - return atomic.LoadInt64(&td.totalDataVolume) -} - -func (td *TestDirection) AddTotalDataVolume(delta int64) int64 { - return atomic.AddInt64(&td.totalDataVolume, delta) -} - -func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) { - if len(td.fns) == 0 { - panic("empty task stack") - } - if mainRequestHandlerIndex > len(td.fns)-1 { - mainRequestHandlerIndex = 0 - } - mainLoadFactor := 0.1 - // When the number of processor cores is equivalent to the processing program, - // the processing efficiency reaches the highest level (VT is not considered). 
- mainN := int(mainLoadFactor * float64(len(td.fns))) - if mainN == 0 { - mainN = 1 - } - if len(td.fns) == 1 { - mainN = td.manager.nThread - } - auxN := td.manager.nThread - mainN - dbg.Printf("Available fns: %d\n", len(td.fns)) - dbg.Printf("mainN: %d\n", mainN) - dbg.Printf("auxN: %d\n", auxN) - wg := sync.WaitGroup{} - td.manager.running = true - stopCapture := td.rateCapture() - - // refresh once function - once := sync.Once{} - td.closeFunc = func() { - once.Do(func() { - stopCapture <- true - close(stopCapture) - td.manager.runningRW.Lock() - td.manager.running = false - td.manager.runningRW.Unlock() - cancel() - dbg.Println("FuncGroup: Stop") - }) - } - - time.AfterFunc(td.manager.captureTime, td.closeFunc) - for i := 0; i < mainN; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for { - td.manager.runningRW.RLock() - running := td.manager.running - td.manager.runningRW.RUnlock() - if !running { - return - } - td.fns[mainRequestHandlerIndex]() - } - }() - } - for j := 0; j < auxN; { - for i := range td.fns { - if j == auxN { - break - } - if i == mainRequestHandlerIndex { - continue - } - wg.Add(1) - t := i - go func() { - defer wg.Done() - for { - td.manager.runningRW.RLock() - running := td.manager.running - td.manager.runningRW.RUnlock() - if !running { - return - } - td.fns[t]() - } - }() - j++ - } - } - wg.Wait() -} - -func (td *TestDirection) rateCapture() chan bool { - ticker := time.NewTicker(td.manager.rateCaptureFrequency) - var prevTotalDataVolume int64 = 0 - stopCapture := make(chan bool) - td.welford = internal.NewWelford(5*time.Second, td.manager.rateCaptureFrequency) - sTime := time.Now() - go func(t *time.Ticker) { - defer t.Stop() - for { - select { - case <-t.C: - newTotalDataVolume := td.GetTotalDataVolume() - deltaDataVolume := newTotalDataVolume - prevTotalDataVolume - prevTotalDataVolume = newTotalDataVolume - if deltaDataVolume != 0 { - td.RateSequence = append(td.RateSequence, deltaDataVolume) - } - // anyway we update the measuring instrument - globalAvg := (float64(td.GetTotalDataVolume())) / float64(time.Since(sTime).Milliseconds()) * 1000 - if td.welford.Update(globalAvg, float64(deltaDataVolume)) { - go td.closeFunc() - } - // reports the current rate at the given rate - if td.captureCallback != nil { - td.captureCallback(ByteRate(td.welford.EWMA())) - } - case stop := <-stopCapture: - if stop { - return - } - } - } - }(ticker) - return stopCapture -} - -func (dm *DataManager) NewChunk() Chunk { - var dc DataChunk - dc.manager = dm - dm.Lock() - *dm.Snapshot = append(*dm.Snapshot, &dc) - dm.Unlock() - return &dc -} - -func (dm *DataManager) AddTotalDownload(value int64) { - dm.download.AddTotalDataVolume(value) -} - -func (dm *DataManager) AddTotalUpload(value int64) { - dm.upload.AddTotalDataVolume(value) +func (dm *DataManager) SetSamplingPeriod(duration time.Duration) control.Manager { + dm.samplingDuration = duration + return dm } -func (dm *DataManager) GetTotalDownload() int64 { - return dm.download.GetTotalDataVolume() +func (dm *DataManager) SetSamplingDuration(duration time.Duration) control.Manager { + dm.samplingPeriod = duration + return dm } -func (dm *DataManager) GetTotalUpload() int64 { - return dm.upload.GetTotalDataVolume() +func (dm *DataManager) GetSamplingPeriod() time.Duration { + return dm.samplingDuration } -func (dm *DataManager) SetRateCaptureFrequency(duration time.Duration) Manager { - dm.rateCaptureFrequency = duration - return dm +func (dm *DataManager) GetSamplingDuration() time.Duration { + return dm.samplingPeriod } 
-func (dm *DataManager) SetCaptureTime(duration time.Duration) Manager { - dm.captureTime = duration - return dm +func (dm *DataManager) SetNThread(n int) control.Manager { + return dm.SetMaxConnections(n) } -func (dm *DataManager) SetNThread(n int) Manager { +func (dm *DataManager) SetMaxConnections(n int) control.Manager { if n < 1 { - dm.nThread = runtime.NumCPU() + dm.maxConnections = runtime.NumCPU() } else { - dm.nThread = n + dm.maxConnections = n } return dm } -func (dm *DataManager) Snapshots() *Snapshots { - return dm.SnapshotStore -} - -func (dm *DataManager) Reset() { - dm.SnapshotStore.push(dm.Snapshot) - dm.Snapshot = &Snapshot{} - dm.download = dm.NewDataDirection(typeDownload) - dm.upload = dm.NewDataDirection(typeUpload) -} - -func (dm *DataManager) GetAvgDownloadRate() float64 { - unit := float64(dm.captureTime / time.Millisecond) - return float64(dm.download.GetTotalDataVolume()*8/1000) / unit -} - -func (dm *DataManager) GetEWMADownloadRate() float64 { - if dm.download.welford != nil { - return dm.download.welford.EWMA() - } - return 0 -} - -func (dm *DataManager) GetAvgUploadRate() float64 { - unit := float64(dm.captureTime / time.Millisecond) - return float64(dm.upload.GetTotalDataVolume()*8/1000) / unit -} - -func (dm *DataManager) GetEWMAUploadRate() float64 { - if dm.upload.welford != nil { - return dm.upload.welford.EWMA() - } - return 0 -} - -type DataChunk struct { - manager *DataManager - dateType DataType - startTime time.Time - endTime time.Time - err error - ContentLength int64 - remainOrDiscardSize int64 -} - -var blackHolePool = sync.Pool{ - New: func() any { - b := make([]byte, 8192) - return &b - }, -} - -func (dc *DataChunk) GetDuration() time.Duration { - return dc.endTime.Sub(dc.startTime) -} - -func (dc *DataChunk) GetRate() float64 { - if dc.dateType == typeDownload { - return float64(dc.remainOrDiscardSize) / dc.GetDuration().Seconds() - } else if dc.dateType == typeUpload { - return float64(dc.ContentLength-dc.remainOrDiscardSize) * 8 / 1000 / 1000 / dc.GetDuration().Seconds() - } - return 0 -} - -// DownloadHandler No value will be returned here, because the error will interrupt the test. -// The error chunk is generally caused by the remote server actively closing the connection. -func (dc *DataChunk) DownloadHandler(r io.Reader) error { - if dc.dateType != typeEmptyChunk { - dc.err = errors.New("multiple calls to the same chunk handler are not allowed") - return dc.err - } - dc.dateType = typeDownload - dc.startTime = time.Now() - defer func() { - dc.endTime = time.Now() - }() - bufP := blackHolePool.Get().(*[]byte) - defer blackHolePool.Put(bufP) - readSize := 0 - for { - dc.manager.runningRW.RLock() - running := dc.manager.running - dc.manager.runningRW.RUnlock() - if !running { - return nil - } - readSize, dc.err = r.Read(*bufP) - rs := int64(readSize) - - dc.remainOrDiscardSize += rs - dc.manager.download.AddTotalDataVolume(rs) - if dc.err != nil { - if dc.err == io.EOF { - return nil - } - return dc.err - } - } -} - -func (dc *DataChunk) UploadHandler(size int64) Chunk { - if dc.dateType != typeEmptyChunk { - dc.err = errors.New("multiple calls to the same chunk handler are not allowed") - } - - if size <= 0 { - panic("the size of repeated bytes should be > 0") - } - - dc.ContentLength = size - dc.remainOrDiscardSize = size - dc.dateType = typeUpload - dc.startTime = time.Now() - return dc -} - -func (dc *DataChunk) GetParent() Manager { - return dc.manager -} - -// WriteTo Used to hook all traffic. 
-func (dc *DataChunk) WriteTo(w io.Writer) (written int64, err error) { - nw := 0 - nr := readChunkSize - for { - dc.manager.runningRW.RLock() - running := dc.manager.running - dc.manager.runningRW.RUnlock() - if !running || dc.remainOrDiscardSize <= 0 { - dc.endTime = time.Now() - return written, io.EOF - } - if dc.remainOrDiscardSize < readChunkSize { - nr = int(dc.remainOrDiscardSize) - nw, err = w.Write((*dc.manager.repeatByte)[:nr]) - } else { - nw, err = w.Write(*dc.manager.repeatByte) - } - if err != nil { - return - } - n64 := int64(nw) - written += n64 - dc.remainOrDiscardSize -= n64 - dc.manager.AddTotalUpload(n64) - if nr != nw { - return written, io.ErrShortWrite - } - } -} - -// Please don't call it, only used to wrapped by [io.NopCloser] -// We use [DataChunk.WriteTo] that implements [io.WriterTo] to bypass this function. -func (dc *DataChunk) Read(b []byte) (n int, err error) { - panic("unexpected call: only used to implement the io.Reader") -} - -// calcMAFilter Median-Averaging Filter -func _(list []int64) float64 { - if len(list) == 0 { - return 0 - } - var sum int64 = 0 - n := len(list) - if n == 0 { - return 0 - } - length := len(list) - for i := 0; i < length-1; i++ { - for j := i + 1; j < length; j++ { - if list[i] > list[j] { - list[i], list[j] = list[j], list[i] - } - } - } - for i := 1; i < n-1; i++ { - sum += list[i] - } - return float64(sum) / float64(n-2) -} - -func pautaFilter(vector []int64) []int64 { - dbg.Println("Per capture unit") - dbg.Printf("Raw Sequence len: %d\n", len(vector)) - dbg.Printf("Raw Sequence: %v\n", vector) - if len(vector) == 0 { - return vector - } - mean, _, std, _, _ := sampleVariance(vector) - var retVec []int64 - for _, value := range vector { - if math.Abs(float64(value-mean)) < float64(3*std) { - retVec = append(retVec, value) - } - } - dbg.Printf("Raw average: %dByte\n", mean) - dbg.Printf("Pauta Sequence len: %d\n", len(retVec)) - dbg.Printf("Pauta Sequence: %v\n", retVec) - return retVec -} - -// sampleVariance sample Variance -func sampleVariance(vector []int64) (mean, variance, stdDev, min, max int64) { - if len(vector) == 0 { - return 0, 0, 0, 0, 0 - } - var sumNum, accumulate int64 - min = math.MaxInt64 - max = math.MinInt64 - for _, value := range vector { - sumNum += value - if min > value { - min = value - } - if max < value { - max = value - } - } - mean = sumNum / int64(len(vector)) - for _, value := range vector { - accumulate += (value - mean) * (value - mean) - } - variance = accumulate / int64(len(vector)-1) // Bessel's correction - stdDev = int64(math.Sqrt(float64(variance))) - return -} - -const maxSnapshotSize = 10 - -type Snapshot []*DataChunk - -type Snapshots struct { - sp []*Snapshot - maxSize int -} - -func newHistorySnapshots(size int) *Snapshots { - return &Snapshots{ - sp: make([]*Snapshot, 0, size), - maxSize: size, - } -} - -func (rs *Snapshots) push(value *Snapshot) { - if len(rs.sp) == rs.maxSize { - rs.sp = rs.sp[1:] - } - rs.sp = append(rs.sp, value) -} - -func (rs *Snapshots) Latest() *Snapshot { - if len(rs.sp) > 0 { - return rs.sp[len(rs.sp)-1] - } - return nil -} - -func (rs *Snapshots) All() []*Snapshot { - return rs.sp -} - -func (rs *Snapshots) Clean() { - rs.sp = make([]*Snapshot, 0, rs.maxSize) +func (dm *DataManager) History() *control.Tracer { + return dm.Tracer } diff --git a/speedtest/data_manager_test.go b/speedtest/data_manager_test.go index dbd34d3..7a3118c 100644 --- a/speedtest/data_manager_test.go +++ b/speedtest/data_manager_test.go @@ -43,7 +43,7 @@ func 
TestDataManager_AddTotalDownload(t *testing.T) { func TestDataManager_GetAvgDownloadRate(t *testing.T) { dm := NewDataManager() dm.download.totalDataVolume = 3000000 - dm.captureTime = time.Second * 10 + dm.samplingPeriod = time.Second * 10 result := dm.GetAvgDownloadRate() if result != 2.4 { @@ -81,7 +81,7 @@ func TestDynamicRate(t *testing.T) { //t.Error(err) } - server.Context.Manager.Wait() + // server.Context.Manager.Wait() err = server.UploadTest() if err != nil { diff --git a/speedtest/direction.go b/speedtest/direction.go new file mode 100644 index 0000000..4816300 --- /dev/null +++ b/speedtest/direction.go @@ -0,0 +1,184 @@ +package speedtest + +import ( + "bytes" + "context" + "github.com/showwin/speedtest-go/speedtest/control" + "github.com/showwin/speedtest-go/speedtest/http" + "github.com/showwin/speedtest-go/speedtest/transport" + + "github.com/showwin/speedtest-go/speedtest/internal" + "sync" + "sync/atomic" + "time" +) + +type TestDirection struct { + ctx context.Context + testCancel context.CancelFunc + proto control.Proto // see [Proto] + RateSequence []int64 // rate history sequence + manager control.Manager // manager + totalDataVolume int64 // total sent/received data volume + welford *internal.Welford // std/EWMA/mean + samplingCallback func(realTimeRate float64) // sampling callback + trace control.Trace // detailed chunk data tracing + loadBalancer *control.LoadBalancer + repeatBytes []byte + Duration time.Duration + sync.Mutex +} + +func NewDataDirection(m control.Manager, proto control.Proto) *TestDirection { + r := bytes.Repeat([]byte{0xAA}, control.DefaultReadChunkSize) // uniformly distributed sequence of bits + return &TestDirection{ + proto: proto, + manager: m, + repeatBytes: r, + loadBalancer: control.NewLoadBalancer(), + ctx: context.TODO(), + } +} + +func (td *TestDirection) NewChunk() control.Chunk { + var chunk control.Chunk + if td.proto.Assert(control.TypeTCP) { + chunk = transport.NewChunk(td) + } else { + chunk = http.NewChunk(td) // using HTTP as default protocol + } + td.Lock() + defer td.Unlock() + td.trace = append(td.trace, chunk) + return chunk +} + +// Trace returns tracing data +func (td *TestDirection) Trace() control.Trace { + td.Lock() + defer td.Unlock() + return td.trace +} + +// Avg Get the overall average speed in the test direction. +func (td *TestDirection) Avg() float64 { + unit := float64(td.manager.GetSamplingDuration() / time.Millisecond) + return float64(td.GetTotalDataVolume()*8/1000) / unit +} + +// EWMA Get real-time EWMA and average weighted values. +func (td *TestDirection) EWMA() float64 { + if td.welford != nil { + return td.welford.EWMA() + } + internal.DBG().Println("warning: empty td.welford") + return 0 +} + +// GetTotalDataVolume Read the data volume in the current direction. +func (td *TestDirection) GetTotalDataVolume() int64 { + return atomic.LoadInt64(&td.totalDataVolume) +} + +// AddTotalDataVolume Add the data volume in the current direction. +func (td *TestDirection) AddTotalDataVolume(delta int64) int64 { + return atomic.AddInt64(&td.totalDataVolume, delta) +} + +func (td *TestDirection) Add(delta int64) { + td.AddTotalDataVolume(delta) +} + +func (td *TestDirection) Get() int64 { + return td.GetTotalDataVolume() +} + +func (td *TestDirection) Done() <-chan struct{} { + return td.ctx.Done() +} + +func (td *TestDirection) Repeat() []byte { + return td.repeatBytes +} + +// SetSamplingCallback Sets an optional periodic sampling callback function +// for the TestDirection. 
+func (td *TestDirection) SetSamplingCallback(callback func(rate float64)) *TestDirection { + td.samplingCallback = callback + return td +} + +// RegisterHandler Add a test function for TestDirection that sequences a +// size that depends on the maximum number of connections. +func (td *TestDirection) RegisterHandler(task control.Task, priority int64) *TestDirection { + if td.loadBalancer.Len() < td.manager.GetMaxConnections() { + td.loadBalancer.Add(task, priority) + } + return td +} + +// Start The Load balancer for TestDirection. +func (td *TestDirection) Start() { + if td.loadBalancer == nil { + panic("loadBalancer is nil") + } + if td.loadBalancer.Len() == 0 { + panic("empty task stack") + } + // sampling + td.rateSampling() + wg := sync.WaitGroup{} + start := time.Now() + for i := 0; i < td.manager.GetMaxConnections(); i++ { + wg.Add(1) + go func() { + for { + select { + case <-td.Done(): + wg.Done() + return + default: + td.loadBalancer.Dispatch() + } + } + }() + } + wg.Wait() + td.Duration = time.Since(start) +} + +func (td *TestDirection) rateSampling() { + ticker := time.NewTicker(td.manager.GetSamplingPeriod()) + var prevTotalDataVolume int64 = 0 + td.welford = internal.NewWelford(5*time.Second, td.manager.GetSamplingPeriod()) + sTime := time.Now() + go func(t *time.Ticker) { + defer t.Stop() + for { + select { + case <-td.Done(): + internal.DBG().Println("RateSampler: ctx.Done from another goroutine") + return + case <-t.C: + newTotalDataVolume := td.GetTotalDataVolume() + deltaDataVolume := newTotalDataVolume - prevTotalDataVolume + prevTotalDataVolume = newTotalDataVolume + if deltaDataVolume != 0 { + td.RateSequence = append(td.RateSequence, deltaDataVolume) + } + // anyway we update the measuring instrument + globalAvg := (float64(td.GetTotalDataVolume())) / float64(time.Since(sTime).Milliseconds()) * 1000 + if td.welford.Update(globalAvg, float64(deltaDataVolume)) { + // direction canceled early. + td.testCancel() + internal.DBG().Println("RateSampler: terminate due to early stop") + return + } + // reports the current rate by callback. + if td.samplingCallback != nil { + td.samplingCallback(td.welford.EWMA()) + } + } + } + }(ticker) +} diff --git a/speedtest/http/chunk.go b/speedtest/http/chunk.go new file mode 100644 index 0000000..c2c5b19 --- /dev/null +++ b/speedtest/http/chunk.go @@ -0,0 +1,115 @@ +package http + +import ( + "github.com/showwin/speedtest-go/speedtest/control" + "io" + "time" +) + +// DataChunk The Speedtest's I/O implementation of Hypertext Transfer Protocol. 
+type DataChunk struct { + dateType control.Proto + startTime time.Time + endTime time.Time + err error + ContentLength int64 + remainOrDiscardSize int64 + control.Controller +} + +func NewChunk(controller control.Controller) control.Chunk { + return &DataChunk{Controller: controller} +} + +func (dc *DataChunk) Len() int64 { + return dc.ContentLength +} + +// Duration Get chunk duration (start -> end) +func (dc *DataChunk) Duration() time.Duration { + return dc.endTime.Sub(dc.startTime) +} + +// Rate Get chunk avg rate +func (dc *DataChunk) Rate() float64 { + if dc.dateType.Assert(control.TypeDownload) { + return float64(dc.remainOrDiscardSize) / dc.Duration().Seconds() + } else if dc.dateType.Assert(control.TypeUpload) { + return float64(dc.ContentLength-dc.remainOrDiscardSize) * 8 / 1000 / 1000 / dc.Duration().Seconds() + } + return 0 +} + +func (dc *DataChunk) Type() control.Proto { + return dc.dateType +} + +// DownloadHandler No value will be returned here, because the error will interrupt the test. +// The error chunk is generally caused by the remote server actively closing the connection. +// @Related HTTP Download +func (dc *DataChunk) DownloadHandler(r io.Reader) error { + if dc.dateType != control.TypeChunkUndefined { + dc.err = control.ErrDuplicateCall + return dc.err + } + dc.dateType = control.TypeDownload | control.TypeHTTP + dc.startTime = time.Now() + defer func() { + dc.endTime = time.Now() + }() + bufP := control.BlackHole.Get().(*[]byte) + defer control.BlackHole.Put(bufP) + readSize := 0 + for { + select { + case <-dc.Done(): + return nil + default: + readSize, dc.err = r.Read(*bufP) + rs := int64(readSize) + + dc.remainOrDiscardSize += rs + dc.Add(rs) + if dc.err != nil { + if dc.err == io.EOF { + return nil + } + return dc.err + } + } + } +} + +// UploadHandler Create an upload handler +// @Related HTTP UPLOAD +func (dc *DataChunk) UploadHandler(size int64) control.Chunk { + if dc.dateType != control.TypeChunkUndefined { + dc.err = control.ErrDuplicateCall + } + + if size <= 0 { + panic("the size of repeated bytes should be > 0") + } + + dc.ContentLength = size + dc.remainOrDiscardSize = size + dc.dateType = control.TypeUpload | control.TypeHTTP + dc.startTime = time.Now() + return dc +} + +func (dc *DataChunk) Read(b []byte) (n int, err error) { + if dc.remainOrDiscardSize < control.DefaultReadChunkSize { + if dc.remainOrDiscardSize <= 0 { + dc.endTime = time.Now() + return n, io.EOF + } + n = copy(b, dc.Repeat()[:dc.remainOrDiscardSize]) + } else { + n = copy(b, dc.Repeat()) + } + n64 := int64(n) + dc.remainOrDiscardSize -= n64 + dc.Add(n64) + return +} diff --git a/speedtest/debug.go b/speedtest/internal/debug.go similarity index 88% rename from speedtest/debug.go rename to speedtest/internal/debug.go index 4903c07..4f87473 100644 --- a/speedtest/debug.go +++ b/speedtest/internal/debug.go @@ -1,4 +1,4 @@ -package speedtest +package internal import ( "log" @@ -30,4 +30,8 @@ func (d *Debug) Printf(format string, v ...any) { } } +func DBG() *Debug { + return dbg +} + var dbg = NewDebug() diff --git a/speedtest/internal/utils.go b/speedtest/internal/utils.go new file mode 100644 index 0000000..ab28ca7 --- /dev/null +++ b/speedtest/internal/utils.go @@ -0,0 +1,88 @@ +package internal + +import ( + "crypto/rand" + "fmt" + "math" +) + +func GenerateUUID() (string, error) { + randUUID := make([]byte, 16) + _, err := rand.Read(randUUID) + if err != nil { + return "", err + } + randUUID[8] = randUUID[8]&^0xc0 | 0x80 + randUUID[6] = randUUID[6]&^0xf0 | 0x40 + return 
fmt.Sprintf("%x-%x-%x-%x-%x", randUUID[0:4], randUUID[4:6], randUUID[6:8], randUUID[8:10], randUUID[10:]), nil +} + +// calcMAFilter Median-Averaging Filter +func _(list []int64) float64 { + if len(list) == 0 { + return 0 + } + var sum int64 = 0 + n := len(list) + if n == 0 { + return 0 + } + length := len(list) + for i := 0; i < length-1; i++ { + for j := i + 1; j < length; j++ { + if list[i] > list[j] { + list[i], list[j] = list[j], list[i] + } + } + } + for i := 1; i < n-1; i++ { + sum += list[i] + } + return float64(sum) / float64(n-2) +} + +func pautaFilter(vector []int64) []int64 { + DBG().Println("Per capture unit") + DBG().Printf("Raw Sequence len: %d\n", len(vector)) + DBG().Printf("Raw Sequence: %v\n", vector) + if len(vector) == 0 { + return vector + } + mean, _, std, _, _ := sampleVariance(vector) + var retVec []int64 + for _, value := range vector { + if math.Abs(float64(value-mean)) < float64(3*std) { + retVec = append(retVec, value) + } + } + DBG().Printf("Raw average: %dByte\n", mean) + DBG().Printf("Pauta Sequence len: %d\n", len(retVec)) + DBG().Printf("Pauta Sequence: %v\n", retVec) + return retVec +} + +// sampleVariance sample Variance +func sampleVariance(vector []int64) (mean, variance, stdDev, min, max int64) { + if len(vector) == 0 { + return 0, 0, 0, 0, 0 + } + var sumNum, accumulate int64 + min = math.MaxInt64 + max = math.MinInt64 + for _, value := range vector { + sumNum += value + if min > value { + min = value + } + if max < value { + max = value + } + } + mean = sumNum / int64(len(vector)) + for _, value := range vector { + accumulate += (value - mean) * (value - mean) + } + variance = accumulate / int64(len(vector)-1) // Bessel's correction + stdDev = int64(math.Sqrt(float64(variance))) + return +} diff --git a/speedtest/request.go b/speedtest/request.go index f188d56..2eee565 100644 --- a/speedtest/request.go +++ b/speedtest/request.go @@ -4,20 +4,21 @@ import ( "context" "errors" "fmt" + "github.com/showwin/speedtest-go/speedtest/control" + "github.com/showwin/speedtest-go/speedtest/internal" "github.com/showwin/speedtest-go/speedtest/transport" "io" "math" + "net" "net/http" "net/url" "path" "strings" - "sync/atomic" "time" ) type ( - downloadFunc func(context.Context, *Server, int) error - uploadFunc func(context.Context, *Server, int) error + testFunc func(context.Context, *TestDirection, *Server, int) error ) var ( @@ -29,139 +30,103 @@ var ( ErrConnectTimeout = errors.New("server connect timeout") ) -func (s *Server) MultiDownloadTestContext(ctx context.Context, servers Servers) error { - ss := servers.Available() - if ss.Len() == 0 { - return errors.New("not found available servers") +func (s *Server) pullTest( + ctx context.Context, + directionType control.Proto, + testFn testFunc, + callback func(rate float64), + servers Servers, +) (*TestDirection, error) { + var availableServers *Servers + if servers == nil { + availableServers = &Servers{s} + } else { + availableServers = servers.Available() } - mainIDIndex := 0 - var td *TestDirection - _context, cancel := context.WithCancel(ctx) - defer cancel() - var errorTimes int64 = 0 - var requestTimes int64 = 0 - for i, server := range *ss { + + if availableServers.Len() == 0 { + return nil, errors.New("not found available servers") + } + direction := s.Context.NewDirection(ctx, directionType).SetSamplingCallback(callback) + for _, server := range *availableServers { + var priority int64 = 2 if server.ID == s.ID { - mainIDIndex = i + priority = 1 } + internal.DBG().Printf("[%d] Register Handler: %s\n", 
directionType, server.URL)
 sp := server
- dbg.Printf("Register Download Handler: %s\n", sp.URL)
- td = server.Context.RegisterDownloadHandler(func() {
- atomic.AddInt64(&requestTimes, 1)
- if err := downloadRequest(_context, sp, 3); err != nil {
- atomic.AddInt64(&errorTimes, 1)
- }
- })
- }
- if td == nil {
- return ErrorUninitializedManager
- }
- td.Start(cancel, mainIDIndex) // block here
- s.DLSpeed = ByteRate(td.manager.GetEWMADownloadRate())
- if s.DLSpeed == 0 && float64(errorTimes)/float64(requestTimes) > 0.1 {
- s.DLSpeed = -1 // N/A
+ direction.RegisterHandler(func() error {
+ connectContext, cancel := context.WithTimeout(context.Background(), s.Context.estTimeout)
+ defer cancel()
+ //fmt.Println(sp.Host)
+ return testFn(connectContext, direction, sp, 3)
+ }, priority)
+ }
+ direction.Start()
+ return direction, nil
+}
+
+func (s *Server) MultiDownloadTestContext(ctx context.Context, servers Servers, callback func(rate float64)) error {
+ direction, err := s.pullTest(ctx, control.TypeDownload, downloadRequest, callback, servers)
+ if err != nil {
+ return err
 }
+ s.DLSpeed = ByteRate(direction.EWMA())
+ s.TestDuration.Download = &direction.Duration
+ s.testDurationTotalCount()
 return nil
}

-func (s *Server) MultiUploadTestContext(ctx context.Context, servers Servers) error {
- ss := servers.Available()
- if ss.Len() == 0 {
- return errors.New("not found available servers")
- }
- mainIDIndex := 0
- var td *TestDirection
- _context, cancel := context.WithCancel(ctx)
- defer cancel()
- var errorTimes int64 = 0
- var requestTimes int64 = 0
- for i, server := range *ss {
- if server.ID == s.ID {
- mainIDIndex = i
- }
- sp := server
- dbg.Printf("Register Upload Handler: %s\n", sp.URL)
- td = server.Context.RegisterUploadHandler(func() {
- atomic.AddInt64(&requestTimes, 1)
- if err := uploadRequest(_context, sp, 3); err != nil {
- atomic.AddInt64(&errorTimes, 1)
- }
- })
- }
- if td == nil {
- return ErrorUninitializedManager
- }
- td.Start(cancel, mainIDIndex) // block here
- s.ULSpeed = ByteRate(td.manager.GetEWMAUploadRate())
- if s.ULSpeed == 0 && float64(errorTimes)/float64(requestTimes) > 0.1 {
- s.ULSpeed = -1 // N/A
+func (s *Server) MultiUploadTestContext(ctx context.Context, servers Servers, callback func(rate float64)) error {
+ direction, err := s.pullTest(ctx, control.TypeUpload, uploadRequest, callback, servers)
+ if err != nil {
+ return err
 }
+ s.ULSpeed = ByteRate(direction.EWMA())
+ s.TestDuration.Upload = &direction.Duration
+ s.testDurationTotalCount()
 return nil
}

 // DownloadTest executes the test to measure download speed
-func (s *Server) DownloadTest() error {
- return s.downloadTestContext(context.Background(), downloadRequest)
+func (s *Server) DownloadTest(callback func(rate float64)) error {
+ // Usually, the connections handled by a speedtest server are kept alive for less than 1 minute,
+ // so we cap the test at 30 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+ defer cancel()
+ return s.DownloadTestContext(ctx, callback)
}

 // DownloadTestContext executes the test to measure download speed, observing the given context.
-func (s *Server) DownloadTestContext(ctx context.Context) error { - return s.downloadTestContext(ctx, downloadRequest) -} - -func (s *Server) downloadTestContext(ctx context.Context, downloadRequest downloadFunc) error { - var errorTimes int64 = 0 - var requestTimes int64 = 0 - start := time.Now() - _context, cancel := context.WithCancel(ctx) - s.Context.RegisterDownloadHandler(func() { - atomic.AddInt64(&requestTimes, 1) - if err := downloadRequest(_context, s, 3); err != nil { - atomic.AddInt64(&errorTimes, 1) - } - }).Start(cancel, 0) - duration := time.Since(start) - s.DLSpeed = ByteRate(s.Context.GetEWMADownloadRate()) - if s.DLSpeed == 0 && float64(errorTimes)/float64(requestTimes) > 0.1 { - s.DLSpeed = -1 // N/A +func (s *Server) DownloadTestContext(ctx context.Context, callback func(rate float64)) error { + direction, err := s.pullTest(ctx, control.TypeDownload, downloadRequest, callback, nil) + if err != nil { + return err } - s.TestDuration.Download = &duration + s.DLSpeed = ByteRate(direction.EWMA()) + s.TestDuration.Download = &direction.Duration s.testDurationTotalCount() return nil } // UploadTest executes the test to measure upload speed -func (s *Server) UploadTest() error { - return s.uploadTestContext(context.Background(), uploadRequest) +func (s *Server) UploadTest(callback func(rate float64)) error { + return s.UploadTestContext(context.Background(), callback) } // UploadTestContext executes the test to measure upload speed, observing the given context. -func (s *Server) UploadTestContext(ctx context.Context) error { - return s.uploadTestContext(ctx, uploadRequest) -} - -func (s *Server) uploadTestContext(ctx context.Context, uploadRequest uploadFunc) error { - var errorTimes int64 = 0 - var requestTimes int64 = 0 - start := time.Now() - _context, cancel := context.WithCancel(ctx) - s.Context.RegisterUploadHandler(func() { - atomic.AddInt64(&requestTimes, 1) - if err := uploadRequest(_context, s, 4); err != nil { - atomic.AddInt64(&errorTimes, 1) - } - }).Start(cancel, 0) - duration := time.Since(start) - s.ULSpeed = ByteRate(s.Context.GetEWMAUploadRate()) - if s.ULSpeed == 0 && float64(errorTimes)/float64(requestTimes) > 0.1 { - s.ULSpeed = -1 // N/A +func (s *Server) UploadTestContext(ctx context.Context, callback func(rate float64)) error { + direction, err := s.pullTest(ctx, control.TypeUpload, uploadRequest, callback, nil) + if err != nil { + return err } - s.TestDuration.Upload = &duration + s.ULSpeed = ByteRate(direction.EWMA()) + s.TestDuration.Upload = &direction.Duration s.testDurationTotalCount() return nil } -func downloadRequest(ctx context.Context, s *Server, w int) error { +func downloadRequest(ctx context.Context, direction *TestDirection, s *Server, w int) error { size := dlSizes[w] u, err := url.Parse(s.URL) if err != nil { @@ -169,35 +134,94 @@ func downloadRequest(ctx context.Context, s *Server, w int) error { } u.Path = path.Dir(u.Path) xdlURL := u.JoinPath(fmt.Sprintf("random%dx%d.jpg", size, size)).String() - dbg.Printf("XdlURL: %s\n", xdlURL) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, xdlURL, nil) - if err != nil { - return err - } + internal.DBG().Printf("XdlURL: %s\n", xdlURL) - resp, err := s.Context.doer.Do(req) - if err != nil { - return err + chunk := direction.NewChunk() + + if direction.proto.Assert(control.TypeTCP) { + dialer := &net.Dialer{} + client, err1 := transport.NewClient(dialer) + if err1 != nil { + return err1 + } + err = client.Connect(context.TODO(), s.Host) + if err != nil { + return err + } + connReader, 
err1 := client.RegisterDownload(int64(size)) + if err1 != nil { + return err1 + } + return chunk.DownloadHandler(connReader) + } else { + // set est deadline + // TODO: tmp usage, we must split speedtest config and speedtest result. + estContext, cancel := context.WithTimeout(context.Background(), s.Context.estTimeout) + defer cancel() + req, err := http.NewRequestWithContext(estContext, http.MethodGet, xdlURL, nil) + if err != nil { + return err + } + req.Header.Set("Connection", "Keep-Alive") + resp, err := s.Context.doer.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + return chunk.DownloadHandler(resp.Body) } - defer resp.Body.Close() - return s.Context.NewChunk().DownloadHandler(resp.Body) } -func uploadRequest(ctx context.Context, s *Server, w int) error { +func uploadRequest(ctx context.Context, direction *TestDirection, s *Server, w int) error { size := ulSizes[w] - dc := s.Context.NewChunk().UploadHandler(int64(size*100-51) * 10) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.URL, io.NopCloser(dc)) - if err != nil { + chunk := direction.NewChunk() + + if direction.proto.Assert(control.TypeTCP) { + var chunkScale int64 = 1 + chunkSize := 1000 * 1000 * chunkScale + dialer := &net.Dialer{} + client, err := transport.NewClient(dialer) + if err != nil { + fmt.Println(err.Error()) + return err + } + //fmt.Println(s.Host) + err = client.Connect(context.TODO(), "speedtestd.kpn.com:8080") // TODO: NEED fix + if err != nil { + fmt.Println(err.Error()) + return err + } + remainSize, err := client.RegisterUpload(chunkSize) + if err != nil { + fmt.Println(err.Error()) + return err + } + dc := chunk.UploadHandler(remainSize) + rc := io.NopCloser(dc) + _, err = client.Upload(rc) + return err - } - dbg.Printf("Len=%d, XulURL: %s\n", req.ContentLength, s.URL) - req.Header.Set("Content-Type", "application/octet-stream") - resp, err := s.Context.doer.Do(req) - if err != nil { + } else { + dc := chunk.UploadHandler(int64(size*100-51) * 10) + // set est deadline + // TODO: tmp usage, we must split speedtest config and speedtest result. 
+ estContext, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + req, err := http.NewRequestWithContext(estContext, http.MethodPost, s.URL, dc) + if err != nil { + return err + } + req.ContentLength = dc.Len() + req.Header.Set("Content-Type", "application/octet-stream") + internal.DBG().Printf("Len=%d, XulURL: %s\n", dc.Len(), s.URL) + resp, err := s.Context.doer.Do(req) + if err != nil { + return err + } + _, _ = io.Copy(io.Discard, resp.Body) + defer resp.Body.Close() return err } - defer resp.Body.Close() - return err } // PingTest executes test to measure latency @@ -209,9 +233,9 @@ func (s *Server) PingTest(callback func(latency time.Duration)) error { func (s *Server) PingTestContext(ctx context.Context, callback func(latency time.Duration)) (err error) { start := time.Now() var vectorPingResult []int64 - if s.Context.config.PingMode == TCP { + if s.Context.config.PingMode.Assert(control.TypeTCP) { vectorPingResult, err = s.TCPPing(ctx, 10, time.Millisecond*200, callback) - } else if s.Context.config.PingMode == ICMP { + } else if s.Context.config.PingMode.Assert(control.TypeICMP) { vectorPingResult, err = s.ICMPPing(ctx, time.Second*4, 10, time.Millisecond*200, callback) } else { vectorPingResult, err = s.HTTPPing(ctx, 10, time.Millisecond*200, callback) @@ -219,7 +243,7 @@ func (s *Server) PingTestContext(ctx context.Context, callback func(latency time if err != nil || len(vectorPingResult) == 0 { return err } - dbg.Printf("Before StandardDeviation: %v\n", vectorPingResult) + internal.DBG().Printf("Before StandardDeviation: %v\n", vectorPingResult) mean, _, std, minLatency, maxLatency := StandardDeviation(vectorPingResult) duration := time.Since(start) s.Latency = time.Duration(mean) * time.Nanosecond @@ -237,11 +261,11 @@ func (s *Server) TestAll() error { if err != nil { return err } - err = s.DownloadTest() + err = s.DownloadTest(nil) if err != nil { return err } - return s.UploadTest() + return s.UploadTest(nil) } func (s *Server) TCPPing( @@ -300,7 +324,7 @@ func (s *Server) HTTPPing( } u.Path = path.Dir(u.Path) pingDst := u.JoinPath("latency.txt").String() - dbg.Printf("Echo: %s\n", pingDst) + internal.DBG().Printf("Echo: %s\n", pingDst) failTimes := 0 req, err := http.NewRequestWithContext(ctx, http.MethodGet, pingDst, nil) if err != nil { @@ -327,7 +351,7 @@ func (s *Server) HTTPPing( if i > 0 { latency := endTime.Nanoseconds() latencies = append(latencies, latency) - dbg.Printf("RTT: %d\n", latency) + internal.DBG().Printf("RTT: %d\n", latency) if callback != nil { callback(endTime) } @@ -361,7 +385,7 @@ func (s *Server) ICMPPing( if err != nil || len(u.Host) == 0 { return nil, err } - dbg.Printf("Echo: %s\n", strings.Split(u.Host, ":")[0]) + internal.DBG().Printf("Echo: %s\n", strings.Split(u.Host, ":")[0]) dialContext, err := s.Context.ipDialer.DialContext(ctx, "ip:icmp", strings.Split(u.Host, ":")[0]) if err != nil { return nil, err @@ -411,7 +435,7 @@ func (s *Server) ICMPPing( } endTime := time.Since(sTime) latencies = append(latencies, endTime.Nanoseconds()) - dbg.Printf("1RTT: %s\n", endTime) + internal.DBG().Printf("1RTT: %s\n", endTime) if callback != nil { callback(endTime) } diff --git a/speedtest/request_test.go b/speedtest/request_test.go index 54b7190..7d417c4 100644 --- a/speedtest/request_test.go +++ b/speedtest/request_test.go @@ -3,6 +3,7 @@ package speedtest import ( "context" "fmt" + "github.com/showwin/speedtest-go/speedtest/control" "runtime" "testing" "time" @@ -60,7 +61,7 @@ func TestUploadTestContext(t 
*testing.T) { if err != nil { t.Errorf(err.Error()) } - value := server.Context.Manager.GetAvgUploadRate() + value := server.Context.Manager.GetUploadAvgRate() if value < idealSpeed*(1-delta) || idealSpeed*(1+delta) < value { t.Errorf("got unexpected server.ULSpeed '%v', expected between %v and %v", value, idealSpeed*(1-delta), idealSpeed*(1+delta)) } @@ -90,7 +91,7 @@ func TestPautaFilter(t *testing.T) { t.Fail() } - result := pautaFilter(vector1) + result := control.pautaFilter(vector1) if len(result) != 10 { t.Fail() } diff --git a/speedtest/result.go b/speedtest/result.go new file mode 100644 index 0000000..909588d --- /dev/null +++ b/speedtest/result.go @@ -0,0 +1 @@ +package speedtest diff --git a/speedtest/server.go b/speedtest/server.go index 7ab64bb..3a9c502 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -6,6 +6,8 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/showwin/speedtest-go/speedtest/control" + "github.com/showwin/speedtest-go/speedtest/internal" "github.com/showwin/speedtest-go/speedtest/transport" "math" "net/http" @@ -51,6 +53,8 @@ type Server struct { Jitter time.Duration `json:"jitter"` DLSpeed ByteRate `json:"dl_speed"` ULSpeed ByteRate `json:"ul_speed"` + Sent int64 `json:"sent"` + Received int64 `json:"received"` TestDuration TestDuration `json:"test_duration"` PacketLoss transport.PLoss `json:"packet_loss"` @@ -222,7 +226,7 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error) query.Set("lon", strconv.FormatFloat(s.config.Location.Lon, 'f', -1, 64)) } u.RawQuery = query.Encode() - dbg.Printf("Retrieving servers: %s\n", u.String()) + internal.DBG().Printf("Retrieving servers: %s\n", u.String()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return Servers{}, err @@ -277,7 +281,7 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error) return servers, errors.New("response payload decoding not implemented") } - dbg.Printf("Servers Num: %d\n", len(servers)) + internal.DBG().Printf("Servers Num: %d\n", len(servers)) // set doer of server for _, server := range servers { server.Context = s @@ -286,15 +290,15 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error) // ping once var wg sync.WaitGroup pCtx, fc := context.WithTimeout(context.Background(), time.Second*4) - dbg.Println("Echo each server...") + internal.DBG().Println("Echo each server...") for _, server := range servers { wg.Add(1) go func(gs *Server) { var latency []int64 var errPing error - if s.config.PingMode == TCP { + if s.config.PingMode.Assert(control.TypeTCP) { latency, errPing = gs.TCPPing(pCtx, 1, time.Millisecond, nil) - } else if s.config.PingMode == ICMP { + } else if s.config.PingMode.Assert(control.TypeICMP) { latency, errPing = gs.ICMPPing(pCtx, 4*time.Second, 1, time.Millisecond, nil) } else { latency, errPing = gs.HTTPPing(pCtx, 1, time.Millisecond, nil) diff --git a/speedtest/speedtest.go b/speedtest/speedtest.go index 2844fc7..14799b1 100644 --- a/speedtest/speedtest.go +++ b/speedtest/speedtest.go @@ -3,6 +3,8 @@ package speedtest import ( "context" "fmt" + "github.com/showwin/speedtest-go/speedtest/control" + "github.com/showwin/speedtest-go/speedtest/internal" "net" "net/http" "net/url" @@ -13,22 +15,14 @@ import ( ) var ( - version = "1.7.8" + version = "1.8.0-beta.1" DefaultUserAgent = fmt.Sprintf("showwin/speedtest-go %s", version) ) -type Proto int - -const ( - HTTP Proto = iota - TCP - ICMP -) - // Speedtest is a speedtest client. 
type Speedtest struct { User *User - Manager + *DataManager doer *http.Client config *UserConfig @@ -41,10 +35,11 @@ type UserConfig struct { UserAgent string Proxy string Source string + Protocol control.Proto DnsBindSource bool DialerControl func(network, address string, c syscall.RawConn) error Debug bool - PingMode Proto + PingMode control.Proto SavingMode bool MaxConnections int @@ -66,26 +61,26 @@ func parseAddr(addr string) (string, string) { func (s *Speedtest) NewUserConfig(uc *UserConfig) { if uc.Debug { - dbg.Enable() + internal.DBG().Enable() } if uc.SavingMode { uc.MaxConnections = 1 // Set the number of concurrent connections to 1 } - s.SetNThread(uc.MaxConnections) + s.SetMaxConnections(uc.MaxConnections) if len(uc.CityFlag) > 0 { var err error uc.Location, err = GetLocation(uc.CityFlag) if err != nil { - dbg.Printf("Warning: skipping command line arguments: --city. err: %v\n", err.Error()) + internal.DBG().Printf("Warning: skipping command line arguments: --city. err: %v\n", err.Error()) } } if len(uc.LocationFlag) > 0 { var err error uc.Location, err = ParseLocation(uc.CityFlag, uc.LocationFlag) if err != nil { - dbg.Printf("Warning: skipping command line arguments: --location. err: %v\n", err.Error()) + internal.DBG().Printf("Warning: skipping command line arguments: --location. err: %v\n", err.Error()) } } @@ -102,13 +97,13 @@ func (s *Speedtest) NewUserConfig(uc *UserConfig) { if err == nil { tcpSource = addr0 } else { - dbg.Printf("Warning: skipping parse the source address. err: %s\n", err.Error()) + internal.DBG().Printf("Warning: skipping parse the source address. err: %s\n", err.Error()) } addr1, err := net.ResolveIPAddr("ip", address) // dynamic tcp port if err == nil { icmpSource = addr1 } else { - dbg.Printf("Warning: skipping parse the source address. err: %s\n", err.Error()) + internal.DBG().Printf("Warning: skipping parse the source address. err: %s\n", err.Error()) } if uc.DnsBindSource { net.DefaultResolver.Dial = func(ctx context.Context, network, dnsServer string) (net.Conn, error) { @@ -132,7 +127,7 @@ func (s *Speedtest) NewUserConfig(uc *UserConfig) { if len(uc.Proxy) > 0 { if parse, err := url.Parse(uc.Proxy); err != nil { - dbg.Printf("Warning: skipping parse the proxy host. err: %s\n", err.Error()) + internal.DBG().Printf("Warning: skipping parse the proxy host. 
err: %s\n", err.Error()) } else { proxy = func(_ *http.Request) (*url.URL, error) { return parse, err @@ -153,10 +148,13 @@ func (s *Speedtest) NewUserConfig(uc *UserConfig) { KeepAlive: 30 * time.Second, Control: uc.DialerControl, } - s.config.T = &http.Transport{ - Proxy: proxy, - DialContext: s.tcpDialer.DialContext, + Proxy: proxy, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := s.tcpDialer.DialContext(ctx, network, address) + fmt.Println(conn.LocalAddr().String()) + return conn, err + }, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -189,20 +187,22 @@ func WithDoer(doer *http.Client) Option { func WithUserConfig(userConfig *UserConfig) Option { return func(s *Speedtest) { s.NewUserConfig(userConfig) - dbg.Printf("Source: %s\n", s.config.Source) - dbg.Printf("Proxy: %s\n", s.config.Proxy) - dbg.Printf("SavingMode: %v\n", s.config.SavingMode) - dbg.Printf("Keyword: %v\n", s.config.Keyword) - dbg.Printf("PingType: %v\n", s.config.PingMode) - dbg.Printf("OS: %s, ARCH: %s, NumCPU: %d\n", runtime.GOOS, runtime.GOARCH, runtime.NumCPU()) + internal.DBG().Printf("Source: %s\n", s.config.Source) + internal.DBG().Printf("Proxy: %s\n", s.config.Proxy) + internal.DBG().Printf("SavingMode: %v\n", s.config.SavingMode) + internal.DBG().Printf("Keyword: %v\n", s.config.Keyword) + internal.DBG().Printf("PingType: %v\n", s.config.PingMode) + internal.DBG().Printf("Protocol: %v\n", s.config.Protocol) + internal.DBG().Printf("OS: %s, ARCH: %s, NumCPU: %d\n", runtime.GOOS, runtime.GOARCH, runtime.NumCPU()) } } // New creates a new speedtest client. +// TODO: need refactor func New(opts ...Option) *Speedtest { s := &Speedtest{ - doer: http.DefaultClient, - Manager: NewDataManager(), + doer: http.DefaultClient, + DataManager: NewDataManager(control.TypeHTTP), } // load default config s.NewUserConfig(&UserConfig{UserAgent: DefaultUserAgent}) @@ -210,6 +210,9 @@ func New(opts ...Option) *Speedtest { for _, opt := range opts { opt(s) } + // init protocol + // // TODO tmp usage + s.DataManager.protocol = s.config.Protocol return s } diff --git a/speedtest/speedtest_test.go b/speedtest/speedtest_test.go index c7f51d4..81c2334 100644 --- a/speedtest/speedtest_test.go +++ b/speedtest/speedtest_test.go @@ -1,6 +1,7 @@ package speedtest import ( + "github.com/showwin/speedtest-go/speedtest/internal" "net/http" "net/http/httptest" "testing" @@ -14,7 +15,7 @@ func BenchmarkLogSpeed(b *testing.B) { } WithUserConfig(config)(s) for i := 0; i < b.N; i++ { - dbg.Printf("hello %s\n", "s20080123") // ~1ns/op + internal.dbg.Printf("hello %s\n", "s20080123") // ~1ns/op } } diff --git a/speedtest/transport/chunk.go b/speedtest/transport/chunk.go new file mode 100644 index 0000000..c7f31a9 --- /dev/null +++ b/speedtest/transport/chunk.go @@ -0,0 +1,130 @@ +package transport + +import ( + "github.com/showwin/speedtest-go/speedtest/control" + "io" + "time" +) + +// DataChunk The Speedtest's I/O implementation of Transmission Control Protocol. 
+type DataChunk struct { + dateType control.Proto + startTime time.Time + endTime time.Time + err error + ContentLength int64 + remainOrDiscardSize int64 + ctrl control.Controller +} + +func NewChunk(controller control.Controller) control.Chunk { + return &DataChunk{ctrl: controller} +} + +// UploadHandler Create an upload handler +// @Related TCP UPLOAD +func (dc *DataChunk) UploadHandler(size int64) control.Chunk { + if dc.dateType != control.TypeChunkUndefined { + dc.err = control.ErrDuplicateCall + } + dc.ContentLength = size + dc.remainOrDiscardSize = size + dc.dateType = control.TypeUpload | control.TypeTCP + dc.startTime = time.Now() + return dc +} + +func (dc *DataChunk) DownloadHandler(r io.Reader) error { + if dc.dateType != control.TypeChunkUndefined { + dc.err = control.ErrDuplicateCall + return dc.err + } + dc.dateType = control.TypeDownload | control.TypeTCP + dc.startTime = time.Now() + defer func() { + dc.endTime = time.Now() + }() + bufP := control.BlackHole.Get().(*[]byte) + defer control.BlackHole.Put(bufP) + readSize := 0 + for { + select { + case <-dc.ctrl.Done(): + return nil + default: + readSize, dc.err = r.Read(*bufP) + rs := int64(readSize) + + dc.remainOrDiscardSize += rs + dc.ctrl.Add(rs) + if dc.err != nil { + if dc.err == io.EOF { + return nil + } + return dc.err + } + } + } +} + +func (dc *DataChunk) Rate() float64 { + if dc.dateType.Assert(control.TypeDownload) { + return float64(dc.remainOrDiscardSize) / dc.Duration().Seconds() + } else if dc.dateType.Assert(control.TypeUpload) { + return float64(dc.ContentLength-dc.remainOrDiscardSize) * 8 / 1000 / 1000 / dc.Duration().Seconds() + } + return 0 +} + +func (dc *DataChunk) Duration() time.Duration { + return dc.endTime.Sub(dc.startTime) +} + +func (dc *DataChunk) Type() control.Proto { + return dc.dateType +} + +func (dc *DataChunk) Len() int64 { + return dc.ContentLength +} + +// WriteTo Used to hook body traffic. 
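DownloadHandler above drains the server's response into a pooled scratch buffer (control.BlackHole), counts the received bytes, and Rate later divides that count by the elapsed time, so the download figure is in bytes per second. Before the upload-side WriteTo below, here is a minimal, self-contained sketch of that drain-and-time pattern; the names drain and blackHole are illustrative stand-ins, not identifiers from this repository.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"sync"
	"time"
)

// blackHole mirrors the pooled scratch buffers the chunk reads into.
var blackHole = sync.Pool{New: func() any { b := make([]byte, 32*1024); return &b }}

// drain reads r to EOF, discarding the data and returning bytes per second,
// the same read-discard-count loop DataChunk.DownloadHandler uses.
func drain(r io.Reader) (float64, error) {
	bufP := blackHole.Get().(*[]byte)
	defer blackHole.Put(bufP)
	var total int64
	start := time.Now()
	for {
		n, err := r.Read(*bufP)
		total += int64(n)
		if err == io.EOF {
			break
		}
		if err != nil {
			return 0, err
		}
	}
	return float64(total) / time.Since(start).Seconds(), nil
}

func main() {
	// An 8 MB in-memory source stands in for the server connection.
	rate, err := drain(bytes.NewReader(make([]byte, 8<<20)))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%.0f bytes/s\n", rate)
}
```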
+// @Related TCP UPLOAD +func (dc *DataChunk) WriteTo(w io.Writer) (written int64, err error) { + nw := 0 + nr := control.DefaultReadChunkSize + for { + select { + case <-dc.ctrl.Done(): + dc.endTime = time.Now() + return written, io.EOF + default: + if dc.remainOrDiscardSize <= 0 { + dc.endTime = time.Now() + return written, io.EOF + } + if dc.remainOrDiscardSize < control.DefaultReadChunkSize { + nr = int(dc.remainOrDiscardSize) + nw, err = w.Write(dc.ctrl.Repeat()[:nr]) + } else { + nw, err = w.Write(dc.ctrl.Repeat()) + } + if err != nil { + return + } + n64 := int64(nw) + written += n64 + dc.remainOrDiscardSize -= n64 + dc.ctrl.Add(n64) + if nr != nw { + return written, io.ErrShortWrite + } + } + } +} + +// @Related TCP UPLOAD +func (dc *DataChunk) Read(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} diff --git a/speedtest/transport/scalar.go b/speedtest/transport/scalar.go new file mode 100644 index 0000000..4b24715 --- /dev/null +++ b/speedtest/transport/scalar.go @@ -0,0 +1,20 @@ +package transport + +const MaxBufferSize = 32 * 1000 * 1000 // 32 MB + +type bufferScalar struct { + factor int64 + connections int64 +} + +func (s *bufferScalar) update(rate int64) int64 { + ret := rate / s.factor / s.connections + if ret > MaxBufferSize { + return MaxBufferSize + } + return ret +} + +func newBufAllocator(factor, connections int64) *bufferScalar { + return &bufferScalar{factor: factor, connections: connections} +} diff --git a/speedtest/transport/tcp.go b/speedtest/transport/tcp.go index 16219a0..6f8077f 100644 --- a/speedtest/transport/tcp.go +++ b/speedtest/transport/tcp.go @@ -6,19 +6,21 @@ import ( "context" "errors" "fmt" + "github.com/showwin/speedtest-go/speedtest/internal" + "io" "net" "strconv" "time" ) var ( - pingPrefix = []byte{0x50, 0x49, 0x4e, 0x47, 0x20} - // downloadPrefix = []byte{0x44, 0x4F, 0x57, 0x4E, 0x4C, 0x4F, 0x41, 0x44, 0x20} - // uploadPrefix = []byte{0x55, 0x50, 0x4C, 0x4F, 0x41, 0x44, 0x20} - initPacket = []byte{0x49, 0x4e, 0x49, 0x54, 0x50, 0x4c, 0x4f, 0x53, 0x53} - packetLoss = []byte{0x50, 0x4c, 0x4f, 0x53, 0x53} - hiFormat = []byte{0x48, 0x49} - quitFormat = []byte{0x51, 0x55, 0x49, 0x54} + pingPrefix = []byte{0x50, 0x49, 0x4e, 0x47, 0x20} + downloadPrefix = []byte{0x44, 0x4F, 0x57, 0x4E, 0x4C, 0x4F, 0x41, 0x44, 0x20} + uploadPrefix = []byte{0x55, 0x50, 0x4C, 0x4F, 0x41, 0x44, 0x20} + initPacket = []byte{0x49, 0x4e, 0x49, 0x54, 0x50, 0x4c, 0x4f, 0x53, 0x53} + packetLoss = []byte{0x50, 0x4c, 0x4f, 0x53, 0x53} + hiFormat = []byte{0x48, 0x49} + quitFormat = []byte{0x51, 0x55, 0x49, 0x54} ) var ( @@ -28,10 +30,6 @@ var ( ErrUninitializedPacketLossInst = errors.New("uninitialized packet loss inst") ) -func pingFormat(locTime int64) []byte { - return strconv.AppendInt(pingPrefix, locTime, 10) -} - type Client struct { id string conn net.Conn @@ -44,7 +42,7 @@ type Client struct { } func NewClient(dialer *net.Dialer) (*Client, error) { - uuid, err := generateUUID() + uuid, err := internal.GenerateUUID() if err != nil { return nil, err } @@ -68,6 +66,10 @@ func (client *Client) Connect(ctx context.Context, host string) (err error) { return nil } +func (client *Client) Local() net.Addr { + return client.conn.LocalAddr() +} + func (client *Client) Disconnect() (err error) { _, _ = client.conn.Write(quitFormat) client.conn = nil @@ -76,7 +78,7 @@ func (client *Client) Disconnect() (err error) { return } -func (client *Client) Write(data []byte) (err error) { +func (client *Client) WriteToConn(data []byte) (err error) { if client.conn == 
nil { return ErrEmptyConn } @@ -84,7 +86,7 @@ func (client *Client) Write(data []byte) (err error) { return } -func (client *Client) Read() ([]byte, error) { +func (client *Client) ReadFromConn() ([]byte, error) { if client.conn == nil { return nil, ErrEmptyConn } @@ -93,9 +95,9 @@ func (client *Client) Version() string { if len(client.version) == 0 { - err := client.Write(hiFormat) + err := client.WriteToConn(hiFormat) if err == nil { - message, err := client.Read() + message, err := client.ReadFromConn() if err != nil || len(message) < 8 { return "unknown" } @@ -121,11 +123,11 @@ func (client *Client) PingContext(ctx context.Context) (int64, error) { go func() { for i := 0; i < 2; i++ { t0 := time.Now().UnixNano() - if err := client.Write(pingFormat(t0)); err != nil { + if err := client.WriteToConn(embedFormat(pingPrefix, t0)); err != nil { resultChan <- err return } - data, err := client.Read() + data, err := client.ReadFromConn() t2 := time.Now().UnixNano() if err != nil { resultChan <- err @@ -164,11 +166,11 @@ func (client *Client) InitPacketLoss() error { id := client.id payload := append(hiFormat, 0x20) payload = append(payload, []byte(id)...) - err := client.Write(payload) + err := client.WriteToConn(payload) if err != nil { return err } - return client.Write(initPacket) + return client.WriteToConn(initPacket) } // PLoss Packet loss statistics @@ -205,11 +207,11 @@ func (p PLoss) LossPercent() float64 { } func (client *Client) PacketLoss() (*PLoss, error) { - err := client.Write(packetLoss) + err := client.WriteToConn(packetLoss) if err != nil { return nil, err } - result, err := client.Read() + result, err := client.ReadFromConn() if err != nil { return nil, err } @@ -236,10 +238,35 @@ func (client *Client) PacketLoss() (*PLoss, error) { }, nil } -func (client *Client) Download() { - panic("Unimplemented method: Client.Download()") +func (client *Client) RegisterDownload(chunkSize int64) (io.Reader, error) { + initPacketOperations := embedFormat(downloadPrefix, chunkSize) + err := client.WriteToConn(initPacketOperations) + if err != nil { + return nil, err + } + return client.conn, err +} + +func (client *Client) RegisterUpload(chunkSize int64) (int64, error) { + initPacketOperations := embedFormat(uploadPrefix, chunkSize) + remain := chunkSize - int64(len(initPacketOperations)) - 1 + return remain, client.WriteToConn(initPacketOperations) +} + +// Upload performs the upload operation. +// @endpoint data input source; upload statistics can be collected here. 
+func (client *Client) Upload(endpoint io.Reader) (int64, error) { + if wt, ok := endpoint.(io.WriterTo); ok { + return wt.WriteTo(client) + } else { + panic("endpoint does not implement io.WriterTo") + } +} + +func (client *Client) Write(p []byte) (n int, err error) { + return client.conn.Write(p) } -func (client *Client) Upload() { - panic("Unimplemented method: Client.Upload()") +func embedFormat(prefix []byte, packetSize int64) []byte { + return strconv.AppendInt(prefix, packetSize, 10) } diff --git a/speedtest/transport/udp.go b/speedtest/transport/udp.go index 79e8174..ff22786 100644 --- a/speedtest/transport/udp.go +++ b/speedtest/transport/udp.go @@ -3,7 +3,6 @@ package transport import ( "bytes" "context" - "crypto/rand" "fmt" mrand "math/rand" "net" @@ -52,14 +51,3 @@ func (ps *PacketLossSender) Send(order int) error { _, err := ps.conn.Write(payload) return err } - -func generateUUID() (string, error) { - randUUID := make([]byte, 16) - _, err := rand.Read(randUUID) - if err != nil { - return "", err - } - randUUID[8] = randUUID[8]&^0xc0 | 0x80 - randUUID[6] = randUUID[6]&^0xf0 | 0x40 - return fmt.Sprintf("%x-%x-%x-%x-%x", randUUID[0:4], randUUID[4:6], randUUID[6:8], randUUID[8:10], randUUID[10:]), nil -} diff --git a/speedtest/user.go b/speedtest/user.go index 1bb0ae0..e2b5c21 100644 --- a/speedtest/user.go +++ b/speedtest/user.go @@ -5,6 +5,7 @@ import ( "context" "encoding/xml" "errors" "fmt" + "github.com/showwin/speedtest-go/speedtest/internal" "net/http" ) @@ -35,7 +36,7 @@ func FetchUserInfo() (*User, error) { // FetchUserInfoContext returns information about caller determined by speedtest.net, observing the given context. func (s *Speedtest) FetchUserInfoContext(ctx context.Context) (*User, error) { - dbg.Printf("Retrieving user info: %s\n", speedTestConfigUrl) + internal.DBG().Printf("Retrieving user info: %s\n", speedTestConfigUrl) req, err := http.NewRequestWithContext(ctx, http.MethodGet, speedTestConfigUrl, nil) if err != nil { return nil, err
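Taken together, the TCP upload path added in tcp.go works like this: RegisterUpload writes the UPLOAD header and returns the payload budget that remains after the header (the -1 presumably reserves the trailing newline), a DataChunk armed with UploadHandler acts as the io.Reader/io.WriterTo endpoint, and Client.Upload streams it back through Client.Write. The sketch below only shows that call order; it assumes a connected *transport.Client and a control.Controller, both of which are constructed outside the hunks shown here, so it is not runnable as-is.

```go
package example

import (
	"github.com/showwin/speedtest-go/speedtest/control"
	"github.com/showwin/speedtest-go/speedtest/transport"
)

// uploadOnce sketches one TCP upload pass as wired by this diff. The caller
// must supply a connected client and a control.Controller; both are assumed.
func uploadOnce(client *transport.Client, ctrl control.Controller, size int64) (int64, error) {
	// Sends "UPLOAD <size>" to the server; the returned remainder is the
	// payload budget left after the header.
	remain, err := client.RegisterUpload(size)
	if err != nil {
		return 0, err
	}
	// Arm a DataChunk as the upload source. Client.Upload type-asserts it to
	// io.WriterTo, so DataChunk.WriteTo pushes repeated payload bytes through
	// client.Write until the budget is spent.
	chunk := transport.NewChunk(ctrl).(*transport.DataChunk)
	chunk.UploadHandler(remain)
	return client.Upload(chunk)
}
```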