Skip to content

Commit c3cf480

Browse files
committed
Resource Group pattern matching on client tags
1 parent a0992f1 commit c3cf480

File tree

5 files changed

+84
-28
lines changed

5 files changed

+84
-28
lines changed

plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/SelectorSpec.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class SelectorSpec
3232
private final Optional<Pattern> authenticatedUserRegex;
3333
private final Optional<Pattern> sourceRegex;
3434
private final Optional<String> queryType;
35-
private final Optional<List<String>> clientTags;
35+
private final Optional<List<Pattern>> clientTags;
3636
private final Optional<SelectorResourceEstimate> selectorResourceEstimate;
3737
private final ResourceGroupIdTemplate group;
3838

@@ -44,7 +44,7 @@ public SelectorSpec(
4444
@JsonProperty("authenticatedUser") Optional<Pattern> authenticatedUserRegex,
4545
@JsonProperty("source") Optional<Pattern> sourceRegex,
4646
@JsonProperty("queryType") Optional<String> queryType,
47-
@JsonProperty("clientTags") Optional<List<String>> clientTags,
47+
@JsonProperty("clientTags") Optional<List<Pattern>> clientTags,
4848
@JsonProperty("selectorResourceEstimate") Optional<SelectorResourceEstimate> selectorResourceEstimate,
4949
@JsonProperty("group") ResourceGroupIdTemplate group)
5050
{
@@ -89,7 +89,7 @@ public Optional<String> getQueryType()
8989
return queryType;
9090
}
9191

92-
public Optional<List<String>> getClientTags()
92+
public Optional<List<Pattern>> getClientTags()
9393
{
9494
return clientTags;
9595
}

plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/StaticSelector.java

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.trino.spi.resourcegroups.SelectionContext;
2121
import io.trino.spi.resourcegroups.SelectionCriteria;
2222

23+
import java.util.Collection;
2324
import java.util.HashMap;
2425
import java.util.HashSet;
2526
import java.util.List;
@@ -51,7 +52,7 @@ public StaticSelector(
5152
Optional<Pattern> originalUserRegex,
5253
Optional<Pattern> authenticatedUserRegex,
5354
Optional<Pattern> sourceRegex,
54-
Optional<List<String>> clientTags,
55+
Optional<List<Pattern>> clientTags,
5556
Optional<SelectorResourceEstimate> selectorResourceEstimate,
5657
Optional<String> queryType,
5758
ResourceGroupIdTemplate group)
@@ -90,8 +91,12 @@ public StaticSelector(
9091
new BasicMatcher(criteria -> queryTypeValue.equalsIgnoreCase(criteria.getQueryType().orElse("")))))
9192
.add(selectorResourceEstimate.map(selectorResourceEstimateValue ->
9293
new BasicMatcher(criteria -> selectorResourceEstimateValue.match(criteria.getResourceEstimates()))))
93-
.add(clientTags.map(clientTagsValue ->
94-
new BasicMatcher(criteria -> criteria.getTags().containsAll(clientTagsValue))))
94+
.add(clientTags.map(tags -> {
95+
for (Pattern tag : tags) {
96+
addNamedGroups(tag, variableNames);
97+
}
98+
return new PatternMatcher(variableNames, tags, SelectionCriteria::getTags);
99+
}))
95100
.build()
96101
.stream()
97102
.flatMap(Optional::stream) // remove any empty optionals
@@ -154,29 +159,40 @@ public void populateVariables(SelectionCriteria criteria, Map<String, String> va
154159
}
155160
}
156161

