Skip to content

Commit

Permalink
Support storage
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jun 6, 2024
1 parent a1c1ef7 commit 4ad541f
Show file tree
Hide file tree
Showing 6 changed files with 1,140 additions and 346 deletions.
25 changes: 25 additions & 0 deletions cmd/crproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ import (
"strings"
"time"

"github.com/distribution/distribution/v3/registry/storage/driver/factory"
"github.com/gorilla/handlers"
"github.com/spf13/pflag"
"github.com/wzshiming/geario"

_ "github.com/distribution/distribution/v3/registry/storage/driver/azure"
_ "github.com/distribution/distribution/v3/registry/storage/driver/gcs"
_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"
_ "github.com/wzshiming/crproxy/storage/driver/oss"

"github.com/wzshiming/crproxy"
)

Expand All @@ -27,6 +33,9 @@ var (
blockImageList []string
retry int
retryInterval time.Duration
storageDriver string
storageParameters map[string]string
linkExpires time.Duration
)

func init() {
Expand All @@ -38,6 +47,9 @@ func init() {
pflag.StringSliceVar(&blockImageList, "block-image-list", nil, "block image list")
pflag.IntVar(&retry, "retry", 0, "retry times")
pflag.DurationVar(&retryInterval, "retry-interval", 0, "retry interval")
pflag.StringVar(&storageDriver, "storage-driver", "", "storage driver")
pflag.StringToStringVar(&storageParameters, "storage-parameters", nil, "storage parameters")
pflag.DurationVar(&linkExpires, "link-expires", 0, "link expires")
pflag.Parse()
}

Expand Down Expand Up @@ -105,6 +117,19 @@ func main() {
crproxy.WithDisableKeepAlives(disableKeepAlives),
}

if storageDriver != "" {
parameters := map[string]interface{}{}
for k, v := range storageParameters {
parameters[k] = v
}
sd, err := factory.Create(storageDriver, parameters)
if err != nil {
logger.Println("create storage driver failed:", err)
os.Exit(1)
}
opts = append(opts, crproxy.WithStorageDriver(sd))
}

if len(blockImageList) != 0 {
opts = append(opts, crproxy.WithBlockFunc(func(info *crproxy.PathInfo) bool {
image := info.Host + "/" + info.Image
Expand Down
181 changes: 175 additions & 6 deletions crproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ import (
"io"
"net/http"
"net/textproto"
"path"
"strings"
"sync"
"time"
"crypto/sha256"
"encoding/hex"

"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/client/auth"
"github.com/distribution/distribution/v3/registry/client/auth/challenge"
"github.com/distribution/distribution/v3/registry/client/transport"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/wzshiming/geario"
"github.com/wzshiming/httpseek"
"github.com/wzshiming/lru"
Expand All @@ -41,18 +45,33 @@ type CRProxy struct {
domainAlias map[string]string
userAndPass map[string]Userpass
basicCredentials *basicCredentials
mut sync.Mutex
mutClientset sync.Mutex
bytesPool sync.Pool
logger Logger
totalBlobsSpeedLimit *geario.Gear
blobsSpeedLimit *geario.B
blockFunc func(*PathInfo) bool
retry int
retryInterval time.Duration
storageDriver storagedriver.StorageDriver
linkExpires time.Duration
mutCache sync.Map
}

type Option func(c *CRProxy)

func WithLinkExpires(d time.Duration) Option {
return func(c *CRProxy) {
c.linkExpires = d
}
}

func WithStorageDriver(storageDriver storagedriver.StorageDriver) Option {
return func(c *CRProxy) {
c.storageDriver = storageDriver
}
}

func WithBlobsSpeedLimit(limit geario.B) Option {
return func(c *CRProxy) {
c.blobsSpeedLimit = &limit
Expand Down Expand Up @@ -163,8 +182,8 @@ func (c *CRProxy) getScheme(host string) string {
}

func (c *CRProxy) getClientset(host string, image string) *http.Client {
c.mut.Lock()
defer c.mut.Unlock()
c.mutClientset.Lock()
defer c.mutClientset.Unlock()
if c.clientset[host] != nil {
client, ok := c.clientset[host].Get(image)
if ok {
Expand Down Expand Up @@ -249,8 +268,8 @@ func (c *CRProxy) disableKeepAlives(rt http.RoundTripper) http.RoundTripper {
}

func (c *CRProxy) ping(host string) error {
c.mut.Lock()
defer c.mut.Unlock()
c.mutClientset.Lock()
defer c.mutClientset.Unlock()

if c.logger != nil {
c.logger.Println("ping", host)
Expand Down Expand Up @@ -336,7 +355,7 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
if !strings.HasPrefix(oriPath, prefix) {
http.NotFound(rw, r)
c.notFoundResponse(rw, r)
return
}
if oriPath == catalog {
Expand Down Expand Up @@ -374,6 +393,14 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
r.URL.Scheme = c.getScheme(info.Host)
r.URL.Path = path

if c.storageDriver != nil && info.Blobs != "" {
c.cacheBlobResponse(rw, r, info)
return
}
c.directResponse(rw, r, info)
}

func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) {
cli := c.getClientset(info.Host, info.Image)
resp, err := c.doWithAuth(cli, r, info.Host)
if err != nil {
Expand Down Expand Up @@ -418,6 +445,148 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
}

func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) {
ctx := r.Context()

blob := strings.TrimPrefix(info.Blobs, "sha256:")
blobPath := path.Join("/docker/registry/v2/blobs/sha256", blob[:2], blob, "data")

closeValue, loaded := c.mutCache.LoadOrStore(blobPath, make(chan struct{}))
closeCh := closeValue.(chan struct{})
for loaded {
select {
case <-ctx.Done():
err := ctx.Err().Error()
if c.logger != nil {
c.logger.Println(err)
}
http.Error(rw, err, http.StatusInternalServerError)
return
case <-closeCh:
}
closeValue, loaded = c.mutCache.LoadOrStore(blobPath, make(chan struct{}))
closeCh = closeValue.(chan struct{})
}

doneCache := func() {
c.mutCache.Delete(blobPath)
close(closeCh)
}

_, err := c.storageDriver.Stat(ctx, blobPath)
if err == nil {
err = c.redirect(rw, r, blobPath)
if err == nil {
doneCache()
return
}
c.errorResponse(rw, r, ctx.Err())
return
} else {
if c.logger != nil {
c.logger.Println("Cache miss", blobPath)
}
}

errCh := make(chan error, 1)

go func() {
defer doneCache()
err = c.cacheBlobContent(r, blobPath, info)
errCh <- err
}()

select {
case <-ctx.Done():
c.errorResponse(rw, r, ctx.Err())
return
case err := <-errCh:
if err != nil {
c.errorResponse(rw, r, err)
return
}
err = c.redirect(rw, r, blobPath)
if err != nil {
if c.logger != nil {
c.logger.Println("failed to redirect", blobPath, err)
}
}
return
}
}

func (c *CRProxy) cacheBlobContent(r *http.Request, blobPath string, info *PathInfo) error {
cli := c.getClientset(info.Host, info.Image)
resp, err := c.doWithAuth(cli, r, info.Host)
if err != nil {
return err
}
defer func() {
resp.Body.Close()
}()

buf := c.bytesPool.Get().([]byte)
defer c.bytesPool.Put(buf)

fw, err := c.storageDriver.Writer(r.Context(), blobPath, false)
if err != nil {
return err
}

h := sha256.New()
n, err := io.CopyBuffer(fw, io.TeeReader(resp.Body, h), buf)
if err != nil {
fw.Cancel()
return err
}

if n != resp.ContentLength {
fw.Cancel()
return fmt.Errorf("expected %d bytes, got %d", resp.ContentLength, n)
}

hash := hex.EncodeToString(h.Sum(nil)[:])
if info.Blobs[7:] != hash {
fw.Cancel()
return fmt.Errorf("expected %s hash, got %s", info.Blobs[7:], hash)
}

return fw.Commit()
}

func (c *CRProxy) errorResponse(rw http.ResponseWriter, r *http.Request, err error) {
if err != nil {
e := err.Error()
if c.logger != nil {
c.logger.Println(e)
}
}
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
}

func (c *CRProxy) notFoundResponse(rw http.ResponseWriter, r *http.Request) {
http.NotFound(rw, r)
}

func (c *CRProxy) redirect(rw http.ResponseWriter, r *http.Request, blobPath string) error {
options := map[string]interface{}{
"method": http.MethodGet,
}
linkExpires := c.linkExpires
if linkExpires > 0 {
options["expiry"] = time.Now().Add(linkExpires)
}
u, err := c.storageDriver.URLFor(r.Context(), blobPath, options)
if err != nil {
return err
}
if c.logger != nil {
c.logger.Println("Cache hit", blobPath, u)
}
http.Redirect(rw, r, u, http.StatusTemporaryRedirect)
return nil
}

func (c *CRProxy) getDomainAlias(host string) string {
if c.domainAlias == nil {
return host
Expand Down
57 changes: 43 additions & 14 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/wzshiming/crproxy
go 1.21

require (
github.com/denverdino/aliyungo v0.0.0-20230411124812-ab98a9173ace
github.com/distribution/distribution/v3 v3.0.0-20220907155224-78b9c98c5c31
github.com/gorilla/handlers v1.5.1
github.com/spf13/pflag v1.0.3
Expand All @@ -12,23 +13,51 @@ require (
)

require (
cloud.google.com/go/compute v1.23.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/Azure/azure-sdk-for-go v56.3.0+incompatible // indirect
github.com/Azure/go-autorest v14.2.0+incompatible // indirect
github.com/Azure/go-autorest/autorest v0.11.24 // indirect
github.com/Azure/go-autorest/autorest/adal v0.9.18 // indirect
github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect
github.com/Azure/go-autorest/logger v0.2.1 // indirect
github.com/Azure/go-autorest/tracing v0.6.0 // indirect
github.com/aws/aws-sdk-go v1.48.10 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/docker/go-metrics v0.0.1 // indirect
github.com/felixge/httpsnoop v1.0.1 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gofrs/uuid v4.0.0+incompatible // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.11.0 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/stretchr/testify v1.8.0 // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
google.golang.org/protobuf v1.27.1 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/prometheus/client_golang v1.17.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/stretchr/testify v1.8.4 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/net v0.18.0 // indirect
golang.org/x/oauth2 v0.11.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/api v0.126.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/cloud v0.0.0-20151119220103-975617b05ea8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect
google.golang.org/grpc v1.59.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
)
Loading

0 comments on commit 4ad541f

Please sign in to comment.