Skip to content

Commit c16dd5b

Browse files
andythsuasu80
authored andcommitted
Add tests to verify the groups for impersonation
1 parent 3fcc526 commit c16dd5b

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package io.trino.security;
2+
3+
import com.google.common.collect.ImmutableMap;
4+
import com.google.common.collect.ImmutableSet;
5+
import io.airlift.log.Logging;
6+
import io.trino.Session;
7+
import io.trino.jdbc.TrinoConnection;
8+
import io.trino.plugin.base.security.AllowAllSystemAccessControl;
9+
import io.trino.plugin.memory.MemoryPlugin;
10+
import io.trino.server.BasicQueryInfo;
11+
import io.trino.server.testing.TestingTrinoServer;
12+
import io.trino.spi.QueryId;
13+
import io.trino.spi.security.Identity;
14+
import io.trino.testing.TestingGroupProvider;
15+
import org.junit.jupiter.api.BeforeAll;
16+
import org.junit.jupiter.api.Test;
17+
import org.junit.jupiter.api.TestInstance;
18+
import org.junit.jupiter.api.Timeout;
19+
import org.junit.jupiter.api.parallel.Execution;
20+
21+
import java.sql.Connection;
22+
import java.sql.DriverManager;
23+
import java.sql.SQLException;
24+
import java.sql.Statement;
25+
import java.util.Set;
26+
27+
import static io.trino.jdbc.BaseTrinoDriverTest.getCurrentUser;
28+
import static io.trino.testing.TestingSession.testSessionBuilder;
29+
import static java.lang.String.format;
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
32+
import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD;
33+
34+
@TestInstance(PER_CLASS)
35+
@Execution(SAME_THREAD)
36+
public class TestSessionImpersonation
37+
{
38+
private TestingTrinoServer server;
39+
40+
@BeforeAll
41+
public void setup()
42+
throws Exception
43+
{
44+
Logging.initialize();
45+
server = TestingTrinoServer.builder()
46+
.setSystemAccessControl(new AllowAllSystemAccessControl())
47+
.build();
48+
server.installPlugin(new MemoryPlugin());
49+
server.createCatalog("memory", "memory");
50+
}
51+
52+
@Test
53+
@Timeout(10)
54+
public void testSessionRepresentationReturnsCorrectGroupsDuringImpersonation()
55+
{
56+
Set<String> aliceGroups = ImmutableSet.of("alice_group");
57+
Set<String> johnGroups = ImmutableSet.of("john_group");
58+
Identity alice = Identity.forUser("alice").withGroups(aliceGroups).build();
59+
Identity john = Identity.forUser("john").withGroups(johnGroups).build();
60+
61+
Session aliceImpersonationSession = testSessionBuilder()
62+
.setOriginalIdentity(alice)
63+
.setIdentity(john)
64+
.build();
65+
66+
Set<String> originalUserGroups = aliceImpersonationSession.toSessionRepresentation()
67+
.getOriginalUserGroups();
68+
Set<String> userGroups = aliceImpersonationSession.toSessionRepresentation()
69+
.getGroups();
70+
assertThat(originalUserGroups).isEqualTo(aliceGroups);
71+
assertThat(userGroups).isEqualTo(johnGroups);
72+
}
73+
74+
@Test
75+
@Timeout(60)
76+
public void testSessionReturnsCorrectGroupsForImpersonatedQueries()
77+
throws Exception
78+
{
79+
Set<String> johnGroups = ImmutableSet.of("john_group");
80+
Set<String> aliceGroups = ImmutableSet.of("alice_group");
81+
String alice = "alice";
82+
String john = "john";
83+
84+
TestingGroupProvider testingGroupProvider = new TestingGroupProvider();
85+
testingGroupProvider.setUserGroups(ImmutableMap.of(
86+
john, johnGroups,
87+
alice, aliceGroups));
88+
server.getGroupProvider().setConfiguredGroupProvider(testingGroupProvider);
89+
90+
try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class);
91+
Statement statement = connection.createStatement()) {
92+
assertThat(getCurrentUser(connection)).isEqualTo("alice");
93+
94+
statement.execute("SET SESSION AUTHORIZATION john");
95+
96+
String SHOW_CATALOGS = "SHOW CATALOGS";
97+
String SHOW_SCHEMAS = "SHOW SCHEMAS FROM memory";
98+
String SHOW_TABLES = "SHOW TABLES FROM memory.default";
99+
100+
statement.execute(SHOW_CATALOGS);
101+
statement.execute(SHOW_SCHEMAS);
102+
statement.execute(SHOW_TABLES);
103+
104+
BasicQueryInfo showCatalogsQueryInfo = getQueryInfo(SHOW_CATALOGS);
105+
BasicQueryInfo showSchemasQueryInfo = getQueryInfo(SHOW_SCHEMAS);
106+
BasicQueryInfo showTablesQueryInfo = getQueryInfo(SHOW_TABLES);
107+
108+
assertSessionUsersAndGroups(showCatalogsQueryInfo, alice, aliceGroups, john, johnGroups);
109+
assertSessionUsersAndGroups(showSchemasQueryInfo, alice, aliceGroups, john, johnGroups);
110+
assertSessionUsersAndGroups(showTablesQueryInfo, alice, aliceGroups, john, johnGroups);
111+
}
112+
}
113+
114+
private void assertSessionUsersAndGroups(
115+
BasicQueryInfo queryInfo,
116+
String expectedOriginalUser,
117+
Set<String> expectedOriginalUserGroups,
118+
String expectedUser,
119+
Set<String> expectedUserGroups)
120+
{
121+
assertThat(queryInfo.getSession().getOriginalUser()).isEqualTo(expectedOriginalUser);
122+
assertThat(queryInfo.getSession().getOriginalUserGroups()).isEqualTo(expectedOriginalUserGroups);
123+
assertThat(queryInfo.getSession().getUser()).isEqualTo(expectedUser);
124+
assertThat(queryInfo.getSession().getGroups()).isEqualTo(expectedUserGroups);
125+
}
126+
127+
private BasicQueryInfo getQueryInfo(String query)
128+
{
129+
QueryId queryId = null;
130+
for (BasicQueryInfo basicQueryInfo : server.getDispatchManager().getQueries()) {
131+
if (basicQueryInfo.getQuery().equals(query)) {
132+
queryId = basicQueryInfo.getQueryId();
133+
}
134+
}
135+
return server.getDispatchManager().getQueryInfo(queryId);
136+
}
137+
138+
private Connection createConnection(String catalog, String schema, String user)
139+
throws SQLException
140+
{
141+
String url = format("jdbc:trino://%s/%s/%s", server.getAddress(), catalog, schema);
142+
return DriverManager.getConnection(url, user, null);
143+
}
144+
}

0 commit comments

Comments
 (0)