Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ public Builder staticArguments(StaticArgument... staticArguments) {
return this;
}

/**
 * Controls whether system arguments are disabled in the type inference of this function.
 *
 * <p>The flag is passed through unchanged to the underlying type inference builder.
 *
 * @param disableSystemArguments {@code true} to disable system arguments
 * @return this builder, to allow call chaining
 */
public Builder disableSystemArguments(boolean disableSystemArguments) {
    typeInferenceBuilder.disableSystemArguments(disableSystemArguments);
    return this;
}

/**
* @deprecated Use {@link #staticArguments(StaticArgument...)} instead.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.strategies.ArrayOfStringArgumentTypeStrategy;
import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies;
Expand All @@ -46,12 +48,14 @@
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.table.utils.EncodingUtils;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
Expand All @@ -68,6 +72,7 @@
import static org.apache.flink.table.api.DataTypes.TIMESTAMP_LTZ;
import static org.apache.flink.table.functions.FunctionKind.AGGREGATE;
import static org.apache.flink.table.functions.FunctionKind.OTHER;
import static org.apache.flink.table.functions.FunctionKind.PROCESS_TABLE;
import static org.apache.flink.table.functions.FunctionKind.SCALAR;
import static org.apache.flink.table.functions.FunctionKind.TABLE;
import static org.apache.flink.table.types.inference.InputTypeStrategies.ANY;
Expand Down Expand Up @@ -103,11 +108,13 @@
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.ARRAY_FULLY_COMPARABLE;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.INDEX;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.JSON_ARGUMENT;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.ML_PREDICT_INPUT_TYPE_STRATEGY;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.TWO_EQUALS_COMPARABLE;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.TWO_FULLY_COMPARABLE;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.percentage;
import static org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies.percentageArray;
import static org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies.ARRAY_APPEND_PREPEND;
import static org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies.ML_PREDICT_OUTPUT_TYPE_STRATEGY;

/** Dictionary of function definitions for all built-in functions. */
@PublicEvolving
Expand Down Expand Up @@ -725,6 +732,29 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
.outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN())))
.build();

/**
* Definition of the built-in {@code ML_PREDICT} function.
*
* <p>It is declared with kind {@code PROCESS_TABLE} and with system arguments disabled. The
* static signature consists of four named arguments: {@code INPUT} (a table of {@link Row}),
* {@code MODEL} (a model), {@code ARGS} (a descriptor) and {@code CONFIG} (a map from string
* to string). Input validation and output type derivation are delegated to the dedicated
* {@code ML_PREDICT_*} strategies.
*/
public static final BuiltInFunctionDefinition ML_PREDICT =
BuiltInFunctionDefinition.newBuilder()
.name("ML_PREDICT")
.kind(PROCESS_TABLE)
// system arguments are switched off for this function
.disableSystemArguments(true)
.staticArguments(
// mandatory table argument (isOptional = false)
StaticArgument.table(
"INPUT",
Row.class,
false,
EnumSet.of(StaticArgumentTrait.TABLE)),
// mandatory model argument (isOptional = false)
StaticArgument.model(
"MODEL", false, EnumSet.of(StaticArgumentTrait.MODEL)),
// NOTE(review): the trailing boolean of scalar(...) is presumably the
// "optional" flag, i.e. ARGS is required and CONFIG is optional -- confirm
StaticArgument.scalar("ARGS", DataTypes.DESCRIPTOR(), false),
StaticArgument.scalar(
"CONFIG",
DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()),
true))
.inputTypeStrategy(ML_PREDICT_INPUT_TYPE_STRATEGY)
.outputTypeStrategy(ML_PREDICT_OUTPUT_TYPE_STRATEGY)
.runtimeProvided()
.build();

public static final BuiltInFunctionDefinition GREATEST =
BuiltInFunctionDefinition.newBuilder()
.name("GREATEST")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.types.DataType;

