Skip to content

Commit 93e3d03

Browse files
committed
Swap the order of chart output to ensure metrics come last
Signed-off-by: Yuanchun Shen <[email protected]>
1 parent 1a6c5c5 commit 93e3d03

File tree

3 files changed

+93
-73
lines changed

3 files changed

+93
-73
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,19 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
11081108

11091109
@Override
11101110
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
1111+
visitAggregation(node, context, true);
1112+
return context.relBuilder.peek();
1113+
}
1114+
1115+
/**
1116+
* Visits an aggregation node and builds the corresponding Calcite RelNode.
1117+
*
1118+
* @param node the aggregation node containing group expressions and aggregation functions
1119+
* @param context the Calcite plan context for building RelNodes
1120+
* @param aggFirst if true, aggregation results (metrics) appear first in output schema (agg,
1121+
* group-by fields); if false, group expressions appear first (group-by fields, agg).
1122+
*/
1123+
private void visitAggregation(Aggregation node, CalcitePlanContext context, boolean aggFirst) {
11111124
visitChildren(node, context);
11121125

11131126
List<UnresolvedExpression> aggExprList = node.getAggExprList();
@@ -1152,17 +1165,13 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11521165
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
11531166

11541167
// schema reordering
1155-
// As an example, in command `stats count() by colA, colB`,
1156-
// the sequence of output schema is "count, colA, colB".
11571168
List<RexNode> outputFields = context.relBuilder.fields();
11581169
int numOfOutputFields = outputFields.size();
11591170
int numOfAggList = aggExprList.size();
11601171
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
11611172
// Add aggregation results first
11621173
List<RexNode> aggRexList =
11631174
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
1164-
reordered.addAll(aggRexList);
1165-
// Add group by columns
11661175
List<RexNode> aliasedGroupByList =
11671176
aggregationAttributes.getLeft().stream()
11681177
.map(this::extractAliasLiteral)
@@ -1171,10 +1180,17 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11711180
.map(context.relBuilder::field)
11721181
.map(f -> (RexNode) f)
11731182
.toList();
1174-
reordered.addAll(aliasedGroupByList);
1183+
if (aggFirst) {
1184+
// As an example, in command `stats count() by colA, colB`,
1185+
// the sequence of output schema is "count, colA, colB".
1186+
reordered.addAll(aggRexList);
1187+
// Add group by columns
1188+
reordered.addAll(aliasedGroupByList);
1189+
} else {
1190+
reordered.addAll(aliasedGroupByList);
1191+
reordered.addAll(aggRexList);
1192+
}
11751193
context.relBuilder.project(reordered);
1176-
1177-
return context.relBuilder.peek();
11781194
}
11791195

11801196
private Optional<UnresolvedExpression> getTimeSpanField(UnresolvedExpression expr) {
@@ -2038,7 +2054,13 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20382054
groupExprList,
20392055
null,
20402056
List.of(new Argument(Argument.BUCKET_NULLABLE, AstDSL.booleanLiteral(config.useNull))));
2041-
RelNode aggregated = visitAggregation(aggregation, context);
2057+
visitAggregation(aggregation, context, false);
2058+
RelBuilder relBuilder = context.relBuilder;
2059+
String columnSplitName =
2060+
relBuilder.peek().getRowType().getFieldNames().size() > 2
2061+
? relBuilder.peek().getRowType().getFieldNames().get(1)
2062+
: null;
2063+
RelNode aggregated = context.relBuilder.peek();
20422064

20432065
// If row or column split does not present or limit equals 0, this is the same as `stats agg
20442066
// [group by col]` because all truncating is performed on the column split
@@ -2058,25 +2080,23 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20582080

20592081
// Convert the column split to string if necessary: column split was supposed to be pivoted to
20602082
// column names. This guarantees that its type compatibility with useother and usenull
2061-
RelBuilder relBuilder = context.relBuilder;
2062-
RexNode colSplit = relBuilder.field(2);
2063-
String columSplitName = relBuilder.peek().getRowType().getFieldNames().getLast();
2083+
RexNode colSplit = relBuilder.field(1);
2084+
String columSplitName = relBuilder.peek().getRowType().getFieldNames().get(1);
20642085
if (!SqlTypeUtil.isCharacter(colSplit.getType())) {
20652086
colSplit =
20662087
relBuilder.alias(
20672088
context.rexBuilder.makeCast(
20682089
UserDefinedFunctionUtils.NULLABLE_STRING, colSplit, true, true),
20692090
columSplitName);
20702091
}
2071-
relBuilder.project(relBuilder.field(0), relBuilder.field(1), colSplit);
2092+
relBuilder.project(relBuilder.field(0), colSplit, relBuilder.field(2));
20722093
aggregated = relBuilder.peek();
20732094

