Skip to content

Commit 5d42f00

Browse files
committed
Implement AGGREATE pushdown in Exasol Connector
1 parent 2c8a365 commit 5d42f00

File tree

5 files changed

+178
-30
lines changed

5 files changed

+178
-30
lines changed

plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/StandardColumnMappings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ public static ColumnMapping decimalColumnMapping(DecimalType decimalType, Roundi
177177
checkArgument(roundingMode == UNNECESSARY, "Round mode is not supported for short decimal, map the type to long decimal instead");
178178
return ColumnMapping.longMapping(
179179
decimalType,
180-
shortDecimalReadFunction(decimalType),
180+
shortDecimalReadFunction(decimalType, roundingMode),
181181
shortDecimalWriteFunction(decimalType));
182182
}
183183
return ColumnMapping.objectMapping(
184184
decimalType,
185-
longDecimalReadFunction(decimalType, roundingMode),
185+
longDecimalReadFunction(decimalType),
186186
longDecimalWriteFunction(decimalType));
187187
}
188188

plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,12 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
150150
return switch (connectorBehavior) {
151151
case SUPPORTS_UPDATE -> true;
152152
case SUPPORTS_ADD_COLUMN_WITH_POSITION,
153-
SUPPORTS_CREATE_MATERIALIZED_VIEW,
154-
SUPPORTS_CREATE_VIEW,
155-
SUPPORTS_DEFAULT_COLUMN_VALUE,
156-
SUPPORTS_MERGE,
157-
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN,
158-
SUPPORTS_ROW_LEVEL_UPDATE -> false;
153+
SUPPORTS_CREATE_MATERIALIZED_VIEW,
154+
SUPPORTS_CREATE_VIEW,
155+
SUPPORTS_DEFAULT_COLUMN_VALUE,
156+
SUPPORTS_MERGE,
157+
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN,
158+
SUPPORTS_ROW_LEVEL_UPDATE -> false;
159159
// Dynamic filters can be pushed down only if predicate push down is supported.
160160
// It is possible for a connector to have predicate push down support but not push down dynamic filters.
161161
// TODO default SUPPORTS_DYNAMIC_FILTER_PUSHDOWN to SUPPORTS_PREDICATE_PUSHDOWN
@@ -621,15 +621,15 @@ public void testNumericAggregationPushdown()
621621
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
622622
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
623623
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
624-
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
624+
assertNumericAveragePushdown(emptyTable);
625625
}
626626

627627
try (TestTable testTable = createAggregationTestTable(schemaName + ".test_num_agg_pd",
628628
ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) {
629629
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown();
630630
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown();
631631
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown();
632-
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
632+
assertNumericAveragePushdown(testTable);
633633

634634
// smoke testing of more complex cases
635635
// WHERE on aggregation column
@@ -647,6 +647,11 @@ public void testNumericAggregationPushdown()
647647
}
648648
}
649649

650+
protected void assertNumericAveragePushdown(TestTable testTable)
651+
{
652+
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
653+
}
654+
650655
@Test
651656
public void testCountDistinctWithStringTypes()
652657
{
@@ -1153,12 +1158,12 @@ public void testArithmeticPredicatePushdown()
11531158

11541159
assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % nationkey = 2"))
11551160
.isFullyPushedDown()
1156-
.matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')");
1161+
.matches(getArithmeticPredicatePushdownExpectedValues());
11571162

11581163
// some databases calculate remainder instead of modulus when one of the values is negative
11591164
assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % -nationkey = 2"))
11601165
.isFullyPushedDown()
1161-
.matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')");
1166+
.matches(getArithmeticPredicatePushdownExpectedValues());
11621167

11631168
assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % 0 = 2"))
11641169
.failure().hasMessageContaining("by zero");
@@ -1170,6 +1175,11 @@ public void testArithmeticPredicatePushdown()
11701175
// TODO add coverage for other arithmetic pushdowns https://github.com/trinodb/trino/issues/14808
11711176
}
11721177

