From 4a2a4893dc457896b61c866016c939fde65c377f Mon Sep 17 00:00:00 2001 From: Baohe Zhang Date: Wed, 25 Jan 2023 19:53:11 -0800 Subject: [PATCH] Support SET SESSION AUTHORIZATION and RESET SESSION AUTHORIZATION --- .../src/main/java/io/trino/cli/Console.java | 12 + .../src/main/java/io/trino/cli/Query.java | 10 + .../java/io/trino/client/ClientSession.java | 18 ++ .../java/io/trino/client/ProtocolHeaders.java | 21 ++ .../java/io/trino/client/StatementClient.java | 4 + .../io/trino/client/StatementClientV1.java | 32 +- .../java/io/trino/jdbc/TrinoConnection.java | 17 + .../java/io/trino/jdbc/TestTrinoDriver.java | 54 ++++ .../io/trino/jdbc/TestTrinoResultSet.java | 12 + .../antlr4/io/trino/grammar/sql/SqlBase.g4 | 7 + .../src/main/java/io/trino/Session.java | 23 ++ .../java/io/trino/SessionRepresentation.java | 28 ++ .../io/trino/dispatcher/DispatchManager.java | 1 + .../trino/dispatcher/FailedDispatchQuery.java | 2 + .../java/io/trino/event/QueryMonitor.java | 1 + .../java/io/trino/execution/QueryInfo.java | 19 ++ .../io/trino/execution/QueryStateMachine.java | 19 ++ .../ResetSessionAuthorizationTask.java | 64 ++++ .../SetSessionAuthorizationTask.java | 91 ++++++ .../HttpRequestSessionContextFactory.java | 28 +- .../server/QueryExecutionFactoryModule.java | 6 + .../io/trino/server/QuerySessionSupplier.java | 18 +- .../java/io/trino/server/SessionContext.java | 8 + .../protocol/ExecutingStatementResource.java | 4 + .../java/io/trino/server/protocol/Query.java | 12 + .../server/protocol/QueryResultsResponse.java | 3 + .../security/InsecureAuthenticator.java | 7 +- .../security/PasswordAuthenticator.java | 13 +- .../trino/sql/analyzer/StatementAnalyzer.java | 15 + .../io/trino/testing/LocalQueryRunner.java | 1 + .../java/io/trino/testing/TestingSession.java | 1 + .../java/io/trino/util/StatementUtils.java | 6 + .../execution/MockManagedQueryExecution.java | 2 + .../io/trino/execution/TestQueryInfo.java | 2 + .../TestSetSessionAuthorizationTask.java | 125 ++++++++ .../io/trino/server/TestBasicQueryInfo.java | 2 + .../io/trino/server/TestQueryStateInfo.java | 2 + .../server/TestSessionPropertyDefaults.java | 1 + .../server/security/TestResourceSecurity.java | 32 ++ .../main/java/io/trino/sql/SqlFormatter.java | 17 + .../java/io/trino/sql/parser/AstBuilder.java | 19 ++ .../java/io/trino/sql/tree/AstVisitor.java | 10 + .../sql/tree/ResetSessionAuthorization.java | 74 +++++ .../sql/tree/SetSessionAuthorization.java | 89 ++++++ .../io/trino/sql/parser/TestSqlParser.java | 24 ++ .../sql/parser/TestStatementBuilder.java | 3 + .../trino/spi/eventlistener/QueryContext.java | 9 + .../main/sphinx/develop/client-protocol.md | 9 + docs/src/main/sphinx/sql.md | 2 + .../sql/reset-session-authorization.rst | 22 ++ .../sphinx/sql/set-session-authorization.rst | 47 +++ .../httpquery/TestHttpEventListener.java | 1 + .../mysql/TestMysqlEventListener.java | 2 + ...stDbSessionPropertyManagerIntegration.java | 1 + .../AbstractTestEngineOnlyQueries.java | 1 + .../trino/testing/TestingSessionContext.java | 1 + .../trino/testing/TestTestingTrinoClient.java | 1 + .../TestSetSessionAuthorization.java | 300 ++++++++++++++++++ ...set_session_authorization_permissions.json | 16 + 59 files changed, 1360 insertions(+), 11 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java create mode 100644 core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java create mode 100644 core/trino-main/src/test/java/io/trino/execution/TestSetSessionAuthorizationTask.java create mode 100644 core/trino-parser/src/main/java/io/trino/sql/tree/ResetSessionAuthorization.java create mode 100644 core/trino-parser/src/main/java/io/trino/sql/tree/SetSessionAuthorization.java create mode 100644 docs/src/main/sphinx/sql/reset-session-authorization.rst create mode 100644 docs/src/main/sphinx/sql/set-session-authorization.rst create mode 100644 testing/trino-tests/src/test/java/io/trino/execution/TestSetSessionAuthorization.java create mode 100644 testing/trino-tests/src/test/resources/set_session_authorization_permissions.json diff --git a/client/trino-cli/src/main/java/io/trino/cli/Console.java b/client/trino-cli/src/main/java/io/trino/cli/Console.java index 2e791c4631e1e..c6ba90b6d46bb 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Console.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Console.java @@ -14,6 +14,7 @@ package io.trino.cli; import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.ByteStreams; import io.airlift.units.Duration; @@ -388,6 +389,17 @@ private static boolean process( builder = builder.path(query.getSetPath().get()); } + // update authorization user if present + if (query.getSetAuthorizationUser().isPresent()) { + builder = builder.authorizationUser(query.getSetAuthorizationUser()); + builder = builder.roles(ImmutableMap.of()); + } + + if (query.isResetAuthorizationUser()) { + builder = builder.authorizationUser(Optional.empty()); + builder = builder.roles(ImmutableMap.of()); + } + // update session properties if present if (!query.getSetSessionProperties().isEmpty() || !query.getResetSessionProperties().isEmpty()) { Map sessionProperties = new HashMap<>(session.getProperties()); diff --git a/client/trino-cli/src/main/java/io/trino/cli/Query.java b/client/trino-cli/src/main/java/io/trino/cli/Query.java index cb306b31cc967..4ee49a073197e 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Query.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Query.java @@ -91,6 +91,16 @@ public Optional getSetPath() return client.getSetPath(); } + public Optional getSetAuthorizationUser() + { + return client.getSetAuthorizationUser(); + } + + public boolean isResetAuthorizationUser() + { + return client.isResetAuthorizationUser(); + } + public Map getSetSessionProperties() { return client.getSetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/ClientSession.java b/client/trino-client/src/main/java/io/trino/client/ClientSession.java index f5c83f58c8b95..36f280cd15d8a 100644 --- a/client/trino-client/src/main/java/io/trino/client/ClientSession.java +++ b/client/trino-client/src/main/java/io/trino/client/ClientSession.java @@ -36,6 +36,7 @@ public class ClientSession private final URI server; private final Optional principal; private final Optional user; + private final Optional authorizationUser; private final String source; private final Optional traceToken; private final Set clientTags; @@ -75,6 +76,7 @@ private ClientSession( URI server, Optional principal, Optional user, + Optional authorizationUser, String source, Optional traceToken, Set clientTags, @@ -96,6 +98,7 @@ private ClientSession( this.server = requireNonNull(server, "server is null"); this.principal = requireNonNull(principal, "principal is null"); this.user = requireNonNull(user, "user is null"); + this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null"); this.source = source; this.traceToken = requireNonNull(traceToken, "traceToken is null"); this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null")); @@ -158,6 +161,11 @@ public Optional getUser() return user; } + public Optional getAuthorizationUser() + { + return authorizationUser; + } + public String getSource() { return source; @@ -258,6 +266,7 @@ public String toString() .add("server", server) .add("principal", principal) .add("user", user) + .add("authorizationUser", authorizationUser) .add("clientTags", clientTags) .add("clientInfo", clientInfo) .add("catalog", catalog) @@ -277,6 +286,7 @@ public static final class Builder private URI server; private Optional principal = Optional.empty(); private Optional user = Optional.empty(); + private Optional authorizationUser = Optional.empty(); private String source; private Optional traceToken = Optional.empty(); private Set clientTags = ImmutableSet.of(); @@ -303,6 +313,7 @@ private Builder(ClientSession clientSession) server = clientSession.getServer(); principal = clientSession.getPrincipal(); user = clientSession.getUser(); + authorizationUser = clientSession.getAuthorizationUser(); source = clientSession.getSource(); traceToken = clientSession.getTraceToken(); clientTags = clientSession.getClientTags(); @@ -334,6 +345,12 @@ public Builder user(Optional user) return this; } + public Builder authorizationUser(Optional authorizationUser) + { + this.authorizationUser = authorizationUser; + return this; + } + public Builder principal(Optional principal) { this.principal = principal; @@ -448,6 +465,7 @@ public ClientSession build() server, principal, user, + authorizationUser, source, traceToken, clientTags, diff --git a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java index d78a0b2562873..e09555d847557 100644 --- a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java +++ b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java @@ -26,6 +26,7 @@ public final class ProtocolHeaders private final String name; private final String requestUser; + private final String requestOriginalUser; private final String requestSource; private final String requestCatalog; private final String requestSchema; @@ -52,6 +53,8 @@ public final class ProtocolHeaders private final String responseDeallocatedPrepare; private final String responseStartedTransactionId; private final String responseClearTransactionId; + private final String responseSetAuthorizationUser; + private final String responseResetAuthorizationUser; public static ProtocolHeaders createProtocolHeaders(String name) { @@ -69,6 +72,7 @@ private ProtocolHeaders(String name) this.name = name; String prefix = "X-" + name + "-"; requestUser = prefix + "User"; + requestOriginalUser = prefix + "Original-User"; requestSource = prefix + "Source"; requestCatalog = prefix + "Catalog"; requestSchema = prefix + "Schema"; @@ -95,6 +99,8 @@ private ProtocolHeaders(String name) responseDeallocatedPrepare = prefix + "Deallocated-Prepare"; responseStartedTransactionId = prefix + "Started-Transaction-Id"; responseClearTransactionId = prefix + "Clear-Transaction-Id"; + responseSetAuthorizationUser = prefix + "Set-Authorization-User"; + responseResetAuthorizationUser = prefix + "Reset-Authorization-User"; } public String getProtocolName() @@ -107,6 +113,11 @@ public String requestUser() return requestUser; } + public String requestOriginalUser() + { + return requestOriginalUser; + } + public String requestSource() { return requestSource; @@ -237,6 +248,16 @@ public String responseClearTransactionId() return responseClearTransactionId; } + public String responseSetAuthorizationUser() + { + return responseSetAuthorizationUser; + } + + public String responseResetAuthorizationUser() + { + return responseResetAuthorizationUser; + } + public static ProtocolHeaders detectProtocol(Optional alternateHeaderName, Set headerNames) throws ProtocolDetectionException { diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClient.java b/client/trino-client/src/main/java/io/trino/client/StatementClient.java index f927bde0716b3..841c1b296f5b7 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClient.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClient.java @@ -50,6 +50,10 @@ public interface StatementClient Optional getSetPath(); + Optional getSetAuthorizationUser(); + + boolean isResetAuthorizationUser(); + Map getSetSessionProperties(); Set getResetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index d563dde99b094..96b9ef42acff8 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -79,6 +79,8 @@ class StatementClientV1 private final AtomicReference setCatalog = new AtomicReference<>(); private final AtomicReference setSchema = new AtomicReference<>(); private final AtomicReference setPath = new AtomicReference<>(); + private final AtomicReference setAuthorizationUser = new AtomicReference<>(); + private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); private final Map setRoles = new ConcurrentHashMap<>(); @@ -89,6 +91,7 @@ class StatementClientV1 private final ZoneId timeZone; private final Duration requestTimeoutNanos; private final Optional user; + private final Optional originalUser; private final String clientCapabilities; private final boolean compressionDisabled; @@ -104,7 +107,11 @@ public StatementClientV1(Call.Factory httpCallFactory, ClientSession session, St this.timeZone = session.getTimeZone(); this.query = query; this.requestTimeoutNanos = session.getClientRequestTimeout(); - this.user = Stream.of(session.getUser(), session.getPrincipal()) + this.user = Stream.of(session.getAuthorizationUser(), session.getUser(), session.getPrincipal()) + .filter(Optional::isPresent) + .map(Optional::get) + .findFirst(); + this.originalUser = Stream.of(session.getUser(), session.getPrincipal()) .filter(Optional::isPresent) .map(Optional::get) .findFirst(); @@ -270,6 +277,18 @@ public Optional getSetPath() return Optional.ofNullable(setPath.get()); } + @Override + public Optional getSetAuthorizationUser() + { + return Optional.ofNullable(setAuthorizationUser.get()); + } + + @Override + public boolean isResetAuthorizationUser() + { + return resetAuthorizationUser.get(); + } + @Override public Map getSetSessionProperties() { @@ -319,6 +338,7 @@ private Request.Builder prepareRequest(HttpUrl url) .addHeader(USER_AGENT, USER_AGENT_VALUE) .url(url); user.ifPresent(requestUser -> builder.addHeader(TRINO_HEADERS.requestUser(), requestUser)); + originalUser.ifPresent(originalUser -> builder.addHeader(TRINO_HEADERS.requestOriginalUser(), originalUser)); if (compressionDisabled) { builder.header(ACCEPT_ENCODING, "identity"); } @@ -399,6 +419,16 @@ private void processResponse(Headers headers, QueryResults results) setSchema.set(headers.get(TRINO_HEADERS.responseSetSchema())); setPath.set(headers.get(TRINO_HEADERS.responseSetPath())); + String setAuthorizationUser = headers.get(TRINO_HEADERS.responseSetAuthorizationUser()); + if (setAuthorizationUser != null) { + this.setAuthorizationUser.set(setAuthorizationUser); + } + + String resetAuthorizationUser = headers.get(TRINO_HEADERS.responseResetAuthorizationUser()); + if (resetAuthorizationUser != null) { + this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser)); + } + for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) { List keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession); if (keyValue.size() != 2) { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index bedcf5e1c8321..5babd5b2525cb 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -89,6 +89,7 @@ public class TrinoConnection private final AtomicReference catalog = new AtomicReference<>(); private final AtomicReference schema = new AtomicReference<>(); private final AtomicReference path = new AtomicReference<>(); + private final AtomicReference authorizationUser = new AtomicReference<>(); private final AtomicReference timeZoneId = new AtomicReference<>(); private final AtomicReference locale = new AtomicReference<>(); private final AtomicReference networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); @@ -746,6 +747,7 @@ StatementClient startQuery(String sql, Map sessionPropertiesOver .server(httpUri) .principal(user) .user(sessionUser.get()) + .authorizationUser(Optional.ofNullable(authorizationUser.get())) .source(source) .traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN))) .clientTags(ImmutableSet.copyOf(clientTags)) @@ -781,6 +783,15 @@ void updateSession(StatementClient client) client.getSetSchema().ifPresent(schema::set); client.getSetPath().ifPresent(path::set); + if (client.getSetAuthorizationUser().isPresent()) { + authorizationUser.set(client.getSetAuthorizationUser().get()); + roles.clear(); + } + if (client.isResetAuthorizationUser()) { + authorizationUser.set(null); + roles.clear(); + } + if (client.getStartedTransactionId() != null) { transactionId.set(client.getStartedTransactionId()); } @@ -810,6 +821,12 @@ int activeStatements() return statements.size(); } + @VisibleForTesting + String getAuthorizationUser() + { + return authorizationUser.get(); + } + private void checkOpen() throws SQLException { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java index 2396a84706bbd..4533b229213bd 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java @@ -1105,6 +1105,47 @@ public void testCustomDnsResolver() } } + @Test(timeOut = 10000) + public void testResetSessionAuthorization() + throws Exception + { + try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + assertEquals(connection.getAuthorizationUser(), null); + assertEquals(getCurrentUser(connection), "test"); + statement.execute("SET SESSION AUTHORIZATION john"); + assertEquals(connection.getAuthorizationUser(), "john"); + assertEquals(getCurrentUser(connection), "john"); + statement.execute("SET SESSION AUTHORIZATION bob"); + assertEquals(connection.getAuthorizationUser(), "bob"); + assertEquals(getCurrentUser(connection), "bob"); + statement.execute("RESET SESSION AUTHORIZATION"); + assertEquals(connection.getAuthorizationUser(), null); + assertEquals(getCurrentUser(connection), "test"); + } + } + + @Test(timeOut = 10000) + public void testSetRoleAfterSetSessionAuthorization() + throws Exception + { + try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + statement.execute("SET SESSION AUTHORIZATION john"); + assertEquals(connection.getAuthorizationUser(), "john"); + statement.execute("SET ROLE ALL"); + assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.ALL, Optional.empty()))); + statement.execute("SET SESSION AUTHORIZATION bob"); + assertEquals(connection.getAuthorizationUser(), "bob"); + assertEquals(connection.getRoles(), ImmutableMap.of()); + statement.execute("SET ROLE NONE"); + assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.NONE, Optional.empty()))); + statement.execute("RESET SESSION AUTHORIZATION"); + assertEquals(connection.getAuthorizationUser(), null); + assertEquals(connection.getRoles(), ImmutableMap.of()); + } + } + private QueryState getQueryState(String queryId) throws SQLException { @@ -1166,6 +1207,19 @@ private static Properties toProperties(Map map) return properties; } + private static String getCurrentUser(Connection connection) + throws SQLException + { + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("SELECT current_user")) { + while (rs.next()) { + return rs.getString(1); + } + } + + throw new RuntimeException("Failed to get CURRENT_USER"); + } + public static class TestingDnsResolver implements DnsResolver { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java index 6975a09a98ac7..a7d0676b3d80e 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java @@ -206,6 +206,18 @@ public Optional getSetPath() throw new UnsupportedOperationException(); } + @Override + public Optional getSetAuthorizationUser() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isResetAuthorizationUser() + { + throw new UnsupportedOperationException(); + } + @Override public Map getSetSessionProperties() { diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 9827015525a96..27a7c8c15af01 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -166,6 +166,8 @@ statement (LIKE pattern=string (ESCAPE escape=string)?)? #showFunctions | SHOW SESSION (LIKE pattern=string (ESCAPE escape=string)?)? #showSession + | SET SESSION AUTHORIZATION authorizationUser #setSessionAuthorization + | RESET SESSION AUTHORIZATION #resetSessionAuthorization | SET SESSION qualifiedName EQ expression #setSession | RESET SESSION qualifiedName #resetSession | START TRANSACTION (transactionMode (',' transactionMode)*)? #startTransaction @@ -886,6 +888,11 @@ number | MINUS? INTEGER_VALUE #integerLiteral ; +authorizationUser + : identifier #identifierUser + | string #stringUser + ; + nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved : ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 6dc3db14ea381..19969b1f17f64 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -65,6 +65,7 @@ public final class Session private final Optional transactionId; private final boolean clientTransactionSupport; private final Identity identity; + private final Identity originalIdentity; private final Optional source; private final Optional catalog; private final Optional schema; @@ -93,6 +94,7 @@ public Session( Optional transactionId, boolean clientTransactionSupport, Identity identity, + Identity originalIdentity, Optional source, Optional catalog, Optional schema, @@ -119,6 +121,7 @@ public Session( this.transactionId = requireNonNull(transactionId, "transactionId is null"); this.clientTransactionSupport = clientTransactionSupport; this.identity = requireNonNull(identity, "identity is null"); + this.originalIdentity = requireNonNull(originalIdentity, "originalIdentity is null"); this.source = requireNonNull(source, "source is null"); this.catalog = requireNonNull(catalog, "catalog is null"); this.schema = requireNonNull(schema, "schema is null"); @@ -169,6 +172,11 @@ public Identity getIdentity() return identity; } + public Identity getOriginalIdentity() + { + return originalIdentity; + } + public Optional getSource() { return source; @@ -347,6 +355,7 @@ public Session beginTransactionId(TransactionId transactionId, TransactionManage Identity.from(identity) .withConnectorRoles(connectorRoles.buildOrThrow()) .build(), + originalIdentity, source, catalog, schema, @@ -395,6 +404,7 @@ public Session withDefaultProperties(Map systemPropertyDefaults, transactionId, clientTransactionSupport, identity, + originalIdentity, source, catalog, schema, @@ -426,6 +436,7 @@ public Session withExchangeEncryption(Slice encryptionKey) transactionId, clientTransactionSupport, identity, + originalIdentity, source, catalog, schema, @@ -475,7 +486,9 @@ public SessionRepresentation toSessionRepresentation() transactionId, clientTransactionSupport, identity.getUser(), + originalIdentity.getUser(), identity.getGroups(), + originalIdentity.getGroups(), identity.getPrincipal().map(Principal::toString), identity.getEnabledRoles(), source, @@ -588,6 +601,7 @@ public static class SessionBuilder private TransactionId transactionId; private boolean clientTransactionSupport; private Identity identity; + private Identity originalIdentity; private String source; private String catalog; private String schema; @@ -622,6 +636,7 @@ private SessionBuilder(Session session) this.transactionId = session.transactionId.orElse(null); this.clientTransactionSupport = session.clientTransactionSupport; this.identity = session.identity; + this.originalIdentity = session.originalIdentity; this.source = session.source.orElse(null); this.catalog = session.catalog.orElse(null); this.path = session.path; @@ -783,6 +798,13 @@ public SessionBuilder setIdentity(Identity identity) return this; } + @CanIgnoreReturnValue + public SessionBuilder setOriginalIdentity(Identity originalIdentity) + { + this.originalIdentity = originalIdentity; + return this; + } + @CanIgnoreReturnValue public SessionBuilder setUserAgent(String userAgent) { @@ -889,6 +911,7 @@ public Session build() Optional.ofNullable(transactionId), clientTransactionSupport, identity, + originalIdentity, Optional.ofNullable(source), Optional.ofNullable(catalog), Optional.ofNullable(schema), diff --git a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java index dd95c26f3e579..649b3cdbb3e85 100644 --- a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java @@ -47,7 +47,9 @@ public final class SessionRepresentation private final Optional transactionId; private final boolean clientTransactionSupport; private final String user; + private final String originalUser; private final Set groups; + private final Set originalUserGroups; private final Optional principal; private final Set enabledRoles; private final Optional source; @@ -77,7 +79,9 @@ public SessionRepresentation( @JsonProperty("transactionId") Optional transactionId, @JsonProperty("clientTransactionSupport") boolean clientTransactionSupport, @JsonProperty("user") String user, + @JsonProperty("originalUser") String originalUser, @JsonProperty("groups") Set groups, + @JsonProperty("originalUserGroups") Set originalUserGroups, @JsonProperty("principal") Optional principal, @JsonProperty("enabledRoles") Set enabledRoles, @JsonProperty("source") Optional source, @@ -105,7 +109,9 @@ public SessionRepresentation( this.transactionId = requireNonNull(transactionId, "transactionId is null"); this.clientTransactionSupport = clientTransactionSupport; this.user = requireNonNull(user, "user is null"); + this.originalUser = requireNonNull(originalUser, "originalUser is null"); this.groups = requireNonNull(groups, "groups is null"); + this.originalUserGroups = requireNonNull(originalUserGroups, "originalUserGroups is null"); this.principal = requireNonNull(principal, "principal is null"); this.enabledRoles = ImmutableSet.copyOf(requireNonNull(enabledRoles, "enabledRoles is null")); this.source = requireNonNull(source, "source is null"); @@ -164,12 +170,24 @@ public String getUser() return user; } + @JsonProperty + public String getOriginalUser() + { + return originalUser; + } + @JsonProperty public Set getGroups() { return groups; } + @JsonProperty + public Set getOriginalUserGroups() + { + return originalUserGroups; + } + @JsonProperty public Optional getPrincipal() { @@ -318,6 +336,15 @@ public Identity toIdentity(Map extraCredentials) .build(); } + public Identity toOriginalIdentity(Map extraCredentials) + { + return Identity.forUser(originalUser) + .withGroups(originalUserGroups) + .withPrincipal(principal.map(BasicPrincipal::new)) + .withExtraCredentials(extraCredentials) + .build(); + } + public Session toSession(SessionPropertyManager sessionPropertyManager) { return toSession(sessionPropertyManager, emptyMap(), Optional.empty()); @@ -331,6 +358,7 @@ public Session toSession(SessionPropertyManager sessionPropertyManager, Map void createQueryInternal(QueryId queryId, Span querySpan, Slug slug, session = Session.builder(sessionPropertyManager) .setQueryId(queryId) .setIdentity(sessionContext.getIdentity()) + .setOriginalIdentity(sessionContext.getOriginalIdentity()) .setSource(sessionContext.getSource().orElse(null)) .build(); } diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java index 05b9e596a7a61..36fa9270cba33 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java @@ -226,6 +226,8 @@ private static QueryInfo immediateFailureQueryInfo( Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), + false, ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java index bb258df0bd691..7951fc9d80f3b 100644 --- a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java +++ b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java @@ -343,6 +343,7 @@ private QueryContext createQueryContext(SessionRepresentation session, Optional< { return new QueryContext( session.getUser(), + session.getOriginalUser(), session.getPrincipal(), session.getGroups(), session.getTraceToken(), diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java index 65f1ed3f6ef02..0c276b6ab8cfe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java @@ -60,6 +60,8 @@ public class QueryInfo private final Optional setCatalog; private final Optional setSchema; private final Optional setPath; + private final Optional setAuthorizationUser; + private final boolean resetAuthorizationUser; private final Map setSessionProperties; private final Set resetSessionProperties; private final Map setRoles; @@ -97,6 +99,8 @@ public QueryInfo( @JsonProperty("setCatalog") Optional setCatalog, @JsonProperty("setSchema") Optional setSchema, @JsonProperty("setPath") Optional setPath, + @JsonProperty("setAuthorizationUser") Optional setAuthorizationUser, + @JsonProperty("resetAuthorizationUser") boolean resetAuthorizationUser, @JsonProperty("setSessionProperties") Map setSessionProperties, @JsonProperty("resetSessionProperties") Set resetSessionProperties, @JsonProperty("setRoles") Map setRoles, @@ -129,6 +133,7 @@ public QueryInfo( requireNonNull(setCatalog, "setCatalog is null"); requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); + requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(addedPreparedStatements, "addedPreparedStatements is null"); @@ -158,6 +163,8 @@ public QueryInfo( this.setCatalog = setCatalog; this.setSchema = setSchema; this.setPath = setPath; + this.setAuthorizationUser = setAuthorizationUser; + this.resetAuthorizationUser = resetAuthorizationUser; this.setSessionProperties = ImmutableMap.copyOf(setSessionProperties); this.resetSessionProperties = ImmutableSet.copyOf(resetSessionProperties); this.setRoles = ImmutableMap.copyOf(setRoles); @@ -268,6 +275,18 @@ public Optional getSetPath() return setPath; } + @JsonProperty + public Optional getSetAuthorizationUser() + { + return setAuthorizationUser; + } + + @JsonProperty + public boolean isResetAuthorizationUser() + { + return resetAuthorizationUser; + } + @JsonProperty public Map getSetSessionProperties() { diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 62aeb985b8628..f23197db8d3fd 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -151,6 +151,9 @@ public class QueryStateMachine private final AtomicReference setSchema = new AtomicReference<>(); private final AtomicReference setPath = new AtomicReference<>(); + private final AtomicReference setAuthorizationUser = new AtomicReference<>(); + private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); + private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); @@ -530,6 +533,8 @@ QueryInfo getQueryInfo(Optional rootStage) Optional.ofNullable(setCatalog.get()), Optional.ofNullable(setSchema.get()), Optional.ofNullable(setPath.get()), + Optional.ofNullable(setAuthorizationUser.get()), + resetAuthorizationUser.get(), setSessionProperties, resetSessionProperties, setRoles, @@ -923,6 +928,18 @@ public String getSetPath() return setPath.get(); } + public void setSetAuthorizationUser(String authorizationUser) + { + checkState(authorizationUser != null && !authorizationUser.isEmpty(), "Authorization user cannot be null or empty"); + setAuthorizationUser.set(authorizationUser); + } + + public void resetAuthorizationUser() + { + checkArgument(setAuthorizationUser.get() == null, "Cannot set and reset the authorization user in the same request"); + resetAuthorizationUser.set(true); + } + public void addSetSessionProperties(String key, String value) { setSessionProperties.put(requireNonNull(key, "key is null"), requireNonNull(value, "value is null")); @@ -1306,6 +1323,8 @@ public void pruneQueryInfo() queryInfo.getSetCatalog(), queryInfo.getSetSchema(), queryInfo.getSetPath(), + queryInfo.getSetAuthorizationUser(), + queryInfo.isResetAuthorizationUser(), queryInfo.getSetSessionProperties(), queryInfo.getResetSessionProperties(), queryInfo.getSetRoles(), diff --git a/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java new file mode 100644 index 0000000000000..6bc1b2e668678 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/ResetSessionAuthorizationTask.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.spi.TrinoException; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.ResetSessionAuthorization; +import io.trino.transaction.TransactionManager; + +import java.util.List; + +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static java.util.Objects.requireNonNull; + +public class ResetSessionAuthorizationTask + implements DataDefinitionTask +{ + private final TransactionManager transactionManager; + + @Inject + public ResetSessionAuthorizationTask(TransactionManager transactionManager) + { + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + } + + @Override + public String getName() + { + return "RESET SESSION AUTHORIZATION"; + } + + @Override + public ListenableFuture execute( + ResetSessionAuthorization statement, + QueryStateMachine stateMachine, + List parameters, + WarningCollector warningCollector) + { + Session session = stateMachine.getSession(); + session.getTransactionId().ifPresent(transactionId -> { + if (!transactionManager.getTransactionInfo(transactionId).isAutoCommitContext()) { + throw new TrinoException(GENERIC_USER_ERROR, "Can't reset authorization user in the middle of a transaction"); + } + }); + stateMachine.resetAuthorizationUser(); + return immediateFuture(null); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java new file mode 100644 index 0000000000000..99afede9118cf --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.security.AccessControl; +import io.trino.spi.TrinoException; +import io.trino.spi.security.Identity; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.SetSessionAuthorization; +import io.trino.sql.tree.StringLiteral; +import io.trino.transaction.TransactionManager; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static java.util.Objects.requireNonNull; + +public class SetSessionAuthorizationTask + implements DataDefinitionTask +{ + private final AccessControl accessControl; + private final TransactionManager transactionManager; + + @Inject + public SetSessionAuthorizationTask(AccessControl accessControl, TransactionManager transactionManager) + { + this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + } + + @Override + public String getName() + { + return "SET SESSION AUTHORIZATION"; + } + + @Override + public ListenableFuture execute( + SetSessionAuthorization statement, + QueryStateMachine stateMachine, + List parameters, + WarningCollector warningCollector) + { + Session session = stateMachine.getSession(); + Identity originalIdentity = session.getOriginalIdentity(); + // Set authorization user in the middle of a transaction is disallowed by the SQL spec + session.getTransactionId().ifPresent(transactionId -> { + if (!transactionManager.getTransactionInfo(transactionId).isAutoCommitContext()) { + throw new TrinoException(GENERIC_USER_ERROR, "Can't set authorization user in the middle of a transaction"); + } + }); + + String user; + Expression userExpression = statement.getUser(); + if (userExpression instanceof Identifier identifier) { + user = identifier.getValue(); + } + else if (userExpression instanceof StringLiteral stringLiteral) { + user = stringLiteral.getValue(); + } + else { + throw new IllegalArgumentException("Unsupported user expression: " + userExpression.getClass().getName()); + } + checkState(user != null && !user.isEmpty(), "Authorization user cannot be null or empty"); + + if (!originalIdentity.getUser().equals(user)) { + accessControl.checkCanSetUser(originalIdentity.getPrincipal(), user); + accessControl.checkCanImpersonateUser(originalIdentity, user); + } + stateMachine.setSetAuthorizationUser(user); + return immediateFuture(null); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java index b3a3234fb134e..58209aae8e46b 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java @@ -107,6 +107,7 @@ public SessionContext createSessionContext( requireNonNull(authenticatedIdentity, "authenticatedIdentity is null"); Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers); + Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers); SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers); Optional source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource())); @@ -165,6 +166,7 @@ else if (nameParts.size() == 2) { path, authenticatedIdentity, identity, + originalIdentity, selectedRole, source, traceToken, @@ -209,21 +211,27 @@ public Identity extractAuthorizedIdentity( } Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, protocolHeaders, headers); + Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers); - accessControl.checkCanSetUser(identity.getPrincipal(), identity.getUser()); + accessControl.checkCanSetUser(originalIdentity.getPrincipal(), originalIdentity.getUser()); // authenticated may not present for HTTP or if authentication is not setup optionalAuthenticatedIdentity.ifPresent(authenticatedIdentity -> { // only check impersonation if authenticated user is not the same as the explicitly set user - if (!authenticatedIdentity.getUser().equals(identity.getUser())) { + if (!authenticatedIdentity.getUser().equals(originalIdentity.getUser())) { // load enabled roles for authenticated identity, so impersonation permissions can be assigned to roles authenticatedIdentity = Identity.from(authenticatedIdentity) .withEnabledRoles(metadata.listEnabledRoles(authenticatedIdentity)) .build(); - accessControl.checkCanImpersonateUser(authenticatedIdentity, identity.getUser()); + accessControl.checkCanImpersonateUser(authenticatedIdentity, originalIdentity.getUser()); } }); + if (!originalIdentity.getUser().equals(identity.getUser())) { + accessControl.checkCanSetUser(originalIdentity.getPrincipal(), identity.getUser()); + accessControl.checkCanImpersonateUser(originalIdentity, identity.getUser()); + } + return addEnabledRoles(identity, parseSystemRoleHeaders(protocolHeaders, headers), metadata); } @@ -265,6 +273,20 @@ private Identity buildSessionIdentity(Optional authenticatedIdentity, .build(); } + private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap headers) + { + // We derive original identity using this header, but older clients will not send it, so fall back to identity + Optional optionalOriginalUser = Optional + .ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser()))); + Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity) + .withUser(originalUser) + .withExtraCredentials(new HashMap<>()) + .withGroups(groupProvider.getGroups(originalUser)) + .build()) + .orElse(identity); + return originalIdentity; + } + private static List splitHttpHeader(MultivaluedMap headers, String name) { List values = firstNonNull(headers.get(name), ImmutableList.of()); diff --git a/core/trino-main/src/main/java/io/trino/server/QueryExecutionFactoryModule.java b/core/trino-main/src/main/java/io/trino/server/QueryExecutionFactoryModule.java index 54d77391ce648..e43f9c2c35b54 100644 --- a/core/trino-main/src/main/java/io/trino/server/QueryExecutionFactoryModule.java +++ b/core/trino-main/src/main/java/io/trino/server/QueryExecutionFactoryModule.java @@ -48,6 +48,7 @@ import io.trino.execution.RenameSchemaTask; import io.trino.execution.RenameTableTask; import io.trino.execution.RenameViewTask; +import io.trino.execution.ResetSessionAuthorizationTask; import io.trino.execution.ResetSessionTask; import io.trino.execution.RevokeRolesTask; import io.trino.execution.RevokeTask; @@ -57,6 +58,7 @@ import io.trino.execution.SetPropertiesTask; import io.trino.execution.SetRoleTask; import io.trino.execution.SetSchemaAuthorizationTask; +import io.trino.execution.SetSessionAuthorizationTask; import io.trino.execution.SetSessionTask; import io.trino.execution.SetTableAuthorizationTask; import io.trino.execution.SetTimeZoneTask; @@ -93,6 +95,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; @@ -102,6 +105,7 @@ import io.trino.sql.tree.SetRole; import io.trino.sql.tree.SetSchemaAuthorization; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -159,6 +163,7 @@ public void configure(Binder binder) bindDataDefinitionTask(binder, executionBinder, RenameTable.class, RenameTableTask.class); bindDataDefinitionTask(binder, executionBinder, RenameView.class, RenameViewTask.class); bindDataDefinitionTask(binder, executionBinder, ResetSession.class, ResetSessionTask.class); + bindDataDefinitionTask(binder, executionBinder, ResetSessionAuthorization.class, ResetSessionAuthorizationTask.class); bindDataDefinitionTask(binder, executionBinder, Revoke.class, RevokeTask.class); bindDataDefinitionTask(binder, executionBinder, RevokeRoles.class, RevokeRolesTask.class); bindDataDefinitionTask(binder, executionBinder, Rollback.class, RollbackTask.class); @@ -169,6 +174,7 @@ public void configure(Binder binder) bindDataDefinitionTask(binder, executionBinder, SetRole.class, SetRoleTask.class); bindDataDefinitionTask(binder, executionBinder, SetSchemaAuthorization.class, SetSchemaAuthorizationTask.class); bindDataDefinitionTask(binder, executionBinder, SetSession.class, SetSessionTask.class); + bindDataDefinitionTask(binder, executionBinder, SetSessionAuthorization.class, SetSessionAuthorizationTask.class); bindDataDefinitionTask(binder, executionBinder, SetTableAuthorization.class, SetTableAuthorizationTask.class); bindDataDefinitionTask(binder, executionBinder, SetViewAuthorization.class, SetViewAuthorizationTask.class); bindDataDefinitionTask(binder, executionBinder, StartTransaction.class, StartTransactionTask.class); diff --git a/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java b/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java index 21c063182c8e9..65902978425c9 100644 --- a/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java +++ b/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java @@ -71,20 +71,29 @@ public QuerySessionSupplier( @Override public Session createSession(QueryId queryId, Span querySpan, SessionContext context) { - Identity identity = context.getIdentity(); - accessControl.checkCanSetUser(identity.getPrincipal(), identity.getUser()); + Identity originalIdentity = context.getOriginalIdentity(); + accessControl.checkCanSetUser(originalIdentity.getPrincipal(), originalIdentity.getUser()); // authenticated identity is not present for HTTP or if authentication is not setup if (context.getAuthenticatedIdentity().isPresent()) { Identity authenticatedIdentity = context.getAuthenticatedIdentity().get(); // only check impersonation if authenticated user is not the same as the explicitly set user - if (!authenticatedIdentity.getUser().equals(identity.getUser())) { + if (!authenticatedIdentity.getUser().equals(originalIdentity.getUser())) { // add enabled roles for authenticated identity, so impersonation permissions can be assigned to roles authenticatedIdentity = addEnabledRoles(authenticatedIdentity, context.getSelectedRole(), metadata); - accessControl.checkCanImpersonateUser(authenticatedIdentity, identity.getUser()); + accessControl.checkCanImpersonateUser(authenticatedIdentity, originalIdentity.getUser()); } } + Identity identity = context.getIdentity(); + if (!originalIdentity.getUser().equals(identity.getUser())) { + // When the current user (user) and the original user are different, we check if the original user can impersonate current user. + // We preserve the information of original user in the originalIdentity, + // and it will be used for the impersonation checks and be used as the source of audit information. + accessControl.checkCanSetUser(originalIdentity.getPrincipal(), identity.getUser()); + accessControl.checkCanImpersonateUser(originalIdentity, identity.getUser()); + } + // add the enabled roles identity = addEnabledRoles(identity, context.getSelectedRole(), metadata); @@ -92,6 +101,7 @@ public Session createSession(QueryId queryId, Span querySpan, SessionContext con .setQueryId(queryId) .setQuerySpan(querySpan) .setIdentity(identity) + .setOriginalIdentity(originalIdentity) .setPath(context.getPath().or(() -> defaultPath).map(SqlPath::new)) .setSource(context.getSource()) .setRemoteUserAddress(context.getRemoteUserAddress()) diff --git a/core/trino-main/src/main/java/io/trino/server/SessionContext.java b/core/trino-main/src/main/java/io/trino/server/SessionContext.java index 318548129e654..6db911268af14 100644 --- a/core/trino-main/src/main/java/io/trino/server/SessionContext.java +++ b/core/trino-main/src/main/java/io/trino/server/SessionContext.java @@ -39,6 +39,7 @@ public class SessionContext private final Optional authenticatedIdentity; private final Identity identity; + private final Identity originalIdentity; private final SelectedRole selectedRole; private final Optional source; @@ -67,6 +68,7 @@ public SessionContext( Optional path, Optional authenticatedIdentity, Identity identity, + Identity originalIdentity, SelectedRole selectedRole, Optional source, Optional traceToken, @@ -90,6 +92,7 @@ public SessionContext( this.path = requireNonNull(path, "path is null"); this.authenticatedIdentity = requireNonNull(authenticatedIdentity, "authenticatedIdentity is null"); this.identity = requireNonNull(identity, "identity is null"); + this.originalIdentity = requireNonNull(originalIdentity, "originalIdentity is null"); this.selectedRole = requireNonNull(selectedRole, "selectedRole is null"); this.source = requireNonNull(source, "source is null"); this.traceToken = requireNonNull(traceToken, "traceToken is null"); @@ -125,6 +128,11 @@ public Identity getIdentity() return identity; } + public Identity getOriginalIdentity() + { + return originalIdentity; + } + public SelectedRole getSelectedRole() { return selectedRole; diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index d80c9f57edbbd..2053385cd5386 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -234,6 +234,10 @@ private Response toResponse(QueryResultsResponse resultsResponse) resultsResponse.setCatalog().ifPresent(catalog -> response.header(protocolHeaders.responseSetCatalog(), catalog)); resultsResponse.setSchema().ifPresent(schema -> response.header(protocolHeaders.responseSetSchema(), schema)); resultsResponse.setPath().ifPresent(path -> response.header(protocolHeaders.responseSetPath(), path)); + resultsResponse.setAuthorizationUser().ifPresent(authorizationUser -> response.header(protocolHeaders.responseSetAuthorizationUser(), authorizationUser)); + if (resultsResponse.resetAuthorizationUser()) { + response.header(protocolHeaders.responseResetAuthorizationUser(), true); + } // add set session properties resultsResponse.setSessionProperties() diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index 5f1b04a7164a5..385cd0dc05c08 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -139,6 +139,12 @@ class Query @GuardedBy("this") private Optional setPath = Optional.empty(); + @GuardedBy("this") + private Optional setAuthorizationUser = Optional.empty(); + + @GuardedBy("this") + private boolean resetAuthorizationUser; + @GuardedBy("this") private Map setSessionProperties = ImmutableMap.of(); @@ -463,6 +469,10 @@ private synchronized QueryResultsResponse getNextResult(long token, UriInfo uriI setSchema = queryInfo.getSetSchema(); setPath = queryInfo.getSetPath(); + // update setAuthorizationUser + setAuthorizationUser = queryInfo.getSetAuthorizationUser(); + resetAuthorizationUser = queryInfo.isResetAuthorizationUser(); + // update setSessionProperties setSessionProperties = queryInfo.getSetSessionProperties(); resetSessionProperties = queryInfo.getResetSessionProperties(); @@ -505,6 +515,8 @@ private synchronized QueryResultsResponse toResultsResponse(QueryResults queryRe setCatalog, setSchema, setPath, + setAuthorizationUser, + resetAuthorizationUser, setSessionProperties, resetSessionProperties, setRoles, diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java index 0bce541cefe59..4387fd9f405b4 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java @@ -28,6 +28,8 @@ record QueryResultsResponse( Optional setCatalog, Optional setSchema, Optional setPath, + Optional setAuthorizationUser, + boolean resetAuthorizationUser, Map setSessionProperties, Set resetSessionProperties, Map setRoles, @@ -42,6 +44,7 @@ record QueryResultsResponse( requireNonNull(setCatalog, "setCatalog is null"); requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); + requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(setRoles, "setRoles is null"); diff --git a/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java index 4e76a2e6bbb04..9b10d40b02064 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java @@ -58,7 +58,10 @@ public Identity authenticate(ContainerRequestContext request) else { try { ProtocolHeaders protocolHeaders = detectProtocol(alternateHeaderName, request.getHeaders().keySet()); - user = emptyToNull(request.getHeaders().getFirst(protocolHeaders.requestUser())); + user = emptyToNull(request.getHeaders().getFirst(protocolHeaders.requestOriginalUser())); + if (user == null) { + user = emptyToNull(request.getHeaders().getFirst(protocolHeaders.requestUser())); + } } catch (ProtocolDetectionException e) { // ignored @@ -67,7 +70,7 @@ public Identity authenticate(ContainerRequestContext request) } if (user == null) { - throw new AuthenticationException("Basic authentication or " + TRINO_HEADERS.requestUser() + " must be sent", BasicAuthCredentials.AUTHENTICATE_HEADER); + throw new AuthenticationException("Basic authentication or " + TRINO_HEADERS.requestOriginalUser() + " or " + TRINO_HEADERS.requestUser() + " must be sent", BasicAuthCredentials.AUTHENTICATE_HEADER); } try { diff --git a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java index dd82da505e441..9af26fbbfef78 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/PasswordAuthenticator.java @@ -94,7 +94,7 @@ private void rewriteUserHeaderToMappedUser(BasicAuthCredentials basicAuthCredent { String userHeader; try { - userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestUser(); + userHeader = getUserHeader(headers); } catch (ProtocolDetectionException ignored) { // this shouldn't fail here, but ignore and it will be handled elsewhere @@ -105,6 +105,17 @@ private void rewriteUserHeaderToMappedUser(BasicAuthCredentials basicAuthCredent } } + // Extract this out in a method so that the logic of preferring originalUser and fallback on user remains in one place + private String getUserHeader(MultivaluedMap headers) + throws ProtocolDetectionException + { + String userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestOriginalUser(); + if (headers.getFirst(userHeader) == null || headers.getFirst(userHeader).isEmpty()) { + userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestUser(); + } + return userHeader; + } + private static AuthenticationException needAuthentication(String message) { return new AuthenticationException(message, BasicAuthCredentials.AUTHENTICATE_HEADER); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 7945af4ef9b04..d4487aa50e701 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -204,6 +204,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.Rollback; import io.trino.sql.tree.Row; @@ -216,6 +217,7 @@ import io.trino.sql.tree.SetProperties; import io.trino.sql.tree.SetSchemaAuthorization; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -1035,6 +1037,18 @@ protected Scope visitResetSession(ResetSession node, Optional scope) return createAndAssignScope(node, scope); } + @Override + protected Scope visitSetSessionAuthorization(SetSessionAuthorization node, Optional scope) + { + return createAndAssignScope(node, scope); + } + + @Override + protected Scope visitResetSessionAuthorization(ResetSessionAuthorization node, Optional scope) + { + return createAndAssignScope(node, scope); + } + @Override protected Scope visitAddColumn(AddColumn node, Optional scope) { @@ -5541,6 +5555,7 @@ private Session createViewSession(Optional catalog, Optional sch .setQueryId(session.getQueryId()) .setTransactionId(session.getTransactionId().orElse(null)) .setIdentity(identity) + .setOriginalIdentity(session.getOriginalIdentity()) .setSource(session.getSource().orElse(null)) .setCatalog(catalog) .setSchema(schema) diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 2aef65e109849..705a69e65f846 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -503,6 +503,7 @@ private LocalQueryRunner( transactionId, defaultSession.isClientTransactionSupport(), defaultSession.getIdentity(), + defaultSession.getOriginalIdentity(), defaultSession.getSource(), defaultSession.getCatalog(), defaultSession.getSchema(), diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java index 9d4f2637a1400..856b65b84948a 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java @@ -53,6 +53,7 @@ public static SessionBuilder testSessionBuilder(SessionPropertyManager sessionPr return Session.builder(sessionPropertyManager) .setQueryId(queryIdGenerator.createNextQueryId()) .setIdentity(Identity.ofUser("user")) + .setOriginalIdentity(Identity.ofUser("user")) .setSource("test") .setCatalog("catalog") .setSchema("schema") diff --git a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java index 8f01a60fe12ba..205fc01fb80da 100644 --- a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java @@ -43,6 +43,7 @@ import io.trino.execution.RenameSchemaTask; import io.trino.execution.RenameTableTask; import io.trino.execution.RenameViewTask; +import io.trino.execution.ResetSessionAuthorizationTask; import io.trino.execution.ResetSessionTask; import io.trino.execution.RevokeRolesTask; import io.trino.execution.RevokeTask; @@ -52,6 +53,7 @@ import io.trino.execution.SetPropertiesTask; import io.trino.execution.SetRoleTask; import io.trino.execution.SetSchemaAuthorizationTask; +import io.trino.execution.SetSessionAuthorizationTask; import io.trino.execution.SetSessionTask; import io.trino.execution.SetTableAuthorizationTask; import io.trino.execution.SetTimeZoneTask; @@ -99,6 +101,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; @@ -108,6 +111,7 @@ import io.trino.sql.tree.SetRole; import io.trino.sql.tree.SetSchemaAuthorization; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -216,6 +220,7 @@ private StatementUtils() {} .add(dataDefinitionStatement(RenameTable.class, RenameTableTask.class)) .add(dataDefinitionStatement(RenameView.class, RenameViewTask.class)) .add(dataDefinitionStatement(ResetSession.class, ResetSessionTask.class)) + .add(dataDefinitionStatement(ResetSessionAuthorization.class, ResetSessionAuthorizationTask.class)) .add(dataDefinitionStatement(Revoke.class, RevokeTask.class)) .add(dataDefinitionStatement(RevokeRoles.class, RevokeRolesTask.class)) .add(dataDefinitionStatement(Rollback.class, RollbackTask.class)) @@ -224,6 +229,7 @@ private StatementUtils() {} .add(dataDefinitionStatement(SetRole.class, SetRoleTask.class)) .add(dataDefinitionStatement(SetSchemaAuthorization.class, SetSchemaAuthorizationTask.class)) .add(dataDefinitionStatement(SetSession.class, SetSessionTask.class)) + .add(dataDefinitionStatement(SetSessionAuthorization.class, SetSessionAuthorizationTask.class)) .add(dataDefinitionStatement(SetProperties.class, SetPropertiesTask.class)) .add(dataDefinitionStatement(SetTableAuthorization.class, SetTableAuthorizationTask.class)) .add(dataDefinitionStatement(SetTimeZone.class, SetTimeZoneTask.class)) diff --git a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java index d439aefc0d16f..8e4502fa0fe8a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java @@ -257,6 +257,8 @@ public QueryInfo getFullQueryInfo() Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), + false, ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java index fff0b491fdf71..2531537b2fb12 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java @@ -108,6 +108,8 @@ private static QueryInfo createQueryInfo() Optional.of("set_catalog"), Optional.of("set_schema"), Optional.of("set_path"), + Optional.of("set_authorization_user"), + false, ImmutableMap.of("set_property", "set_value"), ImmutableSet.of("reset_property"), ImmutableMap.of("set_roles", new SelectedRole(SelectedRole.Type.ROLE, Optional.of("role"))), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSetSessionAuthorizationTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSetSessionAuthorizationTask.java new file mode 100644 index 0000000000000..b2975fa015e5e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/TestSetSessionAuthorizationTask.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import io.trino.client.NodeVersion; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.Metadata; +import io.trino.security.AccessControl; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.TrinoException; +import io.trino.spi.resourcegroups.ResourceGroupId; +import io.trino.sql.parser.ParsingOptions; +import io.trino.sql.parser.SqlParser; +import io.trino.sql.tree.SetSessionAuthorization; +import io.trino.transaction.TransactionId; +import io.trino.transaction.TransactionManager; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; +import static io.trino.metadata.MetadataManager.testMetadataManagerBuilder; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static java.util.Collections.emptyList; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; + +public class TestSetSessionAuthorizationTask +{ + private TransactionManager transactionManager; + private AccessControl accessControl; + private Metadata metadata; + private SqlParser parser; + private ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + + @BeforeClass + public void setUp() + { + transactionManager = createTestTransactionManager(); + accessControl = new AllowAllAccessControl(); + metadata = testMetadataManagerBuilder() + .withTransactionManager(transactionManager) + .build(); + parser = new SqlParser(); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + executor = null; + transactionManager = null; + accessControl = null; + metadata = null; + } + + @Test + public void testSetSessionAuthorization() + { + assertSetSessionAuthorization("SET SESSION AUTHORIZATION otheruser", Optional.of("otheruser")); + assertSetSessionAuthorization("SET SESSION AUTHORIZATION 'otheruser'", Optional.of("otheruser")); + assertSetSessionAuthorization("SET SESSION AUTHORIZATION \"otheruser\"", Optional.of("otheruser")); + } + + @Test + public void testSetSessionAuthorizationInTransaction() + { + String query = "SET SESSION AUTHORIZATION user"; + SetSessionAuthorization statement = (SetSessionAuthorization) parser.createStatement(query, new ParsingOptions()); + TransactionId transactionId = transactionManager.beginTransaction(false); + QueryStateMachine stateMachine = createStateMachine(Optional.of(transactionId), query); + assertThatThrownBy(() -> new SetSessionAuthorizationTask(accessControl, transactionManager).execute(statement, stateMachine, emptyList(), WarningCollector.NOOP)) + .isInstanceOf(TrinoException.class) + .hasMessageContaining("Can't set authorization user in the middle of a transaction"); + } + + private void assertSetSessionAuthorization(String query, Optional expected) + { + SetSessionAuthorization statement = (SetSessionAuthorization) parser.createStatement(query, new ParsingOptions()); + QueryStateMachine stateMachine = createStateMachine(Optional.empty(), query); + new SetSessionAuthorizationTask(accessControl, transactionManager).execute(statement, stateMachine, emptyList(), WarningCollector.NOOP); + QueryInfo queryInfo = stateMachine.getQueryInfo(Optional.empty()); + assertEquals(queryInfo.getSetAuthorizationUser(), expected); + } + + private QueryStateMachine createStateMachine(Optional transactionId, String query) + { + QueryStateMachine stateMachine = QueryStateMachine.begin( + transactionId, + query, + Optional.empty(), + testSessionBuilder().build(), + URI.create("fake://uri"), + new ResourceGroupId("test"), + false, + transactionManager, + accessControl, + executor, + metadata, + WarningCollector.NOOP, + createPlanOptimizersStatsCollector(), + Optional.empty(), + true, + new NodeVersion("test")); + return stateMachine; + } +} diff --git a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java index 6f5b543ddda3b..7b43d0cd01bb2 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java @@ -142,6 +142,8 @@ public void testConstructor() Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), + false, ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java index 4ae7ad74c833a..f9ef58545158e 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java @@ -186,6 +186,8 @@ private QueryInfo createQueryInfo(String queryId, QueryState state, String query Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), + false, ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/server/TestSessionPropertyDefaults.java b/core/trino-main/src/test/java/io/trino/server/TestSessionPropertyDefaults.java index 5435d1ef4bab7..29b58f2a09cfc 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestSessionPropertyDefaults.java +++ b/core/trino-main/src/test/java/io/trino/server/TestSessionPropertyDefaults.java @@ -75,6 +75,7 @@ public void testApplyDefaultProperties() Session session = Session.builder(sessionPropertyManager) .setQueryId(new QueryId("test_query_id")) .setIdentity(Identity.ofUser("testUser")) + .setOriginalIdentity(Identity.ofUser("testUser")) .setSystemProperty(QUERY_MAX_MEMORY, "1GB") // Override this default system property .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "partitioned") .setSystemProperty(MAX_HASH_PARTITION_COUNT, "43") diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index 60d9581b3de3a..715e33e5c8905 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -954,6 +954,38 @@ public void testJwtWithRefreshTokensForOAuth2Enabled() } } + @Test + public void testResourceSecurityImpersonation() + throws Exception + { + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(SECURE_PROPERTIES) + .put("password-authenticator.config-files", passwordConfigDummy.toString()) + .put("http-server.authentication.type", "password") + .put("http-server.authentication.password.user-mapping.pattern", ALLOWED_USER_MAPPING_PATTERN) + .buildOrThrow()) + .setAdditionalModule(binder -> jaxrsBinder(binder).bind(TestResource.class)) + .build()) { + server.getInstance(Key.get(PasswordAuthenticatorManager.class)).setAuthenticators(TestResourceSecurity::authenticate); + server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + + // Authenticated user TEST_USER_LOGIN impersonates impersonated-user by passing request header X-Trino-Authorization-User + Request request = new Request.Builder() + .url(getLocation(httpServerInfo.getHttpsUri(), "/protocol/identity")) + .addHeader("Authorization", Credentials.basic(TEST_USER_LOGIN, TEST_PASSWORD)) + .addHeader("X-Trino-Original-User", TEST_USER_LOGIN) + .addHeader("X-Trino-User", "impersonated-user") + .build(); + try (Response response = client.newCall(request).execute()) { + assertEquals(response.code(), SC_OK); + assertEquals(response.header("user"), "impersonated-user"); + assertEquals(response.header("principal"), TEST_USER_LOGIN); + } + } + } + private static Module oauth2Module(TokenServer tokenServer) { return binder -> { diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index e5f9e367c7f60..7672c9bcf9d6b 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -103,6 +103,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; @@ -117,6 +118,7 @@ import io.trino.sql.tree.SetRole; import io.trino.sql.tree.SetSchemaAuthorization; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -1889,6 +1891,21 @@ public Void visitResetSession(ResetSession node, Integer indent) return null; } + @Override + protected Void visitSetSessionAuthorization(SetSessionAuthorization node, Integer context) + { + builder.append("SET SESSION AUTHORIZATION "); + builder.append(formatExpression(node.getUser())); + return null; + } + + @Override + protected Void visitResetSessionAuthorization(ResetSessionAuthorization node, Integer context) + { + builder.append("RESET SESSION AUTHORIZATION"); + return null; + } + @Override protected Void visitCallArgument(CallArgument node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 883ed1089ae6c..68772c9427a2c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -200,6 +200,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; @@ -217,6 +218,7 @@ import io.trino.sql.tree.SetRole; import io.trino.sql.tree.SetSchemaAuthorization; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -1452,6 +1454,23 @@ public Node visitResetSession(SqlBaseParser.ResetSessionContext context) return new ResetSession(getLocation(context), getQualifiedName(context.qualifiedName())); } + @Override + public Node visitSetSessionAuthorization(SqlBaseParser.SetSessionAuthorizationContext context) + { + if (context.authorizationUser() instanceof SqlBaseParser.IdentifierUserContext || context.authorizationUser() instanceof SqlBaseParser.StringUserContext) { + return new SetSessionAuthorization(getLocation(context), (Expression) visit(context.authorizationUser())); + } + else { + throw new IllegalArgumentException("Unsupported Session Authorization User: " + context.authorizationUser()); + } + } + + @Override + public Node visitResetSessionAuthorization(SqlBaseParser.ResetSessionAuthorizationContext context) + { + return new ResetSessionAuthorization(getLocation(context)); + } + @Override public Node visitCreateRole(SqlBaseParser.CreateRoleContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index fb47eb92bd785..2f975f54ee161 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -187,6 +187,16 @@ protected R visitResetSession(ResetSession node, C context) return visitStatement(node, context); } + protected R visitSetSessionAuthorization(SetSessionAuthorization node, C context) + { + return visitStatement(node, context); + } + + protected R visitResetSessionAuthorization(ResetSessionAuthorization node, C context) + { + return visitStatement(node, context); + } + protected R visitGenericLiteral(GenericLiteral node, C context) { return visitLiteral(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ResetSessionAuthorization.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ResetSessionAuthorization.java new file mode 100644 index 0000000000000..aaca69e5e4e7f --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ResetSessionAuthorization.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +public class ResetSessionAuthorization + extends Statement +{ + public ResetSessionAuthorization() + { + this(Optional.empty()); + } + + public ResetSessionAuthorization(NodeLocation location) + { + this(Optional.of(location)); + } + + private ResetSessionAuthorization(Optional location) + { + super(location); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitResetSessionAuthorization(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public int hashCode() + { + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + return true; + } + + @Override + public String toString() + { + return "RESET SESSION AUTHORIZATION"; + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/SetSessionAuthorization.java b/core/trino-parser/src/main/java/io/trino/sql/tree/SetSessionAuthorization.java new file mode 100644 index 0000000000000..a160758f52d60 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/SetSessionAuthorization.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class SetSessionAuthorization + extends Statement +{ + private final Expression user; + + public SetSessionAuthorization(Expression user) + { + this(Optional.empty(), user); + } + + public SetSessionAuthorization(NodeLocation location, Expression user) + { + this(Optional.of(location), user); + } + + private SetSessionAuthorization(Optional location, Expression user) + { + super(location); + this.user = requireNonNull(user, "user is null"); + } + + public Expression getUser() + { + return user; + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitSetSessionAuthorization(this, context); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SetSessionAuthorization setSessionAuthorization = (SetSessionAuthorization) o; + return Objects.equals(user, setSessionAuthorization.user); + } + + @Override + public int hashCode() + { + return Objects.hash(user); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("user", user) + .toString(); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index ecd45454eee01..5d57b56ddf08e 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -163,6 +163,7 @@ import io.trino.sql.tree.RenameTable; import io.trino.sql.tree.RenameView; import io.trino.sql.tree.ResetSession; +import io.trino.sql.tree.ResetSessionAuthorization; import io.trino.sql.tree.Revoke; import io.trino.sql.tree.RevokeRoles; import io.trino.sql.tree.Rollback; @@ -175,6 +176,7 @@ import io.trino.sql.tree.SetProperties; import io.trino.sql.tree.SetRole; import io.trino.sql.tree.SetSession; +import io.trino.sql.tree.SetSessionAuthorization; import io.trino.sql.tree.SetTableAuthorization; import io.trino.sql.tree.SetTimeZone; import io.trino.sql.tree.SetViewAuthorization; @@ -6001,6 +6003,28 @@ public void testJsonTableNestedColumns() Optional.of(JsonTable.ErrorBehavior.ERROR)))); } + @Test + public void testSetSessionAuthorization() + { + assertStatement("SET SESSION AUTHORIZATION user", new SetSessionAuthorization(identifier("user"))); + assertStatement("SET SESSION AUTHORIZATION \"user\"", new SetSessionAuthorization(identifier("user"))); + assertStatement("SET SESSION AUTHORIZATION 'user'", new SetSessionAuthorization(new StringLiteral("user"))); + + assertStatementIsInvalid("SET SESSION AUTHORIZATION user-a").withMessage("line 1:31: mismatched input '-'. Expecting: "); + assertStatement("SET SESSION AUTHORIZATION \"user-a\"", new SetSessionAuthorization(identifier("user-a"))); + assertStatement("SET SESSION AUTHORIZATION 'user-a'", new SetSessionAuthorization(new StringLiteral("user-a"))); + + assertStatementIsInvalid("SET SESSION AUTHORIZATION null").withMessage("line 1:27: mismatched input 'null'. Expecting: '.', '=', , "); + assertStatement("SET SESSION AUTHORIZATION \"null\"", new SetSessionAuthorization(identifier("null"))); + assertStatement("SET SESSION AUTHORIZATION 'null'", new SetSessionAuthorization(new StringLiteral("null"))); + } + + @Test + public void testResetSessionAuthorization() + { + assertStatement("RESET SESSION AUTHORIZATION", new ResetSessionAuthorization()); + } + private static QualifiedName makeQualifiedName(String tableName) { List parts = Splitter.on('.').splitToList(tableName).stream() diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestStatementBuilder.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestStatementBuilder.java index 60c86f8def38a..29bf92521018b 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestStatementBuilder.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestStatementBuilder.java @@ -329,6 +329,9 @@ public void testStatementBuilder() "when matched and c.action = 'del' then delete\n" + "when not matched and c.action = 'new' then\n" + "insert (part, qty) values (c.part, c.qty)"); + + printStatement("set session authorization user"); + printStatement("reset session authorization"); } @Test diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java index db235e24e70d9..6dc00edac896a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java @@ -32,6 +32,7 @@ public class QueryContext { private final String user; + private final String originalUser; private final Optional principal; private final Set groups; private final Optional traceToken; @@ -62,6 +63,7 @@ public class QueryContext @Unstable public QueryContext( String user, + String originalUser, Optional principal, Set groups, Optional traceToken, @@ -83,6 +85,7 @@ public QueryContext( String retryPolicy) { this.user = requireNonNull(user, "user is null"); + this.originalUser = requireNonNull(originalUser, "originalUser is null"); this.principal = requireNonNull(principal, "principal is null"); this.groups = requireNonNull(groups, "groups is null"); this.traceToken = requireNonNull(traceToken, "traceToken is null"); @@ -110,6 +113,12 @@ public String getUser() return user; } + @JsonProperty + public String getOriginalUser() + { + return originalUser; + } + @JsonProperty public Optional getPrincipal() { diff --git a/docs/src/main/sphinx/develop/client-protocol.md b/docs/src/main/sphinx/develop/client-protocol.md index d113d70dd6133..93fffc9dd9b9e 100644 --- a/docs/src/main/sphinx/develop/client-protocol.md +++ b/docs/src/main/sphinx/develop/client-protocol.md @@ -142,6 +142,8 @@ in subsequent requests, just like browser cookies. * - ``X-Trino-User`` - Specifies the session user. If not supplied, the session user is automatically determined via :doc:`/security/user-mapping`. + * - ``X-Trino-Original-User`` + - Specifies the session's original user. * - ``X-Trino-Source`` - For reporting purposes, this supplies the name of the software that submitted the query. @@ -229,6 +231,13 @@ subsequent requests to be consistent with the response headers received. * - ``X-Trino-Set-Schema`` - Instructs the client to set the schema in the ``X-Trino-Schema`` request header in subsequent client requests. + * - ``X-Trino-Set-Authorization-User`` + - Instructs the client to set the session authorization user in the + ``X-Trino-Authorization-User`` request header in subsequent client requests. + * - ``X-Trino-Reset-Authorization-User`` + - Instructs the client to remove ``X-Trino-Authorization-User`` request header + in subsequent client requests to reset the authorization user back to the + original user. * - ``X-Trino-Set-Session`` - The value of the ``X-Trino-Set-Session`` response header is a string of the form *property* = *value*. It diff --git a/docs/src/main/sphinx/sql.md b/docs/src/main/sphinx/sql.md index 589144205f376..687c5f4747304 100644 --- a/docs/src/main/sphinx/sql.md +++ b/docs/src/main/sphinx/sql.md @@ -47,6 +47,7 @@ sql/pattern-recognition-in-window sql/prepare sql/refresh-materialized-view sql/reset-session +sql/reset-session-authorization sql/revoke sql/revoke-roles sql/rollback @@ -54,6 +55,7 @@ sql/select sql/set-path sql/set-role sql/set-session +sql/set-session-authorization sql/set-time-zone sql/show-catalogs sql/show-columns diff --git a/docs/src/main/sphinx/sql/reset-session-authorization.rst b/docs/src/main/sphinx/sql/reset-session-authorization.rst new file mode 100644 index 0000000000000..b1b163a5c90a8 --- /dev/null +++ b/docs/src/main/sphinx/sql/reset-session-authorization.rst @@ -0,0 +1,22 @@ +=========================== +RESET SESSION AUTHORIZATION +=========================== + +Synopsis +-------- + +.. code-block:: none + + RESET SESSION AUTHORIZATION + +Description +----------- + +Resets the current authorization user back to the original user. +The original user is usually the authenticated user (principal), +or it can be the session user when the session user is provided by the client. + +See Also +-------- + +:doc:`set-session-authorization` diff --git a/docs/src/main/sphinx/sql/set-session-authorization.rst b/docs/src/main/sphinx/sql/set-session-authorization.rst new file mode 100644 index 0000000000000..98634dd112355 --- /dev/null +++ b/docs/src/main/sphinx/sql/set-session-authorization.rst @@ -0,0 +1,47 @@ +========================= +SET SESSION AUTHORIZATION +========================= + +Synopsis +-------- + +.. code-block:: none + + SET SESSION AUTHORIZATION username + +Description +----------- + +Changes the current user of the session. +For the ``SET SESSION AUTHORIZATION username`` statement to succeed, +the the original user (that the client connected with) must be able to impersonate the specified user. +User impersonation can be enabled in the system access control. + +Examples +-------- + +In the following example, the original user when the connection to Trino is made is Kevin. +The following sets the session authorization user to John:: + + SET SESSION AUTHORIZATION 'John'; + +Queries will now execute as John instead of Kevin. + +All supported syntax to change the session authorization users are shown below. + +Changing the session authorization with single quotes:: + + SET SESSION AUTHORIZATION 'John'; + +Changing the session authorization with double quotes:: + + SET SESSION AUTHORIZATION "John"; + +Changing the session authorization without quotes:: + + SET SESSION AUTHORIZATION John; + +See Also +-------- + +:doc:`reset-session-authorization` diff --git a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java index 96455b34aa819..5df2effd266ed 100644 --- a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java +++ b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java @@ -105,6 +105,7 @@ public class TestHttpEventListener queryContext = new QueryContext( "user", + "originalUser", Optional.of("principal"), Set.of(), // groups Optional.empty(), // traceToken diff --git a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java index 0bf80c6409e75..5b9f27f32f639 100644 --- a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java +++ b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java @@ -132,6 +132,7 @@ public class TestMysqlEventListener private static final QueryContext FULL_QUERY_CONTEXT = new QueryContext( "user", + "originalUser", Optional.of("principal"), Set.of("group1", "group2"), Optional.of("traceToken"), @@ -284,6 +285,7 @@ public class TestMysqlEventListener private static final QueryContext MINIMAL_QUERY_CONTEXT = new QueryContext( "user", + "originalUser", Optional.empty(), Set.of(), Optional.empty(), diff --git a/plugin/trino-session-property-managers/src/test/java/io/trino/plugin/session/db/TestDbSessionPropertyManagerIntegration.java b/plugin/trino-session-property-managers/src/test/java/io/trino/plugin/session/db/TestDbSessionPropertyManagerIntegration.java index cd597095a8b43..bf1286de9fef1 100644 --- a/plugin/trino-session-property-managers/src/test/java/io/trino/plugin/session/db/TestDbSessionPropertyManagerIntegration.java +++ b/plugin/trino-session-property-managers/src/test/java/io/trino/plugin/session/db/TestDbSessionPropertyManagerIntegration.java @@ -177,6 +177,7 @@ private static Session.SessionBuilder testSessionBuilder() return Session.builder(new SessionPropertyManager()) .setQueryId(new QueryIdGenerator().createNextQueryId()) .setIdentity(Identity.ofUser("user")) + .setOriginalIdentity(Identity.ofUser("user")) .setSource("test") .setCatalog("catalog") .setSchema("schema") diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index a256e5d208392..4c7cddeed405f 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -5354,6 +5354,7 @@ public void testShowSession() Optional.empty(), getSession().isClientTransactionSupport(), getSession().getIdentity(), + getSession().getOriginalIdentity(), getSession().getSource(), getSession().getCatalog(), getSession().getSchema(), diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingSessionContext.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingSessionContext.java index 71df76db4051e..c7bd50642b8a8 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingSessionContext.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingSessionContext.java @@ -50,6 +50,7 @@ else if (enabledRoles.size() == 1) { session.getPath().getRawPath(), Optional.empty(), session.getIdentity(), + session.getOriginalIdentity(), selectedRole, session.getSource(), session.getTraceToken(), diff --git a/testing/trino-testing/src/test/java/io/trino/testing/TestTestingTrinoClient.java b/testing/trino-testing/src/test/java/io/trino/testing/TestTestingTrinoClient.java index 326e519b3e9f3..e356e43fd205f 100644 --- a/testing/trino-testing/src/test/java/io/trino/testing/TestTestingTrinoClient.java +++ b/testing/trino-testing/src/test/java/io/trino/testing/TestTestingTrinoClient.java @@ -47,6 +47,7 @@ public class TestTestingTrinoClient private static final SessionPropertyManager sessionManager = new SessionPropertyManager(); private static final Session session = Session.builder(sessionManager) .setIdentity(Identity.forUser(TEST_USER).build()) + .setOriginalIdentity(Identity.forUser(TEST_USER).build()) .setQueryId(queryIdGenerator.createNextQueryId()) .build(); diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestSetSessionAuthorization.java b/testing/trino-tests/src/test/java/io/trino/execution/TestSetSessionAuthorization.java new file mode 100644 index 0000000000000..4ca18ba55d687 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestSetSessionAuthorization.java @@ -0,0 +1,300 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import io.trino.client.ClientSession; +import io.trino.client.QueryData; +import io.trino.client.StatementClient; +import io.trino.plugin.base.security.FileBasedSystemAccessControl; +import io.trino.spi.ErrorCode; +import io.trino.spi.security.SystemAccessControl; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import okhttp3.OkHttpClient; +import org.testng.annotations.Test; + +import java.time.ZoneId; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.client.StatementClientFactory.newStatementClient; +import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.testng.Assert.assertEquals; + +public class TestSetSessionAuthorization + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(TEST_SESSION) + .setSystemAccessControl(newFileBasedSystemAccessControl("set_session_authorization_permissions.json")) + .build(); + return queryRunner; + } + + private SystemAccessControl newFileBasedSystemAccessControl(String resourceName) + { + return newFileBasedSystemAccessControl(ImmutableMap.of("security.config-file", getResourcePath(resourceName))); + } + + private SystemAccessControl newFileBasedSystemAccessControl(ImmutableMap config) + { + return new FileBasedSystemAccessControl.Factory().create(config); + } + + private String getResourcePath(String resourceName) + { + return this.getClass().getClassLoader().getResource(resourceName).getPath(); + } + + @Test + public void testSetSessionAuthorizationToSelf() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION user", clientSession).getSetAuthorizationUser().get(), + "user"); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), + "alice"); + assertEquals(submitQuery("SET SESSION AUTHORIZATION user", clientSession).getSetAuthorizationUser().get(), + "user"); + } + + @Test + public void testValidSetSessionAuthorization() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), + "alice"); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user2")) + .user(Optional.of("user2")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION bob", clientSession).getSetAuthorizationUser().get(), + "bob"); + } + + @Test + public void testInvalidSetSessionAuthorization() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertError(submitQuery("SET SESSION AUTHORIZATION user2", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user user2"); + assertError(submitQuery("SET SESSION AUTHORIZATION bob", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user bob"); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), "alice"); + assertError(submitQuery("SET SESSION AUTHORIZATION charlie", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user charlie"); + StatementClient client = submitQuery("START TRANSACTION", clientSession); + clientSession = ClientSession.builder(clientSession).transactionId(client.getStartedTransactionId()).build(); + assertError(submitQuery("SET SESSION AUTHORIZATION alice", clientSession), + GENERIC_USER_ERROR.toErrorCode(), "Can't set authorization user in the middle of a transaction"); + } + + // If user A can impersonate user B, and B can impersonate C - but A cannot go to C, + // then we can only go from A->B or B->C, but not A->B->C + @Test + public void testInvalidTransitiveSetSessionAuthorization() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), "alice"); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("alice")) + .user(Optional.of("alice")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION charlie", clientSession).getSetAuthorizationUser().get(), "charlie"); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), "alice"); + assertError(submitQuery("SET SESSION AUTHORIZATION charlie", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user charlie"); + } + + @Test + public void testValidSessionAuthorizationExecution() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("alice")) + .build(); + assertEquals(submitQuery("SELECT 1+1", clientSession).currentStatusInfo().getError(), null); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("user")) + .build(); + assertEquals(submitQuery("SELECT 1+1", clientSession).currentStatusInfo().getError(), null); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .authorizationUser(Optional.of("alice")) + .build(); + assertEquals(submitQuery("SELECT 1+1", clientSession).currentStatusInfo().getError(), null); + } + + @Test + public void testInvalidSessionAuthorizationExecution() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("user2")) + .build(); + assertError(submitQuery("SELECT 1+1", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user user2"); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("user3")) + .build(); + assertError(submitQuery("SELECT 1+1", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user user3"); + + clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("charlie")) + .build(); + assertError(submitQuery("SELECT 1+1", clientSession), + PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user charlie"); + } + + @Test + public void testSelectCurrentUser() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .authorizationUser(Optional.of("alice")) + .build(); + + ImmutableList.Builder> data = ImmutableList.builder(); + submitQuery("SELECT CURRENT_USER", clientSession, data); + List> rows = data.build(); + assertEquals((String) rows.get(0).get(0), "alice"); + } + + @Test + public void testResetSessionAuthorization() + { + ClientSession clientSession = defaultClientSessionBuilder() + .principal(Optional.of("user")) + .user(Optional.of("user")) + .build(); + assertResetAuthorizationUser(submitQuery("RESET SESSION AUTHORIZATION", clientSession)); + assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), "alice"); + assertResetAuthorizationUser(submitQuery("RESET SESSION AUTHORIZATION", clientSession)); + StatementClient client = submitQuery("START TRANSACTION", clientSession); + clientSession = ClientSession.builder(clientSession).transactionId(client.getStartedTransactionId()).build(); + assertError(submitQuery("RESET SESSION AUTHORIZATION", clientSession), + GENERIC_USER_ERROR.toErrorCode(), "Can't reset authorization user in the middle of a transaction"); + } + + private void assertError(StatementClient client, ErrorCode errorCode, String errorMessage) + { + assertEquals(client.getSetAuthorizationUser(), Optional.empty()); + assertEquals(client.currentStatusInfo().getError().getErrorName(), errorCode.getName()); + assertEquals(client.currentStatusInfo().getError().getMessage(), errorMessage); + } + + private void assertResetAuthorizationUser(StatementClient client) + { + assertEquals(client.isResetAuthorizationUser(), true); + assertEquals(client.getSetAuthorizationUser().isEmpty(), true); + } + + private ClientSession.Builder defaultClientSessionBuilder() + { + return ClientSession.builder() + .server(getDistributedQueryRunner().getCoordinator().getBaseUrl()) + .source("source") + .timeZone(ZoneId.of("America/Los_Angeles")) + .locale(Locale.ENGLISH) + .clientRequestTimeout(new Duration(2, MINUTES)); + } + + private StatementClient submitQuery(String query, ClientSession clientSession) + { + OkHttpClient httpClient = new OkHttpClient(); + try { + try (StatementClient client = newStatementClient(httpClient, clientSession, query)) { + // wait for query to be fully scheduled + while (client.isRunning() && !client.currentStatusInfo().getStats().isScheduled()) { + client.advance(); + } + return client; + } + } + finally { + // close the client since, query is not managed by the client protocol + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); + } + } + + private StatementClient submitQuery(String query, ClientSession clientSession, ImmutableList.Builder> data) + { + OkHttpClient httpClient = new OkHttpClient(); + try { + try (StatementClient client = newStatementClient(httpClient, clientSession, query)) { + while (client.isRunning() && !Thread.currentThread().isInterrupted()) { + QueryData results = client.currentData(); + if (results.getData() != null) { + data.addAll(results.getData()); + } + client.advance(); + } + // wait for query to be fully scheduled + while (client.isRunning() && !client.currentStatusInfo().getStats().isScheduled()) { + client.advance(); + } + return client; + } + } + finally { + // close the client since, query is not managed by the client protocol + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); + } + } +} diff --git a/testing/trino-tests/src/test/resources/set_session_authorization_permissions.json b/testing/trino-tests/src/test/resources/set_session_authorization_permissions.json new file mode 100644 index 0000000000000..ae8f035cb54da --- /dev/null +++ b/testing/trino-tests/src/test/resources/set_session_authorization_permissions.json @@ -0,0 +1,16 @@ +{ + "impersonation": [ + { + "originalUser": "user", + "newUser": "alice" + }, + { + "originalUser": "user2", + "newUser": "bob" + }, + { + "originalUser": "alice", + "newUser": "charlie" + } + ] +}