/**
* Provides call information about the model that has been passed to a model argument.
*
* <p>This interface is only available for model arguments (i.e. arguments of a {@link
* ProcessTableFunction} that are annotated with {@code @ArgumentHint(MODEL)}).
*/
@PublicEvolving
public interface ModelSemantics {

/**
* Input data type expected by the passed model.
*
* @return the data type of the model's input
*/
DataType inputDataType();

/**
* Output data type produced by the passed model.
*
* @return the data type of the model's output
*/
DataType outputDataType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.functions.ChangelogFunction;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.ModelSemantics;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.DataType;
Expand Down Expand Up @@ -78,6 +79,15 @@ default Optional<TableSemantics> getTableSemantics(int pos) {
return Optional.empty();
}

/**
* Returns information about the model that has been passed to a model argument.
*
* <p>This method applies only to {@link ProcessTableFunction}s.
*
* <p>The default implementation returns {@link Optional#empty()}.
*
* @param pos position of the argument to look up (presumably the index of the argument in the
*     call -- confirm against {@code getTableSemantics(int)})
* @return the {@link ModelSemantics} of the argument, or an empty {@link Optional} if none is
*     available
*/
default Optional<ModelSemantics> getModelSemantics(int pos) {
return Optional.empty();
}

/**
* Returns the {@link ChangelogMode} that the framework requires from the function.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ private StaticArgument(
checkTraits(traits);
checkOptionalType();
checkTableType();
checkModelType();
}

/**
Expand Down Expand Up @@ -144,6 +145,23 @@ public static StaticArgument table(
return new StaticArgument(name, dataType, null, isOptional, enrichTableTraits(traits));
}

/**
 * Declares a model argument such as {@code f(m => myModel)} or {@code f(m => MODEL myModel)}.
 *
 * <p>An argument declared through this method behaves "polymorphically": it accepts models
 * with arbitrary schemas or types.
 *
 * <p>The given trait set is copied before {@link StaticArgumentTrait#MODEL} is added, so the
 * set passed by the caller is left untouched.
 *
 * @param name name for the assignment operator e.g. {@code f(myArg => myModel)}
 * @param isOptional whether the argument is optional
 * @param traits set of {@link StaticArgumentTrait} requiring {@link StaticArgumentTrait#MODEL}
 * @return the newly created model argument
 */
public static StaticArgument model(
        String name, boolean isOptional, EnumSet<StaticArgumentTrait> traits) {
    final EnumSet<StaticArgumentTrait> modelTraits = EnumSet.copyOf(traits);
    modelTraits.add(StaticArgumentTrait.MODEL);
    // a model argument carries neither of the two type-related constructor arguments
    return new StaticArgument(name, null, null, isOptional, modelTraits);
}

private static EnumSet<StaticArgumentTrait> enrichTableTraits(
EnumSet<StaticArgumentTrait> traits) {
final EnumSet<StaticArgumentTrait> enrichedTraits = EnumSet.copyOf(traits);
Expand Down Expand Up @@ -283,6 +301,13 @@ private void checkTableType() {
checkTypedTableType();
}

/** Runs the model-specific checks; does nothing for arguments without the MODEL trait. */
private void checkModelType() {
    if (traits.contains(StaticArgumentTrait.MODEL)) {
        checkModelNotOptional();
    }
}

private void checkTableNotOptional() {
if (isOptional) {
throw new ValidationException("Table arguments must not be optional.");
Expand Down Expand Up @@ -323,4 +348,10 @@ private void checkTypedTableType() {
type, name));
}
}

/**
 * Rejects model arguments that were declared as optional.
 *
 * @throws ValidationException if this argument is optional
 */
private void checkModelNotOptional() {
    if (!isOptional) {
        return;
    }
    throw new ValidationException("Model arguments must not be optional.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,27 @@ public static TypeInference of(FunctionKind functionKind, TypeInference origin)
final TypeInference.Builder builder = TypeInference.newBuilder();

final List<StaticArgument> systemArgs =
deriveSystemArgs(functionKind, origin.getStaticArguments().orElse(null));
deriveSystemArgs(
functionKind,
origin.getStaticArguments().orElse(null),
origin.disableSystemArguments());
if (systemArgs != null) {
builder.staticArguments(systemArgs);
}
builder.inputTypeStrategy(
deriveSystemInputStrategy(functionKind, systemArgs, origin.getInputTypeStrategy()));
deriveSystemInputStrategy(
functionKind,
systemArgs,
origin.getInputTypeStrategy(),
origin.disableSystemArguments()));
builder.stateTypeStrategies(origin.getStateTypeStrategies());
builder.outputTypeStrategy(
deriveSystemOutputStrategy(
functionKind, systemArgs, origin.getOutputTypeStrategy()));
functionKind,
systemArgs,
origin.getOutputTypeStrategy(),
origin.disableSystemArguments()));
builder.disableSystemArguments(origin.disableSystemArguments());
return builder.build();
}

Expand All @@ -130,7 +141,9 @@ private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
}

