diff --git a/go/test/endtoend/topotest/consul/main_test.go b/go/test/endtoend/topotest/consul/main_test.go index 1c278864ced..0f6fa6ce554 100644 --- a/go/test/endtoend/topotest/consul/main_test.go +++ b/go/test/endtoend/topotest/consul/main_test.go @@ -24,7 +24,9 @@ import ( "testing" "time" + topoutils "vitess.io/vitess/go/test/endtoend/topotest/utils" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/topo" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -140,6 +142,87 @@ func TestTopoRestart(t *testing.T) { } } +// TestShardLocking tests that shard locking works as intended. +func TestShardLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a shard lock. + ctx, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockShard(ctx, KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "lock for shard customer/0 is already held") + // Also check that TryLockShard is non-blocking and returns an error. + _, _, err = ts.TryLockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "node already exists: lock already exists at path keyspaces/customer/shards/0") + // Check that CheckShardLocked doesn't return an error. + err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. 
+ secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second acquiring of lock shard is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the shard. + unlock(&err) + // Check that we no longer have shard lock acquired. + err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.ErrorContains(t, err, "shard customer/0 is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the shard lock. + topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + +// TestKeyspaceLocking tests that keyspace locking works as intended. +func TestKeyspaceLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a keyspace lock. + ctx, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockKeyspace(ctx, KeyspaceName, "TestKeyspaceLocking") + require.ErrorContains(t, err, "lock for keyspace customer is already held") + // Check that CheckKeyspaceLocked doesn't return an error. + err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. 
+ secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second attempt to acquire the keyspace lock is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the keyspace. + unlock(&err) + // Check that we no longer have keyspace lock acquired. + err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.ErrorContains(t, err, "keyspace customer is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the keyspace lock. + topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + func execute(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result { t.Helper() qr, err := conn.ExecuteFetch(query, 1000, true) diff --git a/go/test/endtoend/topotest/etcd2/main_test.go b/go/test/endtoend/topotest/etcd2/main_test.go index db34bd2ee86..747f2721cdc 100644 --- a/go/test/endtoend/topotest/etcd2/main_test.go +++ b/go/test/endtoend/topotest/etcd2/main_test.go @@ -23,7 +23,9 @@ import ( "testing" "time" + topoutils "vitess.io/vitess/go/test/endtoend/topotest/utils" "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/log" @@ -111,10 +113,94 @@ func TestTopoDownServingQuery(t *testing.T) { execMulti(t, conn, `insert into t1(c1, c2, c3, c4) values (300,100,300,'abc'); ;; insert into t1(c1, c2, c3, c4) values (301,101,301,'abcd');;`) utils.AssertMatches(t, conn, `select c1,c2,c3 from t1`, `[[INT64(300) INT64(100) INT64(300)] [INT64(301) INT64(101) INT64(301)]]`) clusterInstance.TopoProcess.TearDown(clusterInstance.Cell, clusterInstance.OriginalVTDATAROOT, clusterInstance.CurrentVTDATAROOT, true, *clusterInstance.TopoFlavorString()) + defer func() { + _ = clusterInstance.TopoProcess.SetupEtcd() + }() time.Sleep(3 * time.Second) 
utils.AssertMatches(t, conn, `select c1,c2,c3 from t1`, `[[INT64(300) INT64(100) INT64(300)] [INT64(301) INT64(101) INT64(301)]]`) } +// TestShardLocking tests that shard locking works as intended. +func TestShardLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a shard lock. + ctx, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockShard(ctx, KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "lock for shard customer/0 is already held") + // Also check that TryLockShard is non-blocking and returns an error. + _, _, err = ts.TryLockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "node already exists: lock already exists at path keyspaces/customer/shards/0") + // Check that CheckShardLocked doesn't return an error. + err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. + secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second acquiring of lock shard is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the shard. + unlock(&err) + // Check that we no longer have shard lock acquired. 
+ err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.ErrorContains(t, err, "shard customer/0 is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the shard lock. + topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + +// TestKeyspaceLocking tests that keyspace locking works as intended. +func TestKeyspaceLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a keyspace lock. + ctx, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockKeyspace(ctx, KeyspaceName, "TestKeyspaceLocking") + require.ErrorContains(t, err, "lock for keyspace customer is already held") + // Check that CheckKeyspaceLocked doesn't return an error. + err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. + secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second attempt to acquire the keyspace lock is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the keyspace. + unlock(&err) + // Check that we no longer have keyspace lock acquired. + err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.ErrorContains(t, err, "keyspace customer is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the keyspace lock. 
+ topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result diff --git a/go/test/endtoend/topotest/utils/utils.go b/go/test/endtoend/topotest/utils/utils.go new file mode 100644 index 00000000000..6b8433b6a7f --- /dev/null +++ b/go/test/endtoend/topotest/utils/utils.go @@ -0,0 +1,41 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// WaitForBoolValue takes a pointer to a boolean and waits for it to reach a certain value. 
+func WaitForBoolValue(t *testing.T, val *bool, waitFor bool) { + timeout := time.After(15 * time.Second) + for { + select { + case <-timeout: + require.Failf(t, "Failed waiting for bool value", "Timed out waiting for the boolean to become %v", waitFor) + return + default: + if *val == waitFor { + return + } + time.Sleep(100 * time.Millisecond) + } + } +} diff --git a/go/test/endtoend/topotest/zk2/main_test.go b/go/test/endtoend/topotest/zk2/main_test.go index 816bbc72d72..48636331747 100644 --- a/go/test/endtoend/topotest/zk2/main_test.go +++ b/go/test/endtoend/topotest/zk2/main_test.go @@ -23,7 +23,9 @@ import ( "testing" "time" + topoutils "vitess.io/vitess/go/test/endtoend/topotest/utils" "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/log" @@ -116,6 +118,87 @@ func TestTopoDownServingQuery(t *testing.T) { utils.AssertMatches(t, conn, `select c1,c2,c3 from t1`, `[[INT64(300) INT64(100) INT64(300)] [INT64(301) INT64(101) INT64(301)]]`) } +// TestShardLocking tests that shard locking works as intended. +func TestShardLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a shard lock. + ctx, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockShard(ctx, KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "lock for shard customer/0 is already held") + // Also check that TryLockShard is non-blocking and returns an error. 
+ _, _, err = ts.TryLockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + require.ErrorContains(t, err, "node already exists: lock already exists at path keyspaces/customer/shards/0") + // Check that CheckShardLocked doesn't return an error. + err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. + secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockShard(context.Background(), KeyspaceName, "0", "TestShardLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second acquiring of lock shard is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the shard. + unlock(&err) + // Check that we no longer have shard lock acquired. + err = topo.CheckShardLocked(ctx, KeyspaceName, "0") + require.ErrorContains(t, err, "shard customer/0 is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the shard lock. + topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + +// TestKeyspaceLocking tests that keyspace locking works as intended. +func TestKeyspaceLocking(t *testing.T) { + // create topo server connection + ts, err := topo.OpenServer(*clusterInstance.TopoFlavorString(), clusterInstance.VtctlProcess.TopoGlobalAddress, clusterInstance.VtctlProcess.TopoGlobalRoot) + require.NoError(t, err) + + // Acquire a keyspace lock. + ctx, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + require.NoError(t, err) + // Check that we can't reacquire it from the same context. + _, _, err = ts.LockKeyspace(ctx, KeyspaceName, "TestKeyspaceLocking") + require.ErrorContains(t, err, "lock for keyspace customer is already held") + // Check that CheckKeyspaceLocked doesn't return an error. 
+ err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.NoError(t, err) + + // We'll now try to acquire the lock from a different thread. + secondThreadLockAcquired := false + go func() { + _, unlock, err := ts.LockKeyspace(context.Background(), KeyspaceName, "TestKeyspaceLocking") + defer unlock(&err) + require.NoError(t, err) + secondThreadLockAcquired = true + }() + + // Wait for some time and ensure that the second attempt to acquire the keyspace lock is blocked. + time.Sleep(100 * time.Millisecond) + require.False(t, secondThreadLockAcquired) + + // Unlock the keyspace. + unlock(&err) + // Check that we no longer have keyspace lock acquired. + err = topo.CheckKeyspaceLocked(ctx, KeyspaceName) + require.ErrorContains(t, err, "keyspace customer is not locked (no lockInfo in map)") + + // Wait to see that the second thread was able to acquire the keyspace lock. + topoutils.WaitForBoolValue(t, &secondThreadLockAcquired, true) +} + func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result diff --git a/go/vt/schemamanager/tablet_executor.go b/go/vt/schemamanager/tablet_executor.go index 592c64e7073..ab58f9f8463 100644 --- a/go/vt/schemamanager/tablet_executor.go +++ b/go/vt/schemamanager/tablet_executor.go @@ -387,7 +387,7 @@ func (exec *TabletExecutor) Execute(ctx context.Context, sqls []string) *Execute } for index, sql := range sqls { // Attempt to renew lease: - if err := rl.Do(func() error { return topo.CheckKeyspaceLockedAndRenew(ctx, exec.keyspace) }); err != nil { + if err := rl.Do(func() error { return topo.CheckKeyspaceLocked(ctx, exec.keyspace) }); err != nil { return errorExecResult(vterrors.Wrapf(err, "CheckKeyspaceLocked in ApplySchemaKeyspace %v", exec.keyspace)) } execResult.CurSQLIndex = index diff --git a/go/vt/topo/keyspace_lock.go b/go/vt/topo/keyspace_lock.go new file mode 100644 index 00000000000..7df1b2ee64f --- /dev/null +++ b/go/vt/topo/keyspace_lock.go @@ -0,0 +1,58 @@ +/* +Copyright 2024
The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package topo + +import ( + "context" + "path" +) + +type keyspaceLock struct { + keyspace string +} + +var _ iTopoLock = (*keyspaceLock)(nil) + +func (s *keyspaceLock) Type() string { + return "keyspace" +} + +func (s *keyspaceLock) ResourceName() string { + return s.keyspace +} + +func (s *keyspaceLock) Path() string { + return path.Join(KeyspacesPath, s.keyspace) +} + +// LockKeyspace will lock the keyspace, and return: +// - a context with a locksInfo structure for future reference. +// - an unlock method +// - an error if anything failed. +func (ts *Server) LockKeyspace(ctx context.Context, keyspace, action string) (context.Context, func(*error), error) { + return ts.internalLock(ctx, &keyspaceLock{ + keyspace: keyspace, + }, action, true) +} + +// CheckKeyspaceLocked can be called on a context to make sure we have the lock +// for a given keyspace. +func CheckKeyspaceLocked(ctx context.Context, keyspace string) error { + return checkLocked(ctx, &keyspaceLock{ + keyspace: keyspace, + }) +} diff --git a/go/vt/topo/keyspace_lock_test.go b/go/vt/topo/keyspace_lock_test.go new file mode 100644 index 00000000000..6d0a34de554 --- /dev/null +++ b/go/vt/topo/keyspace_lock_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package topo_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/topo/memorytopo" +) + +// TestTopoKeyspaceLock tests keyspace lock operations. +func TestTopoKeyspaceLock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ts := memorytopo.NewServer(ctx, "zone1") + defer ts.Close() + + currentTopoLockTimeout := topo.LockTimeout + topo.LockTimeout = testLockTimeout + defer func() { + topo.LockTimeout = currentTopoLockTimeout + }() + + ks1 := "ks1" + ks2 := "ks2" + err := ts.CreateKeyspace(ctx, ks1, &topodatapb.Keyspace{}) + require.NoError(t, err) + err = ts.CreateKeyspace(ctx, ks2, &topodatapb.Keyspace{}) + require.NoError(t, err) + + origCtx := ctx + ctx, unlock, err := ts.LockKeyspace(origCtx, ks1, "ks1") + require.NoError(t, err) + + // locking the same key again, without unlocking, should return an error + _, _, err2 := ts.LockKeyspace(ctx, ks1, "ks1") + require.ErrorContains(t, err2, "already held") + + // Check that we have the keyspace lock shouldn't return an error + err = topo.CheckKeyspaceLocked(ctx, ks1) + require.NoError(t, err) + + // Check that we have the keyspace lock for the other keyspace should return an error + err = topo.CheckKeyspaceLocked(ctx, ks2) + require.ErrorContains(t, err, "keyspace ks2 is not locked") + + // Check we can acquire a keyspace lock for the other keyspace + ctx2, unlock2, err := ts.LockKeyspace(ctx, ks2, "ks2") + require.NoError(t, err) + 
defer unlock2(&err) + + // Unlock the first keyspace + unlock(&err) + + // Check keyspace locked output for both keyspaces + err = topo.CheckKeyspaceLocked(ctx2, ks1) + require.ErrorContains(t, err, "keyspace ks1 is not locked") + err = topo.CheckKeyspaceLocked(ctx2, ks2) + require.NoError(t, err) + + // confirm that the lock can be re-acquired after unlocking + _, unlock, err = ts.LockKeyspace(origCtx, ks1, "ks1") + require.NoError(t, err) + defer unlock(&err) +} diff --git a/go/vt/topo/locks.go b/go/vt/topo/locks.go index 6325124c429..fee2b53df1a 100644 --- a/go/vt/topo/locks.go +++ b/go/vt/topo/locks.go @@ -21,13 +21,11 @@ import ( "encoding/json" "os" "os/user" - "path" "sync" "time" "github.com/spf13/pflag" - _flag "vitess.io/vitess/go/internal/flag" "vitess.io/vitess/go/trace" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/proto/vtrpc" @@ -35,8 +33,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -// This file contains utility methods and definitions to lock -// keyspaces and shards. +// This file contains utility methods and definitions to lock resources using topology server. var ( // LockTimeout is the maximum duration for which a @@ -123,131 +120,37 @@ type locksKeyType int var locksKey locksKeyType -// LockKeyspace will lock the keyspace, and return: -// - a context with a locksInfo structure for future reference. -// - an unlock method -// - an error if anything failed. 
-func (ts *Server) LockKeyspace(ctx context.Context, keyspace, action string) (context.Context, func(*error), error) { - i, ok := ctx.Value(locksKey).(*locksInfo) - if !ok { - i = &locksInfo{ - info: make(map[string]*lockInfo), - } - ctx = context.WithValue(ctx, locksKey, i) - } - i.mu.Lock() - defer i.mu.Unlock() - - // check that we're not already locked - if _, ok = i.info[keyspace]; ok { - return nil, nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "lock for keyspace %v is already held", keyspace) - } - - // lock - l := newLock(action) - lockDescriptor, err := l.lockKeyspace(ctx, ts, keyspace) - if err != nil { - return nil, nil, err - } - - // and update our structure - i.info[keyspace] = &lockInfo{ - lockDescriptor: lockDescriptor, - actionNode: l, - } - return ctx, func(finalErr *error) { - i.mu.Lock() - defer i.mu.Unlock() - - if _, ok := i.info[keyspace]; !ok { - if *finalErr != nil { - log.Errorf("trying to unlock keyspace %v multiple times", keyspace) - } else { - *finalErr = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "trying to unlock keyspace %v multiple times", keyspace) - } - return - } - - err := l.unlockKeyspace(ctx, ts, keyspace, lockDescriptor, *finalErr) - if *finalErr != nil { - if err != nil { - // both error are set, just log the unlock error - log.Errorf("unlockKeyspace(%v) failed: %v", keyspace, err) - } - } else { - *finalErr = err - } - delete(i.info, keyspace) - }, nil +// iTopoLock is the interface for knowing the resource that is being locked. +// It allows for better controlling nuances for different lock types and log messages. +type iTopoLock interface { + Type() string + ResourceName() string + Path() string } -// CheckKeyspaceLocked can be called on a context to make sure we have the lock -// for a given keyspace. 
-func CheckKeyspaceLocked(ctx context.Context, keyspace string) error { - // extract the locksInfo pointer - i, ok := ctx.Value(locksKey).(*locksInfo) - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "keyspace %v is not locked (no locksInfo)", keyspace) - } - i.mu.Lock() - defer i.mu.Unlock() - - // find the individual entry - _, ok = i.info[keyspace] - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "keyspace %v is not locked (no lockInfo in map)", keyspace) - } - - // TODO(alainjobart): check the lock server implementation - // still holds the lock. Will need to look at the lockInfo struct. - - // and we're good for now. - return nil -} - -// CheckKeyspaceLockedAndRenew can be called on a context to make sure we have the lock -// for a given keyspace. The function also attempts to renew the lock. -func CheckKeyspaceLockedAndRenew(ctx context.Context, keyspace string) error { - // extract the locksInfo pointer - i, ok := ctx.Value(locksKey).(*locksInfo) - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "keyspace %v is not locked (no locksInfo)", keyspace) - } - i.mu.Lock() - defer i.mu.Unlock() +// perform the topo lock operation +func (l *Lock) lock(ctx context.Context, ts *Server, lt iTopoLock, isBlocking bool) (LockDescriptor, error) { + log.Infof("Locking %v %v for action %v", lt.Type(), lt.ResourceName(), l.Action) - // find the individual entry - entry, ok := i.info[keyspace] - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "keyspace %v is not locked (no lockInfo in map)", keyspace) - } - // try renewing lease: - return entry.lockDescriptor.Check(ctx) -} - -// lockKeyspace will lock the keyspace in the topology server. -// unlockKeyspace should be called if this returns no error. 
-func (l *Lock) lockKeyspace(ctx context.Context, ts *Server, keyspace string) (LockDescriptor, error) { - log.Infof("Locking keyspace %v for action %v", keyspace, l.Action) - - ctx, cancel := context.WithTimeout(ctx, getLockTimeout()) + ctx, cancel := context.WithTimeout(ctx, LockTimeout) defer cancel() - - span, ctx := trace.NewSpan(ctx, "TopoServer.LockKeyspaceForAction") + span, ctx := trace.NewSpan(ctx, "TopoServer.Lock") span.Annotate("action", l.Action) - span.Annotate("keyspace", keyspace) + span.Annotate("path", lt.Path()) defer span.Finish() - keyspacePath := path.Join(KeyspacesPath, keyspace) j, err := l.ToJSON() if err != nil { return nil, err } - return ts.globalCell.Lock(ctx, keyspacePath, j) + if isBlocking { + return ts.globalCell.Lock(ctx, lt.Path(), j) + } + return ts.globalCell.TryLock(ctx, lt.Path(), j) } -// unlockKeyspace unlocks a previously locked keyspace. -func (l *Lock) unlockKeyspace(ctx context.Context, ts *Server, keyspace string, lockDescriptor LockDescriptor, actionError error) error { +// unlock unlocks a previously locked key. +func (l *Lock) unlock(ctx context.Context, lt iTopoLock, lockDescriptor LockDescriptor, actionError error) error { // Detach from the parent timeout, but copy the trace span. // We need to still release the lock even if the parent // context timed out. 
@@ -255,70 +158,23 @@ func (l *Lock) unlockKeyspace(ctx context.Context, ts *Server, keyspace string, ctx, cancel := context.WithTimeout(ctx, RemoteOperationTimeout) defer cancel() - span, ctx := trace.NewSpan(ctx, "TopoServer.UnlockKeyspaceForAction") + span, ctx := trace.NewSpan(ctx, "TopoServer.Unlock") span.Annotate("action", l.Action) - span.Annotate("keyspace", keyspace) + span.Annotate("path", lt.Path()) defer span.Finish() // first update the actionNode if actionError != nil { - log.Infof("Unlocking keyspace %v for action %v with error %v", keyspace, l.Action, actionError) + log.Infof("Unlocking %v %v for action %v with error %v", lt.Type(), lt.ResourceName(), l.Action, actionError) l.Status = "Error: " + actionError.Error() } else { - log.Infof("Unlocking keyspace %v for successful action %v", keyspace, l.Action) + log.Infof("Unlocking %v %v for successful action %v", lt.Type(), lt.ResourceName(), l.Action) l.Status = "Done" } return lockDescriptor.Unlock(ctx) } -// LockShard will lock the shard, and return: -// - a context with a locksInfo structure for future reference. -// - an unlock method -// - an error if anything failed. -// -// We are currently only using this method to lock actions that would -// impact each-other. Most changes of the Shard object are done by -// UpdateShardFields, which is not locking the shard object. 
The -// current list of actions that lock a shard are: -// * all Vitess-controlled re-parenting operations: -// - InitShardPrimary -// - PlannedReparentShard -// - EmergencyReparentShard -// -// * any vtorc recovery e.g -// - RecoverDeadPrimary -// - ElectNewPrimary -// - FixPrimary -// -// * before any replication repair from replication manager -// -// * operations that we don't want to conflict with re-parenting: -// - DeleteTablet when it's the shard's current primary -func (ts *Server) LockShard(ctx context.Context, keyspace, shard, action string) (context.Context, func(*error), error) { - return ts.internalLockShard(ctx, keyspace, shard, action, true) -} - -// TryLockShard will lock the shard, and return: -// - a context with a locksInfo structure for future reference. -// - an unlock method -// - an error if anything failed. -// -// `TryLockShard` is different from `LockShard`. If there is already a lock on given shard, -// then unlike `LockShard` instead of waiting and blocking the client it returns with -// `Lock already exists` error. With current implementation it may not be able to fail-fast -// for some scenarios. For example there is a possibility that a thread checks for lock for -// a given shard but by the time it acquires the lock, some other thread has already acquired it, -// in this case the client will block until the other caller releases the lock or the -// client call times out (just like standard `LockShard' implementation). In short the lock checking -// and acquiring is not under the same mutex in current implementation of `TryLockShard`. -// -// We are currently using `TryLockShard` during tablet discovery in Vtorc recovery -func (ts *Server) TryLockShard(ctx context.Context, keyspace, shard, action string) (context.Context, func(*error), error) { - return ts.internalLockShard(ctx, keyspace, shard, action, false) -} - -// internalLockShard is used to indicate whether the call should fail-fast or not. 
-func (ts *Server) internalLockShard(ctx context.Context, keyspace, shard, action string, isBlocking bool) (context.Context, func(*error), error) { +func (ts *Server) internalLock(ctx context.Context, lt iTopoLock, action string, isBlocking bool) (context.Context, func(*error), error) { i, ok := ctx.Value(locksKey).(*locksInfo) if !ok { i = &locksInfo{ @@ -328,28 +184,19 @@ func (ts *Server) internalLockShard(ctx context.Context, keyspace, shard, action } i.mu.Lock() defer i.mu.Unlock() - - // check that we're not already locked - mapKey := keyspace + "/" + shard - if _, ok = i.info[mapKey]; ok { - return nil, nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "lock for shard %v/%v is already held", keyspace, shard) + // check that we are not already locked + if _, ok := i.info[lt.ResourceName()]; ok { + return nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "lock for %v %v is already held", lt.Type(), lt.ResourceName()) } - // lock + // lock it l := newLock(action) - var lockDescriptor LockDescriptor - var err error - if isBlocking { - lockDescriptor, err = l.lockShard(ctx, ts, keyspace, shard) - } else { - lockDescriptor, err = l.tryLockShard(ctx, ts, keyspace, shard) - } + lockDescriptor, err := l.lock(ctx, ts, lt, isBlocking) if err != nil { return nil, nil, err } - // and update our structure - i.info[mapKey] = &lockInfo{ + i.info[lt.ResourceName()] = &lockInfo{ lockDescriptor: lockDescriptor, actionNode: l, } @@ -357,118 +204,45 @@ func (ts *Server) internalLockShard(ctx context.Context, keyspace, shard, action i.mu.Lock() defer i.mu.Unlock() - if _, ok := i.info[mapKey]; !ok { + if _, ok := i.info[lt.ResourceName()]; !ok { if *finalErr != nil { - log.Errorf("trying to unlock shard %v/%v multiple times", keyspace, shard) + log.Errorf("trying to unlock %v %v multiple times", lt.Type(), lt.ResourceName()) } else { - *finalErr = vterrors.Errorf(vtrpc.Code_INTERNAL, "trying to unlock shard %v/%v multiple times", keyspace, shard) + *finalErr = 
vterrors.Errorf(vtrpc.Code_INTERNAL, "trying to unlock %v %v multiple times", lt.Type(), lt.ResourceName()) } return } - err := l.unlockShard(ctx, ts, keyspace, shard, lockDescriptor, *finalErr) + err := l.unlock(ctx, lt, lockDescriptor, *finalErr) + // if we have an error, we log it, but we still want to delete the lock if *finalErr != nil { if err != nil { // both error are set, just log the unlock error - log.Warningf("unlockShard(%s/%s) failed: %v", keyspace, shard, err) + log.Warningf("unlock %v %v failed: %v", lt.Type(), lt.ResourceName(), err) } } else { *finalErr = err } - delete(i.info, mapKey) + delete(i.info, lt.ResourceName()) }, nil } -// CheckShardLocked can be called on a context to make sure we have the lock -// for a given shard. -func CheckShardLocked(ctx context.Context, keyspace, shard string) error { +// checkLocked checks that the given resource is locked. +func checkLocked(ctx context.Context, lt iTopoLock) error { // extract the locksInfo pointer i, ok := ctx.Value(locksKey).(*locksInfo) if !ok { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "shard %v/%v is not locked (no locksInfo)", keyspace, shard) + return vterrors.Errorf(vtrpc.Code_INTERNAL, "%v %v is not locked (no locksInfo)", lt.Type(), lt.ResourceName()) } i.mu.Lock() defer i.mu.Unlock() - // func the individual entry - mapKey := keyspace + "/" + shard - li, ok := i.info[mapKey] + // find the individual entry + li, ok := i.info[lt.ResourceName()] if !ok { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "shard %v/%v is not locked (no lockInfo in map)", keyspace, shard) + return vterrors.Errorf(vtrpc.Code_INTERNAL, "%v %v is not locked (no lockInfo in map)", lt.Type(), lt.ResourceName()) } // Check the lock server implementation still holds the lock. return li.lockDescriptor.Check(ctx) } - -// lockShard will lock the shard in the topology server. -// UnlockShard should be called if this returns no error. 
-func (l *Lock) lockShard(ctx context.Context, ts *Server, keyspace, shard string) (LockDescriptor, error) { - return l.internalLockShard(ctx, ts, keyspace, shard, true) -} - -// tryLockShard will lock the shard in the topology server but unlike `lockShard` it fail-fast if not able to get lock -// UnlockShard should be called if this returns no error. -func (l *Lock) tryLockShard(ctx context.Context, ts *Server, keyspace, shard string) (LockDescriptor, error) { - return l.internalLockShard(ctx, ts, keyspace, shard, false) -} - -func (l *Lock) internalLockShard(ctx context.Context, ts *Server, keyspace, shard string, isBlocking bool) (LockDescriptor, error) { - log.Infof("Locking shard %v/%v for action %v", keyspace, shard, l.Action) - - ctx, cancel := context.WithTimeout(ctx, getLockTimeout()) - defer cancel() - - span, ctx := trace.NewSpan(ctx, "TopoServer.LockShardForAction") - span.Annotate("action", l.Action) - span.Annotate("keyspace", keyspace) - span.Annotate("shard", shard) - defer span.Finish() - - shardPath := path.Join(KeyspacesPath, keyspace, ShardsPath, shard) - j, err := l.ToJSON() - if err != nil { - return nil, err - } - if isBlocking { - return ts.globalCell.Lock(ctx, shardPath, j) - } - return ts.globalCell.TryLock(ctx, shardPath, j) -} - -// unlockShard unlocks a previously locked shard. -func (l *Lock) unlockShard(ctx context.Context, ts *Server, keyspace, shard string, lockDescriptor LockDescriptor, actionError error) error { - // Detach from the parent timeout, but copy the trace span. - // We need to still release the lock even if the parent context timed out. 
- ctx = trace.CopySpan(context.TODO(), ctx) - ctx, cancel := context.WithTimeout(ctx, RemoteOperationTimeout) - defer cancel() - - span, ctx := trace.NewSpan(ctx, "TopoServer.UnlockShardForAction") - span.Annotate("action", l.Action) - span.Annotate("keyspace", keyspace) - span.Annotate("shard", shard) - defer span.Finish() - - // first update the actionNode - if actionError != nil { - log.Infof("Unlocking shard %v/%v for action %v with error %v", keyspace, shard, l.Action, actionError) - l.Status = "Error: " + actionError.Error() - } else { - log.Infof("Unlocking shard %v/%v for successful action %v", keyspace, shard, l.Action) - l.Status = "Done" - } - return lockDescriptor.Unlock(ctx) -} - -// getLockTimeout is shim code used for backward compatibility with v15 -// This code can be removed in v17+ and LockTimeout can be used directly -func getLockTimeout() time.Duration { - if _flag.IsFlagProvided("lock-timeout") { - return LockTimeout - } - if _flag.IsFlagProvided("remote_operation_timeout") { - return RemoteOperationTimeout - } - return LockTimeout -} diff --git a/go/vt/topo/locks_test.go b/go/vt/topo/locks_test.go deleted file mode 100644 index c4d2019676e..00000000000 --- a/go/vt/topo/locks_test.go +++ /dev/null @@ -1,101 +0,0 @@ -/* -Copyright 2022 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package topo - -import ( - "os" - "testing" - "time" - - "github.com/spf13/pflag" - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/internal/flag" -) - -// TestGetLockTimeout tests the behaviour of -// getLockTimeout function in different situations where -// the two flags `remote_operation_timeout` and `lock-timeout` are -// provided or not. -func TestGetLockTimeout(t *testing.T) { - tests := []struct { - description string - lockTimeoutValue string - remoteOperationTimeoutValue string - expectedLockTimeout time.Duration - }{ - { - description: "no flags specified", - lockTimeoutValue: "", - remoteOperationTimeoutValue: "", - expectedLockTimeout: 45 * time.Second, - }, { - description: "lock-timeout flag specified", - lockTimeoutValue: "33s", - remoteOperationTimeoutValue: "", - expectedLockTimeout: 33 * time.Second, - }, { - description: "remote operation timeout flag specified", - lockTimeoutValue: "", - remoteOperationTimeoutValue: "33s", - expectedLockTimeout: 33 * time.Second, - }, { - description: "both flags specified", - lockTimeoutValue: "33s", - remoteOperationTimeoutValue: "22s", - expectedLockTimeout: 33 * time.Second, - }, { - description: "remote operation timeout flag specified to the default", - lockTimeoutValue: "", - remoteOperationTimeoutValue: "15s", - expectedLockTimeout: 15 * time.Second, - }, { - description: "lock-timeout flag specified to the default", - lockTimeoutValue: "45s", - remoteOperationTimeoutValue: "33s", - expectedLockTimeout: 45 * time.Second, - }, - } - - for _, tt := range tests { - t.Run(tt.description, func(t *testing.T) { - oldLockTimeout := LockTimeout - oldRemoteOpsTimeout := RemoteOperationTimeout - defer func() { - LockTimeout = oldLockTimeout - RemoteOperationTimeout = oldRemoteOpsTimeout - }() - var args []string - if tt.lockTimeoutValue != "" { - args = append(args, "--lock-timeout", tt.lockTimeoutValue) - } - if tt.remoteOperationTimeoutValue != "" { - args = append(args, 
"--remote_operation_timeout", tt.remoteOperationTimeoutValue) - } - os.Args = os.Args[0:1] - os.Args = append(os.Args, args...) - - fs := pflag.NewFlagSet("test", pflag.ExitOnError) - registerTopoLockFlags(fs) - flag.Parse(fs) - - val := getLockTimeout() - require.Equal(t, tt.expectedLockTimeout, val) - }) - } - -} diff --git a/go/vt/topo/routing_rules_lock.go b/go/vt/topo/routing_rules_lock.go index db4fa63bc9b..c45ddb738c9 100644 --- a/go/vt/topo/routing_rules_lock.go +++ b/go/vt/topo/routing_rules_lock.go @@ -18,20 +18,30 @@ package topo import ( "context" - "fmt" ) -// RoutingRulesLock is a wrapper over TopoLock, to serialize updates to routing rules. -type RoutingRulesLock struct { - *TopoLock +type routingRules struct{} + +var _ iTopoLock = (*routingRules)(nil) + +func (s *routingRules) Type() string { + return RoutingRulesPath +} + +func (s *routingRules) ResourceName() string { + return RoutingRulesPath +} + +func (s *routingRules) Path() string { + return RoutingRulesPath +} + +// LockRoutingRules acquires a lock for routing rules. +func (ts *Server) LockRoutingRules(ctx context.Context, action string) (context.Context, func(*error), error) { + return ts.internalLock(ctx, &routingRules{}, action, true) } -func NewRoutingRulesLock(ctx context.Context, ts *Server, name string) (*RoutingRulesLock, error) { - return &RoutingRulesLock{ - TopoLock: &TopoLock{ - Path: RoutingRulesPath, - Name: fmt.Sprintf("RoutingRules::%s", name), - ts: ts, - }, - }, nil +// CheckRoutingRulesLocked checks if a lock for routing rules is still possessed. 
+func CheckRoutingRulesLocked(ctx context.Context) error { + return checkLocked(ctx, &routingRules{}) } diff --git a/go/vt/topo/routing_rules_lock_test.go b/go/vt/topo/routing_rules_lock_test.go index 23027517019..2627ea8e984 100644 --- a/go/vt/topo/routing_rules_lock_test.go +++ b/go/vt/topo/routing_rules_lock_test.go @@ -19,6 +19,7 @@ package topo_test import ( "context" "testing" + "time" "github.com/stretchr/testify/require" @@ -28,6 +29,61 @@ import ( vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) +// lower the lock timeout for testing +const testLockTimeout = 3 * time.Second + +// TestTopoLockTimeout tests that the lock times out after the specified duration. +func TestTopoLockTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ts := memorytopo.NewServer(ctx, "zone1") + defer ts.Close() + + err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{}) + require.NoError(t, err) + + currentTopoLockTimeout := topo.LockTimeout + topo.LockTimeout = testLockTimeout + defer func() { + topo.LockTimeout = currentTopoLockTimeout + }() + + // acquire the lock + origCtx := ctx + _, unlock, err := ts.LockRoutingRules(origCtx, "ks1") + require.NoError(t, err) + defer unlock(&err) + + // re-acquiring the lock from origCtx (which carries no lock info) should time out rather than succeed + _, _, err2 := ts.LockRoutingRules(origCtx, "ks1") + require.ErrorContains(t, err2, "deadline exceeded") +} + +// TestTopoLockBasic tests basic lock operations. 
+func TestTopoLockBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ts := memorytopo.NewServer(ctx, "zone1") + defer ts.Close() + + err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{}) + require.NoError(t, err) + + origCtx := ctx + ctx, unlock, err := ts.LockRoutingRules(origCtx, "ks1") + require.NoError(t, err) + + // locking the same key again, without unlocking, should return an error + _, _, err2 := ts.LockRoutingRules(ctx, "ks1") + require.ErrorContains(t, err2, "already held") + + // confirm that the lock can be re-acquired after unlocking + unlock(&err) + _, unlock, err = ts.LockRoutingRules(origCtx, "ks1") + require.NoError(t, err) + defer unlock(&err) +} + // TestKeyspaceRoutingRulesLock tests that the lock is acquired and released correctly. func TestKeyspaceRoutingRulesLock(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -44,18 +100,16 @@ func TestKeyspaceRoutingRulesLock(t *testing.T) { err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{}) require.NoError(t, err) - lock, err := topo.NewRoutingRulesLock(ctx, ts, "ks1") - require.NoError(t, err) - _, unlock, err := lock.Lock(ctx) + _, unlock, err := ts.LockRoutingRules(ctx, "ks1") require.NoError(t, err) // re-acquiring the lock should fail - _, _, err = lock.Lock(ctx) + _, _, err = ts.LockRoutingRules(ctx, "ks1") require.Error(t, err) unlock(&err) // re-acquiring the lock should succeed - _, _, err = lock.Lock(ctx) + _, _, err = ts.LockRoutingRules(ctx, "ks1") require.NoError(t, err) } diff --git a/go/vt/topo/shard_lock.go b/go/vt/topo/shard_lock.go new file mode 100644 index 00000000000..72d0b1c8ca4 --- /dev/null +++ b/go/vt/topo/shard_lock.go @@ -0,0 +1,98 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package topo + +import ( + "context" + "path" +) + +type shardLock struct { + keyspace, shard string +} + +var _ iTopoLock = (*shardLock)(nil) + +func (s *shardLock) Type() string { + return "shard" +} + +func (s *shardLock) ResourceName() string { + return s.keyspace + "/" + s.shard +} + +func (s *shardLock) Path() string { + return path.Join(KeyspacesPath, s.keyspace, ShardsPath, s.shard) +} + +// LockShard will lock the shard, and return: +// - a context with a locksInfo structure for future reference. +// - an unlock method +// - an error if anything failed. +// +// We are currently only using this method to lock actions that would +// impact each other. Most changes of the Shard object are done by +// UpdateShardFields, which is not locking the shard object. The +// actions that currently lock a shard are: +// * all Vitess-controlled re-parenting operations: +// - PlannedReparentShard +// - EmergencyReparentShard +// +// * any vtorc recovery, e.g.: +// - RecoverDeadPrimary +// - ElectNewPrimary +// - FixPrimary +// +// * operations that we don't want to conflict with re-parenting: +// - DeleteTablet when it's the shard's current primary +func (ts *Server) LockShard(ctx context.Context, keyspace, shard, action string) (context.Context, func(*error), error) { + return ts.internalLock(ctx, &shardLock{ + keyspace: keyspace, + shard: shard, + }, action, true) +} + +// TryLockShard will lock the shard, and return: +// - a context with a locksInfo structure for future reference. +// - an unlock method +// - an error if anything failed. 
+// +// `TryLockShard` is different from `LockShard`. If there is already a lock on the given shard, +// then unlike `LockShard`, instead of waiting and blocking the client, it returns with a +// `Lock already exists` error. With the current implementation it may not be able to fail fast +// in some scenarios. For example, there is a possibility that a thread checks for a lock on +// a given shard but, by the time it acquires the lock, some other thread has already acquired it; +// in this case the client will block until the other caller releases the lock or the +// client call times out (just like the standard `LockShard` implementation). In short, the lock checking +// and acquiring are not under the same mutex in the current implementation of `TryLockShard`. +// +// We are currently using `TryLockShard` during tablet discovery in Vtorc recovery. +func (ts *Server) TryLockShard(ctx context.Context, keyspace, shard, action string) (context.Context, func(*error), error) { + return ts.internalLock(ctx, &shardLock{ + keyspace: keyspace, + shard: shard, + }, action, false) +} + +// CheckShardLocked can be called on a context to make sure we have the lock +// for a given shard. 
+func CheckShardLocked(ctx context.Context, keyspace, shard string) error { + return checkLocked(ctx, &shardLock{ + keyspace: keyspace, + shard: shard, + }) +} diff --git a/go/vt/topo/topo_lock_test.go b/go/vt/topo/shard_lock_test.go similarity index 57% rename from go/vt/topo/topo_lock_test.go rename to go/vt/topo/shard_lock_test.go index c378c05a9ff..dd37335c4ca 100644 --- a/go/vt/topo/topo_lock_test.go +++ b/go/vt/topo/shard_lock_test.go @@ -19,71 +19,66 @@ package topo_test import ( "context" "testing" - "time" "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" - - vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) -// lower the lock timeout for testing -const testLockTimeout = 3 * time.Second - -// TestTopoLockTimeout tests that the lock times out after the specified duration. -func TestTopoLockTimeout(t *testing.T) { +// TestTopoShardLock tests shard lock operations. +func TestTopoShardLock(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := memorytopo.NewServer(ctx, "zone1") defer ts.Close() - err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{}) - require.NoError(t, err) - lock, err := topo.NewRoutingRulesLock(ctx, ts, "ks1") - require.NoError(t, err) - currentTopoLockTimeout := topo.LockTimeout topo.LockTimeout = testLockTimeout defer func() { topo.LockTimeout = currentTopoLockTimeout }() - // acquire the lock - origCtx := ctx - _, unlock, err := lock.Lock(origCtx) - require.NoError(t, err) - defer unlock(&err) - - // re-acquiring the lock should fail - _, _, err2 := lock.Lock(origCtx) - require.Errorf(t, err2, "deadline exceeded") -} - -// TestTopoLockBasic tests basic lock operations. 
-func TestTopoLockBasic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ts := memorytopo.NewServer(ctx, "zone1") - defer ts.Close() - - err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{}) + ks := "ks" + shard1 := "80-" + shard2 := "-80" + _, err := ts.GetOrCreateShard(ctx, ks, shard1) require.NoError(t, err) - lock, err := topo.NewRoutingRulesLock(ctx, ts, "ks1") + _, err = ts.GetOrCreateShard(ctx, ks, shard2) require.NoError(t, err) origCtx := ctx - ctx, unlock, err := lock.Lock(origCtx) + ctx, unlock, err := ts.LockShard(origCtx, ks, shard1, "ks80-") require.NoError(t, err) // locking the same key again, without unlocking, should return an error - _, _, err2 := lock.Lock(ctx) + _, _, err2 := ts.LockShard(ctx, ks, shard1, "ks80-") require.ErrorContains(t, err2, "already held") - // confirm that the lock can be re-acquired after unlocking + // Check that we have the shard lock shouldn't return an error + err = topo.CheckShardLocked(ctx, ks, shard1) + require.NoError(t, err) + + // Check that we have the shard lock for the other shard should return an error + err = topo.CheckShardLocked(ctx, ks, shard2) + require.ErrorContains(t, err, "shard ks/-80 is not locked") + + // Check we can acquire a shard lock for the other shard + ctx2, unlock2, err := ts.LockShard(ctx, ks, shard2, "ks-80") + require.NoError(t, err) + defer unlock2(&err) + + // Unlock the first shard unlock(&err) - _, unlock, err = lock.Lock(origCtx) + + // Check shard locked output for both shards + err = topo.CheckShardLocked(ctx2, ks, shard1) + require.ErrorContains(t, err, "shard ks/80- is not locked") + err = topo.CheckShardLocked(ctx2, ks, shard2) + require.NoError(t, err) + + // confirm that the lock can be re-acquired after unlocking + _, unlock, err = ts.TryLockShard(origCtx, ks, shard1, "ks80-") require.NoError(t, err) defer unlock(&err) } diff --git a/go/vt/topo/shard_test.go b/go/vt/topo/shard_test.go index 
ccef80944a9..b1de279cb1c 100644 --- a/go/vt/topo/shard_test.go +++ b/go/vt/topo/shard_test.go @@ -77,12 +77,29 @@ func TestRemoveCellsFromList(t *testing.T) { } } +// fakeLockDescriptor implements the topo.LockDescriptor interface +type fakeLockDescriptor struct{} + +// Check implements the topo.LockDescriptor interface +func (f fakeLockDescriptor) Check(ctx context.Context) error { + return nil +} + +// Unlock implements the topo.LockDescriptor interface +func (f fakeLockDescriptor) Unlock(ctx context.Context) error { + return nil +} + +var _ LockDescriptor = (*fakeLockDescriptor)(nil) + func lockedKeyspaceContext(keyspace string) context.Context { ctx := context.Background() return context.WithValue(ctx, locksKey, &locksInfo{ info: map[string]*lockInfo{ // An empty entry is good enough for this. - keyspace: {}, + keyspace: { + lockDescriptor: fakeLockDescriptor{}, + }, }, }) } diff --git a/go/vt/topo/topo_lock.go b/go/vt/topo/topo_lock.go deleted file mode 100644 index ffd732fff36..00000000000 --- a/go/vt/topo/topo_lock.go +++ /dev/null @@ -1,169 +0,0 @@ -/* -Copyright 2024 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package topo - -import ( - "context" - "fmt" - - "vitess.io/vitess/go/trace" - "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" -) - -// ITopoLock is the interface for a lock that can be used to lock a key in the topology server. 
-// The lock is associated with a context and can be unlocked by calling the returned function. -// Note that we don't need an Unlock method on the interface, as the Lock() function -// returns a function that can be used to unlock the lock. -type ITopoLock interface { - Lock(ctx context.Context) (context.Context, func(*error), error) -} - -type TopoLock struct { - Path string // topo path to lock - Name string // name, for logging purposes - - ts *Server -} - -var _ ITopoLock = (*TopoLock)(nil) - -func (ts *Server) NewTopoLock(path, name string) *TopoLock { - return &TopoLock{ - ts: ts, - Path: path, - Name: name, - } -} - -func (tl *TopoLock) String() string { - return fmt.Sprintf("TopoLock{Path: %v, Name: %v}", tl.Path, tl.Name) -} - -// perform the topo lock operation -func (l *Lock) lock(ctx context.Context, ts *Server, path string) (LockDescriptor, error) { - ctx, cancel := context.WithTimeout(ctx, LockTimeout) - defer cancel() - span, ctx := trace.NewSpan(ctx, "TopoServer.Lock") - span.Annotate("action", l.Action) - span.Annotate("path", path) - defer span.Finish() - - j, err := l.ToJSON() - if err != nil { - return nil, err - } - return ts.globalCell.Lock(ctx, path, j) -} - -// unlock unlocks a previously locked key. -func (l *Lock) unlock(ctx context.Context, path string, lockDescriptor LockDescriptor, actionError error) error { - // Detach from the parent timeout, but copy the trace span. - // We need to still release the lock even if the parent - // context timed out. 
- ctx = trace.CopySpan(context.TODO(), ctx) - ctx, cancel := context.WithTimeout(ctx, RemoteOperationTimeout) - defer cancel() - - span, ctx := trace.NewSpan(ctx, "TopoServer.Unlock") - span.Annotate("action", l.Action) - span.Annotate("path", path) - defer span.Finish() - - // first update the actionNode - if actionError != nil { - l.Status = "Error: " + actionError.Error() - } else { - l.Status = "Done" - } - return lockDescriptor.Unlock(ctx) -} - -// Lock adds lock information to the context, checks that the lock is not already held, and locks it. -// It returns a new context with the lock information and a function to unlock the lock. -func (tl TopoLock) Lock(ctx context.Context) (context.Context, func(*error), error) { - i, ok := ctx.Value(locksKey).(*locksInfo) - if !ok { - i = &locksInfo{ - info: make(map[string]*lockInfo), - } - ctx = context.WithValue(ctx, locksKey, i) - } - i.mu.Lock() - defer i.mu.Unlock() - // check that we are not already locked - if _, ok := i.info[tl.Path]; ok { - return nil, nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "lock for %v is already held", tl.Path) - } - - // lock it - l := newLock(fmt.Sprintf("lock for %s", tl.Name)) - lockDescriptor, err := l.lock(ctx, tl.ts, tl.Path) - if err != nil { - return nil, nil, err - } - // and update our structure - i.info[tl.Path] = &lockInfo{ - lockDescriptor: lockDescriptor, - actionNode: l, - } - return ctx, func(finalErr *error) { - i.mu.Lock() - defer i.mu.Unlock() - - if _, ok := i.info[tl.Path]; !ok { - if *finalErr != nil { - log.Errorf("trying to unlock %v multiple times", tl.Path) - } else { - *finalErr = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "trying to unlock %v multiple times", tl.Path) - } - return - } - - err := l.unlock(ctx, tl.Path, lockDescriptor, *finalErr) - // if we have an error, we log it, but we still want to delete the lock - if *finalErr != nil { - if err != nil { - // both error are set, just log the unlock error - log.Errorf("unlock(%v) failed: %v", 
tl.Path, err) - } - } else { - *finalErr = err - } - delete(i.info, tl.Path) - }, nil -} - -func CheckLocked(ctx context.Context, keyPath string) error { - // extract the locksInfo pointer - i, ok := ctx.Value(locksKey).(*locksInfo) - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%s is not locked (no locksInfo)", keyPath) - } - i.mu.Lock() - defer i.mu.Unlock() - - // find the individual entry - _, ok = i.info[keyPath] - if !ok { - return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%s is not locked (no lockInfo in map)", keyPath) - } - - // and we're good for now. - return nil -} diff --git a/go/vt/topotools/routing_rules.go b/go/vt/topotools/routing_rules.go index a3bc5a8a957..5e423f8f55d 100644 --- a/go/vt/topotools/routing_rules.go +++ b/go/vt/topotools/routing_rules.go @@ -166,7 +166,7 @@ func buildKeyspaceRoutingRules(rules *map[string]string) *vschemapb.KeyspaceRout // saveKeyspaceRoutingRulesLocked saves the keyspace routing rules in the topo server. It expects the caller to // have acquired a RoutingRulesLock. func saveKeyspaceRoutingRulesLocked(ctx context.Context, ts *topo.Server, rules map[string]string) error { - if err := topo.CheckLocked(ctx, topo.RoutingRulesPath); err != nil { + if err := topo.CheckRoutingRulesLocked(ctx); err != nil { return err } return ts.SaveKeyspaceRoutingRules(ctx, buildKeyspaceRoutingRules(&rules)) @@ -180,12 +180,7 @@ func saveKeyspaceRoutingRulesLocked(ctx context.Context, ts *topo.Server, rules // then modify the keyspace routing rules in-place. 
func UpdateKeyspaceRoutingRules(ctx context.Context, ts *topo.Server, reason string, update func(ctx context.Context, rules *map[string]string) error) (err error) { - var lock *topo.RoutingRulesLock - lock, err = topo.NewRoutingRulesLock(ctx, ts, reason) - if err != nil { - return err - } - lockCtx, unlock, lockErr := lock.Lock(ctx) + lockCtx, unlock, lockErr := ts.LockRoutingRules(ctx, reason) if lockErr != nil { // If the key does not yet exist then let's create it. if !topo.IsErrType(lockErr, topo.NoNode) { diff --git a/go/vt/topotools/routing_rules_test.go b/go/vt/topotools/routing_rules_test.go index 2d4d9feacd1..6a33bbfff70 100644 --- a/go/vt/topotools/routing_rules_test.go +++ b/go/vt/topotools/routing_rules_test.go @@ -150,9 +150,7 @@ func TestSaveKeyspaceRoutingRulesLocked(t *testing.T) { }) // declare and acquire lock - lock, err := topo.NewRoutingRulesLock(ctx, ts, "test") - require.NoError(t, err) - lockCtx, unlock, err := lock.Lock(ctx) + lockCtx, unlock, err := ts.LockRoutingRules(ctx, "test") require.NoError(t, err) defer unlock(&err)