Skip to content

Commit

Permalink
Support SET SESSION AUTHORIZATION and RESET SESSION AUTHORIZATION
Browse files Browse the repository at this point in the history
  • Loading branch information
baohe-zhang authored and Baohe Zhang committed Aug 29, 2023
1 parent 5a0e489 commit 4a2a489
Show file tree
Hide file tree
Showing 59 changed files with 1,360 additions and 11 deletions.
12 changes: 12 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Console.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.cli;

import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.ByteStreams;
import io.airlift.units.Duration;
Expand Down Expand Up @@ -388,6 +389,17 @@ private static boolean process(
builder = builder.path(query.getSetPath().get());
}

// update authorization user if present
if (query.getSetAuthorizationUser().isPresent()) {
builder = builder.authorizationUser(query.getSetAuthorizationUser());
builder = builder.roles(ImmutableMap.of());
}

if (query.isResetAuthorizationUser()) {
builder = builder.authorizationUser(Optional.empty());
builder = builder.roles(ImmutableMap.of());
}

// update session properties if present
if (!query.getSetSessionProperties().isEmpty() || !query.getResetSessionProperties().isEmpty()) {
Map<String, String> sessionProperties = new HashMap<>(session.getProperties());
Expand Down
10 changes: 10 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ public Optional<String> getSetPath()
return client.getSetPath();
}

public Optional<String> getSetAuthorizationUser()
{
return client.getSetAuthorizationUser();
}

public boolean isResetAuthorizationUser()
{
return client.isResetAuthorizationUser();
}

