diff --git a/go.mod b/go.mod index d84ed1b..18604df 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/jonboulle/clockwork v0.4.0 github.com/letsencrypt/pebble/v2 v2.4.0 github.com/rekby/fastuuid v0.9.0 + github.com/rekby/safemutex v0.2.0 golang.org/x/time v0.3.0 ) diff --git a/go.sum b/go.sum index 61f4d7b..1451d75 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ github.com/rekby/fastuuid v0.9.0 h1:iQk8V/AyqSrgQAtKRdqx/CVep+CaKwaSWeerw1yEP3Q= github.com/rekby/fastuuid v0.9.0/go.mod h1:qP8Lh0BH2+4rNGVRDHmDpkvE/ZuLUhjmKpRWjx+WesY= github.com/rekby/fixenv v0.3.1 h1:zOPocbQmcsxSIjiVu5U+9JAfeu6WeLN7a9ryZkGTGJY= github.com/rekby/fixenv v0.3.1/go.mod h1:/b5LRc06BYJtslRtHKxsPWFT/ySpHV+rWvzTg+XWk4c= +github.com/rekby/safemutex v0.2.0 h1:iEfcPqsR3EApwWHwdHvp+srN9Wfna+IG8bSpN467Jmk= +github.com/rekby/safemutex v0.2.0/go.mod h1:6I/yJdmctX0RmxEp00RzYBJJXl3ona8PsBiIDqg0v+U= github.com/rekby/zapcontext v0.0.4 h1:85600nHTteGCLcuOhGp/SzXHymm9QcCA5sn+MPKCodY= github.com/rekby/zapcontext v0.0.4/go.mod h1:lTIxvHAwWXBZBPPfEvmAEXPbVEcTwd52VaASZWZWcxI= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= diff --git a/internal/acme_client_manager/client_manager.go b/internal/acme_client_manager/client_manager.go index b655ce6..b8e1f54 100644 --- a/internal/acme_client_manager/client_manager.go +++ b/internal/acme_client_manager/client_manager.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/rekby/safemutex" "net/http" "sync" "time" @@ -38,18 +39,21 @@ type AcmeManager struct { AgreeFunction func(tosurl string) bool RenewAccountInterval time.Duration - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + cache cache.Bytes + httpClient *http.Client + + background sync.WaitGroup + mu safemutex.MutexWithPointers[acmeManagerSynced] +} + +type acmeManagerSynced struct { + lastAccountIndex int + accounts []clientAccount + stateLoaded bool + closed bool 
ctxAutorenewCompleted context.Context - cache cache.Bytes - httpClient *http.Client - - background sync.WaitGroup - mu sync.Mutex - lastAccountIndex int - accounts []clientAccount - stateLoaded bool - closed bool } type clientAccount struct { @@ -67,19 +71,22 @@ func New(ctx context.Context, cache cache.Bytes) *AcmeManager { AgreeFunction: acme.AcceptTOS, RenewAccountInterval: renewAccountInterval, httpClient: http.DefaultClient, - lastAccountIndex: -1, + mu: safemutex.NewWithPointers(acmeManagerSynced{lastAccountIndex: -1}), } } func (m *AcmeManager) Close() error { logger := zc.L(m.ctx) logger.Debug("Start close") - m.mu.Lock() - alreadyClosed := m.closed - ctxAutorenewCompleted := m.ctxAutorenewCompleted - m.closed = true - m.ctxCancel() - m.mu.Unlock() + var alreadyClosed bool + var ctxAutorenewCompleted context.Context + m.mu.Lock(func(value acmeManagerSynced) (newValue acmeManagerSynced) { + alreadyClosed = value.closed + ctxAutorenewCompleted = value.ctxAutorenewCompleted + value.closed = true + m.ctxCancel() + return value + }) logger.Debug("Set closed flag", zap.Any("autorenew_context", ctxAutorenewCompleted)) if alreadyClosed { @@ -95,60 +102,77 @@ func (m *AcmeManager) Close() error { return nil } -func (m *AcmeManager) GetClient(ctx context.Context) (_ *acme.Client, disableFunc func(), err error) { +func (m *AcmeManager) GetClient(ctx context.Context) (resClient *acme.Client, disableFunc func(), err error) { if ctx.Err() != nil { return nil, nil, errors.New("acme manager context closed") } - m.mu.Lock() - defer m.mu.Unlock() - - if m.closed { - return nil, nil, xerrors.Errorf("GetClient: %w", errClosed) + fail := func(resErr error) { + resClient = nil + disableFunc = nil + err = resErr } + good := func(c *acme.Client, f func()) { + resClient = c + disableFunc = f + err = nil + } + + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + if synced.closed { + fail(xerrors.Errorf("GetClient: %w", errClosed)) + return synced + } - createDisableFunc 
:= func(index int) func() { - return func() { - wasEnabled := m.disableAccountSelfSync(index) - if wasEnabled { - time.AfterFunc(disableDuration, func() { - m.accountEnableSelfSync(index) - }) + createDisableFunc := func(index int) func() { + return func() { + wasEnabled := m.disableAccountSelfSync(index) + if wasEnabled { + time.AfterFunc(disableDuration, func() { + m.accountEnableSelfSync(index) + }) + } } } - } - if !m.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad { - err := m.loadFromCache(ctx) - if err != nil && err != cache.ErrCacheMiss { - return nil, nil, err + if !synced.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad { + err := m.loadFromCacheLocked(ctx, &synced) + if err != nil && !errors.Is(err, cache.ErrCacheMiss) { + fail(err) + return synced + } + synced.stateLoaded = true } - m.stateLoaded = true - } - if index, ok := m.nextEnabledClientIndex(); ok { - return m.accounts[index].client, createDisableFunc(index), nil - } + if index, ok := m.nextEnabledClientIndexLocked(&synced); ok { + good(synced.accounts[index].client, createDisableFunc(index)) + return synced + } - acc, err := m.registerAccount(ctx) - m.accounts = append(m.accounts, acc) + acc, err := m.registerAccount(ctx) + synced.accounts = append(synced.accounts, acc) - m.background.Add(1) - // handlepanic: in accountRenewSelfSync - go func(index int) { - defer m.background.Done() - m.accountRenewSelfSync(index) - }(len(m.accounts) - 1) + m.background.Add(1) + // handlepanic: in accountRenewSelfSync + go func(index int) { + defer m.background.Done() + m.accountRenewSelfSync(index) + }(len(synced.accounts) - 1) - if err != nil { - return nil, nil, err - } + if err != nil { + fail(err) + return synced + } - if err = m.saveState(ctx); err != nil { - return nil, nil, err - } + if err = m.saveStateLocked(ctx, &synced); err != nil { + fail(err) + return synced + } - return acc.client, createDisableFunc(len(m.accounts) - 1), nil + good(acc.client, createDisableFunc(len(synced.accounts)-1)) + 
return synced + }) + return resClient, disableFunc, err } func (m *AcmeManager) accountRenewSelfSync(index int) { @@ -156,10 +180,12 @@ func (m *AcmeManager) accountRenewSelfSync(index int) { ctx, ctxCancel := context.WithCancel(m.ctx) defer ctxCancel() - m.mu.Lock() - m.ctxAutorenewCompleted = ctx - acc := m.accounts[index] - m.mu.Unlock() + var acc clientAccount + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + synced.ctxAutorenewCompleted = ctx + acc = synced.accounts[index] + return synced + }) if m.ctx.Err() != nil { return @@ -183,37 +209,39 @@ func (m *AcmeManager) accountRenewSelfSync(index int) { newAccount = renewTos(m.ctx, acc.client, acc.account) }() acc.account = newAccount - m.mu.Lock() - m.accounts[index] = acc - m.mu.Unlock() + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + synced.accounts[index] = acc + return synced + }) } } } func (m *AcmeManager) disableAccountSelfSync(index int) (wasEnabled bool) { - m.mu.Lock() - defer m.mu.Unlock() - - if m.accounts[index].enabled { - m.accounts[index].enabled = false - return true - } - - return false + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + if synced.accounts[index].enabled { + synced.accounts[index].enabled = false + wasEnabled = true + return synced + } + wasEnabled = false + return synced + }) + return wasEnabled } func (m *AcmeManager) accountEnableSelfSync(index int) { - m.mu.Lock() - defer m.mu.Unlock() - - m.accounts[index].enabled = true + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + synced.accounts[index].enabled = true + return synced + }) } func (m *AcmeManager) initClient() *acme.Client { return &acme.Client{DirectoryURL: m.DirectoryURL, HTTPClient: m.httpClient} } -func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) { +func (m *AcmeManager) loadFromCacheLocked(ctx context.Context, synced *acmeManagerSynced) (err error) { defer func() { var effectiveError error if err == cache.ErrCacheMiss { @@ -239,7 
+267,7 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) { return xerrors.Errorf("no accounts in state") } - m.accounts = make([]clientAccount, 0, len(state.Accounts)) + synced.accounts = make([]clientAccount, 0, len(state.Accounts)) for index, stateAccount := range state.Accounts { client := m.initClient() client.Key = stateAccount.PrivateKey @@ -255,34 +283,34 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) { defer m.background.Done() m.accountRenewSelfSync(index) }(index) - m.accounts = append(m.accounts, acc) + synced.accounts = append(synced.accounts, acc) } return nil } -func (m *AcmeManager) nextEnabledClientIndex() (int, bool) { +func (m *AcmeManager) nextEnabledClientIndexLocked(synced *acmeManagerSynced) (int, bool) { switch { - case len(m.accounts) == 0: + case len(synced.accounts) == 0: return 0, false - case len(m.accounts) == 1 && m.accounts[0].enabled: + case len(synced.accounts) == 1 && synced.accounts[0].enabled: return 0, true default: // pass } - startIndex := m.lastAccountIndex + startIndex := synced.lastAccountIndex if startIndex < 0 { - startIndex = len(m.accounts) - 1 + startIndex = len(synced.accounts) - 1 } index := startIndex for { index++ - if index >= len(m.accounts) { + if index >= len(synced.accounts) { index = 0 } - if m.accounts[index].enabled { - m.lastAccountIndex = index + if synced.accounts[index].enabled { + synced.lastAccountIndex = index return index, true } if index == startIndex { @@ -310,11 +338,11 @@ func (m *AcmeManager) registerAccount(ctx context.Context) (clientAccount, error return acc, nil } -func (m *AcmeManager) saveState(ctx context.Context) error { +func (m *AcmeManager) saveStateLocked(ctx context.Context, synced *acmeManagerSynced) error { var state acmeManagerState - state.Accounts = make([]acmeAccountState, 0, len(m.accounts)) + state.Accounts = make([]acmeAccountState, 0, len(synced.accounts)) - for _, acc := range m.accounts { + for _, acc := range 
synced.accounts { state.Accounts = append(state.Accounts, acmeAccountState{PrivateKey: acc.client.Key.(*rsa.PrivateKey), AcmeAccount: acc.account}) } diff --git a/internal/acme_client_manager/client_manager_test.go b/internal/acme_client_manager/client_manager_test.go index 86e9796..170bae3 100644 --- a/internal/acme_client_manager/client_manager_test.go +++ b/internal/acme_client_manager/client_manager_test.go @@ -198,17 +198,19 @@ func TestClientManager_nextEnabledClientIndex(t *testing.T) { e, _, flush := th.NewEnv(t) defer flush() - m := AcmeManager{ - lastAccountIndex: test.lastAccountIndex, - } - - for _, enabled := range test.accountsEnabled { - m.accounts = append(m.accounts, clientAccount{enabled: enabled}) - } - - resIndex, resOk := m.nextEnabledClientIndex() - e.Cmp(resIndex, test.resIndex) - e.Cmp(resOk, test.resOk) + m := AcmeManager{} + + m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced { + synced.lastAccountIndex = test.lastAccountIndex + for _, enabled := range test.accountsEnabled { + synced.accounts = append(synced.accounts, clientAccount{enabled: enabled}) + } + + resIndex, resOk := m.nextEnabledClientIndexLocked(&synced) + e.Cmp(resIndex, test.resIndex) + e.Cmp(resOk, test.resOk) + return synced + }) }) } } diff --git a/internal/cache/value_lru.go b/internal/cache/value_lru.go index a01cbb1..f38c777 100644 --- a/internal/cache/value_lru.go +++ b/internal/cache/value_lru.go @@ -18,8 +18,7 @@ type memoryValueLRUItem struct { key string value interface{} - m sync.Mutex // sync update lastUsedTime in Get method - lastUsedTime uint64 + lastUsedTime atomic.Uint64 } type MemoryValueLRU struct { @@ -52,9 +51,7 @@ func (c *MemoryValueLRU) Get(ctx context.Context, key string) (value interface{} defer c.mu.RUnlock() if resp, exist := c.m[key]; exist { - resp.m.Lock() - resp.lastUsedTime = c.time() - resp.m.Unlock() + resp.lastUsedTime.Store(c.time()) return resp.value, nil } return nil, ErrCacheMiss @@ -67,7 +64,7 @@ func (c *MemoryValueLRU) 
Put(ctx context.Context, key string, value interface{}) }() c.mu.Lock() - c.m[key] = &memoryValueLRUItem{key: key, value: value, lastUsedTime: c.time()} + c.m[key] = &memoryValueLRUItem{key: key, value: value, lastUsedTime: newUint64Atomic(c.time())} if len(c.m) > c.MaxSize { // handlepanic: no external call go c.clean() @@ -104,7 +101,7 @@ func (c *MemoryValueLRU) renumberTime() { items := c.getSortedItems() for i, item := range items { - item.lastUsedTime = uint64(i) + item.lastUsedTime.Store(uint64(i)) } c.mu.Unlock() @@ -118,7 +115,7 @@ func (c *MemoryValueLRU) getSortedItems() []*memoryValueLRUItem { } sort.Slice(items, func(i, j int) bool { - return items[i].lastUsedTime < items[j].lastUsedTime + return items[i].lastUsedTime.Load() < items[j].lastUsedTime.Load() }) return items } @@ -146,3 +143,8 @@ func (c *MemoryValueLRU) clean() { delete(c.m, items[i].key) } } +func newUint64Atomic(val uint64) atomic.Uint64 { + var v atomic.Uint64 + v.Store(val) + return v +} diff --git a/internal/cache/value_lru_test.go b/internal/cache/value_lru_test.go index f4aadae..b47a6bb 100644 --- a/internal/cache/value_lru_test.go +++ b/internal/cache/value_lru_test.go @@ -116,12 +116,12 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 0 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 6} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + 
c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(6)} c.clean() td.CmpDeeply(len(c.m), 6) td.CmpDeeply(c.m["1"].value, 1) @@ -134,11 +134,11 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 3 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} c.clean() td.CmpDeeply(len(c.m), 5) td.CmpDeeply(c.m["1"].value, 1) @@ -150,12 +150,12 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 2 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 6} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = 
&memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(6)} c.clean() td.CmpDeeply(len(c.m), 4) td.Nil(c.m["1"]) @@ -169,12 +169,12 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 2 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 6} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 5} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 4} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 3} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 2} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 1} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(6)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(5)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(4)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(3)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(2)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(1)} c.clean() td.CmpDeeply(len(c.m), 4) td.CmpDeeply(c.m["1"].value, 1) @@ -187,24 +187,24 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 5 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = 
&memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 6} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(6)} c.clean() td.CmpDeeply(len(c.m), 0) c.MaxSize = 5 c.CleanCount = 6 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 6} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(6)} c.clean() td.CmpDeeply(len(c.m), 0) @@ -212,12 +212,12 @@ func TestValueLRULimitClean(t *testing.T) { c.MaxSize = 5 c.CleanCount = 3 c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, 
lastUsedTime: 1} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 2} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 3} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 4} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 5} - c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: 6} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(1)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(2)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(3)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(4)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(5)} + c.m["6"] = &memoryValueLRUItem{key: "6", value: 6, lastUsedTime: newUint64Atomic(6)} _, _ = c.Get(ctx, "6") _, _ = c.Get(ctx, "2") _, _ = c.Get(ctx, "3") @@ -242,11 +242,11 @@ func TestLimitValueRenumberItems(t *testing.T) { var c = NewMemoryValueLRU("test") c.m = make(map[string]*memoryValueLRUItem) - c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: 100} - c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: 200} - c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: 300} - c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: 400} - c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: 500} + c.m["1"] = &memoryValueLRUItem{key: "1", value: 1, lastUsedTime: newUint64Atomic(100)} + c.m["2"] = &memoryValueLRUItem{key: "2", value: 2, lastUsedTime: newUint64Atomic(200)} + c.m["3"] = &memoryValueLRUItem{key: "3", value: 3, lastUsedTime: newUint64Atomic(300)} + c.m["4"] = &memoryValueLRUItem{key: "4", value: 4, lastUsedTime: newUint64Atomic(400)} + c.m["5"] = &memoryValueLRUItem{key: "5", value: 5, lastUsedTime: newUint64Atomic(500)} c.lastTime = math.MaxUint64/2 - 1 _ = c.Put(ctx, "6", 6) @@ -255,10 +255,10 @@ func 
TestLimitValueRenumberItems(t *testing.T) { c.mu.RLock() defer c.mu.RLock() - td.CmpDeeply(c.m["1"].lastUsedTime, uint64(0)) - td.CmpDeeply(c.m["2"].lastUsedTime, uint64(1)) - td.CmpDeeply(c.m["3"].lastUsedTime, uint64(2)) - td.CmpDeeply(c.m["4"].lastUsedTime, uint64(3)) - td.CmpDeeply(c.m["5"].lastUsedTime, uint64(4)) - td.CmpDeeply(c.m["6"].lastUsedTime, uint64(5)) + td.CmpDeeply(c.m["1"].lastUsedTime.Load(), uint64(0)) + td.CmpDeeply(c.m["2"].lastUsedTime.Load(), uint64(1)) + td.CmpDeeply(c.m["3"].lastUsedTime.Load(), uint64(2)) + td.CmpDeeply(c.m["4"].lastUsedTime.Load(), uint64(3)) + td.CmpDeeply(c.m["5"].lastUsedTime.Load(), uint64(4)) + td.CmpDeeply(c.m["6"].lastUsedTime.Load(), uint64(5)) } diff --git a/internal/cert_manager/manager.go b/internal/cert_manager/manager.go index be9f5ff..1fa49e1 100644 --- a/internal/cert_manager/manager.go +++ b/internal/cert_manager/manager.go @@ -15,6 +15,7 @@ import ( "encoding/pem" "errors" "fmt" + "github.com/rekby/safemutex" "net/http" "reflect" "strings" @@ -113,8 +114,7 @@ type Manager struct { certForDomainAuthorize cache.Value - certStateMu sync.Mutex - certState cache.Value + certStateMu safemutex.MutexWithPointers[cache.Value] httpTokens cache.Bytes @@ -127,7 +127,7 @@ func New(acmeClientManager AcmeClientManager, c cache.Bytes, r prometheus.Regist res := Manager{} res.acmeClientManager = acmeClientManager res.certForDomainAuthorize = cache.NewMemoryValueLRU("authcert") - res.certState = cache.NewMemoryValueLRU("certstate") + res.certStateMu = safemutex.NewWithPointers[cache.Value](cache.NewMemoryValueLRU("certstate")) res.CertificateIssueTimeout = time.Minute res.httpTokens = cache.NewMemoryCache("Http validation tokens") res.Cache = c @@ -369,19 +369,23 @@ func filterDomains(ctx context.Context, checker DomainChecker, originalDomains [ } func (m *Manager) certStateGet(ctx context.Context, cd CertDescription) *certState { - m.certStateMu.Lock() - defer m.certStateMu.Unlock() + var resInterface any + 
m.certStateMu.Lock(func(certCache cache.Value) (newValue cache.Value) { + var err error + resInterface, err = certCache.Get(ctx, cd.String()) + if err == cache.ErrCacheMiss { + err = nil + } - resInterface, err := m.certState.Get(ctx, cd.String()) - if err == cache.ErrCacheMiss { - err = nil - } - log.DebugFatalCtx(ctx, err, "Got cert state from cache", zap.Bool("is_empty", resInterface == nil)) - if resInterface == nil { - resInterface = &certState{} - err = m.certState.Put(ctx, cd.String(), resInterface) - log.DebugFatalCtx(ctx, err, "Put empty cert state to cache") - } + log.DebugFatalCtx(ctx, err, "Got cert state from cache", zap.Bool("is_empty", resInterface == nil)) + if resInterface == nil { + resInterface = &certState{} + err = certCache.Put(ctx, cd.String(), resInterface) + log.DebugFatalCtx(ctx, err, "Put empty cert state to cache") + } + + return certCache + }) return resInterface.(*certState) } diff --git a/internal/cert_manager/manager_test.go b/internal/cert_manager/manager_test.go index 6837ae4..a73ad2f 100644 --- a/internal/cert_manager/manager_test.go +++ b/internal/cert_manager/manager_test.go @@ -10,6 +10,7 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "github.com/rekby/safemutex" "net" "net/http" "testing" @@ -277,7 +278,7 @@ func createManager(t *testing.T) (res testManagerContext, cancel func()) { AllowRSACert: true, AllowECDSACert: true, certForDomainAuthorize: res.certForDomainAuthorize, - certState: res.certState, + certStateMu: safemutex.NewWithPointers[cache.Value](res.certState), httpTokens: res.httpTokens, } res.manager.initMetrics(nil) diff --git a/internal/th/fixenv.go b/internal/th/fixenv.go index 3d78798..74b006b 100644 --- a/internal/th/fixenv.go +++ b/internal/th/fixenv.go @@ -2,7 +2,7 @@ package th import ( "context" - "sync" + "github.com/rekby/safemutex" "testing" "github.com/maxatome/go-testdeep" @@ -19,7 +19,9 @@ type Env struct { func NewEnv(t *testing.T) (env *Env, ctx context.Context, cancel func()) { td := 
testdeep.NewT(t) ctx, ctxCancel := TestContext(td) - tWrap := &testWrapper{T: td} + tWrap := &testWrapper{ + T: td, + } env = &Env{ EnvT: fixenv.NewEnv(tWrap), Ctx: ctx, @@ -41,32 +43,40 @@ type TD struct { type testWrapper struct { *testdeep.T - m sync.Mutex + m safemutex.MutexWithPointers[testWrapperSynced] +} + +type testWrapperSynced struct { cleanups []func() cleanupsStarted bool } func (t *testWrapper) Cleanup(f func()) { - t.m.Lock() - defer t.m.Unlock() - - t.cleanups = append(t.cleanups, f) + t.m.Lock(func(synced testWrapperSynced) testWrapperSynced { + synced.cleanups = append(synced.cleanups, f) + return synced + }) } func (t *testWrapper) startCleanups() { - t.m.Lock() - started := t.cleanupsStarted - if !started { - t.cleanupsStarted = true - } - t.m.Unlock() + var started bool + var cleanups []func() + + t.m.Lock(func(synced testWrapperSynced) testWrapperSynced { + started = synced.cleanupsStarted + if !started { + synced.cleanupsStarted = true + } + cleanups = synced.cleanups + return synced + }) if started { return } - for i := len(t.cleanups) - 1; i >= 0; i-- { - f := t.cleanups[i] + for i := len(cleanups) - 1; i >= 0; i-- { + f := cleanups[i] f() } } diff --git a/internal/th/fixtures.go b/internal/th/fixtures.go index eb9f8e9..777a4a5 100644 --- a/internal/th/fixtures.go +++ b/internal/th/fixtures.go @@ -1,23 +1,21 @@ package th import ( + "github.com/gojuno/minimock/v3" + "github.com/rekby/fixenv" + "github.com/rekby/safemutex" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" "io/ioutil" "log" "net" "net/http" "net/http/httptest" "os" - "sync" - - "github.com/gojuno/minimock/v3" - "github.com/rekby/fixenv" - "go.uber.org/zap" - "go.uber.org/zap/zaptest" ) var ( - freeListenAddressMutex sync.Mutex - freeListenAddressUsed = map[string]bool{} + freeListenAddressMutex = safemutex.NewWithPointers(map[string]bool{}) ) func MockController(e fixenv.Env) minimock.MockController { @@ -53,11 +51,12 @@ func NewFreeLocalTcpAddress(e fixenv.Env) *net.TCPAddr {
addr := listener.Addr() addrS := addr.String() - freeListenAddressMutex.Lock() - used := freeListenAddressUsed[addrS] - freeListenAddressUsed[addrS] = true - freeListenAddressMutex.Unlock() - + var used bool + freeListenAddressMutex.Lock(func(freeListenAddressUsed map[string]bool) (newValue map[string]bool) { + used = freeListenAddressUsed[addrS] + freeListenAddressUsed[addrS] = true + return freeListenAddressUsed + }) if !used { return addr.(*net.TCPAddr) } diff --git a/vendor/github.com/egorgasay/cidranger/.gitignore b/vendor/github.com/egorgasay/cidranger/.gitignore new file mode 100644 index 0000000..cb756e9 --- /dev/null +++ b/vendor/github.com/egorgasay/cidranger/.gitignore @@ -0,0 +1,2 @@ +vendor +.idea \ No newline at end of file diff --git a/vendor/github.com/egorgasay/cidranger/.travis.yml b/vendor/github.com/egorgasay/cidranger/.travis.yml new file mode 100644 index 0000000..ecc3fb1 --- /dev/null +++ b/vendor/github.com/egorgasay/cidranger/.travis.yml @@ -0,0 +1,10 @@ +language: go +go: + - 1.13.x + - 1.14.x + - tip +before_install: + - travis_retry go get github.com/mattn/goveralls +script: + - go test -v -covermode=count -coverprofile=coverage.out ./... + - travis_retry $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/vendor/github.com/rekby/safemutex/LICENSE.txt b/vendor/github.com/rekby/safemutex/LICENSE.txt new file mode 100644 index 0000000..42ed8ca --- /dev/null +++ b/vendor/github.com/rekby/safemutex/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2023 Timofey Koolin, Rekby. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/rekby/safemutex/README.md b/vendor/github.com/rekby/safemutex/README.md new file mode 100644 index 0000000..edf6ac0 --- /dev/null +++ b/vendor/github.com/rekby/safemutex/README.md @@ -0,0 +1,43 @@ +[![Go Reference](https://pkg.go.dev/badge/github.com/rekby/safemutex.svg)](https://pkg.go.dev/github.com/rekby/safemutex) +[![Coverage Status](https://coveralls.io/repos/github/rekby/safe-mutex/badge.svg?branch=master)](https://coveralls.io/github/rekby/safe-mutex?branch=master) +[![GoReportCard](https://goreportcard.com/badge/github.com/rekby/safemutex)](https://goreportcard.com/report/github.com/rekby/safemutex) + +# Safe mutex + +The package inspired by [Rust mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html). + +Main idea: mutex contains guarded data and no way to use the data with unlocked mutex. 
+ +get command: +```bash +go get github.com/rekby/safemutex +``` + +Example: +```go +package main + +import ( + "fmt" + "github.com/rekby/safemutex" +) + +type GuardedStruct struct { + Name string + Val int +} + +func main() { + simleIntMutex := safemutex.New(1) + simleIntMutex.Lock(func(synced int) int { + fmt.Println(synced) + return synced + }) + + mutexWithStruct := safemutex.New(GuardedStruct{Name: "test", Val: 1}) + mutexWithStruct.Lock(func(synced GuardedStruct) GuardedStruct { + fmt.Println(synced) + return synced + }) +} +``` diff --git a/vendor/github.com/rekby/safemutex/callback.go b/vendor/github.com/rekby/safemutex/callback.go new file mode 100644 index 0000000..1490488 --- /dev/null +++ b/vendor/github.com/rekby/safemutex/callback.go @@ -0,0 +1,4 @@ +package safemutex + +// ReadWriteCallback receive current value, saved in mutex and return new value +type ReadWriteCallback[T any] func(synced T) T diff --git a/vendor/github.com/rekby/safemutex/errors.go b/vendor/github.com/rekby/safemutex/errors.go new file mode 100644 index 0000000..e70f058 --- /dev/null +++ b/vendor/github.com/rekby/safemutex/errors.go @@ -0,0 +1,20 @@ +package safemutex + +import "errors" + +var errContainPointers = errors.New("safe mutex: value type possible to contain pointers, use MutexWithPointers for allow pointers into guarded value") + +var ErrPoisoned = errors.New("safe mutex: mutex poisoned (exit from callback with panic), use NewWithOptions for allow use poisoned value") + +// errWrap need for deny direct compare with returned errors +type errWrap struct { + err error +} + +func (e errWrap) Error() string { + return e.err.Error() +} + +func (e errWrap) Unwrap() error { + return e.err +} diff --git a/vendor/github.com/rekby/safemutex/mutex.go b/vendor/github.com/rekby/safemutex/mutex.go new file mode 100644 index 0000000..bbee804 --- /dev/null +++ b/vendor/github.com/rekby/safemutex/mutex.go @@ -0,0 +1,96 @@ +package safemutex + +import ( + "reflect" +) + +// Mutex contains 
guarded value inside, access to value allowed inside callbacks only +// it allow to guarantee not access to the value without lock the mutex +// zero value is usable as mutex with default options and zero value of guarded type +// Mutex deny to save value with any type of pointers, which allow accidentally change internal state. +// it will panic if T contains any pointer. +type Mutex[T any] struct { + mutexBase[T] + initialized bool +} + +// New create Mutex with initial value and default options. +// New call internal checks for T and panic if checks failed, see MutexOptions for details +func New[T any](value T) Mutex[T] { + res := Mutex[T]{ + mutexBase: mutexBase[T]{ + value: value, + }, + } + + res.validateLocked() + + //nolint:govet + //goland:noinspection GoVetCopyLock + return res +} + +// Lock - call f within locked mutex. +// it will panic if value type not pass internal checks +// it will panic with ErrPoisoned if previous call exited without return value: +// with panic or runtime.Goexit() +func (m *Mutex[T]) Lock(f ReadWriteCallback[T]) { + m.m.Lock() + defer m.m.Unlock() + + m.validateLocked() + m.callLocked(f) +} + +func (m *Mutex[T]) validateLocked() { + m.baseValidateLocked() + + if m.initialized { + return + } + + // check pointers + if checkTypeCanContainPointers(reflect.TypeOf(m.value)) { + panic(errContainPointers) + } + + m.initialized = true +} + +// checkTypeCanContainPointers check the value for potential contain pointers +// return true only of value guaranteed without any pointers and false in other cases (has pointers or unknown) +func checkTypeCanContainPointers(t reflectType) bool { + if t == nil { + return true + } + switch t.Kind() { + case + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Bool, reflect.Complex64, reflect.Complex128, reflect.Float32, reflect.Float64, + reflect.String: + return false + case 
reflect.Struct: + for i := 0; i < t.NumField(); i++ { + structField := t.Field(i) + if checkTypeCanContainPointers(structField.Type) { + return true + } + } + return false + case reflect.Array: + return checkTypeCanContainPointers(t.Elem()) + case reflect.Pointer, reflect.UnsafePointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Interface, + reflect.Func: + return true + default: + return true + } +} + +type reflectType interface { + Kind() reflect.Kind + NumField() int + Field(i int) reflect.StructField + Elem() reflect.Type +} diff --git a/vendor/github.com/rekby/safemutex/mutex_allow_pointers.go b/vendor/github.com/rekby/safemutex/mutex_allow_pointers.go new file mode 100644 index 0000000..e841ddb --- /dev/null +++ b/vendor/github.com/rekby/safemutex/mutex_allow_pointers.go @@ -0,0 +1,33 @@ +package safemutex + +// MutexWithPointers contains guarded value inside, access to value allowed inside callbacks only +// it allow to guarantee not access to the value without lock the mutex +// zero value is usable as mutex with default options and zero value of guarded type +type MutexWithPointers[T any] struct { + mutexBase[T] +} + +// NewWithPointers create Mutex with initial value and default options. +// NewWithPointers call internal checks for T and panic if checks failed, see MutexOptions for details +func NewWithPointers[T any](value T) MutexWithPointers[T] { + res := MutexWithPointers[T]{ + mutexBase: mutexBase[T]{ + value: value, + }, + } + + //nolint:govet + //goland:noinspection GoVetCopyLock + return res +} + +// Lock - call f within locked mutex. 
+// it will panic with ErrPoisoned if previous call exited without return value: +// with panic or runtime.Goexit() +func (m *MutexWithPointers[T]) Lock(f ReadWriteCallback[T]) { + m.m.Lock() + defer m.m.Unlock() + + m.baseValidateLocked() + m.callLocked(f) +} diff --git a/vendor/github.com/rekby/safemutex/mutex_base.go b/vendor/github.com/rekby/safemutex/mutex_base.go new file mode 100644 index 0000000..3be55f6 --- /dev/null +++ b/vendor/github.com/rekby/safemutex/mutex_base.go @@ -0,0 +1,21 @@ +package safemutex + +import "sync" + +type mutexBase[T any] struct { + m sync.Mutex + value T + errWrap errWrap +} + +func (m *mutexBase[T]) baseValidateLocked() { + if m.errWrap.err != nil { + panic(m.errWrap) + } +} + +func (m *mutexBase[T]) callLocked(f ReadWriteCallback[T]) { + m.errWrap.err = ErrPoisoned + m.value = f(m.value) + m.errWrap.err = nil +} diff --git a/vendor/github.com/rekby/safemutex/mutex_try_lock.go b/vendor/github.com/rekby/safemutex/mutex_try_lock.go new file mode 100644 index 0000000..be8fccf --- /dev/null +++ b/vendor/github.com/rekby/safemutex/mutex_try_lock.go @@ -0,0 +1,43 @@ +//go:build go1.19 +// +build go1.19 + +package safemutex + +// TryLock - call f within locked mutex if locked successfully. +// returned true if locked successfully +// return true if mutex already locked +// it will panic if value type not pass internal checks +// it will panic with ErrPoisoned if locked successfully and previous call exited without return value: +// with panic or runtime.Goexit() +// +// Available since go 1.19 only +func (m *Mutex[T]) TryLock(f ReadWriteCallback[T]) bool { + locked := m.m.TryLock() + if !locked { + return false + } + defer m.m.Unlock() + + m.validateLocked() + m.callLocked(f) + return true +} + +// TryLock - call f within locked mutex if locked successfully. 
+// returned true if locked successfully +// return true if mutex already locked +// it will panic with ErrPoisoned if locked successfully and previous call exited without return value: +// with panic or runtime.Goexit() +// +// Available since go 1.19 only +func (m *MutexWithPointers[T]) TryLock(f ReadWriteCallback[T]) bool { + locked := m.m.TryLock() + if !locked { + return false + } + defer m.m.Unlock() + + m.baseValidateLocked() + m.callLocked(f) + return true +} diff --git a/vendor/github.com/rekby/safemutex/safe-mutex.iml b/vendor/github.com/rekby/safemutex/safe-mutex.iml new file mode 100644 index 0000000..19c8f5e --- /dev/null +++ b/vendor/github.com/rekby/safemutex/safe-mutex.iml @@ -0,0 +1,13 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/vendor/modules.txt b/vendor/modules.txt index cb9753a..47c59ff 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -139,6 +139,9 @@ github.com/rekby/fastuuid/internal/ibytes # github.com/rekby/fixenv v0.3.1 ## explicit; go 1.18 github.com/rekby/fixenv +# github.com/rekby/safemutex v0.2.0 +## explicit; go 1.18 +github.com/rekby/safemutex # github.com/rekby/zapcontext v0.0.4 ## explicit github.com/rekby/zapcontext