Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,10 @@ public SolverReturnStatus update(BoundVariables bindings)

TypeSignature constraintTypeSignature = applyBoundVariables(formalTypeSignature, bindings);

if (actualType == UNKNOWN) {
return SolverReturnStatus.UNCHANGED_SATISFIED;
}

return satisfiesCoercion(relationshipType, actualType, constraintTypeSignature) ? SolverReturnStatus.UNCHANGED_SATISFIED : SolverReturnStatus.UNCHANGED_NOT_SATISFIED;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
import static io.trino.operator.scalar.Re2JCastToRegexpFunction.castVarcharToRe2JRegexp;
import static io.trino.operator.scalar.RowToJsonCast.ROW_TO_JSON;
import static io.trino.operator.scalar.RowToRowCast.ROW_TO_ROW_CAST;
import static io.trino.operator.scalar.RowTransformFunction.ROW_TRANSFORM_FUNCTION;
import static io.trino.operator.scalar.TryCastFunction.TRY_CAST;
import static io.trino.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
import static io.trino.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
Expand Down Expand Up @@ -577,7 +578,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(DecimalAverageAggregation.class)
.aggregates(DecimalSumAggregation.class)
.function(DECIMAL_MOD_FUNCTION)
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
.functions(ROW_TRANSFORM_FUNCTION, ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
.functions(MAP_FILTER_FUNCTION, new MapTransformKeysFunction(blockTypeOperators), MAP_TRANSFORM_VALUES_FUNCTION)
.function(FORMAT_FUNCTION)
.function(TRY_CAST)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlRow;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.RowType;
import io.trino.spi.type.RowType.Field;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.lambda.UnaryFunctionInterface;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.type.TypeSignature.functionType;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.spi.type.TypeUtils.writeNativeValue;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.util.Reflection.methodHandle;

public final class RowTransformFunction
extends SqlScalarFunction
{
public static final RowTransformFunction ROW_TRANSFORM_FUNCTION = new RowTransformFunction();
private static final String ROW_TRANSFORM_NAME = "transform";
private static final MethodHandle METHOD_HANDLE = methodHandle(RowTransformFunction.class, "transform", RowType.class, Type.class, SqlRow.class, Slice.class, Object.class, UnaryFunctionInterface.class);

private RowTransformFunction()
{
super(FunctionMetadata.scalarBuilder(ROW_TRANSFORM_NAME)
.signature(Signature.builder()
.variadicTypeParameter("T", "row")
.typeVariable("V")
.returnType(new TypeSignature("T"))
.argumentType(new TypeSignature("T"))
.argumentType(VARCHAR.getTypeSignature())
.argumentType(new TypeSignature("V"))
.argumentType(functionType(new TypeSignature("V"), new TypeSignature("V")))
.build())
.description("Apply lambda to the value of a field, returning the transformed row")
.build());
}

@Override
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
{
RowType rowType = (RowType) boundSignature.getArgumentType(0);
Type valueType = boundSignature.getArgumentType(2);

return new ChoicesSpecializedSqlScalarFunction(
boundSignature,
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, NEVER_NULL, NEVER_NULL, FUNCTION),
ImmutableList.of(UnaryFunctionInterface.class),
METHOD_HANDLE.asType(
METHOD_HANDLE.type()
.changeParameterType(4, valueType.getJavaType())
).bindTo(rowType).bindTo(valueType),
Optional.empty());
}

@UsedByGeneratedCode
public static SqlRow transform(RowType rowType, Type valueType, SqlRow sqlRow, Slice fieldNameSlice, Object dummyValue, UnaryFunctionInterface function)
{
int fieldIndex = -1;
Field match = null;
String fieldName = fieldNameSlice.toStringUtf8();
List<Field> fields = rowType.getFields();
for (int i = 0; i < fields.size(); i++) {
Field field = fields.get(i);
if (field.getName().orElse("").equals(fieldName)) {
match = field;
fieldIndex = i;
break;
}
}

if (match == null) {
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, String.format("Field with name %s not found in row", fieldName));
}
if (match.getType().getClass() != valueType.getClass()) {
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, String.format("Incompatible function types: field is of type %s but lambda returns %s", match.getType(), valueType));
}

Block[] blocks = new Block[fields.size()];
for (int i = 0; i < fields.size(); i++) {
if (i != fieldIndex) {
blocks[i] = sqlRow.getRawFieldBlock(i).getSingleValueBlock(sqlRow.getRawIndex());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to give the right result, but is it correct to get the blocks this way?

}
else {
Object value = readNativeValue(valueType, sqlRow.getRawFieldBlock(i), sqlRow.getRawIndex());
blocks[i] = writeNativeValue(valueType, function.apply(value));
}
}
return new SqlRow(0, blocks);
}
}
5 changes: 4 additions & 1 deletion core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.type.FunctionType;
import io.trino.type.UnknownType;