2074-
// 0: agg; 2: column-split
2075-
relBuilder.project(relBuilder.field(0), relBuilder.field(2));
2076-
// 1: column split; 0: agg
2095+
// 1: column-split, 2: agg
2096+
relBuilder.project(relBuilder.field(1), relBuilder.field(2));
20772097
relBuilder.aggregate(
2078-
relBuilder.groupKey(relBuilder.field(1)),
2079-
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(0))
2098+
relBuilder.groupKey(relBuilder.field(0)),
2099+
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(1))
20802100
.as("__grand_total__")); // results: group key, agg calls
20812101
RexNode grandTotal = relBuilder.field("__grand_total__");
20822102
// Apply sorting: for MIN/EARLIEST, reverse the top/bottom logic
@@ -2105,9 +2125,9 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
21052125

21062126
// on column-split = group key
21072127
relBuilder.join(
2108-
JoinRelType.LEFT, relBuilder.equals(relBuilder.field(2, 0, 2), relBuilder.field(2, 1, 0)));
2128+
JoinRelType.LEFT, relBuilder.equals(relBuilder.field(2, 0, 1), relBuilder.field(2, 1, 0)));
21092129

2110-
RexNode colSplitPostJoin = relBuilder.field(2);
2130+
RexNode colSplitPostJoin = relBuilder.field(1);
21112131
RexNode lteCondition =
21122132
relBuilder.call(
21132133
SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
@@ -2126,25 +2146,25 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
21262146
nullCondition,
21272147
relBuilder.literal(config.nullStr),
21282148
lteCondition,
2129-
relBuilder.field(2),
2149+
relBuilder.field(1), // col split
21302150
relBuilder.literal(config.otherStr));
21312151
} else {
21322152
columnSplitExpr =
21332153
relBuilder.call(
21342154
SqlStdOperatorTable.CASE,
21352155
lteCondition,
2136-
relBuilder.field(2),
2156+
relBuilder.field(1),
21372157
relBuilder.literal(config.otherStr));
21382158
}
21392159

2140-
String aggFieldName = relBuilder.peek().getRowType().getFieldNames().getFirst();
2160+
String aggFieldName = relBuilder.peek().getRowType().getFieldNames().get(2);
21412161
relBuilder.project(
21422162
relBuilder.field(0),
2143-
relBuilder.field(1),
2144-
relBuilder.alias(columnSplitExpr, columSplitName));
2163+
relBuilder.alias(columnSplitExpr, columnSplitName),
2164+
relBuilder.field(2));
21452165
relBuilder.aggregate(
2146-
relBuilder.groupKey(relBuilder.field(1), relBuilder.field(2)),
2147-
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(0)).as(aggFieldName));
2166+
relBuilder.groupKey(relBuilder.field(0), relBuilder.field(1)),
2167+
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(2)).as(aggFieldName));
21482168
return relBuilder.peek();
21492169
}
21502170

docs/user/ppl/cmd/chart.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ PPL query::
109109

110110
os> source=accounts | chart count() by gender
111111
fetched rows / total rows = 2/2
112-
+---------+--------+
113-
| count() | gender |
114-
|---------+--------|
115-
| 1 | F |
116-
| 3 | M |
117-
+---------+--------+
112+
+--------+---------+
113+
| gender | count() |
114+
|--------+---------|
115+
| F | 1 |
116+
| M | 3 |
117+
+--------+---------+
118118

119119
Example 3: Using over and by for multiple field grouping
120120
--------------------------------------------------------

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteChartCommandIT.java

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ public void init() throws Exception {
3535
public void testChartWithSingleGroupKey() throws IOException {
3636
JSONObject result1 =
3737
executeQuery(String.format("source=%s | chart avg(balance) by gender", TEST_INDEX_BANK));
38-
verifySchema(result1, schema("avg(balance)", "double"), schema("gender", "string"));
39-
verifyDataRows(result1, rows(40488, "F"), rows(16377.25, "M"));
38+
verifySchema(result1, schema("gender", "string"), schema("avg(balance)", "double"));
39+
verifyDataRows(result1, rows("F", 40488), rows("M", 16377.25));
4040
JSONObject result2 =
4141
executeQuery(String.format("source=%s | chart avg(balance) over gender", TEST_INDEX_BANK));
4242
assertJsonEquals(result1.toString(), result2.toString());
@@ -74,27 +74,27 @@ public void testChartCombineOverByWithLimit0() throws IOException {
7474
"source=%s | chart limit=0 avg(balance) over state by gender", TEST_INDEX_BANK));
7575
verifySchema(
7676
result,
77-
schema("avg(balance)", "double"),
7877
schema("state", "string"),
79-
schema("gender", "string"));
78+
schema("gender", "string"),
79+
schema("avg(balance)", "double"));
8080
verifyDataRows(
8181
result,
82-
rows(39225.0, "IL", "M"),
83-
rows(48086.0, "IN", "F"),
84-
rows(4180.0, "MD", "M"),
85-
rows(40540.0, "PA", "F"),
86-
rows(5686.0, "TN", "M"),
87-
rows(32838.0, "VA", "F"),
88-
rows(16418.0, "WA", "M"));
82+
rows("IL", "M", 39225.0),
83+
rows("IN", "F", 48086.0),
84+
rows("MD", "M", 4180.0),
85+
rows("PA", "F", 40540.0),
86+
rows("TN", "M", 5686.0),
87+
rows("VA", "F", 32838.0),
88+
rows("WA", "M", 16418.0));
8989
}
9090

