Skip to content

Commit ae4b60c

Browse files
committed
Introduce Pre Post Score Collection Hooks to QueryPhase
Signed-off-by: Atri Sharma <[email protected]>
1 parent 89edd4c commit ae4b60c

File tree

5 files changed

+363
-2
lines changed

5 files changed

+363
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2222
- [Workload Management] Modify logging message in WorkloadGroupService ([#18712](https://github.com/opensearch-project/OpenSearch/pull/18712))
2323
- Add BooleanQuery rewrite moving constant-scoring must clauses to filter clauses ([#18510](https://github.com/opensearch-project/OpenSearch/issues/18510))
2424
- Add functionality for plugins to inject QueryCollectorContext during QueryPhase ([#18637](https://github.com/opensearch-project/OpenSearch/pull/18637))
25+
- Add QueryPhaseExtension interface for pre/post score collection hooks ([#17593](https://github.com/opensearch-project/OpenSearch/issues/17593))
2526
- Add support for non-timing info in profiler ([#18460](https://github.com/opensearch-project/OpenSearch/issues/18460))
2627
- [Rule-based auto tagging] Bug fix and improvements ([#18726](https://github.com/opensearch-project/OpenSearch/pull/18726))
2728
- Extend Approximation Framework to other numeric types ([#18530](https://github.com/opensearch-project/OpenSearch/issues/18530))

server/src/main/java/org/opensearch/search/query/QueryPhase.java

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import org.apache.logging.log4j.LogManager;
3636
import org.apache.logging.log4j.Logger;
37+
import org.apache.logging.log4j.message.ParameterizedMessage;
3738
import org.apache.lucene.index.IndexReader;
3839
import org.apache.lucene.index.LeafReaderContext;
3940
import org.apache.lucene.search.BooleanClause;
@@ -430,7 +431,56 @@ public boolean searchWith(
430431
boolean hasFilterCollector,
431432
boolean hasTimeout
432433
) throws IOException {
433-
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
434+
// Fast path - skip extension logic entirely if no extensions are registered
435+
List<QueryPhaseExtension> extensions = queryPhaseExtensions();
436+
if (extensions == null || extensions.isEmpty()) {
437+
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
438+
}
439+
440+
// Execute beforeScoreCollection extensions
441+
for (QueryPhaseExtension extension : extensions) {
442+
try {
443+
extension.beforeScoreCollection(searchContext);
444+
} catch (Exception e) {
445+
if (extension.failOnError()) {
446+
throw new QueryPhaseExecutionException(
447+
searchContext.shardTarget(),
448+
"Failed to execute beforeScoreCollection extension [" + extension.getClass().getName() + "]",
449+
e
450+
);
451+
}
452+
LOGGER.warn(
453+
new ParameterizedMessage("Failed to execute beforeScoreCollection extension [{}]", extension.getClass().getName()),
454+
e
455+
);
456+
}
457+
}
458+
459+
try {
460+
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
461+
} finally {
462+
// Execute afterScoreCollection extensions
463+
for (QueryPhaseExtension extension : extensions) {
464+
try {
465+
extension.afterScoreCollection(searchContext);
466+
} catch (Exception e) {
467+
if (extension.failOnError()) {
468+
throw new QueryPhaseExecutionException(
469+
searchContext.shardTarget(),
470+
"Failed to execute afterScoreCollection extension [" + extension.getClass().getName() + "]",
471+
e
472+
);
473+
}
474+
LOGGER.warn(
475+
new ParameterizedMessage(
476+
"Failed to execute afterScoreCollection extension [{}]",
477+
extension.getClass().getName()
478+
),
479+
e
480+
);
481+
}
482+
}
483+
}
434484
}
435485

436486
@Override
@@ -447,7 +497,15 @@ protected boolean searchWithCollector(
447497
boolean hasTimeout
448498
) throws IOException {
449499
QueryCollectorContext queryCollectorContext = getQueryCollectorContext(searchContext, hasFilterCollector);
450-
return searchWithCollector(searchContext, searcher, query, collectors, queryCollectorContext, hasFilterCollector, hasTimeout);
500+
return QueryPhase.searchWithCollector(
501+
searchContext,
502+
searcher,
503+
query,
504+
collectors,
505+
queryCollectorContext,
506+
hasFilterCollector,
507+
hasTimeout
508+
);
451509
}
452510

453511
private QueryCollectorContext getQueryCollectorContext(SearchContext searchContext, boolean hasFilterCollector) throws IOException {

server/src/main/java/org/opensearch/search/query/QueryPhaseExtension.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.search.query;
10+
11+
import org.opensearch.common.annotation.PublicApi;
12+
import org.opensearch.search.internal.SearchContext;
13+
14+
/**
15+
* Extension point interface that allows plugins to hook into the query phase
16+
* before and after score collection. This enables custom CollectorManager
17+
* implementations and score processing for advanced search features like
18+
* hybrid queries and neural search.
19+
*
20+
* @opensearch.api
21+
*/
22+
@PublicApi(since = "3.2.0")
23+
public interface QueryPhaseExtension {
24+
25+
/**
26+
* Called before score collection begins in the query phase.
27+
* This allows extensions to set up custom state or modify the search context
28+
* before the main query execution.
29+
*
30+
* @param searchContext the current search context
31+
*/
32+
void beforeScoreCollection(SearchContext searchContext);
33+
34+
/**
35+
* Called after score collection completes in the query phase.
36+
* This allows extensions to process collected scores or perform
37+
* post-collection operations.
38+
*
39+
* @param searchContext the current search context
40+
*/
41+
void afterScoreCollection(SearchContext searchContext);
42+
43+
/**
44+
* Determines whether failures in this extension should fail the entire query.
45+
* When true, an exception thrown by this extension is wrapped in a QueryPhaseExecutionException and fails the search.
46+
* When false (default), exceptions are logged and the search continues.
47+
*
48+
* @return true if extension failures should fail the query, false otherwise
49+
*/
50+
default boolean failOnError() {
51+
return false;
52+
}
53+
}

server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import org.opensearch.search.internal.SearchContext;
1818

1919
import java.io.IOException;
20+
import java.util.Collections;
2021
import java.util.LinkedList;
22+
import java.util.List;
2123

2224
/**
2325
* The extension point which allows to plug in custom search implementation to be
@@ -53,4 +55,12 @@ boolean searchWith(
5355
default AggregationProcessor aggregationProcessor(SearchContext searchContext) {
5456
return new DefaultAggregationProcessor();
5557
}
58+
59+
/**
60+
* Get the list of query phase extensions that should be executed before and after score collection.
61+
* @return list of query phase extensions, empty list if none
62+
*/
63+
default List<QueryPhaseExtension> queryPhaseExtensions() {
64+
return Collections.emptyList();
65+
}
5666
}

0 commit comments

Comments
 (0)