public Map<String, String> getSetSessionProperties()
{
return client.getSetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class ClientSession
private final URI server;
private final Optional<String> principal;
private final Optional<String> user;
private final Optional<String> authorizationUser;
private final String source;
private final Optional<String> traceToken;
private final Set<String> clientTags;
Expand Down Expand Up @@ -75,6 +76,7 @@ private ClientSession(
URI server,
Optional<String> principal,
Optional<String> user,
Optional<String> authorizationUser,
String source,
Optional<String> traceToken,
Set<String> clientTags,
Expand All @@ -96,6 +98,7 @@ private ClientSession(
this.server = requireNonNull(server, "server is null");
this.principal = requireNonNull(principal, "principal is null");
this.user = requireNonNull(user, "user is null");
this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null");
this.source = source;
this.traceToken = requireNonNull(traceToken, "traceToken is null");
this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null"));
Expand Down Expand Up @@ -158,6 +161,11 @@ public Optional<String> getUser()
return user;
}

public Optional<String> getAuthorizationUser()
{
return authorizationUser;
}

public String getSource()
{
return source;
Expand Down Expand Up @@ -258,6 +266,7 @@ public String toString()
.add("server", server)
.add("principal", principal)
.add("user", user)
.add("authorizationUser", authorizationUser)
.add("clientTags", clientTags)
.add("clientInfo", clientInfo)
.add("catalog", catalog)
Expand All @@ -277,6 +286,7 @@ public static final class Builder
private URI server;
private Optional<String> principal = Optional.empty();
private Optional<String> user = Optional.empty();
private Optional<String> authorizationUser = Optional.empty();
private String source;
private Optional<String> traceToken = Optional.empty();
private Set<String> clientTags = ImmutableSet.of();
Expand All @@ -303,6 +313,7 @@ private Builder(ClientSession clientSession)
server = clientSession.getServer();
principal = clientSession.getPrincipal();
user = clientSession.getUser();
authorizationUser = clientSession.getAuthorizationUser();
source = clientSession.getSource();
traceToken = clientSession.getTraceToken();
clientTags = clientSession.getClientTags();
Expand Down Expand Up @@ -334,6 +345,12 @@ public Builder user(Optional<String> user)
return this;
}

public Builder authorizationUser(Optional<String> authorizationUser)
{
this.authorizationUser = authorizationUser;
return this;
}

public Builder principal(Optional<String> principal)
{
this.principal = principal;
Expand Down Expand Up @@ -448,6 +465,7 @@ public ClientSession build()
server,
principal,
user,
authorizationUser,
source,
traceToken,
clientTags,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public final class ProtocolHeaders

private final String name;
private final String requestUser;
private final String requestOriginalUser;
private final String requestSource;
private final String requestCatalog;
private final String requestSchema;
Expand All @@ -52,6 +53,8 @@ public final class ProtocolHeaders
private final String responseDeallocatedPrepare;
private final String responseStartedTransactionId;
private final String responseClearTransactionId;
private final String responseSetAuthorizationUser;
private final String responseResetAuthorizationUser;

public static ProtocolHeaders createProtocolHeaders(String name)
{
Expand All @@ -69,6 +72,7 @@ private ProtocolHeaders(String name)
this.name = name;
String prefix = "X-" + name + "-";
requestUser = prefix + "User";
requestOriginalUser = prefix + "Original-User";
requestSource = prefix + "Source";
requestCatalog = prefix + "Catalog";
requestSchema = prefix + "Schema";
Expand All @@ -95,6 +99,8 @@ private ProtocolHeaders(String name)
responseDeallocatedPrepare = prefix + "Deallocated-Prepare";
responseStartedTransactionId = prefix + "Started-Transaction-Id";
responseClearTransactionId = prefix + "Clear-Transaction-Id";
responseSetAuthorizationUser = prefix + "Set-Authorization-User";
responseResetAuthorizationUser = prefix + "Reset-Authorization-User";
}

public String getProtocolName()
Expand All @@ -107,6 +113,11 @@ public String requestUser()
return requestUser;
}

public String requestOriginalUser()
{
return requestOriginalUser;
}

public String requestSource()
{
return requestSource;
Expand Down Expand Up @@ -237,6 +248,16 @@ public String responseClearTransactionId()
return responseClearTransactionId;
}

public String responseSetAuthorizationUser()
{
return responseSetAuthorizationUser;
}

public String responseResetAuthorizationUser()
{
return responseResetAuthorizationUser;
}

public static ProtocolHeaders detectProtocol(Optional<String> alternateHeaderName, Set<String> headerNames)
throws ProtocolDetectionException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public interface StatementClient

Optional<String> getSetPath();

Optional<String> getSetAuthorizationUser();

boolean isResetAuthorizationUser();

Map<String, String> getSetSessionProperties();

Set<String> getResetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class StatementClientV1
private final AtomicReference<String> setCatalog = new AtomicReference<>();
private final AtomicReference<String> setSchema = new AtomicReference<>();
private final AtomicReference<String> setPath = new AtomicReference<>();
private final AtomicReference<String> setAuthorizationUser = new AtomicReference<>();
private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean();
private final Map<String, String> setSessionProperties = new ConcurrentHashMap<>();
private final Set<String> resetSessionProperties = Sets.newConcurrentHashSet();
private final Map<String, ClientSelectedRole> setRoles = new ConcurrentHashMap<>();
Expand All @@ -89,6 +91,7 @@ class StatementClientV1
private final ZoneId timeZone;
private final Duration requestTimeoutNanos;
private final Optional<String> user;
private final Optional<String> originalUser;
private final String clientCapabilities;
private final boolean compressionDisabled;

Expand All @@ -104,7 +107,11 @@ public StatementClientV1(Call.Factory httpCallFactory, ClientSession session, St
this.timeZone = session.getTimeZone();
this.query = query;
this.requestTimeoutNanos = session.getClientRequestTimeout();
this.user = Stream.of(session.getUser(), session.getPrincipal())
this.user = Stream.of(session.getAuthorizationUser(), session.getUser(), session.getPrincipal())
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
this.originalUser = Stream.of(session.getUser(), session.getPrincipal())
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
Expand Down Expand Up @@ -270,6 +277,18 @@ public Optional<String> getSetPath()
return Optional.ofNullable(setPath.get());
}

@Override
public Optional<String> getSetAuthorizationUser()
{
return Optional.ofNullable(setAuthorizationUser.get());
}

@Override
public boolean isResetAuthorizationUser()
{
return resetAuthorizationUser.get();
}

@Override
public Map<String, String> getSetSessionProperties()
{
Expand Down Expand Up @@ -319,6 +338,7 @@ private Request.Builder prepareRequest(HttpUrl url)
.addHeader(USER_AGENT, USER_AGENT_VALUE)
.url(url);
user.ifPresent(requestUser -> builder.addHeader(TRINO_HEADERS.requestUser(), requestUser));
originalUser.ifPresent(originalUser -> builder.addHeader(TRINO_HEADERS.requestOriginalUser(), originalUser));
if (compressionDisabled) {
builder.header(ACCEPT_ENCODING, "identity");
}
Expand Down Expand Up @@ -399,6 +419,16 @@ private void processResponse(Headers headers, QueryResults results)
setSchema.set(headers.get(TRINO_HEADERS.responseSetSchema()));
setPath.set(headers.get(TRINO_HEADERS.responseSetPath()));

String setAuthorizationUser = headers.get(TRINO_HEADERS.responseSetAuthorizationUser());
if (setAuthorizationUser != null) {
this.setAuthorizationUser.set(setAuthorizationUser);
}

String resetAuthorizationUser = headers.get(TRINO_HEADERS.responseResetAuthorizationUser());
if (resetAuthorizationUser != null) {
this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser));
}

for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) {
List<String> keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession);
if (keyValue.size() != 2) {
Expand Down
17 changes: 17 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class TrinoConnection
private final AtomicReference<String> catalog = new AtomicReference<>();
private final AtomicReference<String> schema = new AtomicReference<>();
private final AtomicReference<String> path = new AtomicReference<>();
private final AtomicReference<String> authorizationUser = new AtomicReference<>();
private final AtomicReference<ZoneId> timeZoneId = new AtomicReference<>();
private final AtomicReference<Locale> locale = new AtomicReference<>();
private final AtomicReference<Integer> networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2)));
Expand Down Expand Up @@ -746,6 +747,7 @@ StatementClient startQuery(String sql, Map<String, String> sessionPropertiesOver
.server(httpUri)
.principal(user)
.user(sessionUser.get())
.authorizationUser(Optional.ofNullable(authorizationUser.get()))
.source(source)
.traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN)))
.clientTags(ImmutableSet.copyOf(clientTags))
Expand Down Expand Up @@ -781,6 +783,15 @@ void updateSession(StatementClient client)
client.getSetSchema().ifPresent(schema::set);
client.getSetPath().ifPresent(path::set);