private static @Nullable List<StaticArgument> deriveSystemArgs(
FunctionKind functionKind, @Nullable List<StaticArgument> declaredArgs) {
FunctionKind functionKind,
@Nullable List<StaticArgument> declaredArgs,
boolean disableSystemArgs) {
if (functionKind != FunctionKind.PROCESS_TABLE) {
if (declaredArgs != null) {
checkScalarArgsOnly(declaredArgs);
Expand All @@ -147,7 +160,9 @@ private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
checkPassThroughColumns(declaredArgs);

final List<StaticArgument> newStaticArgs = new ArrayList<>(declaredArgs);
newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);
Copy link
Contributor

@davidradl davidradl Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious what "system arguments" means. Is this something that the user needs to be aware of? I do not see this phrase in the FLIP, and there is no further information in the Jira. I suggest including a description of, and the motivation behind, this piece. It appears to be a type of static argument that will be added if the boolean flag is on, but I am not sure when this would or should be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This controls whether the uid and on_time fields are added to the PTF input. It is currently used by ML_PREDICT, because it doesn't need the uid and on_time fields. It's not exposed to the PTFs that users can define. Yes, I can add more description if this approach makes sense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not entirely correct. A user-defined PTF can implement a TypeInference and avoid system args, but this is kind of a second-level API.

if (!disableSystemArgs) {
newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);
}
return newStaticArgs;
}

Expand Down Expand Up @@ -197,36 +212,44 @@ private static void checkPassThroughColumns(List<StaticArgument> staticArgs) {
private static InputTypeStrategy deriveSystemInputStrategy(
FunctionKind functionKind,
@Nullable List<StaticArgument> staticArgs,
InputTypeStrategy inputStrategy) {
InputTypeStrategy inputStrategy,
boolean disableSystemArgs) {
if (functionKind != FunctionKind.PROCESS_TABLE) {
return inputStrategy;
}
return new SystemInputStrategy(staticArgs, inputStrategy);
return new SystemInputStrategy(staticArgs, inputStrategy, disableSystemArgs);
}

private static TypeStrategy deriveSystemOutputStrategy(
FunctionKind functionKind,
@Nullable List<StaticArgument> staticArgs,
TypeStrategy outputStrategy) {
TypeStrategy outputStrategy,
boolean disableSystemArgs) {
if (functionKind != FunctionKind.TABLE
&& functionKind != FunctionKind.PROCESS_TABLE
&& functionKind != FunctionKind.ASYNC_TABLE) {
return outputStrategy;
}
return new SystemOutputStrategy(functionKind, staticArgs, outputStrategy);
return new SystemOutputStrategy(
functionKind, staticArgs, outputStrategy, disableSystemArgs);
}

private static class SystemOutputStrategy implements TypeStrategy {

private final FunctionKind functionKind;
private final List<StaticArgument> staticArgs;
private final TypeStrategy origin;
private final boolean disableSystemArgs;

private SystemOutputStrategy(
FunctionKind functionKind, List<StaticArgument> staticArgs, TypeStrategy origin) {
FunctionKind functionKind,
List<StaticArgument> staticArgs,
TypeStrategy origin,
boolean disableSystemArgs) {
this.functionKind = functionKind;
this.staticArgs = staticArgs;
this.origin = origin;
this.disableSystemArgs = disableSystemArgs;
}

@Override
Expand All @@ -247,7 +270,10 @@ public Optional<DataType> inferType(CallContext callContext) {
// this whole topic is kind of vendor specific already
fields.addAll(derivePassThroughFields(callContext));
fields.addAll(deriveFunctionOutputFields(functionDataType));
fields.addAll(deriveRowtimeField(callContext));

if (!disableSystemArgs) {
fields.addAll(deriveRowtimeField(callContext));
}

final List<Field> uniqueFields = makeFieldNamesUnique(fields);

Expand Down Expand Up @@ -480,10 +506,15 @@ private static class SystemInputStrategy implements InputTypeStrategy {

private final List<StaticArgument> staticArgs;
private final InputTypeStrategy origin;
private final boolean disableSystemArgs;

private SystemInputStrategy(List<StaticArgument> staticArgs, InputTypeStrategy origin) {
private SystemInputStrategy(
List<StaticArgument> staticArgs,
InputTypeStrategy origin,
boolean disableSystemArgs) {
this.staticArgs = staticArgs;
this.origin = origin;
this.disableSystemArgs = disableSystemArgs;
}

@Override
Expand Down Expand Up @@ -514,7 +545,9 @@ public Optional<List<DataType>> inferInputTypes(

try {
checkTableArgs(staticArgs, callContext);
checkUidArg(callContext);
if (!disableSystemArgs) {
checkUidArg(callContext);
}
} catch (ValidationException e) {
return callContext.fail(throwOnFailure, e.getMessage());
}
Expand Down
Loading