Skip to content

Commit 0fa04ac

Browse files
committed
Simplify compiled page projection
Avoid constructing new PageProjectionWork object per page
1 parent 9ca4c3c commit 0fa04ac

File tree

3 files changed

+66
-64
lines changed

3 files changed

+66
-64
lines changed

core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
import io.trino.sql.gen.PageProjectionWork;
2121
import io.trino.sql.relational.RowExpression;
2222

23-
import java.lang.invoke.MethodHandle;
24-
2523
import static com.google.common.base.MoreObjects.toStringHelper;
26-
import static com.google.common.base.Throwables.throwIfUnchecked;
2724
import static java.util.Objects.requireNonNull;
2825

2926
public class GeneratedPageProjection
@@ -32,16 +29,16 @@ public class GeneratedPageProjection
3229
private final RowExpression projection;
3330
private final boolean isDeterministic;
3431
private final InputChannels inputChannels;
35-
private final MethodHandle pageProjectionWorkFactory;
32+
private final PageProjectionWork pageProjectionWork;
3633

3734
private BlockBuilder blockBuilder;
3835

39-
public GeneratedPageProjection(RowExpression projection, boolean isDeterministic, InputChannels inputChannels, MethodHandle pageProjectionWorkFactory)
36+
public GeneratedPageProjection(RowExpression projection, boolean isDeterministic, InputChannels inputChannels, PageProjectionWork pageProjectionWork)
4037
{
4138
this.projection = requireNonNull(projection, "projection is null");
4239
this.isDeterministic = isDeterministic;
4340
this.inputChannels = requireNonNull(inputChannels, "inputChannels is null");
44-
this.pageProjectionWorkFactory = requireNonNull(pageProjectionWorkFactory, "pageProjectionWorkFactory is null");
41+
this.pageProjectionWork = requireNonNull(pageProjectionWork, "pageProjectionWork is null");
4542
this.blockBuilder = projection.type().createBlockBuilder(null, 1);
4643
}
4744

@@ -61,12 +58,7 @@ public InputChannels getInputChannels()
6158
public Block project(ConnectorSession session, SourcePage page, SelectedPositions selectedPositions)
6259
{
6360
blockBuilder = blockBuilder.newBlockBuilderLike(selectedPositions.size(), null);
64-
try {
65-
return ((PageProjectionWork) pageProjectionWorkFactory.invoke(blockBuilder, session, page, selectedPositions)).process();
66-
}
67-
catch (Throwable throwable) {
68-
throw propagate(throwable);
69-
}
61+
return pageProjectionWork.process(session, page, selectedPositions, blockBuilder);
7062
}
7163

7264
@Override
@@ -76,13 +68,4 @@ public String toString()
7668
.add("projection", projection)
7769
.toString();
7870
}
79-
80-
private static RuntimeException propagate(Throwable throwable)
81-
{
82-
if (throwable instanceof InterruptedException) {
83-
Thread.currentThread().interrupt();
84-
}
85-
throwIfUnchecked(throwable);
86-
throw new RuntimeException(throwable);
87-
}
8871
}

core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
import org.weakref.jmx.Managed;
6262
import org.weakref.jmx.Nested;
6363

64-
import java.lang.invoke.MethodHandle;
64+
import java.lang.reflect.Constructor;
6565
import java.util.List;
6666
import java.util.Map;
6767
import java.util.Optional;
@@ -72,6 +72,7 @@
7272

7373
import static com.google.common.base.MoreObjects.toStringHelper;
7474
import static com.google.common.base.Throwables.throwIfInstanceOf;
75+
import static com.google.common.collect.ImmutableList.toImmutableList;
7576
import static io.airlift.bytecode.Access.FINAL;
7677
import static io.airlift.bytecode.Access.PRIVATE;
7778
import static io.airlift.bytecode.Access.PUBLIC;
@@ -97,7 +98,6 @@
9798
import static io.trino.sql.relational.DeterminismEvaluator.isDeterministic;
9899
import static io.trino.util.CompilerUtils.defineClass;
99100
import static io.trino.util.CompilerUtils.makeClassName;
100-
import static io.trino.util.Reflection.constructorMethodHandle;
101101
import static java.util.Objects.requireNonNull;
102102

