Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use safe mutex instead usual mutex #211

Merged
merged 6 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/jonboulle/clockwork v0.4.0
github.com/letsencrypt/pebble/v2 v2.4.0
github.com/rekby/fastuuid v0.9.0
github.com/rekby/safemutex v0.2.0
golang.org/x/time v0.3.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ github.com/rekby/fastuuid v0.9.0 h1:iQk8V/AyqSrgQAtKRdqx/CVep+CaKwaSWeerw1yEP3Q=
github.com/rekby/fastuuid v0.9.0/go.mod h1:qP8Lh0BH2+4rNGVRDHmDpkvE/ZuLUhjmKpRWjx+WesY=
github.com/rekby/fixenv v0.3.1 h1:zOPocbQmcsxSIjiVu5U+9JAfeu6WeLN7a9ryZkGTGJY=
github.com/rekby/fixenv v0.3.1/go.mod h1:/b5LRc06BYJtslRtHKxsPWFT/ySpHV+rWvzTg+XWk4c=
github.com/rekby/safemutex v0.2.0 h1:iEfcPqsR3EApwWHwdHvp+srN9Wfna+IG8bSpN467Jmk=
github.com/rekby/safemutex v0.2.0/go.mod h1:6I/yJdmctX0RmxEp00RzYBJJXl3ona8PsBiIDqg0v+U=
github.com/rekby/zapcontext v0.0.4 h1:85600nHTteGCLcuOhGp/SzXHymm9QcCA5sn+MPKCodY=
github.com/rekby/zapcontext v0.0.4/go.mod h1:lTIxvHAwWXBZBPPfEvmAEXPbVEcTwd52VaASZWZWcxI=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
Expand Down
208 changes: 118 additions & 90 deletions internal/acme_client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/rekby/safemutex"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -38,18 +39,21 @@ type AcmeManager struct {
AgreeFunction func(tosurl string) bool
RenewAccountInterval time.Duration

ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
cache cache.Bytes
httpClient *http.Client

background sync.WaitGroup
mu safemutex.MutexWithPointers[acmeManagerSynced]
}

type acmeManagerSynced struct {
lastAccountIndex int
accounts []clientAccount
stateLoaded bool
closed bool
ctxAutorenewCompleted context.Context
cache cache.Bytes
httpClient *http.Client

background sync.WaitGroup
mu sync.Mutex
lastAccountIndex int
accounts []clientAccount
stateLoaded bool
closed bool
}

type clientAccount struct {
Expand All @@ -67,19 +71,22 @@ func New(ctx context.Context, cache cache.Bytes) *AcmeManager {
AgreeFunction: acme.AcceptTOS,
RenewAccountInterval: renewAccountInterval,
httpClient: http.DefaultClient,
lastAccountIndex: -1,
mu: safemutex.NewWithPointers(acmeManagerSynced{lastAccountIndex: -1}),
}
}

