Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: apply PCR AOST time stamp for authentication flow #135929

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pkg/sql/catalog/descs/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,22 @@ func (tc *Collection) GetIndexComment(
return tc.GetComment(catalogkeys.MakeCommentKey(uint32(tableID), uint32(indexID), catalogkeys.IndexCommentType))
}

// MaybeSetReplicationSafeTS modifies a txn to apply the replication safe timestamp,
// if we are executing against a PCR reader catalog.
func (tc *Collection) MaybeSetReplicationSafeTS(ctx context.Context, txn *kv.Txn) error {
now := txn.DB().Clock().Now()
desc, err := tc.leased.lm.Acquire(ctx, now, keys.SystemDatabaseID)
if err != nil {
return err
}
defer desc.Release(ctx)

if desc.Underlying().(catalog.DatabaseDescriptor).GetReplicatedPCRVersion() == 0 {
return nil
}
return txn.SetFixedTimestamp(ctx, tc.leased.lm.GetSafeReplicationTS())
}

// GetConstraintComment implements the scdecomp.CommentGetter interface.
func (tc *Collection) GetConstraintComment(
tableID descpb.ID, constraintID catid.ConstraintID,
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/catalog/descs/leased_descriptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type LeaseManager interface {
IncGaugeAfterLeaseDuration(
gaugeType lease.AfterLeaseDurationGauge,
) (decrAfterWait func())

GetSafeReplicationTS() hlc.Timestamp
}

type deadlineHolder interface {
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/catalog/replication/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ go_test(
"//pkg/util/randutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_stretchr_testify//require",
],
)
48 changes: 44 additions & 4 deletions pkg/sql/catalog/replication/reader_catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/errors"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -270,6 +271,8 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
srcRunner := sqlutils.MakeSQLRunner(srcConn)

ddlToExec := []string{
"CREATE USER bob password 'bob'",
"GRANT ADMIN TO bob;",
"CREATE SEQUENCE sq1;",
"CREATE TYPE IF NOT EXISTS status AS ENUM ('open', 'closed', 'inactive');",
"CREATE TABLE t1(j int default nextval('sq1'), val status);",
Expand Down Expand Up @@ -297,6 +300,9 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
// Connect only after the reader catalog is setup, so the connection
// executor is aware.
destConn := destTenant.SQLConn(t)
destURL, destURLCleanup := destTenant.PGUrl(t, serverutils.UserPassword("bob", "bob"), serverutils.ClientCerts(false))
defer destURLCleanup()
require.NoError(t, err)
destRunner := sqlutils.MakeSQLRunner(destConn)

check := func(query string, isEqual bool) {
Expand All @@ -322,6 +328,16 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
} else {
require.NotEqualValues(t, srcRes, destRes)
}

// Sanity: Execute the same query as prepared statement inside the reader
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to confirm, were these unit tests with prepared stmts failing before your patch to the source code?

// catalog .
destPgxConn, err := pgx.Connect(ctx, destURL.String())
_, err = destPgxConn.Prepare(ctx, query, query)
require.NoError(t, err)
rows, err := destPgxConn.Query(ctx, query)
require.NoError(t, err)
defer rows.Close()
require.NoError(t, destPgxConn.Close(ctx))
}

compareEqual := func(query string) {
Expand All @@ -333,6 +349,10 @@ func TestReaderCatalogTSAdvance(t *testing.T) {

var newTS hlc.Timestamp
descriptorRefreshHookEnabled.Store(true)
existingPgxConn, err := pgx.Connect(ctx, destURL.String())
require.NoError(t, err)
_, err = existingPgxConn.Prepare(ctx, "basic select", "SELECT * FROM t1, v1, t2")
require.NoError(t, err)
for _, useAOST := range []bool{false, true} {
if useAOST {
closeWaitForRefresh()
Expand Down Expand Up @@ -385,7 +405,9 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
destRunner.Exec(t, "SET bypass_pcr_reader_catalog_aost='on'")
}
iterationsDone := false
uniqueIdx := 0
for !iterationsDone {
uniqueIdx++
if !useAOST {
select {
case waitForRefresh <- struct{}{}:
Expand All @@ -397,8 +419,27 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
case <-iterationsDoneCh:
iterationsDone = true
default:
// Prepare on an existing connection.
rows, err := existingPgxConn.Query(ctx, "SELECT * FROM t1, v1, t2")
require.NoError(t, err)
rows.Close()
uniqueQuery := fmt.Sprintf("SELECT a.j + %d FROM t1 as a, v1 as b, t2 as c ", uniqueIdx)
_, err = existingPgxConn.Prepare(ctx, fmt.Sprintf("q%d", uniqueIdx), uniqueQuery)
require.NoError(t, err)
rows, err = existingPgxConn.Query(ctx, uniqueQuery)
require.NoError(t, err)
rows.Close()
// Open new connections.
newPgxConn, err := pgx.Connect(ctx, destURL.String())
require.NoError(t, err)
_, err = newPgxConn.Prepare(ctx, "basic select", "SELECT * FROM t1, v1, t2")
require.NoError(t, err)
rows, err = newPgxConn.Query(ctx, "SELECT * FROM t1, v1, t2")
require.NoError(t, err)
require.NoError(t, newPgxConn.Close(ctx))

tx := destRunner.Begin(t)
_, err := tx.Exec("SELECT * FROM t1")
_, err = tx.Exec("SELECT * FROM t1")
checkAOSTError(err)
_, err = tx.Exec("SELECT * FROM v1")
checkAOSTError(err)
Expand All @@ -414,13 +455,13 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
checkAOSTError(err)
}
}

// Finally ensure the queries actually match.
require.NoError(t, grp.Wait())
// Check if the error was detected.
require.Equalf(t, !useAOST, errorDetected,
"error was detected unexpectedly (AOST = %t on connection)", useAOST)
}
require.NoError(t, existingPgxConn.Close(ctx))
now = newTS
compareEqual("SELECT * FROM t1 ORDER BY j")
compareEqual("SELECT * FROM v1 ORDER BY 1")
Expand All @@ -432,7 +473,7 @@ func TestReaderCatalogTSAdvance(t *testing.T) {
// with PCR even if a long running txn is running.
func TestReaderCatalogTSAdvanceWithLongTxn(t *testing.T) {
defer leaktest.AfterTest(t)()
skip.UnderDuress(t)
// skip.UnderDuress(t)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: intended?


ctx := context.Background()
ts := serverutils.StartServerOnly(t, base.TestServerArgs{
Expand All @@ -459,7 +500,6 @@ func TestReaderCatalogTSAdvanceWithLongTxn(t *testing.T) {

ddlToExec := []string{
"CREATE USER roacher WITH CREATEROLE;",
"GRANT ADMIN TO roacher;",
"ALTER USER roacher SET timezone='America/New_York';",
"CREATE SEQUENCE sq1;",
"CREATE TABLE t1(n int default nextval('sq1'), val TEXT);",
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/regions/region_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ func (f fakeLeaseManager) IncGaugeAfterLeaseDuration(gauge lease.AfterLeaseDurat
return func() {}
}

func (f fakeLeaseManager) GetSafeReplicationTS() hlc.Timestamp {
return hlc.Timestamp{}
}

var _ descs.LeaseManager = (*fakeLeaseManager)(nil)

type fakeSystemDatabase struct {
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/sessioninit/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func (a *Cache) GetAuthInfo(
err = db.DescsTxn(ctx, func(
ctx context.Context, txn descs.Txn,
) error {
if err := txn.Descriptors().MaybeSetReplicationSafeTS(ctx, txn.KV()); err != nil {
return err
}
_, usersTableDesc, err = descs.PrefixAndTable(ctx, txn.Descriptors().ByNameWithLeased(txn.KV()).Get(), UsersTableName)
if err != nil {
return err
Expand Down
157 changes: 83 additions & 74 deletions pkg/sql/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ func GetUserSessionInitInfo(
return execCfg.InternalDB.DescsTxn(ctx, func(
ctx context.Context, txn descs.Txn,
) error {
if err := txn.Descriptors().MaybeSetReplicationSafeTS(ctx, txn.KV()); err != nil {
return err
}
memberships, err := MemberOfWithAdminOption(ctx, execCfg, txn, user)
if err != nil {
return err
Expand Down Expand Up @@ -288,92 +291,98 @@ func retrieveAuthInfo(
// we should always look up the latest data.
const getHashedPassword = `SELECT "hashedPassword" FROM system.public.users ` +
`WHERE username=$1`
ie := f.Executor()
values, err := ie.QueryRowEx(
ctx, "get-hashed-pwd", nil, /* txn */
sessiondata.NodeUserSessionDataOverride,
getHashedPassword, user)

if err != nil {
return aInfo, errors.Wrapf(err, "error looking up user %s", user)
}
var hashedPassword []byte
if values != nil {
aInfo.UserExists = true
if v := values[0]; v != tree.DNull {
hashedPassword = []byte(*(v.(*tree.DBytes)))
err := f.DescsTxn(ctx, func(ctx context.Context, txn descs.Txn) error {
if err := txn.Descriptors().MaybeSetReplicationSafeTS(ctx, txn.KV()); err != nil {
return err
}
values, err := txn.QueryRowEx(
ctx, "get-hashed-pwd", txn.KV(), /* txn */
sessiondata.NodeUserSessionDataOverride,
getHashedPassword, user)
if err != nil {
return err
}
}
aInfo.HashedPassword = password.LoadPasswordHash(ctx, hashedPassword)

if !aInfo.UserExists {
return aInfo, nil
}
var hashedPassword []byte
if values != nil {
aInfo.UserExists = true
if v := values[0]; v != tree.DNull {
hashedPassword = []byte(*(v.(*tree.DBytes)))
}
}

// None of the rest of the role options are relevant for root.
if user.IsRootUser() {
return aInfo, nil
}
aInfo.HashedPassword = password.LoadPasswordHash(ctx, hashedPassword)

// Use fully qualified table name to avoid looking up "".system.role_options.
const getLoginDependencies = `SELECT option, value FROM system.public.role_options ` +
`WHERE username=$1 AND option IN ('NOLOGIN', 'VALID UNTIL', 'NOSQLLOGIN', 'REPLICATION', 'SUBJECT')`
if !aInfo.UserExists {
return nil
}

roleOptsIt, err := ie.QueryIteratorEx(
ctx, "get-login-dependencies", nil, /* txn */
sessiondata.NodeUserSessionDataOverride,
getLoginDependencies,
user,
)
// None of the rest of the role options are relevant for root.
if user.IsRootUser() {
return nil
}

if err != nil {
return aInfo, errors.Wrapf(err, "error looking up user %s", user)
}
// We have to make sure to close the iterator since we might return from
// the for loop early (before Next() returns false).
defer func() { retErr = errors.CombineErrors(retErr, roleOptsIt.Close()) }()
// Use fully qualified table name to avoid looking up "".system.role_options.
const getLoginDependencies = `SELECT option, value FROM system.public.role_options ` +
`WHERE username=$1 AND option IN ('NOLOGIN', 'VALID UNTIL', 'NOSQLLOGIN', 'REPLICATION', 'SUBJECT')`

// To support users created before 20.1, allow all USERS/ROLES to login
// if NOLOGIN is not found.
aInfo.CanLoginSQLRoleOpt = true
aInfo.CanLoginDBConsoleRoleOpt = true
var ok bool
roleOptsIt, err := txn.QueryIteratorEx(
ctx, "get-login-dependencies", txn.KV(), /* txn */
sessiondata.NodeUserSessionDataOverride,
getLoginDependencies,
user,
)

for ok, err = roleOptsIt.Next(ctx); ok; ok, err = roleOptsIt.Next(ctx) {
row := roleOptsIt.Cur()
option := string(tree.MustBeDString(row[0]))
switch option {
case "NOLOGIN":
aInfo.CanLoginSQLRoleOpt = false
aInfo.CanLoginDBConsoleRoleOpt = false
case "NOSQLLOGIN":
aInfo.CanLoginSQLRoleOpt = false
case "REPLICATION":
aInfo.CanUseReplicationRoleOpt = true
case "VALID UNTIL":
if row[1] != tree.DNull {
ts := string(tree.MustBeDString(row[1]))
// This is okay because the VALID UNTIL is stored as a string
// representation of a TimestampTZ which has the same underlying
// representation in the table as a Timestamp (UTC time).
timeCtx := tree.NewParseContext(timeutil.Now())
aInfo.ValidUntil, _, err = tree.ParseDTimestamp(timeCtx, ts, time.Microsecond)
if err != nil {
return aInfo, errors.Wrap(err,
"error trying to parse timestamp while retrieving password valid until value")
if err != nil {
return errors.Wrapf(err, "error looking up user %s", user)
}
// We have to make sure to close the iterator since we might return from
// the for loop early (before Next() returns false).
defer func() { retErr = errors.CombineErrors(retErr, roleOptsIt.Close()) }()

// To support users created before 20.1, allow all USERS/ROLES to login
// if NOLOGIN is not found.
aInfo.CanLoginSQLRoleOpt = true
aInfo.CanLoginDBConsoleRoleOpt = true
var ok bool

for ok, err = roleOptsIt.Next(ctx); ok; ok, err = roleOptsIt.Next(ctx) {
row := roleOptsIt.Cur()
option := string(tree.MustBeDString(row[0]))
switch option {
case "NOLOGIN":
aInfo.CanLoginSQLRoleOpt = false
aInfo.CanLoginDBConsoleRoleOpt = false
case "NOSQLLOGIN":
aInfo.CanLoginSQLRoleOpt = false
case "REPLICATION":
aInfo.CanUseReplicationRoleOpt = true
case "VALID UNTIL":
if row[1] != tree.DNull {
ts := string(tree.MustBeDString(row[1]))
// This is okay because the VALID UNTIL is stored as a string
// representation of a TimestampTZ which has the same underlying
// representation in the table as a Timestamp (UTC time).
timeCtx := tree.NewParseContext(timeutil.Now())
aInfo.ValidUntil, _, err = tree.ParseDTimestamp(timeCtx, ts, time.Microsecond)
if err != nil {
return errors.Wrap(err,
"error trying to parse timestamp while retrieving password valid until value")
}
}
}
case "SUBJECT":
if row[1] != tree.DNull {
subjectStr := string(tree.MustBeDString(row[1]))
dn, err := distinguishedname.ParseDN(subjectStr)
if err != nil {
return aInfo, err
case "SUBJECT":
if row[1] != tree.DNull {
subjectStr := string(tree.MustBeDString(row[1]))
dn, err := distinguishedname.ParseDN(subjectStr)
if err != nil {
return err
}
aInfo.Subject = dn
}
aInfo.Subject = dn
}
}
}
return nil
})

return aInfo, err
}
Expand Down
Loading