Skip to content

Commit ef9fc02

Browse files
andythsukokosing
authored andcommitted
Replace originalIdentity's groups with identity's groups in set session authorization
1 parent 7549ad5 commit ef9fc02

File tree

2 files changed

+158
-1
lines changed

2 files changed

+158
-1
lines changed

core/trino-main/src/main/java/io/trino/Session.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ public SessionRepresentation toSessionRepresentation()
548548
identity.getUser(),
549549
originalIdentity.getUser(),
550550
originalIdentity.getEnabledRoles(),
551-
originalIdentity.getGroups(),
551+
identity.getGroups(),
552552
originalIdentity.getGroups(),
553553
identity.getPrincipal().map(Principal::toString),
554554
identity.getEnabledRoles(),
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.security;
15+
16+
import com.google.common.collect.ImmutableMap;
17+
import com.google.common.collect.ImmutableSet;
18+
import io.airlift.log.Logging;
19+
import io.trino.Session;
20+
import io.trino.jdbc.TrinoConnection;
21+
import io.trino.plugin.base.security.AllowAllSystemAccessControl;
22+
import io.trino.plugin.memory.MemoryPlugin;
23+
import io.trino.server.BasicQueryInfo;
24+
import io.trino.server.testing.TestingTrinoServer;
25+
import io.trino.spi.QueryId;
26+
import io.trino.spi.security.Identity;
27+
import io.trino.testing.TestingGroupProvider;
28+
import org.junit.jupiter.api.BeforeAll;
29+
import org.junit.jupiter.api.Test;
30+
import org.junit.jupiter.api.TestInstance;
31+
import org.junit.jupiter.api.Timeout;
32+
import org.junit.jupiter.api.parallel.Execution;
33+
34+
import java.sql.Connection;
35+
import java.sql.DriverManager;
36+
import java.sql.SQLException;
37+
import java.sql.Statement;
38+
import java.util.Set;
39+
40+
import static io.trino.jdbc.BaseTrinoDriverTest.getCurrentUser;
41+
import static io.trino.testing.TestingSession.testSessionBuilder;
42+
import static java.lang.String.format;
43+
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
45+
import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD;
46+
47+
@TestInstance(PER_CLASS)
48+
@Execution(SAME_THREAD)
49+
public class TestSessionImpersonation
50+
{
51+
private TestingTrinoServer server;
52+
53+
@BeforeAll
54+
public void setup()
55+
throws Exception
56+
{
57+
Logging.initialize();
58+
server = TestingTrinoServer.builder()
59+
.setSystemAccessControl(new AllowAllSystemAccessControl())
60+
.build();
61+
server.installPlugin(new MemoryPlugin());
62+
server.createCatalog("memory", "memory");
63+
}
64+
65+
@Test
66+
@Timeout(10)
67+
public void testSessionRepresentationReturnsCorrectGroupsDuringImpersonation()
68+
{
69+
Set<String> aliceGroups = ImmutableSet.of("alice_group");
70+
Set<String> johnGroups = ImmutableSet.of("john_group");
71+
Identity alice = Identity.forUser("alice").withGroups(aliceGroups).build();
72+
Identity john = Identity.forUser("john").withGroups(johnGroups).build();
73+
74+
Session aliceImpersonationSession = testSessionBuilder()
75+
.setOriginalIdentity(alice)
76+
.setIdentity(john)
77+
.build();
78+
79+
Set<String> originalUserGroups = aliceImpersonationSession.toSessionRepresentation()
80+
.getOriginalUserGroups();
81+
Set<String> userGroups = aliceImpersonationSession.toSessionRepresentation()
82+
.getGroups();
83+
assertThat(originalUserGroups).isEqualTo(aliceGroups);
84+
assertThat(userGroups).isEqualTo(johnGroups);
85+
}
86+
87+
@Test
88+
@Timeout(60)
89+
public void testSessionReturnsCorrectGroupsForImpersonatedQueries()
90+
throws Exception
91+
{
92+
Set<String> johnGroups = ImmutableSet.of("john_group");
93+
Set<String> aliceGroups = ImmutableSet.of("alice_group");
94+
String alice = "alice";
95+
String john = "john";
96+
97+
TestingGroupProvider testingGroupProvider = new TestingGroupProvider();
98+
testingGroupProvider.setUserGroups(ImmutableMap.of(
99+
john, johnGroups,
100+
alice, aliceGroups));
101+
server.getGroupProvider().setConfiguredGroupProvider(testingGroupProvider);
102+
103+
try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class);
104+
Statement statement = connection.createStatement()) {
105+
assertThat(getCurrentUser(connection)).isEqualTo("alice");
106+
107+
statement.execute("SET SESSION AUTHORIZATION john");
108+
109+
String showCatalogsQuery = "SHOW CATALOGS";
110+
String showSchemasQuery = "SHOW SCHEMAS FROM memory";
111+
String showTablesQuery = "SHOW TABLES FROM memory.default";
112+
113+
statement.execute(showCatalogsQuery);
114+
statement.execute(showSchemasQuery);
115+
statement.execute(showTablesQuery);
116+
117+
BasicQueryInfo showCatalogsQueryInfo = getQueryInfo(showCatalogsQuery);
118+
BasicQueryInfo showSchemasQueryInfo = getQueryInfo(showSchemasQuery);
119+
BasicQueryInfo showTablesQueryInfo = getQueryInfo(showTablesQuery);
120+
121+
assertSessionUsersAndGroups(showCatalogsQueryInfo, alice, aliceGroups, john, johnGroups);
122+
assertSessionUsersAndGroups(showSchemasQueryInfo, alice, aliceGroups, john, johnGroups);
123+
assertSessionUsersAndGroups(showTablesQueryInfo, alice, aliceGroups, john, johnGroups);
124+
}
125+
}
126+
127+
private void assertSessionUsersAndGroups(
128+
BasicQueryInfo queryInfo,
129+
String expectedOriginalUser,
130+
Set<String> expectedOriginalUserGroups,
131+
String expectedUser,
132+
Set<String> expectedUserGroups)
133+
{
134+
assertThat(queryInfo.getSession().getOriginalUser()).isEqualTo(expectedOriginalUser);
135+
assertThat(queryInfo.getSession().getOriginalUserGroups()).isEqualTo(expectedOriginalUserGroups);
136+
assertThat(queryInfo.getSession().getUser()).isEqualTo(expectedUser);
137+
assertThat(queryInfo.getSession().getGroups()).isEqualTo(expectedUserGroups);
138+
}
139+
140+
private BasicQueryInfo getQueryInfo(String query)
141+
{
142+
QueryId queryId = null;
143+
for (BasicQueryInfo basicQueryInfo : server.getDispatchManager().getQueries()) {
144+
if (basicQueryInfo.getQuery().equals(query)) {
145+
queryId = basicQueryInfo.getQueryId();
146+
}
147+
}
148+
return server.getDispatchManager().getQueryInfo(queryId);
149+
}
150+
151+
private Connection createConnection(String catalog, String schema, String user)
152+
throws SQLException
153+
{
154+
String url = format("jdbc:trino://%s/%s/%s", server.getAddress(), catalog, schema);
155+
return DriverManager.getConnection(url, user, null);
156+
}
157+
}

0 commit comments

Comments
 (0)