Skip to content

Commit 779c293

Browse files
authored
fix(driver): implement canceling and updating progress for putting for some drivers (#7847)
* fix(driver): additionally implement canceling and updating progress for putting for some drivers * refactor: add driver archive api into template * fix(123): use built-in MD5 to avoid caching full * . * fix build failed
1 parent b9f397d commit 779c293

File tree

35 files changed

+457
-256
lines changed

35 files changed

+457
-256
lines changed

drivers/115/driver.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,12 @@ func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
215215
var uploadResult *UploadResult
216216
// 闪传失败,上传
217217
if stream.GetSize() <= 10*utils.MB { // 文件大小小于10MB,改用普通模式上传
218-
if uploadResult, err = d.UploadByOSS(&fastInfo.UploadOSSParams, stream, dirID); err != nil {
218+
if uploadResult, err = d.UploadByOSS(ctx, &fastInfo.UploadOSSParams, stream, dirID, up); err != nil {
219219
return nil, err
220220
}
221221
} else {
222222
// 分片上传
223-
if uploadResult, err = d.UploadByMultipart(&fastInfo.UploadOSSParams, stream.GetSize(), stream, dirID); err != nil {
223+
if uploadResult, err = d.UploadByMultipart(ctx, &fastInfo.UploadOSSParams, stream.GetSize(), stream, dirID, up); err != nil {
224224
return nil, err
225225
}
226226
}

drivers/115/util.go

+26-5
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@ package _115
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/md5"
67
"crypto/tls"
78
"encoding/hex"
89
"encoding/json"
910
"fmt"
11+
"github.com/alist-org/alist/v3/internal/driver"
12+
"github.com/alist-org/alist/v3/internal/stream"
1013
"io"
1114
"net/http"
1215
"net/url"
1316
"strconv"
1417
"strings"
1518
"sync"
19+
"sync/atomic"
1620
"time"
1721

1822
"github.com/alist-org/alist/v3/internal/conf"
@@ -271,7 +275,7 @@ func UploadDigestRange(stream model.FileStreamer, rangeSpec string) (result stri
271275
}
272276

273277
// UploadByOSS use aliyun sdk to upload
274-
func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dirID string) (*UploadResult, error) {
278+
func (c *Pan115) UploadByOSS(ctx context.Context, params *driver115.UploadOSSParams, s model.FileStreamer, dirID string, up driver.UpdateProgress) (*UploadResult, error) {
275279
ossToken, err := c.client.GetOSSToken()
276280
if err != nil {
277281
return nil, err
@@ -286,6 +290,13 @@ func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dir
286290
}
287291

288292
var bodyBytes []byte
293+
r := &stream.ReaderWithCtx{
294+
Reader: &stream.ReaderUpdatingProgress{
295+
Reader: s,
296+
UpdateProgress: up,
297+
},
298+
Ctx: ctx,
299+
}
289300
if err = bucket.PutObject(params.Object, r, append(
290301
driver115.OssOption(params, ossToken),
291302
oss.CallbackResult(&bodyBytes),
@@ -301,7 +312,8 @@ func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dir
301312
}
302313

303314
// UploadByMultipart upload by mutipart blocks
304-
func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize int64, stream model.FileStreamer, dirID string, opts ...driver115.UploadMultipartOption) (*UploadResult, error) {
315+
func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.UploadOSSParams, fileSize int64, s model.FileStreamer,
316+
dirID string, up driver.UpdateProgress, opts ...driver115.UploadMultipartOption) (*UploadResult, error) {
305317
var (
306318
chunks []oss.FileChunk
307319
parts []oss.UploadPart
@@ -313,7 +325,7 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i
313325
err error
314326
)
315327

316-
tmpF, err := stream.CacheFullInTempFile()
328+
tmpF, err := s.CacheFullInTempFile()
317329
if err != nil {
318330
return nil, err
319331
}
@@ -372,6 +384,7 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i
372384
quit <- struct{}{}
373385
}()
374386

387+
completedNum := atomic.Int32{}
375388
// consumers
376389
for i := 0; i < options.ThreadsNum; i++ {
377390
go func(threadId int) {
@@ -384,6 +397,8 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i
384397
var part oss.UploadPart // 出现错误就继续尝试,共尝试3次
385398
for retry := 0; retry < 3; retry++ {
386399
select {
400+
case <-ctx.Done():
401+
break
387402
case <-ticker.C:
388403
if ossToken, err = d.client.GetOSSToken(); err != nil { // 到时重新获取ossToken
389404
errCh <- errors.Wrap(err, "刷新token时出现错误")
@@ -396,12 +411,18 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i
396411
continue
397412
}
398413

399-
if part, err = bucket.UploadPart(imur, bytes.NewBuffer(buf), chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
414+
if part, err = bucket.UploadPart(imur, &stream.ReaderWithCtx{
415+
Reader: bytes.NewBuffer(buf),
416+
Ctx: ctx,
417+
}, chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
400418
break
401419
}
402420
}
403421
if err != nil {
404-
errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", stream.GetName(), chunk.Number, err))
422+
errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", s.GetName(), chunk.Number, err))
423+
} else {
424+
num := completedNum.Add(1)
425+
up(float64(num) * 100.0 / float64(len(chunks)))
405426
}
406427
UploadedPartsCh <- part
407428
}

drivers/123/driver.go

+34-24
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/base64"
77
"encoding/hex"
88
"fmt"
9+
"github.com/alist-org/alist/v3/internal/stream"
910
"io"
1011
"net/http"
1112
"net/url"
@@ -185,32 +186,35 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error {
185186
}
186187
}
187188

188-
func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
189-
// const DEFAULT int64 = 10485760
190-
h := md5.New()
191-
// need to calculate md5 of the full content
192-
tempFile, err := stream.CacheFullInTempFile()
193-
if err != nil {
194-
return err
195-
}
196-
defer func() {
197-
_ = tempFile.Close()
198-
}()
199-
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
200-
return err
201-
}
202-
_, err = tempFile.Seek(0, io.SeekStart)
203-
if err != nil {
204-
return err
189+
func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
190+
etag := file.GetHash().GetHash(utils.MD5)
191+
if len(etag) < utils.MD5.Width {
192+
// const DEFAULT int64 = 10485760
193+
h := md5.New()
194+
// need to calculate md5 of the full content
195+
tempFile, err := file.CacheFullInTempFile()
196+
if err != nil {
197+
return err
198+
}
199+
defer func() {
200+
_ = tempFile.Close()
201+
}()
202+
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
203+
return err
204+
}
205+
_, err = tempFile.Seek(0, io.SeekStart)
206+
if err != nil {
207+
return err
208+
}
209+
etag = hex.EncodeToString(h.Sum(nil))
205210
}
206-
etag := hex.EncodeToString(h.Sum(nil))
207211
data := base.Json{
208212
"driveId": 0,
209213
"duplicate": 2, // 2->覆盖 1->重命名 0->默认
210214
"etag": etag,
211-
"fileName": stream.GetName(),
215+
"fileName": file.GetName(),
212216
"parentFileId": dstDir.GetID(),
213-
"size": stream.GetSize(),
217+
"size": file.GetSize(),
214218
"type": 0,
215219
}
216220
var resp UploadResp
@@ -225,7 +229,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
225229
return nil
226230
}
227231
if resp.Data.AccessKeyId == "" || resp.Data.SecretAccessKey == "" || resp.Data.SessionToken == "" {
228-
err = d.newUpload(ctx, &resp, stream, tempFile, up)
232+
err = d.newUpload(ctx, &resp, file, up)
229233
return err
230234
} else {
231235
cfg := &aws.Config{
@@ -239,15 +243,21 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
239243
return err
240244
}
241245
uploader := s3manager.NewUploader(s)
242-
if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize {
243-
uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1)
246+
if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize {
247+
uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1)
244248
}
245249
input := &s3manager.UploadInput{
246250
Bucket: &resp.Data.Bucket,
247251
Key: &resp.Data.Key,
248-
Body: tempFile,
252+
Body: &stream.ReaderUpdatingProgress{
253+
Reader: file,
254+
UpdateProgress: up,
255+
},
249256
}
250257
_, err = uploader.UploadWithContext(ctx, input)
258+
if err != nil {
259+
return err
260+
}
251261
}
252262
_, err = d.Request(UploadComplete, http.MethodPost, func(req *resty.Request) {
253263
req.SetBody(base.Json{

drivers/123/upload.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F
6969
return err
7070
}
7171

72-
func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, reader io.Reader, up driver.UpdateProgress) error {
72+
func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error {
7373
chunkSize := int64(1024 * 1024 * 16)
7474
// fetch s3 pre signed urls
7575
chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize)))
@@ -103,7 +103,7 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
103103
if j == chunkCount {
104104
curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize
105105
}
106-
err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(reader, chunkSize), curSize, false, getS3UploadUrl)
106+
err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(file, chunkSize), curSize, false, getS3UploadUrl)
107107
if err != nil {
108108
return err
109109
}

drivers/alist_v3/driver.go

+11-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package alist_v3
33
import (
44
"context"
55
"fmt"
6+
"github.com/alist-org/alist/v3/internal/stream"
67
"io"
78
"net/http"
89
"path"
@@ -181,25 +182,28 @@ func (d *AListV3) Remove(ctx context.Context, obj model.Obj) error {
181182
return err
182183
}
183184

184-
func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
185-
req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", stream)
185+
func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error {
186+
req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", &stream.ReaderUpdatingProgress{
187+
Reader: s,
188+
UpdateProgress: up,
189+
})
186190
if err != nil {
187191
return err
188192
}
189193
req.Header.Set("Authorization", d.Token)
190-
req.Header.Set("File-Path", path.Join(dstDir.GetPath(), stream.GetName()))
194+
req.Header.Set("File-Path", path.Join(dstDir.GetPath(), s.GetName()))
191195
req.Header.Set("Password", d.MetaPassword)
192-
if md5 := stream.GetHash().GetHash(utils.MD5); len(md5) > 0 {
196+
if md5 := s.GetHash().GetHash(utils.MD5); len(md5) > 0 {
193197
req.Header.Set("X-File-Md5", md5)
194198
}
195-
if sha1 := stream.GetHash().GetHash(utils.SHA1); len(sha1) > 0 {
199+
if sha1 := s.GetHash().GetHash(utils.SHA1); len(sha1) > 0 {
196200
req.Header.Set("X-File-Sha1", sha1)
197201
}
198-
if sha256 := stream.GetHash().GetHash(utils.SHA256); len(sha256) > 0 {
202+
if sha256 := s.GetHash().GetHash(utils.SHA256); len(sha256) > 0 {
199203
req.Header.Set("X-File-Sha256", sha256)
200204
}
201205

202-
req.ContentLength = stream.GetSize()
206+
req.ContentLength = s.GetSize()
203207
// client := base.NewHttpClient()
204208
// client.Timeout = time.Hour * 6
205209
res, err := base.HttpClient.Do(req)

drivers/chaoxing/driver.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"github.com/alist-org/alist/v3/internal/stream"
910
"io"
1011
"mime/multipart"
1112
"net/http"
@@ -215,7 +216,7 @@ func (d *ChaoXing) Remove(ctx context.Context, obj model.Obj) error {
215216
return nil
216217
}
217218

218-
func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
219+
func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
219220
var resp UploadDataRsp
220221
_, err := d.request("https://noteyd.chaoxing.com/pc/files/getUploadConfig", http.MethodGet, func(req *resty.Request) {
221222
}, &resp)
@@ -227,11 +228,11 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileS
227228
}
228229
body := &bytes.Buffer{}
229230
writer := multipart.NewWriter(body)
230-
filePart, err := writer.CreateFormFile("file", stream.GetName())
231+
filePart, err := writer.CreateFormFile("file", file.GetName())
231232
if err != nil {
232233
return err
233234
}
234-
_, err = utils.CopyWithBuffer(filePart, stream)
235+
_, err = utils.CopyWithBuffer(filePart, file)
235236
if err != nil {
236237
return err
237238
}
@@ -248,7 +249,14 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileS
248249
if err != nil {
249250
return err
250251
}
251-
req, err := http.NewRequest("POST", "https://pan-yz.chaoxing.com/upload", body)
252+
r := &stream.ReaderUpdatingProgress{
253+
Reader: &stream.SimpleReaderWithSize{
254+
Reader: body,
255+
Size: int64(body.Len()),
256+
},
257+
UpdateProgress: up,
258+
}
259+
req, err := http.NewRequestWithContext(ctx, "POST", "https://pan-yz.chaoxing.com/upload", r)
252260
if err != nil {
253261
return err
254262
}

drivers/ftp/driver.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ftp
22

33
import (
44
"context"
5+
"github.com/alist-org/alist/v3/internal/stream"
56
stdpath "path"
67

78
"github.com/alist-org/alist/v3/internal/driver"
@@ -114,13 +115,18 @@ func (d *FTP) Remove(ctx context.Context, obj model.Obj) error {
114115
}
115116
}
116117

117-
func (d *FTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
118+
func (d *FTP) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error {
118119
if err := d.login(); err != nil {
119120
return err
120121
}
121-
// TODO: support cancel
122-
path := stdpath.Join(dstDir.GetPath(), stream.GetName())
123-
return d.conn.Stor(encode(path, d.Encoding), stream)
122+
path := stdpath.Join(dstDir.GetPath(), s.GetName())
123+
return d.conn.Stor(encode(path, d.Encoding), &stream.ReaderWithCtx{
124+
Reader: &stream.ReaderUpdatingProgress{
125+
Reader: s,
126+
UpdateProgress: up,
127+
},
128+
Ctx: ctx,
129+
})
124130
}
125131

126132
var _ driver.Driver = (*FTP)(nil)

drivers/github/driver.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/alist-org/alist/v3/internal/driver"
1717
"github.com/alist-org/alist/v3/internal/errs"
1818
"github.com/alist-org/alist/v3/internal/model"
19+
"github.com/alist-org/alist/v3/internal/stream"
1920
"github.com/alist-org/alist/v3/pkg/utils"
2021
"github.com/go-resty/resty/v2"
2122
log "github.com/sirupsen/logrus"
@@ -649,15 +650,15 @@ func (d *Github) createGitKeep(path, message string) error {
649650
return nil
650651
}
651652

652-
func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress) (string, error) {
653+
func (d *Github) putBlob(ctx context.Context, s model.FileStreamer, up driver.UpdateProgress) (string, error) {
653654
beforeContent := "{\"encoding\":\"base64\",\"content\":\""
654655
afterContent := "\"}"
655-
length := int64(len(beforeContent)) + calculateBase64Length(stream.GetSize()) + int64(len(afterContent))
656+
length := int64(len(beforeContent)) + calculateBase64Length(s.GetSize()) + int64(len(afterContent))
656657
beforeContentReader := strings.NewReader(beforeContent)
657658
contentReader, contentWriter := io.Pipe()
658659
go func() {
659660
encoder := base64.NewEncoder(base64.StdEncoding, contentWriter)
660-
if _, err := utils.CopyWithBuffer(encoder, stream); err != nil {
661+
if _, err := utils.CopyWithBuffer(encoder, s); err != nil {
661662
_ = contentWriter.CloseWithError(err)
662663
return
663664
}
@@ -667,10 +668,12 @@ func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driv
667668
afterContentReader := strings.NewReader(afterContent)
668669
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
669670
fmt.Sprintf("https://api.github.com/repos/%s/%s/git/blobs", d.Owner, d.Repo),
670-
&ReaderWithProgress{
671-
Reader: io.MultiReader(beforeContentReader, contentReader, afterContentReader),
672-
Length: length,
673-
Progress: up,
671+
&stream.ReaderUpdatingProgress{
672+
Reader: &stream.SimpleReaderWithSize{
673+
Reader: io.MultiReader(beforeContentReader, contentReader, afterContentReader),
674+
Size: length,
675+
},
676+
UpdateProgress: up,
674677
})
675678
if err != nil {
676679
return "", err

0 commit comments

Comments
 (0)