Skip to content

Commit 561d175

Browse files
MBkktmeta-codesync[bot]
authored andcommitted
fix: MetadataFilter construction from flat (multi) And/Or (facebookincubator#14981)
Summary: Pull Request resolved: facebookincubator#14981 Reviewed By: kagamiori Differential Revision: D83756179 Pulled By: bikramSingh91 fbshipit-source-id: 68b6532464650213cd7df01c28ca035497c5dabf
1 parent ed9efc7 commit 561d175

File tree

4 files changed

+173
-69
lines changed

4 files changed

+173
-69
lines changed

velox/dwio/common/MetadataFilter.cpp

Lines changed: 93 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -68,86 +68,118 @@ class MetadataFilter::LeafNode : public Node {
6868
std::unique_ptr<Filter> filter_;
6969
};
7070

71-
struct MetadataFilter::AndNode : Node {
71+
struct MetadataFilter::ConditionNode : Node {
7272
static std::unique_ptr<Node> create(
73-
std::unique_ptr<Node> lhs,
74-
std::unique_ptr<Node> rhs) {
75-
if (!lhs) {
76-
return rhs;
77-
}
78-
if (!rhs) {
79-
return lhs;
73+
bool conjuction,
74+
std::vector<std::unique_ptr<Node>> args);
75+
76+
static std::unique_ptr<Node> fromExpression(
77+
const std::vector<core::TypedExprPtr>& inputs,
78+
core::ExpressionEvaluator* evaluator,
79+
bool conjunction,
80+
bool negated) {
81+
conjunction = negated ? !conjunction : conjunction;
82+
std::vector<std::unique_ptr<Node>> args;
83+
args.reserve(inputs.size());
84+
for (const auto& input : inputs) {
85+
auto node = Node::fromExpression(*input, evaluator, negated);
86+
if (node) {
87+
args.push_back(std::move(node));
88+
} else if (!conjunction) {
89+
return nullptr;
90+
}
8091
}
81-
return std::make_unique<AndNode>(std::move(lhs), std::move(rhs));
92+
return create(conjunction, std::move(args));
8293
}
8394

84-
AndNode(std::unique_ptr<Node> lhs, std::unique_ptr<Node> rhs)
85-
: lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}
95+
explicit ConditionNode(std::vector<std::unique_ptr<Node>> args)
96+
: args_{std::move(args)} {}
8697

87-
void addToScanSpec(ScanSpec& scanSpec) const override {
88-
lhs_->addToScanSpec(scanSpec);
89-
rhs_->addToScanSpec(scanSpec);
90-
}
91-
92-
uint64_t* eval(LeafResults& leafResults, int size) const override {
93-
auto* l = lhs_->eval(leafResults, size);
94-
auto* r = rhs_->eval(leafResults, size);
95-
if (!l) {
96-
return r;
97-
}
98-
if (!r) {
99-
return l;
98+
void addToScanSpec(ScanSpec& scanSpec) const final {
99+
for (const auto& arg : args_) {
100+
arg->addToScanSpec(scanSpec);
100101
}
101-
bits::orBits(l, r, 0, size);
102-
return l;
103102
}
104103

105-
std::string toString() const override {
106-
return "and(" + lhs_->toString() + "," + rhs_->toString() + ")";
104+
protected:
105+
std::string ToStringImpl(std::string_view prefix) const {
106+
std::string result{prefix};
107+
for (size_t i = 0; i < args_.size(); ++i) {
108+
if (i != 0) {
109+
result += ",";
110+
}
111+
result += args_[i]->toString();
112+
}
113+
result += ")";
114+
return result;
107115
}
108116

109-
private:
110-
std::unique_ptr<Node> lhs_;
111-
std::unique_ptr<Node> rhs_;
117+
std::vector<std::unique_ptr<Node>> args_;
112118
};
113119

114-
struct MetadataFilter::OrNode : Node {
115-
static std::unique_ptr<Node> create(
116-
std::unique_ptr<Node> lhs,
117-
std::unique_ptr<Node> rhs) {
118-
if (!lhs || !rhs) {
119-
return nullptr;
120+
struct MetadataFilter::AndNode final : ConditionNode {
121+
using ConditionNode::ConditionNode;
122+
123+
uint64_t* eval(LeafResults& leafResults, int size) const final {
124+
uint64_t* result = nullptr;
125+
for (const auto& arg : args_) {
126+
auto* a = arg->eval(leafResults, size);
127+
if (!a) {
128+
continue;
129+
}
130+
if (!result) {
131+
result = a;
132+
} else {
133+
bits::orBits(result, a, 0, size);
134+
}
120135
}
121-
return std::make_unique<OrNode>(std::move(lhs), std::move(rhs));
136+
return result;
122137
}
123138

124-
OrNode(std::unique_ptr<Node> lhs, std::unique_ptr<Node> rhs)
125-
: lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}
126-
127-
void addToScanSpec(ScanSpec& scanSpec) const override {
128-
lhs_->addToScanSpec(scanSpec);
129-
rhs_->addToScanSpec(scanSpec);
139+
std::string toString() const final {
140+
return ToStringImpl("and(");
130141
}
142+
};
131143

132-
uint64_t* eval(LeafResults& leafResults, int size) const override {
133-
auto* l = lhs_->eval(leafResults, size);
134-
auto* r = rhs_->eval(leafResults, size);
135-
if (!l || !r) {
136-
return nullptr;
144+
struct MetadataFilter::OrNode final : ConditionNode {
145+
using ConditionNode::ConditionNode;
146+
147+
uint64_t* eval(LeafResults& leafResults, int size) const final {
148+
uint64_t* result = nullptr;
149+
for (const auto& arg : args_) {
150+
auto* a = arg->eval(leafResults, size);
151+
if (!a) {
152+
return nullptr;
153+
}
154+
if (!result) {
155+
result = a;
156+
} else {
157+
bits::andBits(result, a, 0, size);
158+
}
137159
}
138-
bits::andBits(l, r, 0, size);
139-
return l;
160+
return result;
140161
}
141162

142-
std::string toString() const override {
143-
return "or(" + lhs_->toString() + "," + rhs_->toString() + ")";
163+
std::string toString() const final {
164+
return ToStringImpl("or(");
144165
}
145-
146-
private:
147-
std::unique_ptr<Node> lhs_;
148-
std::unique_ptr<Node> rhs_;
149166
};
150167

168+
std::unique_ptr<MetadataFilter::Node> MetadataFilter::ConditionNode::create(
169+
bool conjunction,
170+
std::vector<std::unique_ptr<Node>> args) {
171+
if (args.empty()) {
172+
return nullptr;
173+
}
174+
if (args.size() == 1) {
175+
return std::move(args[0]);
176+
}
177+
if (conjunction) {
178+
return std::make_unique<AndNode>(std::move(args));
179+
}
180+
return std::make_unique<OrNode>(std::move(args));
181+
}
182+
151183
namespace {
152184

153185
const core::CallTypedExpr* asCall(const core::ITypedExpr* expr) {
@@ -165,16 +197,12 @@ std::unique_ptr<MetadataFilter::Node> MetadataFilter::Node::fromExpression(
165197
return nullptr;
166198
}
167199
if (call->name() == expression::kAnd) {
168-
auto lhs = fromExpression(*call->inputs()[0], evaluator, negated);
169-
auto rhs = fromExpression(*call->inputs()[1], evaluator, negated);
170-
return negated ? OrNode::create(std::move(lhs), std::move(rhs))
171-
: AndNode::create(std::move(lhs), std::move(rhs));
200+
return ConditionNode::fromExpression(
201+
call->inputs(), evaluator, true, negated);
172202
}
173203
if (call->name() == expression::kOr) {
174-
auto lhs = fromExpression(*call->inputs()[0], evaluator, negated);
175-
auto rhs = fromExpression(*call->inputs()[1], evaluator, negated);
176-
return negated ? AndNode::create(std::move(lhs), std::move(rhs))
177-
: OrNode::create(std::move(lhs), std::move(rhs));
204+
return ConditionNode::fromExpression(
205+
call->inputs(), evaluator, false, negated);
178206
}
179207
if (call->name() == "not") {
180208
return fromExpression(*call->inputs()[0], evaluator, !negated);

velox/dwio/common/MetadataFilter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class MetadataFilter {
5050

5151
private:
5252
struct Node;
53+
struct ConditionNode;
5354
struct AndNode;
5455
struct OrNode;
5556

velox/dwio/common/tests/utils/E2EFilterTestBase.cpp

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "velox/dwio/common/tests/utils/DataSetBuilder.h"
2020
#include "velox/expression/Expr.h"
21+
#include "velox/expression/ExprConstants.h"
2122
#include "velox/expression/ExprToSubfieldFilter.h"
2223
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
2324
#include "velox/parse/Expressions.h"
@@ -491,15 +492,31 @@ void E2EFilterTestBase::testMetadataFilterImpl(
491492
core::ExpressionEvaluator* evaluator,
492493
const std::string& remainingFilter,
493494
std::function<bool(int64_t, int64_t)> validationFilter) {
494-
SCOPED_TRACE(fmt::format("remainingFilter={}", remainingFilter));
495+
SCOPED_TRACE(fmt::format("remainingFilter='{}'", remainingFilter));
496+
auto untypedExpr = parse::parseExpr(remainingFilter, {});
497+
auto typedExpr = core::Expressions::inferTypes(
498+
untypedExpr, batches[0]->type(), leafPool_.get());
499+
testMetadataFilterImpl(
500+
batches,
501+
std::move(filterField),
502+
std::move(filter),
503+
evaluator,
504+
std::move(typedExpr),
505+
std::move(validationFilter));
506+
}
507+
508+
void E2EFilterTestBase::testMetadataFilterImpl(
509+
const std::vector<RowVectorPtr>& batches,
510+
common::Subfield filterField,
511+
std::unique_ptr<common::Filter> filter,
512+
core::ExpressionEvaluator* evaluator,
513+
core::TypedExprPtr typedExpr,
514+
std::function<bool(int64_t, int64_t)> validationFilter) {
495515
auto spec = std::make_shared<common::ScanSpec>("<root>");
496516
if (filter) {
497517
spec->getOrCreateChild(std::move(filterField))
498518
->setFilter(std::move(filter));
499519
}
500-
auto untypedExpr = parse::parseExpr(remainingFilter, {});
501-
auto typedExpr = core::Expressions::inferTypes(
502-
untypedExpr, batches[0]->type(), leafPool_.get());
503520
auto metadataFilter =
504521
std::make_shared<MetadataFilter>(*spec, *typedExpr, evaluator);
505522
auto specA = spec->getOrCreateChild(common::Subfield("a"));
@@ -621,6 +638,56 @@ void E2EFilterTestBase::testMetadataFilter() {
621638
[](int64_t a, int64_t) {
622639
return !!(a == 2 || a == 3 || a == 5 || a == 7);
623640
});
641+
{
642+
SCOPED_TRACE("remainingFilter='a == 1 or a == 3 or a == 8'");
643+
auto typedExpr1 = core::Expressions::inferTypes(
644+
parse::parseExpr("a == 1", {}), batches[0]->type(), leafPool_.get());
645+
auto typedExpr2 = core::Expressions::inferTypes(
646+
parse::parseExpr("a == 3", {}), batches[0]->type(), leafPool_.get());
647+
auto typedExpr3 = core::Expressions::inferTypes(
648+
parse::parseExpr("a == 8", {}), batches[0]->type(), leafPool_.get());
649+
650+
auto typedExpr = std::make_shared<core::CallTypedExpr>(
651+
velox::BOOLEAN(),
652+
std::vector{
653+
std::move(typedExpr1),
654+
std::move(typedExpr2),
655+
std::move(typedExpr3),
656+
},
657+
expression::kOr);
658+
testMetadataFilterImpl(
659+
batches,
660+
common::Subfield("a"),
661+
nullptr,
662+
&evaluator,
663+
std::move(typedExpr),
664+
[](int64_t a, int64_t) { return a == 1 || a == 3 || a == 8; });
665+
}
666+
{
667+
SCOPED_TRACE("remainingFilter='a >= 1 and a <= 100 and a == 8'");
668+
auto typedExpr1 = core::Expressions::inferTypes(
669+
parse::parseExpr("a >= 1", {}), batches[0]->type(), leafPool_.get());
670+
auto typedExpr2 = core::Expressions::inferTypes(
671+
parse::parseExpr("a <= 100", {}), batches[0]->type(), leafPool_.get());
672+
auto typedExpr3 = core::Expressions::inferTypes(
673+
parse::parseExpr("b.c != 8", {}), batches[0]->type(), leafPool_.get());
674+
675+
auto typedExpr = std::make_shared<core::CallTypedExpr>(
676+
velox::BOOLEAN(),
677+
std::vector{
678+
std::move(typedExpr1),
679+
std::move(typedExpr2),
680+
std::move(typedExpr3),
681+
},
682+
expression::kAnd);
683+
testMetadataFilterImpl(
684+
batches,
685+
common::Subfield("a"),
686+
nullptr,
687+
&evaluator,
688+
std::move(typedExpr),
689+
[](int64_t a, int64_t c) { return a >= 1 && a <= 100 && c != 8; });
690+
}
624691

625692
{
626693
SCOPED_TRACE("Values not unique in row group");

velox/dwio/common/tests/utils/E2EFilterTestBase.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,14 @@ class E2EFilterTestBase : public testing::Test {
336336
const std::string& remainingFilter,
337337
std::function<bool(int64_t a, int64_t c)> validationFilter);
338338

339+
void testMetadataFilterImpl(
340+
const std::vector<RowVectorPtr>& batches,
341+
common::Subfield filterField,
342+
std::unique_ptr<common::Filter> filter,
343+
core::ExpressionEvaluator* evaluator,
344+
core::TypedExprPtr typedExpr,
345+
std::function<bool(int64_t, int64_t)> validationFilter);
346+
339347
protected:
340348
void testMetadataFilter();
341349

0 commit comments

Comments
 (0)