Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:Progress Tracking for Uploads/Downloads #35

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions axios4go_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package axios4go

import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -863,3 +866,55 @@ func TestGetByProxy(t *testing.T) {
}
})
}

func TestProgressCallbacks(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Read the request body to trigger upload progress
_, err := io.Copy(io.Discard, r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}

// Simulate a large file for download
w.Header().Set("Content-Length", "1000000")
for i := 0; i < 1000000; i++ {
_, err := w.Write([]byte("a"))
if err != nil {
t.Fatalf("Failed to write response: %v", err)
}
}
}))
defer server.Close()

uploadCalled := false
downloadCalled := false

body := bytes.NewReader([]byte(strings.Repeat("b", 500000))) // 500KB upload

_, err := Post(server.URL, body, &RequestOptions{
OnUploadProgress: func(bytesRead, totalBytes int64) {
uploadCalled = true
if bytesRead > totalBytes {
t.Errorf("Upload progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
}
},
OnDownloadProgress: func(bytesRead, totalBytes int64) {
downloadCalled = true
if bytesRead > totalBytes {
t.Errorf("Download progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
}
},
MaxContentLength: 2000000, // Set this to allow our 1MB response
})

if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if !uploadCalled {
t.Error("Upload progress callback was not called")
}
if !downloadCalled {
t.Error("Download progress callback was not called")
}
}
73 changes: 68 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ type RequestOptions struct {
ResponseType string
ResponseEncoding string
MaxRedirects int
MaxContentLength int
MaxBodyLength int
MaxContentLength int64
MaxBodyLength int64
Decompress bool
ValidateStatus func(int) bool
InterceptorOptions InterceptorOptions
Proxy *Proxy
OnUploadProgress func(bytesRead, totalBytes int64)
OnDownloadProgress func(bytesRead, totalBytes int64)
}

type Proxy struct {
Expand All @@ -74,6 +76,38 @@ type Auth struct {
Password string
}

type ProgressReader struct {
reader io.Reader
total int64
read int64
onProgress func(bytesRead, totalBytes int64)
}

type ProgressWriter struct {
writer io.Writer
total int64
written int64
onProgress func(bytesWritten, totalBytes int64)
}

func (pr *ProgressReader) Read(p []byte) (int, error) {
n, err := pr.reader.Read(p)
pr.read += int64(n)
if pr.onProgress != nil {
pr.onProgress(pr.read, pr.total)
}
return n, err
}

func (pw *ProgressWriter) Write(p []byte) (int, error) {
n, err := pw.writer.Write(p)
pw.written += int64(n)
if pw.onProgress != nil {
pw.onProgress(pw.written, pw.total)
}
return n, err
}

var defaultClient = &Client{HTTPClient: &http.Client{}}

func (r *Response) JSON(v interface{}) error {
Expand Down Expand Up @@ -345,6 +379,14 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
if options.MaxBodyLength > 0 && bodyLength > int64(options.MaxBodyLength) {
return nil, errors.New("request body length exceeded maxBodyLength")
}

if options.Body != nil && options.OnUploadProgress != nil {
bodyReader = &ProgressReader{
reader: bodyReader,
total: bodyLength,
onProgress: options.OnUploadProgress,
}
}
}

req, err := http.NewRequest(options.Method, fullURL, bodyReader)
Expand Down Expand Up @@ -428,9 +470,24 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
}
}()

responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
var responseBody []byte
if options.OnDownloadProgress != nil {
buf := &bytes.Buffer{}
progressWriter := &ProgressWriter{
writer: buf,
total: resp.ContentLength,
onProgress: options.OnDownloadProgress,
}
_, err = io.Copy(progressWriter, resp.Body)
if err != nil {
return nil, err
}
responseBody = buf.Bytes()
} else {
responseBody, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
}

if int64(len(responseBody)) > int64(options.MaxContentLength) {
Expand Down Expand Up @@ -504,6 +561,12 @@ func mergeOptions(dst, src *RequestOptions) {
if src.InterceptorOptions.ResponseInterceptors != nil {
dst.InterceptorOptions.ResponseInterceptors = src.InterceptorOptions.ResponseInterceptors
}
if src.OnUploadProgress != nil {
dst.OnUploadProgress = src.OnUploadProgress
}
if src.OnDownloadProgress != nil {
dst.OnDownloadProgress = src.OnDownloadProgress
}
if src.Proxy != nil {
dst.Proxy = src.Proxy
}
Expand Down
54 changes: 54 additions & 0 deletions examples/download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package main

import (
"fmt"
"os"
"time"

"github.com/rezmoss/axios4go"
)

func main() {
url := "https://ash-speed.hetzner.com/1GB.bin"
outputPath := "1GB.bin"

startTime := time.Now()
lastPrintTime := startTime

resp, err := axios4go.Get(url, &axios4go.RequestOptions{
MaxContentLength: 5 * 1024 * 1024 * 1024, // 5GB
Timeout: 60000 * 5,
OnDownloadProgress: func(bytesRead, totalBytes int64) {
currentTime := time.Now()
if currentTime.Sub(lastPrintTime) >= time.Second || bytesRead == totalBytes {
percentage := float64(bytesRead) / float64(totalBytes) * 100
downloadedMB := float64(bytesRead) / 1024 / 1024
totalMB := float64(totalBytes) / 1024 / 1024
elapsedTime := currentTime.Sub(startTime)
speed := float64(bytesRead) / elapsedTime.Seconds() / 1024 / 1024 // MB/s

fmt.Printf("\rDownloaded %.2f%% (%.2f MB / %.2f MB) - Speed: %.2f MB/s",
percentage, downloadedMB, totalMB, speed)

lastPrintTime = currentTime
}
},
})

if err != nil {
fmt.Printf("\nError downloading file: %v\n", err)
return
}

err = writeResponseToFile(resp, outputPath)
if err != nil {
fmt.Printf("\nError writing file: %v\n", err)
return
}

fmt.Println("\nDownload completed successfully!!")
}

func writeResponseToFile(resp *axios4go.Response, outputPath string) error {
return os.WriteFile(outputPath, resp.Body, 0644)
}
Loading