Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions src/iceberg/expression/binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "iceberg/expression/binder.h"

#include "iceberg/util/macros.h"

namespace iceberg {

Binder::Binder(const Schema& schema, bool case_sensitive)
Expand Down Expand Up @@ -54,30 +56,30 @@ Result<std::shared_ptr<Expression>> Binder::Or(

Result<std::shared_ptr<Expression>> Binder::Predicate(
const std::shared_ptr<UnboundPredicate>& pred) {
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
return pred->Bind(schema_, case_sensitive_);
}

Result<std::shared_ptr<Expression>> Binder::Predicate(
const std::shared_ptr<BoundPredicate>& pred) {
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
return InvalidExpression("Found already bound predicate: {}", pred->ToString());
}

Result<std::shared_ptr<Expression>> Binder::Aggregate(
const std::shared_ptr<BoundAggregate>& aggregate) {
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString());
}

Result<std::shared_ptr<Expression>> Binder::Aggregate(
const std::shared_ptr<UnboundAggregate>& aggregate) {
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
return aggregate->Bind(schema_, case_sensitive_);
}

Result<bool> IsBoundVisitor::IsBound(const std::shared_ptr<Expression>& expr) {
ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null");
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
IsBoundVisitor visitor;
return Visit<bool, IsBoundVisitor>(expr, visitor);
}
Expand Down Expand Up @@ -113,4 +115,54 @@ Result<bool> IsBoundVisitor::Aggregate(
return false;
}

Result<std::unordered_set<int32_t>> ReferenceVisitor::GetReferencedFieldIds(
const std::shared_ptr<Expression>& expr) {
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
ReferenceVisitor visitor;
return Visit<FieldIdsSetRef, ReferenceVisitor>(expr, visitor);
}

Result<FieldIdsSetRef> ReferenceVisitor::AlwaysTrue() { return referenced_field_ids_; }

Result<FieldIdsSetRef> ReferenceVisitor::AlwaysFalse() { return referenced_field_ids_; }

Result<FieldIdsSetRef> ReferenceVisitor::Not(
[[maybe_unused]] const FieldIdsSetRef& child_result) {
return referenced_field_ids_;
}

Result<FieldIdsSetRef> ReferenceVisitor::And(
[[maybe_unused]] const FieldIdsSetRef& left_result,
[[maybe_unused]] const FieldIdsSetRef& right_result) {
return referenced_field_ids_;
}

Result<FieldIdsSetRef> ReferenceVisitor::Or(
[[maybe_unused]] const FieldIdsSetRef& left_result,
[[maybe_unused]] const FieldIdsSetRef& right_result) {
return referenced_field_ids_;
}

Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
const std::shared_ptr<BoundPredicate>& pred) {
referenced_field_ids_.insert(pred->reference()->field_id());
return referenced_field_ids_;
}

Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
[[maybe_unused]] const std::shared_ptr<UnboundPredicate>& pred) {
return referenced_field_ids_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we error on unbound predicates and unbound aggregates? ISTM that an expression should be either fully bound or fully unbound, and that mixing the two is not possible, and visiting an unbound expression would be meaningless.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just took a look at the Java's impl of ReferenceVisitor and IsBoundVisitor, it appears that mixed bound and unbound predicate should error, I created PR #503 to fix our IsBoundVisitor logic, please correct me if f I'm wrong.

}

Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
const std::shared_ptr<BoundAggregate>& aggregate) {
referenced_field_ids_.insert(aggregate->reference()->field_id());
return referenced_field_ids_;
}

Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
[[maybe_unused]] const std::shared_ptr<UnboundAggregate>& aggregate) {
return referenced_field_ids_;
}

} // namespace iceberg
30 changes: 29 additions & 1 deletion src/iceberg/expression/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
/// \file iceberg/expression/binder.h
/// Bind an expression to a schema.

#include <functional>
#include <unordered_set>

#include "iceberg/expression/expression_visitor.h"

namespace iceberg {
Expand Down Expand Up @@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor<bool> {
Result<bool> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) override;
};

// TODO(gangwu): add the Java parity `ReferenceVisitor`
using FieldIdsSetRef = std::reference_wrapper<std::unordered_set<int32_t>>;

