Skip to content

Commit e2e8def

Browse files
committed
Apply complete PR opensearch-project#4612 changes: multi-field binning, AVG rounding, timestamp support
Signed-off-by: Kai Huang <[email protected]>
1 parent 6471779 commit e2e8def

File tree

1 file changed

+199
-23
lines changed

1 file changed

+199
-23
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 199 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -462,6 +462,22 @@ private void validateWildcardPatterns(
462462
}
463463
}
464464

465+
/** Extract field name from UnresolvedExpression. Handles Field and Alias expressions. */
466+
private String extractFieldName(UnresolvedExpression expr) {
467+
if (expr instanceof org.opensearch.sql.ast.expression.Field) {
468+
return ((org.opensearch.sql.ast.expression.Field) expr).getField().toString();
469+
} else if (expr instanceof org.opensearch.sql.ast.expression.Alias) {
470+
org.opensearch.sql.ast.expression.Alias alias =
471+
(org.opensearch.sql.ast.expression.Alias) expr;
472+
if (alias.getDelegated() instanceof org.opensearch.sql.ast.expression.Field) {
473+
return ((org.opensearch.sql.ast.expression.Field) alias.getDelegated())
474+
.getField()
475+
.toString();
476+
}
477+
}
478+
return null;
479+
}
480+
465481
/**
 * Tells whether {@code fieldName} is a key of {@code OpenSearchConstants.METADATAFIELD_TYPE_MAP}.
 *
 * @param fieldName name to look up
 * @return {@code true} when the map contains {@code fieldName} as a key
 */
private boolean isMetadataField(String fieldName) {
466482
return OpenSearchConstants.METADATAFIELD_TYPE_MAP.containsKey(fieldName);
467483
}
@@ -671,7 +687,26 @@ public RelNode visitBin(Bin node, CalcitePlanContext context) {
671687
RexNode binExpression = BinUtils.createBinExpression(node, fieldExpr, context, rexVisitor);
672688

673689
String alias = node.getAlias() != null ? node.getAlias() : fieldName;
674-
projectPlusOverriding(List.of(binExpression), List.of(alias), context);
690+
691+
// Check if this field is used in aggregation grouping with multiple fields
692+
if (context.getAggregationGroupByFields().contains(fieldName)
693+
&& context.getAggregationGroupByCount() > 1) {
694+
// For multi-field aggregation: preserve BOTH original field and binned field
695+
// The binned field (fieldName_bin) is used for grouping
696+
// The original field is used for MIN aggregation to show actual timestamps per group
697+
List<RexNode> allFields = new ArrayList<>(context.relBuilder.fields());
698+
List<String> allFieldNames =
699+
new ArrayList<>(context.relBuilder.peek().getRowType().getFieldNames());
700+
701+
// Add the binned field with _bin suffix
702+
allFields.add(binExpression);
703+
allFieldNames.add(fieldName + "_bin");
704+
705+
context.relBuilder.project(allFields, allFieldNames);
706+
} else {
707+
// For non-aggregation queries OR single-field binning: replace field with binned value
708+
projectPlusOverriding(List.of(binExpression), List.of(alias), context);
709+
}
675710

676711
return context.relBuilder.peek();
677712
}
@@ -1084,20 +1119,98 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
10841119
CalcitePlanContext context) {
10851120
List<AggCall> aggCallList =
10861121
aggExprList.stream().map(expr -> aggVisitor.analyze(expr, context)).toList();
1087-
List<RexNode> groupByList =
1088-
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
1089-
return Pair.of(groupByList, aggCallList);
1122+
1123+
// Get available field names in the current relation
1124+
List<String> availableFields = context.relBuilder.peek().getRowType().getFieldNames();
1125+
1126+
// Build group-by list, replacing fields with their _bin columns if they exist
1127+
List<RexNode> groupByList = new ArrayList<>();
1128+
List<AggCall> additionalAggCalls = new ArrayList<>();
1129+
1130+
// Track if we have bin columns - we'll need this to decide whether to add MIN aggregations
1131+
boolean hasBinColumns = false;
1132+
int nonBinGroupByCount = 0;
1133+
1134+
for (UnresolvedExpression groupExpr : groupExprList) {
1135+
RexNode resolvedExpr = rexVisitor.analyze(groupExpr, context);
1136+
1137+
// Extract field name from UnresolvedExpression
1138+
String fieldName = extractFieldName(groupExpr);
1139+
1140+
// Check if this field has a corresponding _bin column
1141+
if (fieldName != null) {
1142+
String binColumnName = fieldName + "_bin";
1143+
if (availableFields.contains(binColumnName)) {
1144+
// Use the _bin column for grouping
1145+
groupByList.add(context.relBuilder.field(binColumnName));
1146+
hasBinColumns = true;
1147+
continue;
1148+
}
1149+
}
1150+
1151+
// Regular group-by field
1152+
groupByList.add(resolvedExpr);
1153+
nonBinGroupByCount++;
1154+
}
1155+
1156+
// Only add MIN aggregations for bin columns if there are OTHER group-by fields
1157+
// This matches OpenSearch behavior:
1158+
// - With multi-field grouping (e.g., by region, timestamp): Show MIN(timestamp) per group
1159+
// - With single-field grouping (e.g., by timestamp only): Show bin start time
1160+
if (hasBinColumns && nonBinGroupByCount > 0) {
1161+
for (UnresolvedExpression groupExpr : groupExprList) {
1162+
String fieldName = extractFieldName(groupExpr);
1163+
if (fieldName != null) {
1164+
String binColumnName = fieldName + "_bin";
1165+
if (availableFields.contains(binColumnName)) {
1166+
// Add MIN(original_field) to show minimum timestamp per bin
1167+
additionalAggCalls.add(
1168+
context.relBuilder.min(context.relBuilder.field(fieldName)).as(fieldName));
1169+
}
1170+
}
1171+
}
1172+
}
1173+
1174+
// Combine original aggregations with additional MIN aggregations for binned fields
1175+
List<AggCall> combinedAggCalls = new ArrayList<>(aggCallList);
1176+
combinedAggCalls.addAll(additionalAggCalls);
1177+
1178+
return Pair.of(groupByList, combinedAggCalls);
10901179
}
10911180

