Skip to content

Commit

Permalink
Merge pull request #35 from rezmoss/download-upload-progress
Browse files Browse the repository at this point in the history
feat:Progress Tracking for Uploads/Downloads
  • Loading branch information
rezmoss authored Oct 10, 2024
2 parents 2d2e034 + 9324508 commit d109af7
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
55 changes: 55 additions & 0 deletions axios4go_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package axios4go

import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -863,3 +866,55 @@ func TestGetByProxy(t *testing.T) {
}
})
}

func TestProgressCallbacks(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Read the request body to trigger upload progress
_, err := io.Copy(io.Discard, r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}

// Simulate a large file for download
w.Header().Set("Content-Length", "1000000")
for i := 0; i < 1000000; i++ {
_, err := w.Write([]byte("a"))
if err != nil {
t.Fatalf("Failed to write response: %v", err)
}
}
}))
defer server.Close()

uploadCalled := false
downloadCalled := false

body := bytes.NewReader([]byte(strings.Repeat("b", 500000))) // 500KB upload

_, err := Post(server.URL, body, &RequestOptions{
OnUploadProgress: func(bytesRead, totalBytes int64) {
uploadCalled = true
if bytesRead > totalBytes {
t.Errorf("Upload progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
}
},
OnDownloadProgress: func(bytesRead, totalBytes int64) {
downloadCalled = true
if bytesRead > totalBytes {
t.Errorf("Download progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
}
},
MaxContentLength: 2000000, // Set this to allow our 1MB response
})

if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if !uploadCalled {
t.Error("Upload progress callback was not called")
}
if !downloadCalled {
t.Error("Download progress callback was not called")
}
}
73 changes: 68 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ type RequestOptions struct {
ResponseType string
ResponseEncoding string
MaxRedirects int
MaxContentLength int
MaxBodyLength int
MaxContentLength int64
MaxBodyLength int64
Decompress bool
ValidateStatus func(int) bool
InterceptorOptions InterceptorOptions
Proxy *Proxy
OnUploadProgress func(bytesRead, totalBytes int64)
OnDownloadProgress func(bytesRead, totalBytes int64)
}

type Proxy struct {
Expand All @@ -74,6 +76,38 @@ type Auth struct {
Password string
}

type ProgressReader struct {
reader io.Reader
total int64
read int64
onProgress func(bytesRead, totalBytes int64)
}

type ProgressWriter struct {
writer io.Writer
total int64
written int64
onProgress func(bytesWritten, totalBytes int64)
}

func (pr *ProgressReader) Read(p []byte) (int, error) {
n, err := pr.reader.Read(p)
pr.read += int64(n)
if pr.onProgress != nil {
pr.onProgress(pr.read, pr.total)
}
return n, err
}

func (pw *ProgressWriter) Write(p []byte) (int, error) {
n, err := pw.writer.Write(p)
pw.written += int64(n)
if pw.onProgress != nil {
pw.onProgress(pw.written, pw.total)
}
return n, err
}

var defaultClient = &Client{HTTPClient: &http.Client{}}

func (r *Response) JSON(v interface{}) error {
Expand Down Expand Up @@ -345,6 +379,14 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
if options.MaxBodyLength > 0 && bodyLength > int64(options.MaxBodyLength) {
return nil, errors.New("request body length exceeded maxBodyLength")
}

if options.Body != nil && options.OnUploadProgress != nil {
bodyReader = &ProgressReader{
reader: bodyReader,
total: bodyLength,
onProgress: options.OnUploadProgress,
}
}
}

req, err := http.NewRequest(options.Method, fullURL, bodyReader)
Expand Down Expand Up @@ -428,9 +470,24 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
}
}()

responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
var responseBody []byte
if options.OnDownloadProgress != nil {
buf := &bytes.Buffer{}
progressWriter := &ProgressWriter{
writer: buf,
total: resp.ContentLength,
onProgress: options.OnDownloadProgress,
}
_, err = io.Copy(progressWriter, resp.Body)
if err != nil {
return nil, err
}
responseBody = buf.Bytes()
} else {
responseBody, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
}

if int64(len(responseBody)) > int64(options.MaxContentLength) {
Expand Down Expand Up @@ -504,6 +561,12 @@ func mergeOptions(dst, src *RequestOptions) {
if src.InterceptorOptions.ResponseInterceptors != nil {
dst.InterceptorOptions.ResponseInterceptors = src.InterceptorOptions.ResponseInterceptors
}
if src.OnUploadProgress != nil {
dst.OnUploadProgress = src.OnUploadProgress
}
if src.OnDownloadProgress != nil {
dst.OnDownloadProgress = src.OnDownloadProgress
}
if src.Proxy != nil {
dst.Proxy = src.Proxy
}
Expand Down
54 changes: 54 additions & 0 deletions examples/download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package main

import (
"fmt"
"os"
"time"

"github.com/rezmoss/axios4go"
)

func main() {
url := "https://ash-speed.hetzner.com/1GB.bin"
outputPath := "1GB.bin"

startTime := time.Now()
lastPrintTime := startTime

resp, err := axios4go.Get(url, &axios4go.RequestOptions{
MaxContentLength: 5 * 1024 * 1024 * 1024, // 5GB
Timeout: 60000 * 5,
OnDownloadProgress: func(bytesRead, totalBytes int64) {
currentTime := time.Now()
if currentTime.Sub(lastPrintTime) >= time.Second || bytesRead == totalBytes {
percentage := float64(bytesRead) / float64(totalBytes) * 100
downloadedMB := float64(bytesRead) / 1024 / 1024
totalMB := float64(totalBytes) / 1024 / 1024
elapsedTime := currentTime.Sub(startTime)
speed := float64(bytesRead) / elapsedTime.Seconds() / 1024 / 1024 // MB/s

fmt.Printf("\rDownloaded %.2f%% (%.2f MB / %.2f MB) - Speed: %.2f MB/s",
percentage, downloadedMB, totalMB, speed)

lastPrintTime = currentTime
}
},
})

if err != nil {
fmt.Printf("\nError downloading file: %v\n", err)
return
}

err = writeResponseToFile(resp, outputPath)
if err != nil {
fmt.Printf("\nError writing file: %v\n", err)
return
}

fmt.Println("\nDownload completed successfully!!")
}

func writeResponseToFile(resp *axios4go.Response, outputPath string) error {
return os.WriteFile(outputPath, resp.Body, 0644)
}

0 comments on commit d109af7

Please sign in to comment.