import java.util.Arrays;
import java.util.Collection;
Expand All @@ -37,6 +39,7 @@
import static com.google.common.collect.Streams.stream;
import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.type.UnknownType.UNKNOWN;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

Expand All @@ -46,7 +49,7 @@ private IrUtils() {}

/**
 * Verifies that {@code expression} has the {@code expected} type.
 *
 * <p>An expression whose type is a {@link FunctionType} with an {@code UNKNOWN} return
 * type is also accepted, even when its type differs from {@code expected}.
 * NOTE(review): presumably this covers lambdas whose body is an untyped NULL - confirm.
 *
 * @throws IllegalArgumentException if the type does not match and the exemption does not apply
 */
static void validateType(Type expected, Expression expression)
{
    Type actual = expression.type();
    // A single check: running the strict equality check first would throw before the
    // UNKNOWN-returning function exemption could ever be considered.
    boolean unknownReturningFunction = actual instanceof FunctionType functionType && functionType.getReturnType() == UNKNOWN;
    checkArgument(expected.equals(actual) || unknownReturningFunction, "Expected '%s' type but found '%s' for expression: %s", expected, actual, expression);
}

public static List<Expression> extractConjuncts(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.trino.spi.type.ArrayType;
import io.trino.sql.query.QueryAssertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static io.trino.spi.type.ArrayType.arrayType;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RowType.field;
import static io.trino.spi.type.RowType.rowType;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static org.apache.commons.io.IOUtils.closeQuietly;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;

/**
 * Query-level tests for the row {@code transform(row, field_name, dummy, lambda)} function,
 * covering scalar, array and nested-row fields as well as rows nested inside arrays.
 */
@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
public class TestRowTransformFunction
{
    // Row literal shared by the tests that operate on two varchar fields.
    private static final String HELLO_WORLD_ROW = "CAST(ROW('hello', 'world') as ROW(greeting varchar, planet varchar))";

    private QueryAssertions assertions;

    @BeforeAll
    public void init()
    {
        assertions = new QueryAssertions();
    }

    @AfterAll
    public void teardown()
    {
        closeQuietly(assertions);
        assertions = null;
    }

    /** An integer field is doubled; the other field is left untouched. */
    @Test
    public void testInteger()
    {
        assertThat(assertions
                .expression("transform(a, 'greeting', 1337, greeting -> greeting * 2)")
                .binding("a", "CAST(ROW(2, 3) as ROW(greeting integer, planet integer))"))
                .hasType(rowType(field("greeting", INTEGER), field("planet", INTEGER)))
                .isEqualTo(ImmutableList.of(4, 3));
    }

    /** A varchar field gets a suffix appended. */
    @Test
    public void testVarchar()
    {
        assertThat(assertions
                .expression("transform(a, 'greeting', '', greeting -> concat(greeting, ' or goodbye'))")
                .binding("a", HELLO_WORLD_ROW))
                .hasType(rowType(field("greeting", VARCHAR), field("planet", VARCHAR)))
                .isEqualTo(ImmutableList.of("hello or goodbye", "world"));
    }

    /** A lambda returning NULL replaces the field with NULL. */
    @Test
    public void testNullReturn()
    {
        // ImmutableList rejects nulls, so the expected row is built with Arrays.asList.
        List<String> expectedRow = Arrays.asList(null, "world");

        assertThat(assertions
                .expression("transform(a, 'greeting', '', greeting -> NULL)")
                .binding("a", HELLO_WORLD_ROW))
                .hasType(rowType(field("greeting", VARCHAR), field("planet", VARCHAR)))
                .isEqualTo(expectedRow);
    }

    /** An array(integer) field has an element appended. */
    @Test
    public void testIntegerArray()
    {
        assertThat(assertions
                .expression("transform(a, 'greeting', ARRAY[0], greeting -> greeting || 2)")
                .binding("a", "CAST(ROW(ARRAY[1], 'world') as ROW(greeting array(integer), planet varchar))"))
                .hasType(rowType(field("greeting", arrayType(INTEGER)), field("planet", VARCHAR)))
                .isEqualTo(ImmutableList.of(ImmutableList.of(1, 2), "world"));
    }

    /** An array(varchar) field has two elements appended. */
    @Test
    public void testVarcharArray()
    {
        assertThat(assertions
                .expression("transform(a, 'greeting', ARRAY[''], greeting -> greeting || 'or' || 'goodbye')")
                .binding("a", "CAST(ROW(ARRAY['hello'], 'world') AS ROW(greeting array(varchar), planet varchar))"))
                .hasType(rowType(field("greeting", arrayType(VARCHAR)), field("planet", VARCHAR)))
                .isEqualTo(ImmutableList.of(ImmutableList.of("hello", "or", "goodbye"), "world"));
    }

    /** A nested row field is transformed by a nested transform call. */
    @Test
    public void testVarcharRowType()
    {
        assertThat(assertions
                .expression("transform(a, 'greeting', CAST(ROW('') as ROW(text varchar)), greeting -> transform(greeting, 'text', '', text -> concat(text, ' or goodbye')))")
                .binding("a", "CAST(ROW(ROW('hello'), 'world') as ROW(greeting ROW(text varchar), planet varchar))"))
                .hasType(rowType(field("greeting", rowType(field("text", VARCHAR))), field("planet", VARCHAR)))
                .isEqualTo(ImmutableList.of(ImmutableList.of("hello or goodbye"), "world"));
    }

    /** The row transform is applied to every element of an array of rows. */
    @Test
    public void testVarcharRowTypeArrayType()
    {
        assertThat(assertions
                .expression("""
                transform(a, data ->
                transform(data, 'greeting', '', greeting -> concat(greeting, ' or goodbye')))
                """)
                .binding("a", """
                ARRAY[CAST(ROW('hello', 'world') as ROW(greeting varchar, planet varchar)),
                CAST(ROW('hi', 'mars') as ROW(greeting varchar, planet varchar)),
                CAST(ROW('hey', 'jupiter') as ROW(greeting varchar, planet varchar))]
                """))
                .hasType(arrayType(rowType(field("greeting", VARCHAR), field("planet", VARCHAR))))
                .isEqualTo(ImmutableList.of(
                        ImmutableList.of("hello or goodbye", "world"),
                        ImmutableList.of("hi or goodbye", "mars"),
                        ImmutableList.of("hey or goodbye", "jupiter")));
    }
}
5 changes: 5 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper
return operatorDeclaration;
}