103103
public class PageFunctionCompiler
@@ -198,9 +198,10 @@ private Supplier<PageProjection> compileProjectionInternal(RowExpression project
198198

199199
ClassDefinition pageProjectionWorkDefinition = definePageProjectWorkClass(result.getRewrittenExpression(), callSiteBinder, classNameSuffix);
200200

201-
Class<?> pageProjectionWorkClass;
201+
Constructor<? extends PageProjectionWork> pageProjectionWorkConstructor;
202202
try {
203-
pageProjectionWorkClass = defineClass(pageProjectionWorkDefinition, PageProjectionWork.class, callSiteBinder.getBindings(), getClass().getClassLoader());
203+
Class<? extends PageProjectionWork> pageProjectionWorkClass = defineClass(pageProjectionWorkDefinition, PageProjectionWork.class, callSiteBinder.getBindings(), getClass().getClassLoader());
204+
pageProjectionWorkConstructor = pageProjectionWorkClass.getConstructor();
204205
}
205206
catch (Exception e) {
206207
if (Throwables.getRootCause(e) instanceof MethodTooLargeException) {
@@ -210,12 +211,19 @@ private Supplier<PageProjection> compileProjectionInternal(RowExpression project
210211
throw new TrinoException(COMPILER_ERROR, e);
211212
}
212213

213-
MethodHandle pageProjectionConstructor = constructorMethodHandle(pageProjectionWorkClass, BlockBuilder.class, ConnectorSession.class, SourcePage.class, SelectedPositions.class);
214-
return () -> new GeneratedPageProjection(
215-
result.getRewrittenExpression(),
216-
isExpressionDeterministic,
217-
result.getInputChannels(),
218-
pageProjectionConstructor);
214+
return () -> {
215+
try {
216+
PageProjectionWork pageProjectionWork = pageProjectionWorkConstructor.newInstance();
217+
return new GeneratedPageProjection(
218+
result.getRewrittenExpression(),
219+
isExpressionDeterministic,
220+
result.getInputChannels(),
221+
pageProjectionWork);
222+
}
223+
catch (ReflectiveOperationException e) {
224+
throw new TrinoException(COMPILER_ERROR, e);
225+
}
226+
};
219227
}
220228

221229
private static ParameterizedType generateProjectionWorkClassName(Optional<String> classNameSuffix)
@@ -232,89 +240,96 @@ private ClassDefinition definePageProjectWorkClass(RowExpression projection, Cal
232240
type(PageProjectionWork.class));
233241

234242
FieldDefinition blockBuilderField = classDefinition.declareField(a(PRIVATE), "blockBuilder", BlockBuilder.class);
235-
FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE), "session", ConnectorSession.class);
236-
FieldDefinition selectedPositionsField = classDefinition.declareField(a(PRIVATE), "selectedPositions", SelectedPositions.class);
237243

238244
CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);
239245

246+
List<Integer> inputChannels = getInputChannels(projection);
247+
List<FieldDefinition> blockFields = inputChannels.stream()
248+
.map(channel -> classDefinition.declareField(a(PRIVATE), "block_" + channel, Block.class))
249+
.collect(toImmutableList());
240250
// process
241-
generateProcessMethod(classDefinition, blockBuilderField, sessionField, selectedPositionsField);
251+
generateProcessMethod(classDefinition, blockBuilderField, blockFields, inputChannels);
242252

243253
// evaluate
244254
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, projection);
245255
generateEvaluateMethod(classDefinition, callSiteBinder, cachedInstanceBinder, compiledLambdaMap, projection, blockBuilderField);
246256

247257
// constructor
248-
Parameter blockBuilder = arg("blockBuilder", BlockBuilder.class);
249-
Parameter session = arg("session", ConnectorSession.class);
250-
Parameter page = arg("page", SourcePage.class);
251-
Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class);
252-
253-
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), blockBuilder, session, page, selectedPositions);
258+
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC));
254259

255260
BytecodeBlock body = constructorDefinition.getBody();
256261
Variable thisVariable = constructorDefinition.getThis();
257262

258263
body.comment("super();")
259264
.append(thisVariable)
260-
.invokeConstructor(Object.class)
261-
.append(thisVariable.setField(blockBuilderField, blockBuilder))
262-
.append(thisVariable.setField(sessionField, session))
263-
.append(thisVariable.setField(selectedPositionsField, selectedPositions));
264-
265-
for (int channel : getInputChannels(projection)) {
266-
FieldDefinition blockField = classDefinition.declareField(a(PRIVATE, FINAL), "block_" + channel, Block.class);
267-
body.append(thisVariable.setField(blockField, page.invoke("getBlock", Block.class, constantInt(channel))));
268-
}
265+
.invokeConstructor(Object.class);
269266

270267
cachedInstanceBinder.generateInitializations(thisVariable, body);
271268
body.ret();
272269

273270
return classDefinition;
274271
}
275272

