diff --git a/client/trino-cli/src/main/java/io/trino/cli/Console.java b/client/trino-cli/src/main/java/io/trino/cli/Console.java index bb76e04556bc..c6ba90b6d46b 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Console.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Console.java @@ -14,6 +14,7 @@ package io.trino.cli; import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.ByteStreams; import io.airlift.units.Duration; @@ -391,10 +392,12 @@ private static boolean process( // update authorization user if present if (query.getSetAuthorizationUser().isPresent()) { builder = builder.authorizationUser(query.getSetAuthorizationUser()); + builder = builder.roles(ImmutableMap.of()); } if (query.isResetAuthorizationUser()) { builder = builder.authorizationUser(Optional.empty()); + builder = builder.roles(ImmutableMap.of()); } // update session properties if present diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index b11b51f6aab3..5babd5b2525c 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -783,9 +783,13 @@ void updateSession(StatementClient client) client.getSetSchema().ifPresent(schema::set); client.getSetPath().ifPresent(path::set); - client.getSetAuthorizationUser().ifPresent(authorizationUser::set); + if (client.getSetAuthorizationUser().isPresent()) { + authorizationUser.set(client.getSetAuthorizationUser().get()); + roles.clear(); + } if (client.isResetAuthorizationUser()) { authorizationUser.set(null); + roles.clear(); } if (client.getStartedTransactionId() != null) { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java index 7ba65bea697d..e3128179da30 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java @@ -1120,6 +1120,27 @@ public void testResetSessionAuthorization() } } + @Test(timeOut = 10000) + public void testSetRoleAfterSetSessionAuthorization() + throws Exception + { + try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + statement.execute("SET SESSION AUTHORIZATION john"); + assertEquals(connection.getAuthorizationUser(), "john"); + statement.execute("SET ROLE ALL"); + assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.ALL, Optional.empty()))); + statement.execute("SET SESSION AUTHORIZATION bob"); + assertEquals(connection.getAuthorizationUser(), "bob"); + assertEquals(connection.getRoles(), ImmutableMap.of()); + statement.execute("SET ROLE NONE"); + assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.NONE, Optional.empty()))); + statement.execute("RESET SESSION AUTHORIZATION"); + assertEquals(connection.getAuthorizationUser(), null); + assertEquals(connection.getRoles(), ImmutableMap.of()); + } + } + private QueryState getQueryState(String queryId) throws SQLException {