func (m *AcmeManager) Close() error {
logger := zc.L(m.ctx)
logger.Debug("Start close")
m.mu.Lock()
alreadyClosed := m.closed
ctxAutorenewCompleted := m.ctxAutorenewCompleted
m.closed = true
m.ctxCancel()
m.mu.Unlock()
var alreadyClosed bool
var ctxAutorenewCompleted context.Context
m.mu.Lock(func(value acmeManagerSynced) (newValue acmeManagerSynced) {
alreadyClosed = value.closed
ctxAutorenewCompleted = value.ctxAutorenewCompleted
value.closed = true
m.ctxCancel()
return value
})
logger.Debug("Set closed flag", zap.Any("autorenew_context", ctxAutorenewCompleted))

if alreadyClosed {
Expand All @@ -95,71 +102,90 @@ func (m *AcmeManager) Close() error {
return nil
}

func (m *AcmeManager) GetClient(ctx context.Context) (_ *acme.Client, disableFunc func(), err error) {
func (m *AcmeManager) GetClient(ctx context.Context) (resClient *acme.Client, disableFunc func(), err error) {
if ctx.Err() != nil {
return nil, nil, errors.New("acme manager context closed")
}

m.mu.Lock()
defer m.mu.Unlock()

if m.closed {
return nil, nil, xerrors.Errorf("GetClient: %w", errClosed)
fail := func(resErr error) {
resClient = nil
disableFunc = nil
err = resErr
}
good := func(c *acme.Client, f func()) {
resClient = c
disableFunc = f
err = nil
}

m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
if synced.closed {
fail(xerrors.Errorf("GetClient: %w", errClosed))
return synced
}

createDisableFunc := func(index int) func() {
return func() {
wasEnabled := m.disableAccountSelfSync(index)
if wasEnabled {
time.AfterFunc(disableDuration, func() {
m.accountEnableSelfSync(index)
})
createDisableFunc := func(index int) func() {
return func() {
wasEnabled := m.disableAccountSelfSync(index)
if wasEnabled {
time.AfterFunc(disableDuration, func() {
m.accountEnableSelfSync(index)
})
}
}
}
}

if !m.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad {
err := m.loadFromCache(ctx)
if err != nil && err != cache.ErrCacheMiss {
return nil, nil, err
if !synced.stateLoaded && m.cache != nil && !m.IgnoreCacheLoad {
err := m.loadFromCacheLocked(ctx, &synced)
if err != nil && !errors.Is(err, cache.ErrCacheMiss) {
fail(err)
return synced
}
synced.stateLoaded = true
}
m.stateLoaded = true
}

if index, ok := m.nextEnabledClientIndex(); ok {
return m.accounts[index].client, createDisableFunc(index), nil
}
if index, ok := m.nextEnabledClientIndexLocked(&synced); ok {
good(synced.accounts[index].client, createDisableFunc(index))
return synced
}

acc, err := m.registerAccount(ctx)
m.accounts = append(m.accounts, acc)
acc, err := m.registerAccount(ctx)
synced.accounts = append(synced.accounts, acc)

m.background.Add(1)
// handlepanic: in accountRenewSelfSync
go func(index int) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(len(m.accounts) - 1)
m.background.Add(1)
// handlepanic: in accountRenewSelfSync
go func(index int) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(len(synced.accounts) - 1)

if err != nil {
return nil, nil, err
}
if err != nil {
fail(err)
return synced
}

if err = m.saveState(ctx); err != nil {
return nil, nil, err
}
if err = m.saveStateLocked(ctx, &synced); err != nil {
fail(err)
return synced
}

return acc.client, createDisableFunc(len(m.accounts) - 1), nil
good(acc.client, createDisableFunc(len(synced.accounts)-1))
return synced
})
return resClient, disableFunc, err
}

func (m *AcmeManager) accountRenewSelfSync(index int) {
logger := zc.L(m.ctx)
ctx, ctxCancel := context.WithCancel(m.ctx)
defer ctxCancel()

m.mu.Lock()
m.ctxAutorenewCompleted = ctx
acc := m.accounts[index]
m.mu.Unlock()
var acc clientAccount
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.ctxAutorenewCompleted = ctx
acc = synced.accounts[index]
return synced
})

if m.ctx.Err() != nil {
return
Expand All @@ -183,37 +209,39 @@ func (m *AcmeManager) accountRenewSelfSync(index int) {
newAccount = renewTos(m.ctx, acc.client, acc.account)
}()
acc.account = newAccount
m.mu.Lock()
m.accounts[index] = acc
m.mu.Unlock()
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.accounts[index] = acc
return synced
})
}
}
}

func (m *AcmeManager) disableAccountSelfSync(index int) (wasEnabled bool) {
m.mu.Lock()
defer m.mu.Unlock()

if m.accounts[index].enabled {
m.accounts[index].enabled = false
return true
}

return false
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
if synced.accounts[index].enabled {
synced.accounts[index].enabled = false
wasEnabled = true
return synced
}
wasEnabled = false
return synced
})
return wasEnabled
}

func (m *AcmeManager) accountEnableSelfSync(index int) {
m.mu.Lock()
defer m.mu.Unlock()

m.accounts[index].enabled = true
m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.accounts[index].enabled = true
return synced
})
}