157-
private record PatternMatcher(Set<String> variableNames, Pattern pattern,
158-
Function<SelectionCriteria, String> valueExtractor)
162+
private record PatternMatcher(Set<String> variableNames, Collection<Pattern> patterns,
163+
Function<SelectionCriteria, Collection<String>> valuesExtractor)
159164
implements SelectionMatcher
160165
{
166+
public PatternMatcher(Set<String> variableNames, Pattern pattern,
167+
Function<SelectionCriteria, String> valuesExtractor)
168+
{
169+
this(variableNames, ImmutableList.of(pattern), valuesExtractor.andThen(ImmutableList::of));
170+
}
171+
161172
@Override
162173
public boolean matches(SelectionCriteria criteria)
163174
{
164-
return pattern.matcher(valueExtractor.apply(criteria)).matches();
175+
return patterns.stream().allMatch(pattern ->
176+
valuesExtractor.apply(criteria).stream().anyMatch(v -> pattern.matcher(v).matches()));
165177
}
166178

167179
@Override
168180
public void populateVariables(SelectionCriteria criteria, Map<String, String> variables)
169181
{
170-
Matcher matcher = pattern.matcher(valueExtractor.apply(criteria));
171-
if (!matcher.matches()) {
172-
return;
173-
}
174-
Map<String, Integer> namedGroups = matcher.namedGroups();
175-
for (String key : variableNames) {
176-
if (namedGroups.containsKey(key)) {
177-
String value = matcher.group(namedGroups.get(key));
178-
if (value != null) {
179-
variables.put(key, value);
182+
for (Pattern pattern : patterns) {
183+
for (String tagValue : valuesExtractor.apply(criteria)) {
184+
Matcher matcher = pattern.matcher(tagValue);
185+
if (!matcher.matches()) {
186+
continue;
187+
}
188+
Map<String, Integer> namedGroups = matcher.namedGroups();
189+
for (String key : variableNames) {
190+
if (namedGroups.containsKey(key)) {
191+
String value = matcher.group(namedGroups.get(key));
192+
if (value != null) {
193+
variables.put(key, value);
194+
}
195+
}
180196
}
181197
}
182198
}

plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/SelectorRecord.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class SelectorRecord
3939
private final Optional<Pattern> authenticatedUserRegex;
4040
private final Optional<Pattern> sourceRegex;
4141
private final Optional<String> queryType;
42-
private final Optional<List<String>> clientTags;
42+
private final Optional<List<Pattern>> clientTags;
4343
private final Optional<SelectorResourceEstimate> selectorResourceEstimate;
4444

4545
public SelectorRecord(
@@ -51,7 +51,7 @@ public SelectorRecord(
5151
Optional<Pattern> authenticatedUserRegex,
5252
Optional<Pattern> sourceRegex,
5353
Optional<String> queryType,
54-
Optional<List<String>> clientTags,
54+
Optional<List<Pattern>> clientTags,
5555
Optional<SelectorResourceEstimate> selectorResourceEstimate)
5656
{
5757
this.resourceGroupId = resourceGroupId;
@@ -106,7 +106,7 @@ public Optional<String> getQueryType()
106106
return queryType;
107107
}
108108

109-
public Optional<List<String>> getClientTags()
109+
public Optional<List<Pattern>> getClientTags()
110110
{
111111
return clientTags;
112112
}
@@ -119,7 +119,7 @@ public Optional<SelectorResourceEstimate> getSelectorResourceEstimate()
119119
public static class Mapper
120120
implements RowMapper<SelectorRecord>
121121
{
122-
private static final JsonCodec<List<String>> LIST_STRING_CODEC = listJsonCodec(String.class);
122+
private static final JsonCodec<List<Pattern>> LIST_PATTERN_CODEC = listJsonCodec(Pattern.class);
123123
private static final JsonCodec<SelectorResourceEstimate> SELECTOR_RESOURCE_ESTIMATE_JSON_CODEC = jsonCodec(SelectorResourceEstimate.class);
124124

125125
@Override
@@ -135,7 +135,7 @@ public SelectorRecord map(ResultSet resultSet, StatementContext context)
135135
Optional.ofNullable(resultSet.getString("authenticated_user_regex")).map(Pattern::compile),
136136
Optional.ofNullable(resultSet.getString("source_regex")).map(Pattern::compile),
137137
Optional.ofNullable(resultSet.getString("query_type")),
138-
Optional.ofNullable(resultSet.getString("client_tags")).map(LIST_STRING_CODEC::fromJson),
138+
Optional.ofNullable(resultSet.getString("client_tags")).map(LIST_PATTERN_CODEC::fromJson),
139139
Optional.ofNullable(resultSet.getString("selector_resource_estimate")).map(SELECTOR_RESOURCE_ESTIMATE_JSON_CODEC::fromJson));
140140
}
141141
}

plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/TestStaticSelector.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ public void testClientTags()
202202
Optional.empty(),
203203
Optional.empty(),
204204
Optional.empty(),
205-
Optional.of(ImmutableList.of("tag1", "tag2")),
205+
Optional.of(ImmutableList.of(Pattern.compile("tag1"), Pattern.compile("tag2"))),
206206
Optional.empty(),
207207
Optional.empty(),
208208
new ResourceGroupIdTemplate("global.foo"));
@@ -212,6 +212,46 @@ public void testClientTags()
212212
assertThat(selector.match(newSelectionCriteria("A.user", "a source b", ImmutableSet.of("tag1", "tag2", "tag3"), EMPTY_RESOURCE_ESTIMATES)).map(SelectionContext::getResourceGroupId)).isEqualTo(Optional.of(resourceGroupId));
213213
}
214214

215+
@Test
216+
public void testClientTagsRegex()
217+
{
218+
ResourceGroupId resourceGroupId = new ResourceGroupId(new ResourceGroupId("global"), "foo");
219+
StaticSelector selector = new StaticSelector(
220+
Optional.empty(),
221+
Optional.empty(),
222+
Optional.empty(),
223+
Optional.empty(),
224+
Optional.empty(),
225+
Optional.of(ImmutableList.of(Pattern.compile("tag1"), Pattern.compile("tagPattern.*"))),
226+
Optional.empty(),
227+
Optional.empty(),
228+
new ResourceGroupIdTemplate("global.foo"));
229+
assertThat(selector.match(newSelectionCriteria("userA", null, ImmutableSet.of("tag1", "tagPattern2"), EMPTY_RESOURCE_ESTIMATES)).map(SelectionContext::getResourceGroupId)).isEqualTo(Optional.of(resourceGroupId));
230+
assertThat(selector.match(newSelectionCriteria("userB", "source", ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES))).isEqualTo(Optional.empty());
231+
assertThat(selector.match(newSelectionCriteria("A.user", "a source b", ImmutableSet.of("tag1"), EMPTY_RESOURCE_ESTIMATES))).isEqualTo(Optional.empty());
232+
assertThat(selector.match(newSelectionCriteria("A.user", "a source b", ImmutableSet.of("tag1", "tagPattern222", "tag3"), EMPTY_RESOURCE_ESTIMATES)).map(SelectionContext::getResourceGroupId)).isEqualTo(Optional.of(resourceGroupId));
233+
}
234+
235+
@Test
236+
void testClientTagsRegexCustomGroup()
237+
{
238+
ResourceGroupId resourceGroupId = new ResourceGroupId(new ResourceGroupId("global"), "foo_userA_my_job");
239+
StaticSelector selector = new StaticSelector(
240+
Optional.empty(),
241+
Optional.empty(),
242+
Optional.empty(),
243+
Optional.empty(),
244+
Optional.empty(),
245+
Optional.of(ImmutableList.of(Pattern.compile("job_id=job_(?<jobid>.*)"))),
246+
Optional.empty(),
247+
Optional.empty(),
248+
new ResourceGroupIdTemplate("global.foo_${USER}_${jobid}"));
249+
assertThat(selector.match(newSelectionCriteria("userA", null, ImmutableSet.of("job_id=job_my_job"), EMPTY_RESOURCE_ESTIMATES))).hasValueSatisfying(context -> {
250+
assertThat(context.getResourceGroupId()).isEqualTo(resourceGroupId);
251+
assertThat(context.getContext().getVariableNames()).containsExactlyInAnyOrder("jobid", "USER");
252+
});
253+
}
254+
215255
@Test
216256
public void testSelectorResourceEstimate()
217257
{

plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestResourceGroupsDao.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ private static void testSelectorInsert(H2ResourceGroupsDao dao, Map<Long, Select
142142
Optional.of(Pattern.compile("admin_auth_user")),
143143
Optional.of(Pattern.compile(".*")),
144144
Optional.of(EXPLAIN.name()),
145-
Optional.of(ImmutableList.of("tag1", "tag2")),
145+
Optional.of(ImmutableList.of(Pattern.compile("tag1"), Pattern.compile("tag2"))),
146146
Optional.empty()));
147147
map.put(4L,
148148
new SelectorRecord(
@@ -181,7 +181,7 @@ private static void testSelectorUpdate(H2ResourceGroupsDao dao, Map<Long, Select
181181
Optional.of(Pattern.compile("ping_auth.*")),
182182
Optional.of(Pattern.compile("ping_source")),
183183
Optional.empty(),
184-
Optional.of(ImmutableList.of("tag1")),
184+
Optional.of(ImmutableList.of(Pattern.compile("tag1"))),
185185
Optional.empty());
186186
map.put(2L, updated);
187187
compareSelectors(map, dao.getSelectors(ENVIRONMENT));
@@ -202,7 +202,7 @@ private static void testSelectorUpdateNull(H2ResourceGroupsDao dao, Map<Long, Se
202202
Optional.of(Pattern.compile("ping_auth.*")),
203203
Optional.of(Pattern.compile("ping_source")),
204204
Optional.of(EXPLAIN.name()),
205-
Optional.of(ImmutableList.of("tag1", "tag2")),
205+
Optional.of(ImmutableList.of(Pattern.compile("tag1"), Pattern.compile("tag2"))),
206206
Optional.empty());
207207
map.put(2L, updated);
208208
dao.updateSelector(2, "ping.*", "ping_gr.*", "ping_original.*", "ping_auth.*", "ping_source", LIST_STRING_CODEC.toJson(ImmutableList.of("tag1", "tag2")), null, null, null, null, null, null);

0 commit comments

Comments
 (0)