276-
private static MethodDefinition generateProcessMethod(
273+
private static void generateProcessMethod(
277274
ClassDefinition classDefinition,
278-
FieldDefinition blockBuilder,
279-
FieldDefinition session,
280-
FieldDefinition selectedPositions)
275+
FieldDefinition blockBuilderField,
276+
List<FieldDefinition> blockFields,
277+
List<Integer> inputChannels)
281278
{
282-
MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "process", type(Block.class), ImmutableList.of());
279+
Parameter session = arg("session", ConnectorSession.class);
280+
Parameter page = arg("page", SourcePage.class);
281+
Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class);
282+
Parameter blockBuilder = arg("blockBuilder", BlockBuilder.class);
283+
284+
MethodDefinition method = classDefinition.declareMethod(
285+
a(PUBLIC),
286+
"process",
287+
type(Block.class),
288+
ImmutableList.<Parameter>builder()
289+
.add(session)
290+
.add(page)
291+
.add(selectedPositions)
292+
.add(blockBuilder)
293+
.build());
283294

284295
Scope scope = method.getScope();
285296
Variable thisVariable = method.getThis();
286297
BytecodeBlock body = method.getBody();
287298

288-
Variable from = scope.declareVariable("from", body, thisVariable.getField(selectedPositions).invoke("getOffset", int.class));
289-
Variable to = scope.declareVariable("to", body, add(thisVariable.getField(selectedPositions).invoke("getOffset", int.class), thisVariable.getField(selectedPositions).invoke("size", int.class)));
299+
for (int i = 0; i < inputChannels.size(); i++) {
300+
int channel = inputChannels.get(i);
301+
body.append(thisVariable.setField(blockFields.get(i), page.invoke("getBlock", Block.class, constantInt(channel))));
302+
}
303+
body.append(thisVariable.setField(blockBuilderField, blockBuilder));
304+
305+
Variable from = scope.declareVariable("from", body, selectedPositions.invoke("getOffset", int.class));
306+
Variable to = scope.declareVariable("to", body, add(selectedPositions.invoke("getOffset", int.class), selectedPositions.invoke("size", int.class)));
290307
Variable positions = scope.declareVariable(int[].class, "positions");
291308
Variable index = scope.declareVariable(int.class, "index");
292309

293310
IfStatement ifStatement = new IfStatement()
294-
.condition(thisVariable.getField(selectedPositions).invoke("isList", boolean.class));
311+
.condition(selectedPositions.invoke("isList", boolean.class));
295312
body.append(ifStatement);
296313

297314
ifStatement.ifTrue(new BytecodeBlock()
298-
.append(positions.set(thisVariable.getField(selectedPositions).invoke("getPositions", int[].class)))
315+
.append(positions.set(selectedPositions.invoke("getPositions", int[].class)))
299316
.append(new ForLoop("positions loop")
300317
.initialize(index.set(from))
301318
.condition(lessThan(index, to))
302319
.update(index.increment())
303320
.body(new BytecodeBlock()
304-
.append(thisVariable.invoke("evaluate", void.class, thisVariable.getField(session), positions.getElement(index))))));
321+
.append(thisVariable.invoke("evaluate", void.class, session, positions.getElement(index))))));
305322

306323
ifStatement.ifFalse(new ForLoop("range based loop")
307324
.initialize(index.set(from))
308325
.condition(lessThan(index, to))
309326
.update(index.increment())
310327
.body(new BytecodeBlock()
311-
.append(thisVariable.invoke("evaluate", void.class, thisVariable.getField(session), index))));
328+
.append(thisVariable.invoke("evaluate", void.class, session, index))));
312329

313330
body.comment("return this.blockBuilder.build();")
314-
.append(thisVariable.getField(blockBuilder).invoke("build", Block.class))
331+
.append(thisVariable.getField(blockBuilderField).invoke("build", Block.class))
315332
.retObject();
316-
317-
return method;
318333
}
319334

320335
private MethodDefinition generateEvaluateMethod(

core/trino-main/src/main/java/io/trino/sql/gen/PageProjectionWork.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
*/
1414
package io.trino.sql.gen;
1515

16+
import io.trino.operator.project.SelectedPositions;
1617
import io.trino.spi.block.Block;
18+
import io.trino.spi.block.BlockBuilder;
19+
import io.trino.spi.connector.ConnectorSession;
20+
import io.trino.spi.connector.SourcePage;
1721

1822
public interface PageProjectionWork
1923
{
20-
Block process();
24+
Block process(ConnectorSession session, SourcePage page, SelectedPositions selectedPositions, BlockBuilder builder);
2125
}

0 commit comments

Comments
 (0)