1178+
protected String getArithmeticPredicatePushdownExpectedValues()
1179+
{
1180+
return "VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')";
1181+
}
1182+
11731183
@Test
11741184
public void testCaseSensitiveTopNPushdown()
11751185
{
@@ -1307,7 +1317,8 @@ public void testJoinPushdown()
13071317
assertThat(query(session, format("SELECT n.name FROM nation n %s orders o ON DATE '2025-03-19' = o.orderdate", joinOperator))).joinIsNotFullyPushedDown();
13081318

13091319
// no projection on the probe side, only filter
1310-
assertJoinConditionallyPushedDown(session, format("SELECT n.name FROM nation n %s orders o ON n.regionkey = 1", joinOperator),
1320+
// reduced the size of the join table to make the test faster: instead of joining on large orders table join only on one record
1321+
assertJoinConditionallyPushedDown(session, format("SELECT n.name FROM nation n %s (SELECT * FROM orders WHERE orderkey = 1) o ON n.regionkey = 1", joinOperator),
13111322
expectJoinPushdownOnEmptyProjection(joinOperator));
13121323

13131324
// pushdown when using USING

plugin/trino-exasol/src/main/java/io/trino/plugin/exasol/ExasolClient.java

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import com.google.common.collect.ImmutableSet;
1717
import com.google.inject.Inject;
1818
import io.airlift.slice.Slices;
19+
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
20+
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
1921
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
2022
import io.trino.plugin.base.mapping.IdentifierMapping;
2123
import io.trino.plugin.jdbc.BaseJdbcClient;
@@ -36,6 +38,22 @@
3638
import io.trino.plugin.jdbc.SliceWriteFunction;
3739
import io.trino.plugin.jdbc.WriteFunction;
3840
import io.trino.plugin.jdbc.WriteMapping;
41+
import io.trino.plugin.jdbc.aggregation.ImplementAvgDecimal;
42+
import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint;
43+
import io.trino.plugin.jdbc.aggregation.ImplementCorr;
44+
import io.trino.plugin.jdbc.aggregation.ImplementCount;
45+
import io.trino.plugin.jdbc.aggregation.ImplementCountAll;
46+
import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct;
47+
import io.trino.plugin.jdbc.aggregation.ImplementCovariancePop;
48+
import io.trino.plugin.jdbc.aggregation.ImplementCovarianceSamp;
49+
import io.trino.plugin.jdbc.aggregation.ImplementMinMax;
50+
import io.trino.plugin.jdbc.aggregation.ImplementRegrIntercept;
51+
import io.trino.plugin.jdbc.aggregation.ImplementRegrSlope;
52+
import io.trino.plugin.jdbc.aggregation.ImplementStddevPop;
53+
import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp;
54+
import io.trino.plugin.jdbc.aggregation.ImplementSum;
55+
import io.trino.plugin.jdbc.aggregation.ImplementVariancePop;
56+
import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp;
3957
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
4058
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
4159
import io.trino.plugin.jdbc.expression.RewriteIn;
@@ -50,6 +68,7 @@
5068
import io.trino.spi.expression.ConnectorExpression;
5169
import io.trino.spi.type.DecimalType;
5270
import io.trino.spi.type.Type;
71+
import io.trino.spi.type.VarcharType;
5372

5473
import java.sql.Connection;
5574
import java.sql.Date;
@@ -73,6 +92,7 @@
7392
import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction;
7493
import static io.trino.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction;
7594
import static io.trino.plugin.jdbc.StandardColumnMappings.smallintColumnMapping;
95+
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
7696
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling;
7797
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
7898
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
@@ -91,7 +111,9 @@ public class ExasolClient
91111
.add("EXA_STATISTICS")
92112
.add("SYS")
93113
.build();
114+
public static final int MAX_EXASOL_DECIMAL_PRECISION = 36;
94115
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
116+
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
95117

96118
@Inject
97119
public ExasolClient(
@@ -102,12 +124,13 @@ public ExasolClient(
102124
RemoteQueryModifier queryModifier)
103125
{
104126
super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false);
105-
// Basic implementation required to enable JOIN pushdown support
106-
// It is covered by "testJoinpushdown" integration tests.
107-
// More detailed test case scenarios are covered by Unit tests in "TestConvertPredicate"
127+
// Basic implementation required to enable JOIN and AGGREGATION pushdown support
128+
// It is covered by "testJoinPushdown" and "testAggregationPushdown" integration tests.
129+
// More detailed test case scenarios are covered by Unit tests in "TestExasolClient"
108130
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
109131
.addStandardRules(this::quoted)
110132
.add(new RewriteIn())
133+
.withTypeClass("numeric_type", ImmutableSet.of("decimal", "double"))
111134
.map("$equal(left, right)").to("left = right")
112135
.map("$not_equal(left, right)").to("left <> right")
113136
// Exasol doesn't support "IS NOT DISTINCT FROM" expression,
@@ -120,7 +143,38 @@ public ExasolClient(
120143
.map("$not($is_null(value))").to("value IS NOT NULL")
121144
.map("$not(value: boolean)").to("NOT value")
122145
.map("$is_null(value)").to("value IS NULL")
146+
.map("$add(left: numeric_type, right: numeric_type)").to("left + right")
147+
.map("$subtract(left: numeric_type, right: numeric_type)").to("left - right")
148+
.map("$multiply(left: numeric_type, right: numeric_type)").to("left * right")
149+
.map("$divide(left: numeric_type, right: numeric_type)").to("left / right")
150+
.map("$modulus(left: numeric_type, right: numeric_type)").to("mod(left, right)")
151+
.map("$negate(value: numeric_type)").to("-value")
152+
.map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern")
153+
.map("$like(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape")
154+
.map("$nullif(first, second)").to("NULLIF(first, second)")
123155
.build();
156+
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
157+
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
158+
this.connectorExpressionRewriter,
159+
ImmutableSet.<AggregateFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
160+
.add(new ImplementCountAll(bigintTypeHandle))
161+
.add(new ImplementMinMax(true))
162+
.add(new ImplementCount(bigintTypeHandle))
163+
.add(new ImplementCountDistinct(bigintTypeHandle, true))
164+
.add(new ImplementSum(ExasolClient::toSumTypeHandle))
165+
.add(new ImplementAvgFloatingPoint())
166+
.add(new ImplementAvgDecimal())
167+
.add(new ImplementExasolAvgBigInt())
168+
.add(new ImplementStddevSamp())
169+
.add(new ImplementStddevPop())
170+
.add(new ImplementVarianceSamp())
171+
.add(new ImplementVariancePop())
172+
.add(new ImplementCovarianceSamp())
173+
.add(new ImplementCovariancePop())
174+
.add(new ImplementCorr())
175+
.add(new ImplementRegrIntercept())
176+
.add(new ImplementRegrSlope())
177+
.build());
124178
}
125179

126180
@Override
@@ -234,11 +288,22 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon
234288
return true;
235289
}
236290

291+
@Override
292+
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
293+
{
294+
return true;
295+
}
296+
237297
@Override
238298
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
239299
{
240-
// Deactivated because test 'testCaseSensitiveAggregationPushdown()' requires write access which is not implemented for Exasol
241-
return Optional.empty();
300+
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
301+
}
302+
303+
@Override
304+
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
305+
{
306+
return true;
242307
}
243308

244309
@Override
@@ -263,9 +328,9 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
263328
case Types.DOUBLE:
264329
return Optional.of(doubleColumnMapping());
265330
case Types.DECIMAL:
266-
int decimalDigits = typeHandle.requiredDecimalDigits();
267-
int columnSize = typeHandle.requiredColumnSize();
268-
return Optional.of(decimalColumnMapping(createDecimalType(columnSize, decimalDigits)));
331+
int precision = typeHandle.requiredColumnSize();
332+
int scale = typeHandle.requiredDecimalDigits();
333+
return Optional.of(decimalColumnMapping(createDecimalType(precision, scale)));
269334
case Types.CHAR:
270335
return Optional.of(defaultCharColumnMapping(typeHandle.requiredColumnSize(), true));
271336
case Types.VARCHAR:
@@ -289,6 +354,12 @@ private boolean isHashType(JdbcTypeHandle typeHandle)
289354
&& typeHandle.jdbcTypeName().get().equalsIgnoreCase("HASHTYPE");
290355
}
291356