func (m *AcmeManager) initClient() *acme.Client {
return &acme.Client{DirectoryURL: m.DirectoryURL, HTTPClient: m.httpClient}
}

func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
func (m *AcmeManager) loadFromCacheLocked(ctx context.Context, synced *acmeManagerSynced) (err error) {
defer func() {
var effectiveError error
if err == cache.ErrCacheMiss {
Expand All @@ -239,7 +267,7 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
return xerrors.Errorf("no accounts in state")
}

m.accounts = make([]clientAccount, 0, len(state.Accounts))
synced.accounts = make([]clientAccount, 0, len(state.Accounts))
for index, stateAccount := range state.Accounts {
client := m.initClient()
client.Key = stateAccount.PrivateKey
Expand All @@ -255,34 +283,34 @@ func (m *AcmeManager) loadFromCache(ctx context.Context) (err error) {
defer m.background.Done()
m.accountRenewSelfSync(index)
}(index)
m.accounts = append(m.accounts, acc)
synced.accounts = append(synced.accounts, acc)
}

return nil
}

func (m *AcmeManager) nextEnabledClientIndex() (int, bool) {
func (m *AcmeManager) nextEnabledClientIndexLocked(synced *acmeManagerSynced) (int, bool) {
switch {
case len(m.accounts) == 0:
case len(synced.accounts) == 0:
return 0, false
case len(m.accounts) == 1 && m.accounts[0].enabled:
case len(synced.accounts) == 1 && synced.accounts[0].enabled:
return 0, true
default:
// pass
}

startIndex := m.lastAccountIndex
startIndex := synced.lastAccountIndex
if startIndex < 0 {
startIndex = len(m.accounts) - 1
startIndex = len(synced.accounts) - 1
}
index := startIndex
for {
index++
if index >= len(m.accounts) {
if index >= len(synced.accounts) {
index = 0
}
if m.accounts[index].enabled {
m.lastAccountIndex = index
if synced.accounts[index].enabled {
synced.lastAccountIndex = index
return index, true
}
if index == startIndex {
Expand Down Expand Up @@ -310,11 +338,11 @@ func (m *AcmeManager) registerAccount(ctx context.Context) (clientAccount, error
return acc, nil
}

func (m *AcmeManager) saveState(ctx context.Context) error {
func (m *AcmeManager) saveStateLocked(ctx context.Context, synced *acmeManagerSynced) error {
var state acmeManagerState
state.Accounts = make([]acmeAccountState, 0, len(m.accounts))
state.Accounts = make([]acmeAccountState, 0, len(synced.accounts))

for _, acc := range m.accounts {
for _, acc := range synced.accounts {
state.Accounts = append(state.Accounts, acmeAccountState{PrivateKey: acc.client.Key.(*rsa.PrivateKey), AcmeAccount: acc.account})
}

Expand Down
24 changes: 13 additions & 11 deletions internal/acme_client_manager/client_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,19 @@ func TestClientManager_nextEnabledClientIndex(t *testing.T) {
e, _, flush := th.NewEnv(t)
defer flush()

m := AcmeManager{
lastAccountIndex: test.lastAccountIndex,
}

for _, enabled := range test.accountsEnabled {
m.accounts = append(m.accounts, clientAccount{enabled: enabled})
}

resIndex, resOk := m.nextEnabledClientIndex()
e.Cmp(resIndex, test.resIndex)
e.Cmp(resOk, test.resOk)
m := AcmeManager{}

m.mu.Lock(func(synced acmeManagerSynced) acmeManagerSynced {
synced.lastAccountIndex = test.lastAccountIndex
for _, enabled := range test.accountsEnabled {
synced.accounts = append(synced.accounts, clientAccount{enabled: enabled})
}

resIndex, resOk := m.nextEnabledClientIndexLocked(&synced)
e.Cmp(resIndex, test.resIndex)
e.Cmp(resOk, test.resOk)
return synced
})
})
}
}
Expand Down
Loading