Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
baohe-zhang committed Aug 15, 2023
1 parent dcee8b4 commit ce689dc
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ public String toString()
.add("server", server)
.add("principal", principal)
.add("user", user)
.add("authorizationUser", authorizationUser.orElse(null))
.add("authorizationUser", authorizationUser)
.add("clientTags", clientTags)
.add("clientInfo", clientInfo)
.add("catalog", catalog)
Expand Down
18 changes: 18 additions & 0 deletions client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -1111,12 +1111,17 @@ public void testResetSessionAuthorization()
{
try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class);
Statement statement = connection.createStatement()) {
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
statement.execute("SET SESSION AUTHORIZATION john");
assertEquals(connection.getAuthorizationUser(), "john");
assertEquals(getCurrentUser(connection), "john");
statement.execute("SET SESSION AUTHORIZATION bob");
assertEquals(connection.getAuthorizationUser(), "bob");
assertEquals(getCurrentUser(connection), "bob");
statement.execute("RESET SESSION AUTHORIZATION");
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
}
}

Expand Down Expand Up @@ -1202,6 +1207,19 @@ private static Properties toProperties(Map<String, String> map)
return properties;
}

private static String getCurrentUser(Connection connection)
throws SQLException
{
try (Statement statement = connection.createStatement();
ResultSet rs = statement.executeQuery("SELECT current_user")) {
while (rs.next()) {
return rs.getString(1);
}
}

throw new RuntimeException("Failed to get CURRENT_USER");
}

public static class TestingDnsResolver
implements DnsResolver
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,6 @@ public Identity toIdentity(Map<String, String> extraCredentials)
.build();
}

public Identity toOriginalIdentity()
{
return toOriginalIdentity(emptyMap());
}

public Identity toOriginalIdentity(Map<String, String> extraCredentials)
{
return Identity.forUser(originalUser)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public SessionContext createSessionContext(

requireNonNull(authenticatedIdentity, "authenticatedIdentity is null");
Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers);
Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);
SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers);

Optional<String> source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource()));
Expand Down Expand Up @@ -153,15 +154,6 @@ else if (nameParts.size() == 2) {
catalogSessionProperties = catalogSessionProperties.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));

Optional<String> optionalOriginalUser = Optional
.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser())));
Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity)
.withUser(originalUser)
.withExtraCredentials(new HashMap<>())
.withGroups(groupProvider.getGroups(originalUser))
.build())
.orElse(identity);

Map<String, String> preparedStatements = parsePreparedStatementsHeaders(protocolHeaders, headers);

String transactionIdHeader = headers.getFirst(protocolHeaders.requestTransactionId());
Expand Down Expand Up @@ -219,15 +211,7 @@ public Identity extractAuthorizedIdentity(
}

Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, protocolHeaders, headers);

Optional<String> optionalOriginalUser = Optional
.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser())));
Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity)
.withUser(originalUser)
.withExtraCredentials(new HashMap<>())
.withGroups(groupProvider.getGroups(originalUser))
.build())
.orElse(identity);
Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);

accessControl.checkCanSetUser(originalIdentity.getPrincipal(), originalIdentity.getUser());

Expand Down Expand Up @@ -289,6 +273,20 @@ private Identity buildSessionIdentity(Optional<Identity> authenticatedIdentity,
.build();
}

private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
{
// We derive original identity using this header, but older clients will not send it, so fall back to identity
Optional<String> optionalOriginalUser = Optional
.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser())));
Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity)
.withUser(originalUser)
.withExtraCredentials(new HashMap<>())
.withGroups(groupProvider.getGroups(originalUser))
.build())
.orElse(identity);
return originalIdentity;
}

private static List<String> splitHttpHeader(MultivaluedMap<String, String> headers, String name)
{
List<String> values = firstNonNull(headers.get(name), ImmutableList.of());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ public Session createSession(QueryId queryId, Span querySpan, SessionContext con

Identity identity = context.getIdentity();
if (!originalIdentity.getUser().equals(identity.getUser())) {
// When the current user (user) and the original user are different, we check if the original user can impersonate current user.
// We preserve the information of original user in the originalIdentity,
// and it will be used for the impersonation checks and be used as the source of audit information.
accessControl.checkCanSetUser(originalIdentity.getPrincipal(), identity.getUser());
accessControl.checkCanImpersonateUser(originalIdentity, identity.getUser());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ private void rewriteUserHeaderToMappedUser(BasicAuthCredentials basicAuthCredent
{
String userHeader;
try {
userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestOriginalUser();
if (headers.getFirst(userHeader) == null || headers.getFirst(userHeader).isEmpty()) {
userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestUser();
}
userHeader = getUserHeader(headers);
}
catch (ProtocolDetectionException ignored) {
// this shouldn't fail here, but ignore and it will be handled elsewhere
Expand All @@ -108,6 +105,17 @@ private void rewriteUserHeaderToMappedUser(BasicAuthCredentials basicAuthCredent
}
}

// Extract this out in a method so that the logic of preferring originalUser and fallback on user remains in one place
private String getUserHeader(MultivaluedMap<String, String> headers)
throws ProtocolDetectionException
{
String userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestOriginalUser();
if (headers.getFirst(userHeader) == null || headers.getFirst(userHeader).isEmpty()) {
userHeader = detectProtocol(alternateHeaderName, headers.keySet()).requestUser();
}
return userHeader;
}

private static AuthenticationException needAuthentication(String message)
{
return new AuthenticationException(message, BasicAuthCredentials.AUTHENTICATE_HEADER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ public void tearDown()
@Test
public void testSetSessionAuthorization()
{
assertSetSessionAuthorization("SET SESSION AUTHORIZATION user", Optional.of("user"));
assertSetSessionAuthorization("SET SESSION AUTHORIZATION 'user'", Optional.of("user"));
assertSetSessionAuthorization("SET SESSION AUTHORIZATION \"user\"", Optional.of("user"));
assertSetSessionAuthorization("SET SESSION AUTHORIZATION otheruser", Optional.of("otheruser"));
assertSetSessionAuthorization("SET SESSION AUTHORIZATION 'otheruser'", Optional.of("otheruser"));
assertSetSessionAuthorization("SET SESSION AUTHORIZATION \"otheruser\"", Optional.of("otheruser"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ public void testJwtWithRefreshTokensForOAuth2Enabled()
}

@Test
public void testWebUiUserImpersonation()
public void testResourceSecurityImpersonation()
throws Exception
{
try (TestingTrinoServer server = TestingTrinoServer.builder()
Expand Down
2 changes: 1 addition & 1 deletion docs/src/main/sphinx/sql/set-session-authorization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Description

Changes the current user of the session.
For the ``SET SESSION AUTHORIZATION username`` statement to succeed,
the current user must be able to impersonate the specified user.
the the original user (that the client connected with) must be able to impersonate the specified user.
User impersonation can be enabled in the system access control.

Examples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ public void testInvalidSetSessionAuthorization()
assertError(submitQuery("SET SESSION AUTHORIZATION bob", clientSession),
PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user bob");
assertEquals(submitQuery("SET SESSION AUTHORIZATION alice", clientSession).getSetAuthorizationUser().get(), "alice");
assertError(submitQuery("SET SESSION AUTHORIZATION user2", clientSession),
PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user user2");
assertError(submitQuery("SET SESSION AUTHORIZATION charlie", clientSession),
PERMISSION_DENIED.toErrorCode(), "Access Denied: User user cannot impersonate user charlie");
StatementClient client = submitQuery("START TRANSACTION", clientSession);
Expand Down

0 comments on commit ce689dc

Please sign in to comment.