diff --git a/pyproject.toml b/pyproject.toml
index 7407070..8aa3246 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com
 license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.8.1"
-dependencies = ["protobuf >= 3.20"]
+dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0", "PyYAML"]
 dynamic = ["version"]
 
 [tool.setuptools_scm]
diff --git a/src/substrait/sql/__init__.py b/src/substrait/sql/__init__.py
new file mode 100644
index 0000000..9dbd56a
--- /dev/null
+++ b/src/substrait/sql/__init__.py
@@ -0,0 +1,2 @@
+from .extended_expression import parse_sql_extended_expression
+from .functions_catalog import FunctionsCatalog
diff --git a/src/substrait/sql/__main__.py b/src/substrait/sql/__main__.py
new file mode 100644
index 0000000..f135e4a
--- /dev/null
+++ b/src/substrait/sql/__main__.py
@@ -0,0 +1,98 @@
+import pathlib
+import argparse
+
+from substrait import proto
+from .functions_catalog import FunctionsCatalog
+from .extended_expression import parse_sql_extended_expression
+
+
+def main():
+    """Command line tool to test the SQL to ExtendedExpression parser.
+
+    Run as python -m substrait.sql first_name=String,surname=String,age=I32 "SELECT surname, age + 1 as next_birthday, age + 2 WHERE age = 32"
+
+    The standard extension files are read from the third_party/substrait/extensions
+    directory, resolved relative to the location of this file.
+    """
+    parser = argparse.ArgumentParser(
+        description="Convert a SQL SELECT statement to an ExtendedExpression"
+    )
+    parser.add_argument("schema", type=str, help="Schema of the input data")
+    parser.add_argument("sql", type=str, help="SQL SELECT statement")
+    args = parser.parse_args()
+
+    catalog = FunctionsCatalog()
+    catalog.load_standard_extensions(
+        pathlib.Path(__file__).parent.parent.parent.parent
+        / "third_party"
+        / "substrait"
+        / "extensions",
+    )
+    schema = parse_schema(args.schema)
+    projection_expr, filter_expr = parse_sql_extended_expression(
+        catalog, schema, args.sql
+    )
+
+    print("---- SQL INPUT ----")
+    print(args.sql)
+    print("---- PROJECTION ----")
+    print(projection_expr)
+    print("---- FILTER ----")
+    # The filter is None when the statement has no WHERE clause.
+    print(filter_expr if filter_expr is not None else "(no WHERE clause)")
+
+
+def _type_field_name(type_name):
+    """Return the name of the proto.Type field holding a type_name message.
+
+    The field name is not always the lowercased message name (for example
+    Boolean is stored in the "bool" field), so it is looked up in the
+    message descriptor instead of being derived from the type name.
+
+    Raises ValueError if no field of proto.Type has that message type.
+    """
+    for field in proto.Type.DESCRIPTOR.fields:
+        message_type = field.message_type
+        if message_type is not None and message_type.name == type_name:
+            return field.name
+    raise ValueError(f"Unsupported field type: {type_name!r}")
+
+
+def parse_schema(schema_string):
+    """Parse Schema from a comma separated string of fieldname=fieldtype pairs.
+
+    For example: "first_name=String,surname=String,age=I32"
+
+    Every field is declared with NULLABILITY_REQUIRED.
+
+    Raises ValueError if an entry is not in the fieldname=fieldtype form
+    or if fieldtype is not the name of a Substrait type message.
+    """
+    types = []
+    names = []
+
+    for field in schema_string.split(","):
+        fieldname, separator, fieldtype = field.partition("=")
+        fieldname = fieldname.strip()
+        fieldtype = fieldtype.strip()
+        if not separator or not fieldname or not fieldtype:
+            raise ValueError(
+                f"Invalid schema field {field!r}, expected fieldname=fieldtype"
+            )
+        type_field = _type_field_name(fieldtype)
+        proto_type = getattr(proto.Type, fieldtype)
+        names.append(fieldname)
+        types.append(
+            proto.Type(
+                **{
+                    type_field: proto_type(
+                        nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
+                    )
+                }
+            )
+        )
+    return proto.NamedStruct(names=names, struct=proto.Type.Struct(types=types))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py
new file mode 100644
index 0000000..7d41b3c
--- /dev/null
+++ b/src/substrait/sql/extended_expression.py
@@ -0,0 +1,301 @@
+import itertools
+
+import sqlglot
+
+from substrait import proto
+from .utils import DispatchRegistry
+
+
+# Substrait function name for each supported SQLGlot operator node.
+SQL_FUNCTIONS = {
+    # Arithmetic
+    sqlglot.expressions.Add: "add",
+    sqlglot.expressions.Div: "div",
+    sqlglot.expressions.Mul: "mul",
+    sqlglot.expressions.Sub: "sub",
+    sqlglot.expressions.Mod: "modulus",
+    sqlglot.expressions.BitwiseAnd: "bitwise_and",
+    sqlglot.expressions.BitwiseOr: "bitwise_or",
+    sqlglot.expressions.BitwiseXor: "bitwise_xor",
+    sqlglot.expressions.BitwiseNot: "bitwise_not",
+    # Comparisons
+    sqlglot.expressions.EQ: "equal",
+    sqlglot.expressions.NullSafeEQ: "is_not_distinct_from",
+    sqlglot.expressions.NEQ: "not_equal",
+    sqlglot.expressions.GT: "gt",
+    sqlglot.expressions.GTE: "gte",
+    sqlglot.expressions.LT: "lt",
+    sqlglot.expressions.LTE: "lte",
+    sqlglot.expressions.IsNan: "is_nan",
+    # Logical
+    sqlglot.expressions.And: "and",
+    sqlglot.expressions.Or: "or",
+    sqlglot.expressions.Not: "not",
+}
+
+
+def parse_sql_extended_expression(catalog, schema, sql):
+    """Parse a SQL SELECT statement into ExtendedExpressions.
+
+    Only supports SELECT statements with projections and WHERE clauses.
+
+    Returns a (projection, filter) tuple of ExtendedExpression messages,
+    where filter is None if the statement has no WHERE clause.
+
+    Raises ValueError if the SQL is not a SELECT statement.
+    """
+    select = sqlglot.parse_one(sql)
+    if not isinstance(select, sqlglot.expressions.Select):
+        raise ValueError("a SELECT statement was expected")
+
+    sqlglot_parser = SQLGlotParser(catalog, schema)
+
+    # Handle the projections in the SELECT statement.
+    project_expressions = []
+    projection_invoked_functions = set()
+    for sqlexpr in select.expressions:
+        parsed_expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
+        projection_invoked_functions.update(parsed_expr.invoked_functions)
+        project_expressions.append(
+            proto.ExpressionReference(
+                expression=parsed_expr.expression,
+                output_names=[parsed_expr.output_name],
+            )
+        )
+    extension_uris, extensions = catalog.extensions_for_functions(
+        projection_invoked_functions
+    )
+    projection_extended_expr = proto.ExtendedExpression(
+        extension_uris=extension_uris,
+        extensions=extensions,
+        base_schema=schema,
+        referred_expr=project_expressions,
+    )
+
+    # Handle WHERE clause in the SELECT statement.
+    # Only the WHERE of the statement itself is used (not one nested in a
+    # subquery), and a statement without it has no filter to emit.
+    where = select.args.get("where")
+    if where is None:
+        return projection_extended_expr, None
+
+    filter_parsed_expr = sqlglot_parser.expression_from_sqlglot(where.this)
+    extension_uris, extensions = catalog.extensions_for_functions(
+        filter_parsed_expr.invoked_functions
+    )
+    filter_extended_expr = proto.ExtendedExpression(
+        extension_uris=extension_uris,
+        extensions=extensions,
+        base_schema=schema,
+        referred_expr=[
+            proto.ExpressionReference(expression=filter_parsed_expr.expression)
+        ],
+    )
+
+    return projection_extended_expr, filter_extended_expr
+
+
+class SQLGlotParser:
+    """Convert SQLGlot expression nodes into Substrait expressions.
+
+    Each supported node class has a handler registered in DISPATCH_REGISTRY,
+    and every handler returns a ParsedSubstraitExpression.
+    """
+
+    DISPATCH_REGISTRY = DispatchRegistry()
+
+    def __init__(self, functions_catalog, schema):
+        self._functions_catalog = functions_catalog
+        self._schema = schema
+        # Used to give a distinct suffix to generated output names.
+        self._counter = itertools.count()
+
+        self._parse_expression = self.DISPATCH_REGISTRY.bind(self)
+
+    def expression_from_sqlglot(self, sqlglot_node):
+        """Parse a SQLGlot expression into a Substrait Expression."""
+        return self._parse_expression(sqlglot_node)
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
+    def _parse_Literal(self, expr):
+        if expr.is_string:
+            return ParsedSubstraitExpression(
+                f"literal${next(self._counter)}",
+                proto.Type(string=proto.Type.String()),
+                proto.Expression(literal=proto.Expression.Literal(string=expr.name)),
+            )
+        elif expr.is_int:
+            return ParsedSubstraitExpression(
+                f"literal${next(self._counter)}",
+                proto.Type(i32=proto.Type.I32()),
+                proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))),
+            )
+        elif sqlglot.helper.is_float(expr.name):
+            return ParsedSubstraitExpression(
+                f"literal${next(self._counter)}",
+                proto.Type(fp32=proto.Type.FP32()),
+                proto.Expression(
+                    # The Literal message has no "float" field, fp32 is
+                    # the one matching the FP32 type declared above.
+                    literal=proto.Expression.Literal(fp32=float(expr.name))
+                ),
+            )
+        else:
+            raise ValueError(f"Unsupported literal: {expr}")
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
+    def _parse_Column(self, expr):
+        column_name = expr.output_name
+        try:
+            schema_field = list(self._schema.names).index(column_name)
+        except ValueError:
+            raise ValueError(f"Column not found in schema: {column_name}") from None
+        schema_type = self._schema.struct.types[schema_field]
+        return ParsedSubstraitExpression(
+            column_name,
+            schema_type,
+            proto.Expression(
+                selection=proto.Expression.FieldReference(
+                    direct_reference=proto.Expression.ReferenceSegment(
+                        struct_field=proto.Expression.ReferenceSegment.StructField(
+                            field=schema_field
+                        )
+                    )
+                )
+            ),
+        )
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Alias)
+    def _parse_Alias(self, expr):
+        parsed_expression = self._parse_expression(expr.this)
+        return parsed_expression.duplicate(output_name=expr.output_name)
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Is)
+    def _parse_IS(self, expr):
+        # IS NULL is a special case because in SQLGlot it is a binary expression
+        # while in Substrait there are only the is_null and is_not_null unary functions
+        argument_parsed_expr = self._parse_expression(expr.left)
+        if isinstance(expr.right, sqlglot.expressions.Null):
+            function_name = "is_null"
+        else:
+            raise ValueError(f"Unsupported IS expression: {expr}")
+        signature, result_type, function_expression = self._parse_function_invokation(
+            function_name, argument_parsed_expr
+        )
+        result_name = (
+            f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
+        )
+        return ParsedSubstraitExpression(
+            result_name,
+            result_type,
+            function_expression,
+            argument_parsed_expr.invoked_functions | {signature},
+        )
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Binary)
+    def _parse_Binary(self, expr):
+        left_parsed_expr = self._parse_expression(expr.left)
+        right_parsed_expr = self._parse_expression(expr.right)
+        function_name = SQL_FUNCTIONS[type(expr)]
+        signature, result_type, function_expression = self._parse_function_invokation(
+            function_name, left_parsed_expr, right_parsed_expr
+        )
+        result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
+        return ParsedSubstraitExpression(
+            result_name,
+            result_type,
+            function_expression,
+            left_parsed_expr.invoked_functions
+            | right_parsed_expr.invoked_functions
+            | {signature},
+        )
+
+    @DISPATCH_REGISTRY.register(sqlglot.expressions.Unary)
+    def _parse_Unary(self, expr):
+        argument_parsed_expr = self._parse_expression(expr.this)
+        function_name = SQL_FUNCTIONS[type(expr)]
+        signature, result_type, function_expression = self._parse_function_invokation(
+            function_name, argument_parsed_expr
+        )
+        result_name = (
+            f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
+        )
+        return ParsedSubstraitExpression(
+            result_name,
+            result_type,
+            function_expression,
+            argument_parsed_expr.invoked_functions | {signature},
+        )
+
+    def _parse_function_invokation(
+        self, function_name, argument_parsed_expr, *additional_arguments
+    ):
+        """Generates a Substrait function invocation expression.
+
+        The function invocation will be generated from the function name
+        and the arguments as ParsedSubstraitExpression.
+
+        Returns the function signature, the return type and the
+        invocation expression itself.
+
+        Raises KeyError if the catalog has no function matching the arguments.
+        """
+        arguments = [argument_parsed_expr] + list(additional_arguments)
+        signature = self._functions_catalog.make_signature(
+            function_name, proto_argtypes=[arg.type for arg in arguments]
+        )
+
+        registered_function = self._functions_catalog.lookup_function(signature)
+        if registered_function is None:
+            raise KeyError(f"Function not found: {signature}")
+
+        return (
+            registered_function.signature,
+            registered_function.return_type,
+            proto.Expression(
+                scalar_function=proto.Expression.ScalarFunction(
+                    function_reference=registered_function.function_anchor,
+                    arguments=[
+                        proto.FunctionArgument(value=arg.expression)
+                        for arg in arguments
+                    ],
+                )
+            ),
+        )
+
+
+class ParsedSubstraitExpression:
+    """A Substrait expression that was parsed from a SQLGlot node.
+
+    This stores the expression itself, with an associated output name
+    in case it is required to emit projections.
+
+    It also stores the type of the expression (i64, string, boolean, etc...)
+    and the functions that the expression is going to invoke.
+    """
+
+    def __init__(self, output_name, type, expression, invoked_functions=None):
+        self.expression = expression
+        self.output_name = output_name
+        self.type = type
+
+        if invoked_functions is None:
+            invoked_functions = set()
+        self.invoked_functions = invoked_functions
+
+    def duplicate(
+        self, output_name=None, type=None, expression=None, invoked_functions=None
+    ):
+        """Return a copy, replacing the values for which a truthy argument is given."""
+        return ParsedSubstraitExpression(
+            output_name or self.output_name,
+            type or self.type,
+            expression or self.expression,
+            invoked_functions or self.invoked_functions,
+        )
+
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}("
+            f"output_name={self.output_name!r}, type={self.type!r})"
+        )
diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py
new file mode 100644
index 0000000..4e026b7
--- /dev/null
+++ b/src/substrait/sql/functions_catalog.py
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import logging
+import os
+import pathlib
+from collections.abc import Iterable
+
+import yaml
+
+from substrait.gen.proto.type_pb2 import Type as SubstraitType
+from substrait.gen.proto.extensions.extensions_pb2 import (
+    SimpleExtensionURI,
+    SimpleExtensionDeclaration,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RegisteredSubstraitFunction:
+    """A Substrait function loaded from an extension file.
+
+    The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction
+    and will use them to generate the necessary extension URIs and extensions.
+    """
+
+    def __init__(self, signature: str, function_anchor: int | None, impl: dict):
+        self.signature = signature
+        self.function_anchor = function_anchor
+        self.variadic = impl.get("variadic", False)
+
+        if "return" in impl:
+            self.return_type = self._type_from_name(impl["return"])
+        else:
+            # We do always need a return type
+            # to know which type to propagate up to the invoker
+            _, argtypes = FunctionsCatalog.parse_signature(signature)
+            # TODO: Is this the right way to handle this?
+            # Without arguments there is no type to derive it from.
+            self.return_type = self._type_from_name(argtypes[0]) if argtypes else None
+
+    @property
+    def name(self) -> str:
+        name, _ = FunctionsCatalog.parse_signature(self.signature)
+        return name
+
+    @property
+    def arguments(self) -> list[str]:
+        _, argtypes = FunctionsCatalog.parse_signature(self.signature)
+        return argtypes
+
+    @property
+    def arguments_type(self) -> list[SubstraitType | None]:
+        return [self._type_from_name(arg) for arg in self.arguments]
+
+    def _type_from_name(self, typename: str) -> SubstraitType | None:
+        """Build the Substrait type for a type name of an extension file.
+
+        Returns None for the "any" types and for the names that don't
+        match a field of the Type message.
+        """
+        # TODO: improve support for complex types like LIST?
+        typename, *_ = typename.split("<", 1)
+        typename = typename.lower()
+
+        nullable = typename.endswith("?")
+        typename = typename.strip("?")
+        if typename in ("any", "any1"):
+            return None
+
+        if typename == "boolean":
+            # For some reason boolean is an exception to the naming convention
+            typename = "bool"
+
+        try:
+            type_descriptor = SubstraitType.DESCRIPTOR.fields_by_name[
+                typename
+            ].message_type
+        except KeyError:
+            # TODO: improve resolution of complex types like LIST?
+            logger.debug("Unsupported type %s", typename)
+            return None
+
+        type_class = getattr(SubstraitType, type_descriptor.name)
+        nullability = (
+            SubstraitType.Nullability.NULLABILITY_NULLABLE
+            if nullable
+            else SubstraitType.Nullability.NULLABILITY_REQUIRED
+        )
+        return SubstraitType(**{typename: type_class(nullability=nullability)})
+
+
+class FunctionsCatalog:
+    """Catalog of Substrait functions and extensions.
+
+    Loads extensions from YAML files and records the declared functions.
+    Given a set of functions it can generate the necessary extension URIs
+    and extensions to be included in an ExtendedExpression or Plan.
+    """
+
+    # TODO: Find a way to support standard extensions in released distribution.
+    #       IE: Include the standard extension yaml files in the package data and
+    #       update them when gen_proto is used.
+    STANDARD_EXTENSIONS = (
+        "/functions_aggregate_approx.yaml",
+        "/functions_aggregate_generic.yaml",
+        "/functions_arithmetic.yaml",
+        "/functions_arithmetic_decimal.yaml",
+        "/functions_boolean.yaml",
+        "/functions_comparison.yaml",
+        # "/functions_datetime.yaml", for now skip, it has duplicated functions
+        "/functions_geometry.yaml",
+        "/functions_logarithmic.yaml",
+        "/functions_rounding.yaml",
+        "/functions_set.yaml",
+        "/functions_string.yaml",
+    )
+
+    def __init__(self):
+        self._substrait_extension_uris = {}
+        self._substrait_extension_functions = {}
+        self._functions = {}
+
+    def load_standard_extensions(self, dirpath: str | os.PathLike):
+        """Load all standard substrait extensions from the target directory."""
+        for ext in self.STANDARD_EXTENSIONS:
+            self.load(dirpath, ext)
+
+    def load(self, dirpath: str | os.PathLike, filename: str):
+        """Load an extension from a YAML file in a target directory."""
+        path = pathlib.Path(dirpath) / filename.strip("/")
+        with open(path, encoding="utf-8") as f:
+            # An empty file is parsed as None.
+            sections = yaml.safe_load(f) or {}
+
+        loaded_functions = {}
+        for functions in sections.values():
+            if not isinstance(functions, list):
+                # Not a list of declarations, there is nothing to register.
+                continue
+            for function in functions:
+                function_name = function["name"]
+                for impl in function.get("impls", []):
+                    # TODO: There seem to be some functions that have arguments without type. What to do?
+                    # TODO: improve support for complex types like LIST?
+                    argtypes = [
+                        t.get("value", "unknown").strip("?")
+                        for t in impl.get("args", [])
+                    ]
+                    if not argtypes:
+                        signature = function_name
+                    else:
+                        signature = f"{function_name}:{'_'.join(argtypes)}"
+                    loaded_functions[signature] = RegisteredSubstraitFunction(
+                        signature, None, impl
+                    )
+
+        self._register_extensions(filename, loaded_functions)
+
+    def _register_extensions(
+        self,
+        extension_uri: str,
+        loaded_functions: dict[str, RegisteredSubstraitFunction],
+    ):
+        if extension_uri not in self._substrait_extension_uris:
+            ext_anchor_id = len(self._substrait_extension_uris) + 1
+            self._substrait_extension_uris[extension_uri] = SimpleExtensionURI(
+                extension_uri_anchor=ext_anchor_id, uri=extension_uri
+            )
+
+        for signature, registered_function in loaded_functions.items():
+            if signature in self._substrait_extension_functions:
+                extensions_by_anchor = self.extension_uris_by_anchor
+                existing_function = self._substrait_extension_functions[signature]
+                function_extension = extensions_by_anchor[
+                    existing_function.extension_uri_reference
+                ].uri
+                raise ValueError(
+                    f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}"
+                )
+            extension_anchor = self._substrait_extension_uris[
+                extension_uri
+            ].extension_uri_anchor
+            function_anchor = len(self._substrait_extension_functions) + 1
+            self._substrait_extension_functions[signature] = (
+                SimpleExtensionDeclaration.ExtensionFunction(
+                    extension_uri_reference=extension_anchor,
+                    name=signature,
+                    function_anchor=function_anchor,
+                )
+            )
+            registered_function.function_anchor = function_anchor
+            self._functions.setdefault(registered_function.name, []).append(
+                registered_function
+            )
+
+    @property
+    def extension_uris_by_anchor(self) -> dict[int, SimpleExtensionURI]:
+        return {
+            ext.extension_uri_anchor: ext
+            for ext in self._substrait_extension_uris.values()
+        }
+
+    @property
+    def extension_uris(self) -> list[SimpleExtensionURI]:
+        return list(self._substrait_extension_uris.values())
+
+    @property
+    def extensions_functions(
+        self,
+    ) -> list[SimpleExtensionDeclaration.ExtensionFunction]:
+        return list(self._substrait_extension_functions.values())
+
+    @classmethod
+    def make_signature(
+        cls, function_name: str, proto_argtypes: Iterable[SubstraitType]
+    ):
+        """Create a function signature from a function name and substrait types.
+
+        The signature is generated according to Function Signature Compound Names
+        as described in the Substrait documentation.
+        """
+
+        def _normalize_arg_types(argtypes):
+            for argtype in argtypes:
+                kind = argtype.WhichOneof("kind")
+                if kind == "bool":
+                    yield "boolean"
+                else:
+                    yield kind
+
+        return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}"
+
+    @classmethod
+    def parse_signature(cls, signature: str) -> tuple[str, list[str]]:
+        """Parse a function signature and return the name and the type names."""
+        try:
+            function_name, signature_args = signature.split(":")
+        except ValueError:
+            return signature, []
+        # make_signature emits "name:" when there are no arguments,
+        # which must not be parsed as a single argument with an empty type.
+        argtypes = signature_args.split("_") if signature_args else []
+        return function_name, argtypes
+
+    def extensions_for_functions(
+        self, function_signatures: Iterable[str]
+    ) -> tuple[list[SimpleExtensionURI], list[SimpleExtensionDeclaration]]:
+        """Given a set of function signatures, return the necessary extensions.
+
+        The function will return the URIs of the extensions and the extensions
+        that have to be declared in the plan to use the functions.
+
+        Raises KeyError for a signature that was not loaded in the catalog.
+        """
+        uris_anchors = set()
+        extensions = []
+        for f in function_signatures:
+            ext = self._substrait_extension_functions[f]
+            uris_anchors.add(ext.extension_uri_reference)
+            extensions.append(SimpleExtensionDeclaration(extension_function=ext))
+
+        uris_by_anchor = self.extension_uris_by_anchor
+        # Sorted to not depend on the iteration order of the set.
+        extension_uris = [uris_by_anchor[anchor] for anchor in sorted(uris_anchors)]
+        return extension_uris, extensions
+
+    def lookup_function(self, signature: str) -> RegisteredSubstraitFunction | None:
+        """Given the signature of a function invocation, return the matching function.
+
+        Returns None when no loaded function accepts the arguments.
+        If more than one function matches, the last one loaded is returned.
+        """
+        function_name, invocation_argtypes = self.parse_signature(signature)
+
+        functions = self._functions.get(function_name)
+        if not functions:
+            # No function with such a name at all.
+            return None
+
+        is_variadic = functions[0].variadic
+        if is_variadic:
+            # If it's variadic we care about only the first parameter.
+            invocation_argtypes = invocation_argtypes[:1]
+        provided = len(invocation_argtypes)
+
+        found_function = None
+        for function in functions:
+            accepted_function_arguments = function.arguments
+            if provided > len(accepted_function_arguments):
+                # More arguments than available were provided
+                continue
+            if any(
+                accepted != argtype and accepted not in ("any", "any1")
+                for accepted, argtype in zip(
+                    accepted_function_arguments, invocation_argtypes
+                )
+            ):
+                continue
+            # Arguments that were not provided must all be optional. Slicing by
+            # the number of provided arguments works for invocations without
+            # arguments too.
+            remainder = accepted_function_arguments[provided:]
+            if all(arg.endswith("?") for arg in remainder):
+                found_function = function
+
+        return found_function
diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py
new file mode 100644
index 0000000..9ffad36
--- /dev/null
+++ b/src/substrait/sql/utils.py
@@ -0,0 +1,45 @@
+import types
+
+
+class DispatchRegistry:
+    """Dispatch a function based on the class of the argument.
+
+    This class allows to register a function to execute for a specific class
+    and expose this as a method of an object which will be dispatched
+    based on the argument.
+
+    It is similar to functools.singledispatch but it allows more
+    customization in case the dispatch rules grow in complexity
+    and works for class methods as well
+    (singledispatch supports methods only in more recent versions)
+    """
+
+    def __init__(self):
+        self._registry = {}
+
+    def register(self, cls):
+        """Return a decorator registering the function as the handler for cls."""
+
+        def decorator(func):
+            self._registry[cls] = func
+            return func
+
+        return decorator
+
+    def bind(self, obj):
+        """Return a callable dispatching on its first argument with obj as self."""
+        return types.MethodType(self, obj)
+
+    def __getitem__(self, argument):
+        # Walk the MRO so that the handler of the most specific class wins
+        # regardless of the order in which the handlers were registered.
+        for cls in type(argument).__mro__:
+            func = self._registry.get(cls)
+            if func is not None:
+                return func
+        raise ValueError(
+            f"Unsupported SQL Node type: {argument.__class__.__name__} -> {argument}"
+        )
+
+    def __call__(self, obj, dispatch_argument, *args, **kwargs):
+        return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)