if (client.getSetAuthorizationUser().isPresent()) {
authorizationUser.set(client.getSetAuthorizationUser().get());
roles.clear();
}
if (client.isResetAuthorizationUser()) {
authorizationUser.set(null);
roles.clear();
}

if (client.getStartedTransactionId() != null) {
transactionId.set(client.getStartedTransactionId());
}
Expand Down Expand Up @@ -810,6 +821,12 @@ int activeStatements()
return statements.size();
}

@VisibleForTesting
String getAuthorizationUser()
{
return authorizationUser.get();
}

private void checkOpen()
throws SQLException
{
Expand Down
54 changes: 54 additions & 0 deletions client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,47 @@ public void testCustomDnsResolver()
}
}

@Test(timeOut = 10000)
public void testResetSessionAuthorization()
throws Exception
{
try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class);
Statement statement = connection.createStatement()) {
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
statement.execute("SET SESSION AUTHORIZATION john");
assertEquals(connection.getAuthorizationUser(), "john");
assertEquals(getCurrentUser(connection), "john");
statement.execute("SET SESSION AUTHORIZATION bob");
assertEquals(connection.getAuthorizationUser(), "bob");
assertEquals(getCurrentUser(connection), "bob");
statement.execute("RESET SESSION AUTHORIZATION");
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
}
}

@Test(timeOut = 10000)
public void testSetRoleAfterSetSessionAuthorization()
throws Exception
{
try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class);
Statement statement = connection.createStatement()) {
statement.execute("SET SESSION AUTHORIZATION john");
assertEquals(connection.getAuthorizationUser(), "john");
statement.execute("SET ROLE ALL");
assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.ALL, Optional.empty())));
statement.execute("SET SESSION AUTHORIZATION bob");
assertEquals(connection.getAuthorizationUser(), "bob");
assertEquals(connection.getRoles(), ImmutableMap.of());
statement.execute("SET ROLE NONE");
assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.NONE, Optional.empty())));
statement.execute("RESET SESSION AUTHORIZATION");
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(connection.getRoles(), ImmutableMap.of());
}
}

private QueryState getQueryState(String queryId)
throws SQLException
{
Expand Down Expand Up @@ -1166,6 +1207,19 @@ private static Properties toProperties(Map<String, String> map)
return properties;
}

private static String getCurrentUser(Connection connection)
throws SQLException
{
try (Statement statement = connection.createStatement();
ResultSet rs = statement.executeQuery("SELECT current_user")) {
while (rs.next()) {
return rs.getString(1);
}
}

throw new RuntimeException("Failed to get CURRENT_USER");
}

public static class TestingDnsResolver
implements DnsResolver
{
Expand Down
Loading

0 comments on commit 4a2a489

Please sign in to comment.