9191
@Test
9292
public void testChartMaxBalanceByAgeSpan() throws IOException {
9393
JSONObject result =
9494
executeQuery(
9595
String.format("source=%s | chart max(balance) by age span=10", TEST_INDEX_BANK));
96-
verifySchema(result, schema("max(balance)", "bigint"), schema("age", "int"));
97-
verifyDataRows(result, rows(32838, 20), rows(48086, 30));
96+
verifySchema(result, schema("age", "int"), schema("max(balance)", "bigint"));
97+
verifyDataRows(result, rows(20, 32838), rows(30, 48086));
9898
}
9999

100100
@Test
@@ -172,37 +172,37 @@ public void testChartLimit0WithUseOther() throws IOException {
172172
TEST_INDEX_OTEL_LOGS));
173173
verifySchema(
174174
result,
175-
schema("max(severityNumber)", "bigint"),
176175
schema("flags", "bigint"),
177-
schema("severityText", "string"));
176+
schema("severityText", "string"),
177+
schema("max(severityNumber)", "bigint"));
178178
verifyDataRows(
179179
result,
180-
rows(5, 0, "DEBUG"),
181-
rows(6, 0, "DEBUG2"),
182-
rows(7, 0, "DEBUG3"),
183-
rows(8, 0, "DEBUG4"),
184-
rows(17, 0, "ERROR"),
185-
rows(18, 0, "ERROR2"),
186-
rows(19, 0, "ERROR3"),
187-
rows(20, 0, "ERROR4"),
188-
rows(21, 0, "FATAL"),
189-
rows(22, 0, "FATAL2"),
190-
rows(23, 0, "FATAL3"),
191-
rows(24, 0, "FATAL4"),
192-
rows(9, 0, "INFO"),
193-
rows(10, 0, "INFO2"),
194-
rows(11, 0, "INFO3"),
195-
rows(12, 0, "INFO4"),
196-
rows(2, 0, "TRACE2"),
197-
rows(3, 0, "TRACE3"),
198-
rows(4, 0, "TRACE4"),
199-
rows(13, 0, "WARN"),
200-
rows(14, 0, "WARN2"),
201-
rows(15, 0, "WARN3"),
202-
rows(16, 0, "WARN4"),
203-
rows(17, 1, "ERROR"),
204-
rows(9, 1, "INFO"),
205-
rows(1, 1, "TRACE"));
180+
rows(0, "DEBUG", 5),
181+
rows(0, "DEBUG2", 6),
182+
rows(0, "DEBUG3", 7),
183+
rows(0, "DEBUG4", 8),
184+
rows(0, "ERROR", 17),
185+
rows(0, "ERROR2", 18),
186+
rows(0, "ERROR3", 19),
187+
rows(0, "ERROR4", 20),
188+
rows(0, "FATAL", 21),
189+
rows(0, "FATAL2", 22),
190+
rows(0, "FATAL3", 23),
191+
rows(0, "FATAL4", 24),
192+
rows(0, "INFO", 9),
193+
rows(0, "INFO2", 10),
194+
rows(0, "INFO3", 11),
195+
rows(0, "INFO4", 12),
196+
rows(0, "TRACE2", 2),
197+
rows(0, "TRACE3", 3),
198+
rows(0, "TRACE4", 4),
199+
rows(0, "WARN", 13),
200+
rows(0, "WARN2", 14),
201+
rows(0, "WARN3", 15),
202+
rows(0, "WARN4", 16),
203+
rows(1, "ERROR", 17),
204+
rows(1, "INFO", 9),
205+
rows(1, "TRACE", 1));
206206
}
207207

208208
@Test

0 commit comments

Comments
 (0)