/// \brief Visitor to collect referenced field IDs from an expression.
class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor<FieldIdsSetRef> {
public:
static Result<std::unordered_set<int32_t>> GetReferencedFieldIds(
const std::shared_ptr<Expression>& expr);

Result<FieldIdsSetRef> AlwaysTrue() override;
Result<FieldIdsSetRef> AlwaysFalse() override;
Result<FieldIdsSetRef> Not(const FieldIdsSetRef& child_result) override;
Result<FieldIdsSetRef> And(const FieldIdsSetRef& left_result,
const FieldIdsSetRef& right_result) override;
Result<FieldIdsSetRef> Or(const FieldIdsSetRef& left_result,
const FieldIdsSetRef& right_result) override;
Result<FieldIdsSetRef> Predicate(const std::shared_ptr<BoundPredicate>& pred) override;
Result<FieldIdsSetRef> Predicate(
const std::shared_ptr<UnboundPredicate>& pred) override;
Result<FieldIdsSetRef> Aggregate(
const std::shared_ptr<BoundAggregate>& aggregate) override;
Result<FieldIdsSetRef> Aggregate(
const std::shared_ptr<UnboundAggregate>& aggregate) override;

private:
std::unordered_set<int32_t> referenced_field_ids_;
};

} // namespace iceberg
206 changes: 206 additions & 0 deletions src/iceberg/test/expression_visitor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,4 +505,210 @@ TEST_F(RewriteNotTest, ComplexExpression) {
EXPECT_EQ(rewritten->op(), Expression::Operation::kOr);
}

class ReferenceVisitorTest : public ExpressionVisitorTest {};

TEST_F(ReferenceVisitorTest, Constants) {
// Constants should have no referenced fields
auto true_expr = Expressions::AlwaysTrue();
ICEBERG_UNWRAP_OR_FAIL(auto refs_true,
ReferenceVisitor::GetReferencedFieldIds(true_expr));
EXPECT_TRUE(refs_true.empty());

auto false_expr = Expressions::AlwaysFalse();
ICEBERG_UNWRAP_OR_FAIL(auto refs_false,
ReferenceVisitor::GetReferencedFieldIds(false_expr));
EXPECT_TRUE(refs_false.empty());
}

TEST_F(ReferenceVisitorTest, UnboundPredicate) {
// Unbound predicates should have no referenced field IDs (not yet bound to schema)
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
ICEBERG_UNWRAP_OR_FAIL(auto refs,
ReferenceVisitor::GetReferencedFieldIds(unbound_pred));
EXPECT_TRUE(refs.empty());
}

TEST_F(ReferenceVisitorTest, BoundPredicate) {
// Bound predicate should return the field ID
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
ICEBERG_UNWRAP_OR_FAIL(auto bound_pred, Bind(unbound_pred));

ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_pred));
EXPECT_EQ(refs.size(), 1);
EXPECT_EQ(refs.count(2), 1); // name field has id=2
}

TEST_F(ReferenceVisitorTest, MultiplePredicates) {
// Test various predicates with different fields
auto pred_age = Expressions::GreaterThan("age", Literal::Int(25));
ICEBERG_UNWRAP_OR_FAIL(auto bound_age, Bind(pred_age));
ICEBERG_UNWRAP_OR_FAIL(auto refs_age,
ReferenceVisitor::GetReferencedFieldIds(bound_age));
EXPECT_EQ(refs_age.size(), 1);
EXPECT_EQ(refs_age.count(3), 1); // age field has id=3

auto pred_salary = Expressions::LessThan("salary", Literal::Double(50000.0));
ICEBERG_UNWRAP_OR_FAIL(auto bound_salary, Bind(pred_salary));
ICEBERG_UNWRAP_OR_FAIL(auto refs_salary,
ReferenceVisitor::GetReferencedFieldIds(bound_salary));
EXPECT_EQ(refs_salary.size(), 1);
EXPECT_EQ(refs_salary.count(4), 1); // salary field has id=4
}

TEST_F(ReferenceVisitorTest, UnaryPredicates) {
// Test unary predicates
auto pred_is_null = Expressions::IsNull("name");
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_null, Bind(pred_is_null));
ICEBERG_UNWRAP_OR_FAIL(auto refs,
ReferenceVisitor::GetReferencedFieldIds(bound_is_null));
EXPECT_EQ(refs.size(), 1);
EXPECT_EQ(refs.count(2), 1);

auto pred_is_nan = Expressions::IsNaN("salary");
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_nan, Bind(pred_is_nan));
ICEBERG_UNWRAP_OR_FAIL(auto refs_nan,
ReferenceVisitor::GetReferencedFieldIds(bound_is_nan));
EXPECT_EQ(refs_nan.size(), 1);
EXPECT_EQ(refs_nan.count(4), 1);
}

TEST_F(ReferenceVisitorTest, AndExpression) {
// AND expression should return union of field IDs from both sides
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
auto and_expr = Expressions::And(pred1, pred2);

ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));

EXPECT_EQ(refs.size(), 2);
EXPECT_EQ(refs.count(2), 1); // name field
EXPECT_EQ(refs.count(3), 1); // age field
}

TEST_F(ReferenceVisitorTest, OrExpression) {
// OR expression should return union of field IDs from both sides
auto pred1 = Expressions::IsNull("salary");
auto pred2 = Expressions::Equal("active", Literal::Boolean(true));
auto or_expr = Expressions::Or(pred1, pred2);

ICEBERG_UNWRAP_OR_FAIL(auto bound_or, Bind(or_expr));
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_or));