357+
private static Optional<JdbcTypeHandle> toSumTypeHandle(DecimalType decimalType)
358+
{
359+
return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("decimal"),
360+
Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
361+
}
362+
292363
private static ColumnMapping dateColumnMapping()
293364
{
294365
// Exasol driver does not support LocalDate
@@ -350,6 +421,16 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
350421
}
351422
return WriteMapping.objectMapping(dataType, longDecimalWriteFunction(decimalType));
352423
}
424+
if (type instanceof VarcharType varcharType) {
425+
String dataType;
426+
if (varcharType.isUnbounded()) {
427+
dataType = "varchar";
428+
}
429+
else {
430+
dataType = "varchar(" + varcharType.getBoundedLength() + ")";
431+
}
432+
return WriteMapping.sliceMapping(dataType, varcharWriteFunction());
433+
}
353434

354435
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
355436
}
@@ -398,10 +479,4 @@ public boolean isLimitGuaranteed(ConnectorSession session)
398479
{
399480
return true;
400481
}
401-
402-
@Override
403-
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
404-
{
405-
return true;
406-
}
407482
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.plugin.exasol;
15+
16+
import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint;
17+
18+
public class ImplementExasolAvgBigInt
19+
extends BaseImplementAvgBigint
20+
{
21+
@Override
22+
protected String getRewriteFormatExpression()
23+
{
24+
return "avg(CAST(%s AS double))";
25+
}
26+
}

0 commit comments

Comments
 (0)