Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ public class CalcitePlanContext {
@Getter @Setter private boolean isResolvingSubquery = false;
@Getter @Setter private boolean inCoalesceFunction = false;

/** Fields that are being grouped by in aggregation (for bin operations to preserve originals) */
@Getter @Setter
private java.util.Set<String> aggregationGroupByFields = new java.util.HashSet<>();

/** Total number of group-by fields in current aggregation */
@Getter @Setter private int aggregationGroupByCount = 0;

/**
* The flag used to determine whether we do metadata field projection for user 1. If a project is
* never visited, we will do metadata field projection for user 2. Else not because user may
Expand All @@ -59,6 +66,14 @@ public class CalcitePlanContext {
private final Stack<RexCorrelVariable> correlVar = new Stack<>();
private final Stack<List<RexNode>> windowPartitions = new Stack<>();

/**
* Partition columns for bin operations. Used to create partitioned window functions for MIN/MAX
* calculations in WIDTH_BUCKET, matching auto_date_histogram's per-group behavior. Stores
* UnresolvedExpression objects that will be analyzed by bin handlers.
*/
private final Stack<List<org.opensearch.sql.ast.expression.UnresolvedExpression>>
binPartitionExpressions = new Stack<>();

@Getter public Map<String, RexLambdaRef> rexLambdaRefMap;

private CalcitePlanContext(FrameworkConfig config, SysLimit sysLimit, QueryType queryType) {
Expand Down Expand Up @@ -134,4 +149,27 @@ public static boolean isLegacyPreferred() {
public void putRexLambdaRefMap(Map<String, RexLambdaRef> candidateMap) {
this.rexLambdaRefMap.putAll(candidateMap);
}

/**
 * Pushes the partition expressions for a bin operation onto this context's stack. The pushed
 * list becomes the value returned by {@link #getBinPartitionExpressions()} until it is popped.
 *
 * <p>A {@code null} argument is stored as an empty list, so the getter keeps its documented
 * "empty list when nothing is set" contract instead of handing {@code null} to callers.
 *
 * @param partitionExpressions unresolved expressions to partition by; may be {@code null}
 */
public void pushBinPartitionExpressions(
    List<org.opensearch.sql.ast.expression.UnresolvedExpression> partitionExpressions) {
  if (partitionExpressions == null) {
    // Normalize a missing list so peek() never yields null to bin handlers.
    binPartitionExpressions.push(List.of());
    return;
  }
  binPartitionExpressions.push(partitionExpressions);
}

/**
 * Removes the most recently pushed bin partition expressions.
 *
 * <p>Calling this when nothing has been pushed is a no-op rather than an error.
 */
public void popBinPartitionExpressions() {
  if (binPartitionExpressions.empty()) {
    // Nothing to discard; tolerate unbalanced pops.
    return;
  }
  binPartitionExpressions.pop();
}

/**
 * Returns the bin partition expressions currently on top of the stack.
 *
 * @return the most recently pushed list, or an empty immutable list when nothing has been pushed
 */
public List<org.opensearch.sql.ast.expression.UnresolvedExpression> getBinPartitionExpressions() {
  if (binPartitionExpressions.empty()) {
    return List.of();
  }
  return binPartitionExpressions.peek();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,22 @@ private void validateWildcardPatterns(
}
}

/**
 * Resolves the field name referenced by an unresolved expression.
 *
 * <p>Two shapes are recognized: a plain {@code Field}, and an {@code Alias} whose delegated
 * expression is a {@code Field}. Only one level of alias is unwrapped.
 *
 * @param expr the expression to inspect
 * @return the string form of the field's name, or {@code null} for any other expression shape
 */
private String extractFieldName(UnresolvedExpression expr) {
  // Peel off a single Alias wrapper, if present, so both shapes share one Field check below.
  UnresolvedExpression candidate = expr;
  if (candidate instanceof org.opensearch.sql.ast.expression.Alias) {
    candidate = ((org.opensearch.sql.ast.expression.Alias) candidate).getDelegated();
  }

  if (!(candidate instanceof org.opensearch.sql.ast.expression.Field)) {
    return null;
  }
  org.opensearch.sql.ast.expression.Field field =
      (org.opensearch.sql.ast.expression.Field) candidate;
  return field.getField().toString();
}

/**
 * Tells whether the given name is a key of {@code OpenSearchConstants.METADATAFIELD_TYPE_MAP}.
 *
 * @param fieldName the field name to look up
 * @return {@code true} if the map contains the name as a key, {@code false} otherwise
 */
private boolean isMetadataField(String fieldName) {
  boolean knownMetadataField =
      OpenSearchConstants.METADATAFIELD_TYPE_MAP.containsKey(fieldName);
  return knownMetadataField;
}
Expand Down Expand Up @@ -662,7 +678,26 @@ public RelNode visitBin(Bin node, CalcitePlanContext context) {
RexNode binExpression = BinUtils.createBinExpression(node, fieldExpr, context, rexVisitor);

String alias = node.getAlias() != null ? node.getAlias() : fieldName;
projectPlusOverriding(List.of(binExpression), List.of(alias), context);

// Check if this field is used in aggregation grouping with multiple fields
if (context.getAggregationGroupByFields().contains(fieldName)
&& context.getAggregationGroupByCount() > 1) {
// For multi-field aggregation: preserve BOTH original field and binned field
// The binned field (fieldName_bin) is used for grouping
// The original field is used for MIN aggregation to show actual timestamps per group
List<RexNode> allFields = new ArrayList<>(context.relBuilder.fields());
List<String> allFieldNames =
new ArrayList<>(context.relBuilder.peek().getRowType().getFieldNames());

// Add the binned field with _bin suffix
allFields.add(binExpression);
allFieldNames.add(fieldName + "_bin");

context.relBuilder.project(allFields, allFieldNames);
} else {
// For non-aggregation queries OR single-field binning: replace field with binned value
projectPlusOverriding(List.of(binExpression), List.of(alias), context);
}

return context.relBuilder.peek();
}
Expand Down Expand Up @@ -1020,20 +1055,98 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
CalcitePlanContext context) {
List<AggCall> aggCallList =
aggExprList.stream().map(expr -> aggVisitor.analyze(expr, context)).toList();
List<RexNode> groupByList =
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
return Pair.of(groupByList, aggCallList);

// Get available field names in the current relation
List<String> availableFields = context.relBuilder.peek().getRowType().getFieldNames();

// Build group-by list, replacing fields with their _bin columns if they exist
List<RexNode> groupByList = new ArrayList<>();
List<AggCall> additionalAggCalls = new ArrayList<>();

// Track if we have bin columns - we'll need this to decide whether to add MIN aggregations
boolean hasBinColumns = false;
int nonBinGroupByCount = 0;

for (UnresolvedExpression groupExpr : groupExprList) {
RexNode resolvedExpr = rexVisitor.analyze(groupExpr, context);

// Extract field name from UnresolvedExpression
String fieldName = extractFieldName(groupExpr);

// Check if this field has a corresponding _bin column
if (fieldName != null) {
String binColumnName = fieldName + "_bin";
if (availableFields.contains(binColumnName)) {
// Use the _bin column for grouping
groupByList.add(context.relBuilder.field(binColumnName));
hasBinColumns = true;
continue;
}
}

// Regular group-by field
groupByList.add(resolvedExpr);
nonBinGroupByCount++;
}

// Only add MIN aggregations for bin columns if there are OTHER group-by fields
// This matches OpenSearch behavior:
// - With multi-field grouping (e.g., by region, timestamp): Show MIN(timestamp) per group
// - With single-field grouping (e.g., by timestamp only): Show bin start time
if (hasBinColumns && nonBinGroupByCount > 0) {
for (UnresolvedExpression groupExpr : groupExprList) {
String fieldName = extractFieldName(groupExpr);
if (fieldName != null) {
String binColumnName = fieldName + "_bin";
if (availableFields.contains(binColumnName)) {
// Add MIN(original_field) to show minimum timestamp per bin
additionalAggCalls.add(
context.relBuilder.min(context.relBuilder.field(fieldName)).as(fieldName));
}
}
}
}

// Combine original aggregations with additional MIN aggregations for binned fields
List<AggCall> combinedAggCalls = new ArrayList<>(aggCallList);
combinedAggCalls.addAll(additionalAggCalls);

return Pair.of(groupByList, combinedAggCalls);
}

@Override
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
visitChildren(node, context);

// Prepare partition columns for bin operations before visiting children
// This allows WIDTH_BUCKET to use per-group min/max (matching auto_date_histogram)
List<UnresolvedExpression> aggExprList = node.getAggExprList();
List<UnresolvedExpression> groupExprList = new ArrayList<>();
UnresolvedExpression span = node.getSpan();
if (Objects.nonNull(span)) {
groupExprList.add(span);
}
groupExprList.addAll(node.getGroupExprList());

// Store group-by field names and count so bin operations can preserve original fields
java.util.Set<String> savedGroupByFields = context.getAggregationGroupByFields();
int savedGroupByCount = context.getAggregationGroupByCount();
context.setAggregationGroupByFields(new java.util.HashSet<>());
context.setAggregationGroupByCount(groupExprList.size());
for (UnresolvedExpression groupExpr : groupExprList) {
String fieldName = extractFieldName(groupExpr);
if (fieldName != null) {
context.getAggregationGroupByFields().add(fieldName);
}
}

visitChildren(node, context);

// Restore previous group-by fields and count
context.setAggregationGroupByFields(savedGroupByFields);
context.setAggregationGroupByCount(savedGroupByCount);

groupExprList.clear();
// The span column is always the first column in result whatever
// the order of span in query is first or last one
UnresolvedExpression span = node.getSpan();
if (Objects.nonNull(span)) {
groupExprList.add(span);
List<RexNode> timeSpanFilters =
Expand Down Expand Up @@ -1093,23 +1206,86 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
// the sequence of output schema is "count, colA, colB".
List<RexNode> outputFields = context.relBuilder.fields();
int numOfOutputFields = outputFields.size();
int numOfAggList = aggExprList.size();
int numOfUserAggregations = aggExprList.size();
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
// Add aggregation results first
List<RexNode> aggRexList =
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
reordered.addAll(aggRexList);
// Add group by columns
List<RexNode> aliasedGroupByList =
aggregationAttributes.getLeft().stream()
.map(this::extractAliasLiteral)
.flatMap(Optional::stream)
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
.map(context.relBuilder::field)
.map(f -> (RexNode) f)
.toList();
reordered.addAll(aliasedGroupByList);
context.relBuilder.project(reordered);

// Add user-specified aggregation results first (exclude MIN aggregations for binned fields)
List<RexNode> userAggRexList =
outputFields.subList(
numOfOutputFields - aggregationAttributes.getRight().size(),
numOfOutputFields - aggregationAttributes.getRight().size() + numOfUserAggregations);

// Wrap AVG aggregations with ROUND to fix floating point precision
for (int i = 0; i < userAggRexList.size(); i++) {
RexNode aggRex = userAggRexList.get(i);
UnresolvedExpression aggExpr = aggExprList.get(i);

// Unwrap Alias to get to the actual aggregation function
UnresolvedExpression actualAggExpr = aggExpr;
if (aggExpr instanceof org.opensearch.sql.ast.expression.Alias) {
actualAggExpr = ((org.opensearch.sql.ast.expression.Alias) aggExpr).getDelegated();
}

// Check if this is an AVG aggregation
if (actualAggExpr instanceof org.opensearch.sql.ast.expression.AggregateFunction) {
org.opensearch.sql.ast.expression.AggregateFunction aggFunc =
(org.opensearch.sql.ast.expression.AggregateFunction) actualAggExpr;
if ("avg".equalsIgnoreCase(aggFunc.getFuncName())) {
// Wrap with ROUND(value, 2)
aggRex =
context.relBuilder.call(
org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND,
aggRex,
context.rexBuilder.makeLiteral(
2,
context
.relBuilder
.getTypeFactory()
.createSqlType(org.apache.calcite.sql.type.SqlTypeName.INTEGER),
false));
}
}
reordered.add(aggRex);
}

// Add group by columns, replacing _bin columns with their MIN aggregations
// Get field names from the aggregate output (group-by fields come first)
List<String> allFieldNames = context.relBuilder.peek().getRowType().getFieldNames();
int numGroupByFields = aggregationAttributes.getLeft().size();

List<String> outputFieldNames = new ArrayList<>();

for (int i = 0; i < numGroupByFields; i++) {
String fieldName = allFieldNames.get(i);
if (fieldName.endsWith("_bin")) {
// This is a bin column
String originalFieldName = fieldName.substring(0, fieldName.length() - 4); // Remove "_bin"
// Check if we have a MIN aggregation for this field (only present for multi-field grouping)
if (allFieldNames.contains(originalFieldName)) {
// Use the MIN aggregation
reordered.add(context.relBuilder.field(originalFieldName));
outputFieldNames.add(originalFieldName);
} else {
// Use the bin column directly (for single-field binning) - rename to original name
reordered.add(context.relBuilder.field(fieldName));
outputFieldNames.add(originalFieldName); // Rename _bin field to original name
}
} else {
// Regular group-by field
reordered.add(context.relBuilder.field(fieldName));
outputFieldNames.add(fieldName);
}
}

// Add aggregation field names (after group-by fields in the reordered list)
// The user aggregations are at the beginning of reordered list, so we add their names
int aggStartIndex = numOfOutputFields - aggregationAttributes.getRight().size();
for (int i = aggStartIndex; i < aggStartIndex + numOfUserAggregations; i++) {
outputFieldNames.add(
0, allFieldNames.get(i)); // Add at beginning to match reordered list order
}

context.relBuilder.project(reordered, outputFieldNames);

return context.relBuilder.peek();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.sql.calcite.utils.binning.handlers;

import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.tree.Bin;
import org.opensearch.sql.ast.tree.CountBin;
Expand All @@ -30,20 +29,21 @@ public RexNode createExpression(
requestedBins = BinConstants.DEFAULT_BINS;
}

// Calculate data range using window functions
// Calculate global MIN and MAX using window functions
RexNode minValue = context.relBuilder.min(fieldExpr).over().toRex();
RexNode maxValue = context.relBuilder.max(fieldExpr).over().toRex();
RexNode dataRange = context.relBuilder.call(SqlStdOperatorTable.MINUS, maxValue, minValue);

// Convert start/end parameters
RexNode startValue = convertParameter(countBin.getStart(), context);
RexNode endValue = convertParameter(countBin.getEnd(), context);

// WIDTH_BUCKET(field_value, num_bins, data_range, max_value)
// WIDTH_BUCKET(field_value, num_bins, min_value, max_value)
// Note: We pass minValue instead of dataRange - WIDTH_BUCKET will calculate the range
// internally
RexNode numBins = context.relBuilder.literal(requestedBins);

return context.rexBuilder.makeCall(
PPLBuiltinOperators.WIDTH_BUCKET, fieldExpr, numBins, dataRange, maxValue);
PPLBuiltinOperators.WIDTH_BUCKET, fieldExpr, numBins, minValue, maxValue);
}

private RexNode convertParameter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public RexNode createExpression(
Number minspanNum = (Number) ((RexLiteral) minspanValue).getValue();
double minspan = minspanNum.doubleValue();

// Calculate data range using window functions
// Calculate global data range using window functions
RexNode minValue = context.relBuilder.min(fieldExpr).over().toRex();
RexNode maxValue = context.relBuilder.max(fieldExpr).over().toRex();
RexNode dataRange = context.relBuilder.call(SqlStdOperatorTable.MINUS, maxValue, minValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public RexNode createExpression(

RangeBin rangeBin = (RangeBin) node;

// Simple MIN/MAX calculation - cleaner than complex CASE expressions
// Simple global MIN/MAX calculation - cleaner than complex CASE expressions
RexNode dataMin = context.relBuilder.min(fieldExpr).over().toRex();
RexNode dataMax = context.relBuilder.max(fieldExpr).over().toRex();

Expand Down
Loading