EXPECT_EQ(refs.size(), 2);
EXPECT_EQ(refs.count(4), 1); // salary field
EXPECT_EQ(refs.count(5), 1); // active field
}

TEST_F(ReferenceVisitorTest, NotExpression) {
// NOT expression should return field IDs from its child
auto pred = Expressions::Equal("name", Literal::String("Alice"));
auto not_expr = Expressions::Not(pred);

ICEBERG_UNWRAP_OR_FAIL(auto bound_not, Bind(not_expr));
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_not));

EXPECT_EQ(refs.size(), 1);
EXPECT_EQ(refs.count(2), 1); // name field
}

TEST_F(ReferenceVisitorTest, ComplexNestedExpression) {
// (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true)
// Should reference fields: name(2), age(3), salary(4), active(5)
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
auto pred3 = Expressions::LessThan("salary", Literal::Double(30000.0));
auto pred4 = Expressions::Equal("active", Literal::Boolean(true));

auto and1 = Expressions::And(pred1, pred2);
auto and2 = Expressions::And(pred3, pred4);
auto complex_or = Expressions::Or(and1, and2);

ICEBERG_UNWRAP_OR_FAIL(auto bound_complex, Bind(complex_or));
ICEBERG_UNWRAP_OR_FAIL(auto refs,
ReferenceVisitor::GetReferencedFieldIds(bound_complex));

EXPECT_EQ(refs.size(), 4);
EXPECT_EQ(refs.count(2), 1); // name field
EXPECT_EQ(refs.count(3), 1); // age field
EXPECT_EQ(refs.count(4), 1); // salary field
EXPECT_EQ(refs.count(5), 1); // active field
}

TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) {
// Multiple predicates referencing the same field
// age > 25 AND age < 50
auto pred1 = Expressions::GreaterThan("age", Literal::Int(25));
auto pred2 = Expressions::LessThan("age", Literal::Int(50));
auto and_expr = Expressions::And(pred1, pred2);

ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));

// Should only contain the field ID once (set semantics)
EXPECT_EQ(refs.size(), 1);
EXPECT_EQ(refs.count(3), 1); // age field
}

TEST_F(ReferenceVisitorTest, SetPredicates) {
// Test In predicate
auto pred_in =
Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)});
ICEBERG_UNWRAP_OR_FAIL(auto bound_in, Bind(pred_in));
ICEBERG_UNWRAP_OR_FAIL(auto refs_in, ReferenceVisitor::GetReferencedFieldIds(bound_in));

EXPECT_EQ(refs_in.size(), 1);
EXPECT_EQ(refs_in.count(3), 1); // age field

// Test NotIn predicate
auto pred_not_in =
Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")});
ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, Bind(pred_not_in));
ICEBERG_UNWRAP_OR_FAIL(auto refs_not_in,
ReferenceVisitor::GetReferencedFieldIds(bound_not_in));

EXPECT_EQ(refs_not_in.size(), 1);
EXPECT_EQ(refs_not_in.count(2), 1); // name field
}

TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) {
// Expression with both bound and unbound predicates
auto bound_pred = Expressions::Equal("name", Literal::String("Alice"));
ICEBERG_UNWRAP_OR_FAIL(auto pred1, Bind(bound_pred));

auto unbound_pred = Expressions::GreaterThan("age", Literal::Int(25));

auto mixed_and = Expressions::And(pred1, unbound_pred);
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(mixed_and));

// Should only return field IDs from bound predicates
EXPECT_EQ(refs.size(), 1);
EXPECT_EQ(refs.count(2), 1); // name field only
}

TEST_F(ReferenceVisitorTest, AllFields) {
// Create expression referencing all fields in the schema
auto pred1 = Expressions::NotNull("id");
auto pred2 = Expressions::Equal("name", Literal::String("Test"));
auto pred3 = Expressions::GreaterThan("age", Literal::Int(0));
auto pred4 = Expressions::LessThan("salary", Literal::Double(100000.0));
auto pred5 = Expressions::Equal("active", Literal::Boolean(true));

auto and1 = Expressions::And(pred1, pred2);
auto and2 = Expressions::And(pred3, pred4);
auto and3 = Expressions::And(and1, and2);
auto all_fields = Expressions::And(and3, pred5);

ICEBERG_UNWRAP_OR_FAIL(auto bound_all, Bind(all_fields));
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_all));

// Should reference all 5 fields
EXPECT_EQ(refs.size(), 4);
EXPECT_EQ(refs.count(1), 0); // id field is optimized out
EXPECT_EQ(refs.count(2), 1); // name field
EXPECT_EQ(refs.count(3), 1); // age field
EXPECT_EQ(refs.count(4), 1); // salary field
EXPECT_EQ(refs.count(5), 1); // active field
}

} // namespace iceberg
Loading