10921181
@Override
10931182
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
1094-
visitChildren(node, context);
1095-
1183+
// Prepare partition columns for bin operations before visiting children
1184+
// This allows WIDTH_BUCKET to use per-group min/max (matching auto_date_histogram)
10961185
List<UnresolvedExpression> aggExprList = node.getAggExprList();
10971186
List<UnresolvedExpression> groupExprList = new ArrayList<>();
1187+
UnresolvedExpression span = node.getSpan();
1188+
if (Objects.nonNull(span)) {
1189+
groupExprList.add(span);
1190+
}
1191+
groupExprList.addAll(node.getGroupExprList());
1192+
1193+
// Store group-by field names and count so bin operations can preserve original fields
1194+
java.util.Set<String> savedGroupByFields = context.getAggregationGroupByFields();
1195+
int savedGroupByCount = context.getAggregationGroupByCount();
1196+
context.setAggregationGroupByFields(new java.util.HashSet<>());
1197+
context.setAggregationGroupByCount(groupExprList.size());
1198+
for (UnresolvedExpression groupExpr : groupExprList) {
1199+
String fieldName = extractFieldName(groupExpr);
1200+
if (fieldName != null) {
1201+
context.getAggregationGroupByFields().add(fieldName);
1202+
}
1203+
}
1204+
1205+
visitChildren(node, context);
1206+
1207+
// Restore previous group-by fields and count
1208+
context.setAggregationGroupByFields(savedGroupByFields);
1209+
context.setAggregationGroupByCount(savedGroupByCount);
1210+
1211+
groupExprList.clear();
10981212
// The span column is always the first column in the result, regardless of
10991213
// whether the span appears first or last in the query.
1100-
UnresolvedExpression span = node.getSpan();
11011214
if (Objects.nonNull(span)) {
11021215
groupExprList.add(span);
11031216
List<RexNode> timeSpanFilters =
@@ -1142,23 +1255,86 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11421255
// the sequence of output schema is "count, colA, colB".
11431256
List<RexNode> outputFields = context.relBuilder.fields();
11441257
int numOfOutputFields = outputFields.size();
1145-
int numOfAggList = aggExprList.size();
1258+
int numOfUserAggregations = aggExprList.size();
11461259
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
1147-
// Add aggregation results first
1148-
List<RexNode> aggRexList =
1149-
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
1150-
reordered.addAll(aggRexList);
1151-
// Add group by columns
1152-
List<RexNode> aliasedGroupByList =
1153-
aggregationAttributes.getLeft().stream()
1154-
.map(this::extractAliasLiteral)
1155-
.flatMap(Optional::stream)
1156-
.map(ref -> ref.getValueAs(String.class))
1157-
.map(context.relBuilder::field)
1158-
.map(f -> (RexNode) f)
1159-
.toList();
1160-
reordered.addAll(aliasedGroupByList);
1161-
context.relBuilder.project(reordered);
1260+
1261+
// Add user-specified aggregation results first (exclude MIN aggregations for binned fields)
1262+
List<RexNode> userAggRexList =
1263+
outputFields.subList(
1264+
numOfOutputFields - aggregationAttributes.getRight().size(),
1265+
numOfOutputFields - aggregationAttributes.getRight().size() + numOfUserAggregations);
1266+
1267+
// Wrap AVG aggregations with ROUND to fix floating point precision
1268+
for (int i = 0; i < userAggRexList.size(); i++) {
1269+
RexNode aggRex = userAggRexList.get(i);
1270+
UnresolvedExpression aggExpr = aggExprList.get(i);
1271+
1272+
// Unwrap Alias to get to the actual aggregation function
1273+
UnresolvedExpression actualAggExpr = aggExpr;
1274+
if (aggExpr instanceof org.opensearch.sql.ast.expression.Alias) {
1275+
actualAggExpr = ((org.opensearch.sql.ast.expression.Alias) aggExpr).getDelegated();
1276+
}
1277+
1278+
// Check if this is an AVG aggregation
1279+
if (actualAggExpr instanceof org.opensearch.sql.ast.expression.AggregateFunction) {
1280+
org.opensearch.sql.ast.expression.AggregateFunction aggFunc =
1281+
(org.opensearch.sql.ast.expression.AggregateFunction) actualAggExpr;
1282+
if ("avg".equalsIgnoreCase(aggFunc.getFuncName())) {
1283+
// Wrap with ROUND(value, 2)
1284+
aggRex =
1285+
context.relBuilder.call(
1286+
org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND,
1287+
aggRex,
1288+
context.rexBuilder.makeLiteral(
1289+
2,
1290+
context
1291+
.relBuilder
1292+
.getTypeFactory()
1293+
.createSqlType(org.apache.calcite.sql.type.SqlTypeName.INTEGER),
1294+
false));
1295+
}
1296+
}
1297+
reordered.add(aggRex);
1298+
}
1299+
1300+
// Add group by columns, replacing _bin columns with their MIN aggregations
1301+
// Get field names from the aggregate output (group-by fields come first)
1302+
List<String> allFieldNames = context.relBuilder.peek().getRowType().getFieldNames();
1303+
int numGroupByFields = aggregationAttributes.getLeft().size();
1304+
1305+
List<String> outputFieldNames = new ArrayList<>();
1306+
1307+
for (int i = 0; i < numGroupByFields; i++) {
1308+
String fieldName = allFieldNames.get(i);
1309+
if (fieldName.endsWith("_bin")) {
1310+
// This is a bin column
1311+
String originalFieldName = fieldName.substring(0, fieldName.length() - 4); // Remove "_bin"
1312+
// Check if we have a MIN aggregation for this field (only present for multi-field grouping)
1313+
if (allFieldNames.contains(originalFieldName)) {
1314+
// Use the MIN aggregation
1315+
reordered.add(context.relBuilder.field(originalFieldName));
1316+
outputFieldNames.add(originalFieldName);
1317+
} else {
1318+
// Use the bin column directly (for single-field binning) - rename to original name
1319+
reordered.add(context.relBuilder.field(fieldName));
1320+
outputFieldNames.add(originalFieldName); // Rename _bin field to original name
1321+
}
1322+
} else {
1323+
// Regular group-by field
1324+
reordered.add(context.relBuilder.field(fieldName));
1325+
outputFieldNames.add(fieldName);
1326+
}
1327+
}
1328+
1329+
// Prepend the user-aggregation names so they precede the group-by names, as in reordered.
1330+
// NOTE(review): add(0, ...) in a forward loop reverses the names for 2+ aggregations - confirm.
1331+
int aggStartIndex = numOfOutputFields - aggregationAttributes.getRight().size();
1332+
for (int i = aggStartIndex; i < aggStartIndex + numOfUserAggregations; i++) {
1333+
outputFieldNames.add(
1334+
0, allFieldNames.get(i)); // Add at beginning to match reordered list order
1335+
}
1336+
1337+
context.relBuilder.project(reordered, outputFieldNames);
11621338

11631339
return context.relBuilder.peek();
11641340
}

0 commit comments

Comments
 (0)