diff --git a/axios4go_test.go b/axios4go_test.go index 00b241b..62056e9 100644 --- a/axios4go_test.go +++ b/axios4go_test.go @@ -1,9 +1,12 @@ package axios4go import ( + "bytes" "encoding/json" + "io" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -863,3 +866,55 @@ func TestGetByProxy(t *testing.T) { } }) } + +func TestProgressCallbacks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read the request body to trigger upload progress + _, err := io.Copy(io.Discard, r.Body) + if err != nil { + t.Fatalf("Failed to read request body: %v", err) + } + + // Simulate a large file for download + w.Header().Set("Content-Length", "1000000") + for i := 0; i < 1000000; i++ { + _, err := w.Write([]byte("a")) + if err != nil { + t.Fatalf("Failed to write response: %v", err) + } + } + })) + defer server.Close() + + uploadCalled := false + downloadCalled := false + + body := bytes.NewReader([]byte(strings.Repeat("b", 500000))) // 500KB upload + + _, err := Post(server.URL, body, &RequestOptions{ + OnUploadProgress: func(bytesRead, totalBytes int64) { + uploadCalled = true + if bytesRead > totalBytes { + t.Errorf("Upload progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes) + } + }, + OnDownloadProgress: func(bytesRead, totalBytes int64) { + downloadCalled = true + if bytesRead > totalBytes { + t.Errorf("Download progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes) + } + }, + MaxContentLength: 2000000, // Set this to allow our 1MB response + }) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if !uploadCalled { + t.Error("Upload progress callback was not called") + } + if !downloadCalled { + t.Error("Download progress callback was not called") + } +} diff --git a/client.go b/client.go index 91e61f5..4778b6c 100644 --- a/client.go +++ b/client.go @@ -54,12 +54,14 @@ type RequestOptions struct { ResponseType string ResponseEncoding string MaxRedirects int - MaxContentLength int - MaxBodyLength int + MaxContentLength int64 + MaxBodyLength int64 Decompress bool ValidateStatus func(int) bool InterceptorOptions InterceptorOptions Proxy *Proxy + OnUploadProgress func(bytesRead, totalBytes int64) + OnDownloadProgress func(bytesRead, totalBytes int64) } type Proxy struct { @@ -74,6 +76,38 @@ type Auth struct { Password string } +type ProgressReader struct { + reader io.Reader + total int64 + read int64 + onProgress func(bytesRead, totalBytes int64) +} + +type ProgressWriter struct { + writer io.Writer + total int64 + written int64 + onProgress func(bytesWritten, totalBytes int64) +} + +func (pr *ProgressReader) Read(p []byte) (int, error) { + n, err := pr.reader.Read(p) + pr.read += int64(n) + if pr.onProgress != nil { + pr.onProgress(pr.read, pr.total) + } + return n, err +} + +func (pw *ProgressWriter) Write(p []byte) (int, error) { + n, err := pw.writer.Write(p) + pw.written += int64(n) + if pw.onProgress != nil { + pw.onProgress(pw.written, pw.total) + } + return n, err +} + var defaultClient = &Client{HTTPClient: &http.Client{}} func (r *Response) JSON(v interface{}) error { @@ -345,6 +379,14 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) { if options.MaxBodyLength > 0 && bodyLength > int64(options.MaxBodyLength) { return nil, errors.New("request body length exceeded maxBodyLength") } + + if options.Body != nil && options.OnUploadProgress != nil { + bodyReader = &ProgressReader{ + reader: bodyReader, + total: bodyLength, + onProgress: options.OnUploadProgress, + } + } } req, err := http.NewRequest(options.Method, fullURL, bodyReader) @@ -428,9 +470,24 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) { } }() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err + var responseBody []byte + if options.OnDownloadProgress != nil { + buf := &bytes.Buffer{} + progressWriter := &ProgressWriter{ + writer: buf, + total: resp.ContentLength, + onProgress: options.OnDownloadProgress, + } + _, err = io.Copy(progressWriter, resp.Body) + if err != nil { + return nil, err + } + responseBody = buf.Bytes() + } else { + responseBody, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } } if int64(len(responseBody)) > int64(options.MaxContentLength) { @@ -504,6 +561,12 @@ func mergeOptions(dst, src *RequestOptions) { if src.InterceptorOptions.ResponseInterceptors != nil { dst.InterceptorOptions.ResponseInterceptors = src.InterceptorOptions.ResponseInterceptors } + if src.OnUploadProgress != nil { + dst.OnUploadProgress = src.OnUploadProgress + } + if src.OnDownloadProgress != nil { + dst.OnDownloadProgress = src.OnDownloadProgress + } if src.Proxy != nil { dst.Proxy = src.Proxy } diff --git a/examples/download.go b/examples/download.go new file mode 100644 index 0000000..82a5a75 --- /dev/null +++ b/examples/download.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/rezmoss/axios4go" +) + +func main() { + url := "https://ash-speed.hetzner.com/1GB.bin" + outputPath := "1GB.bin" + + startTime := time.Now() + lastPrintTime := startTime + + resp, err := axios4go.Get(url, &axios4go.RequestOptions{ + MaxContentLength: 5 * 1024 * 1024 * 1024, // 5GB + Timeout: 60000 * 5, + OnDownloadProgress: func(bytesRead, totalBytes int64) { + currentTime := time.Now() + if currentTime.Sub(lastPrintTime) >= time.Second || bytesRead == totalBytes { + percentage := float64(bytesRead) / float64(totalBytes) * 100 + downloadedMB := float64(bytesRead) / 1024 / 1024 + totalMB := float64(totalBytes) / 1024 / 1024 + elapsedTime := currentTime.Sub(startTime) + speed := float64(bytesRead) / elapsedTime.Seconds() / 1024 / 1024 // MB/s + + fmt.Printf("\rDownloaded %.2f%% (%.2f MB / %.2f MB) - Speed: %.2f MB/s", + percentage, downloadedMB, totalMB, speed) + + lastPrintTime = currentTime + } + }, + }) + + if err != nil { + fmt.Printf("\nError downloading file: %v\n", err) + return + } + + err = writeResponseToFile(resp, outputPath) + if err != nil { + fmt.Printf("\nError writing file: %v\n", err) + return + } + + fmt.Println("\nDownload completed successfully!!") +} + +func writeResponseToFile(resp *axios4go.Response, outputPath string) error { + return os.WriteFile(outputPath, resp.Body, 0644) +}