public static ArrayType arrayType(Type elementType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added to be consistent with RowType.rowType

{
return new ArrayType(elementType);
}

private synchronized void generateTypeOperators(TypeOperators typeOperators)
{
if (operatorDeclaration != null) {
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/sphinx/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Map <functions/map>
Math <functions/math>
Quantile digest <functions/qdigest>
Regular expression <functions/regexp>
Row <functions/row>
Session <functions/session>
Set Digest <functions/setdigest>
String <functions/string>
Expand Down
55 changes: 55 additions & 0 deletions docs/src/main/sphinx/functions/row.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Row functions

Row functions use the [ROW type](row-type).
Create a row by explicitly casting the field names and types:

```sql
SELECT CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar))
-- ROW('hello', 'world')
```

Access a field using dot notation followed by the field name:

```sql
SELECT CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar)).greeting
-- 'hello'
```

## Row functions

:::{function} transform(T, varchar, V, function(V, V)) -> T
Updates a single field of a row by applying a lambda function to the field's
current value, and returns the row with that field replaced. The first argument
is the row, and the second argument is the name of the field to update. The
third argument, `V`, is a dummy value that is only used to resolve the type of
the lambda function; its value is ignored. It can be any value, as long as its
type is equal to both the argument type and the return type of the lambda
function.

```sql
SELECT transform(
CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar)),
'greeting', '', greeting -> concat(greeting, ' or goodbye'));
-- ROW('hello or goodbye', 'world')
```

The transform can be used to reach fields in nested rows or fields in rows
in arrays:

```sql
SELECT transform(
CAST(ROW(ROW('hello'), 'world') as ROW(greeting ROW(text varchar), planet varchar)),
'greeting',
CAST(ROW('') as ROW(text varchar)),
greeting -> transform(
greeting, 'text', '', text -> concat(text, ' or goodbye')));
-- ROW(ROW('hello or goodbye'), 'world')
```


```sql
SELECT transform(ARRAY[
CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar)),
CAST(ROW('hi', 'mars') AS ROW(greeting varchar, planet varchar))],
data -> transform(data, 'greeting', '', greeting -> concat(greeting, ' or goodbye')));
-- ARRAY[ROW('hello or goodbye', 'world'), ROW('hi or goodbye', 'mars')]
```
:::
Loading