diff --git a/beam-core/Database/Beam/Backend/SQL.hs b/beam-core/Database/Beam/Backend/SQL.hs index 017cfe6f..9d4be330 100644 --- a/beam-core/Database/Beam/Backend/SQL.hs +++ b/beam-core/Database/Beam/Backend/SQL.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE UndecidableInstances #-} module Database.Beam.Backend.SQL @@ -7,10 +8,12 @@ module Database.Beam.Backend.SQL , MonadBeam(..) - , BeamSqlBackend + , BeamSqlBackend(..) , BeamSqlBackendSyntax , MockSqlBackend + , beamSqlDefaultColumnNames + , BeamSqlBackendIsString , BeamSql99ExpressionBackend @@ -80,6 +83,7 @@ import qualified Control.Monad.State.Strict as Strict import qualified Control.Monad.Writer.Strict as Strict import Data.Kind (Type) +import Data.String (fromString) import Data.Tagged (Tagged) import Data.Text (Text) @@ -230,7 +234,17 @@ class ( -- Every SQL backend must be a beam backend -- Needed for the Eq instance on QGenExpr , Eq (BeamSqlBackendExpressionSyntax be) - ) => BeamSqlBackend be + + , KnownBool (BeamSqlBackendSupportsColumnAliases be) + ) => BeamSqlBackend be where + type BeamSqlBackendSupportsColumnAliases be :: Bool + + beamSqlBackendDefaultColumnNames :: [Text] + beamSqlBackendDefaultColumnNames = beamSqlDefaultColumnNames + +-- | Infinite list of column names that we use for projections, by default +beamSqlDefaultColumnNames :: [Text] +beamSqlDefaultColumnNames = map (\n -> "res" <> fromString (show n)) [0..] type family BeamSqlBackendSyntax be :: Type @@ -252,7 +266,8 @@ instance ( IsSql92Syntax syntax -- Needed for the Eq instance on QGenExpr , Eq (Sql92ExpressionSyntax syntax) - ) => BeamSqlBackend (MockSqlBackend syntax) + ) => BeamSqlBackend (MockSqlBackend syntax) where + type BeamSqlBackendSupportsColumnAliases (MockSqlBackend syntax) = True type instance BeamSqlBackendSyntax (MockSqlBackend syntax) = syntax -- | Type class for things which are text-like in this backend diff --git a/beam-core/Database/Beam/Backend/SQL/AST.hs b/beam-core/Database/Beam/Backend/SQL/AST.hs index 4a52e7d0..bda85434 100644 --- a/beam-core/Database/Beam/Backend/SQL/AST.hs +++ b/beam-core/Database/Beam/Backend/SQL/AST.hs @@ -479,7 +479,7 @@ instance IsSql92TableNameSyntax TableName where data TableSource = TableNamed TableName | TableFromSubSelect Select - | TableFromValues [ [ Expression ] ] + | TableFromValues Int [ [ Expression ] ] deriving (Show, Eq) instance IsSql92TableSourceSyntax TableSource where diff --git a/beam-core/Database/Beam/Backend/SQL/Builder.hs b/beam-core/Database/Beam/Backend/SQL/Builder.hs index 86b123b9..801f86c6 100644 --- a/beam-core/Database/Beam/Backend/SQL/Builder.hs +++ b/beam-core/Database/Beam/Backend/SQL/Builder.hs @@ -397,7 +397,7 @@ instance IsSql92TableSourceSyntax SqlSyntaxBuilder where tableNamed = id tableFromSubSelect query = SqlSyntaxBuilder (byteString "(" <> buildSql query <> byteString ")") - tableFromValues vss = + tableFromValues _ vss = SqlSyntaxBuilder $ byteString "VALUES " <> buildSepBy (byteString ", ") diff --git a/beam-core/Database/Beam/Backend/SQL/SQL92.hs b/beam-core/Database/Beam/Backend/SQL/SQL92.hs index fb122334..23d28fd8 100644 --- a/beam-core/Database/Beam/Backend/SQL/SQL92.hs +++ b/beam-core/Database/Beam/Backend/SQL/SQL92.hs @@ -362,7 +362,8 @@ class IsSql92TableNameSyntax (Sql92TableSourceTableNameSyntax tblSource) => tableNamed :: Sql92TableSourceTableNameSyntax tblSource -> tblSource tableFromSubSelect :: Sql92TableSourceSelectSyntax tblSource -> tblSource - tableFromValues :: [ [ Sql92TableSourceExpressionSyntax tblSource ] ] -> tblSource + -- | First argument is the number of columns to return + tableFromValues :: Int -> [ [ Sql92TableSourceExpressionSyntax tblSource ] ] -> tblSource class IsSql92GroupingSyntax grouping where type Sql92GroupingExpressionSyntax grouping :: Type diff --git a/beam-core/Database/Beam/Backend/Types.hs b/beam-core/Database/Beam/Backend/Types.hs index 4738e7ba..42114f28 100644 --- a/beam-core/Database/Beam/Backend/Types.hs +++ b/beam-core/Database/Beam/Backend/Types.hs @@ -1,7 +1,8 @@ -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} module Database.Beam.Backend.Types ( BeamBackend(..) + , KnownBool(..) , Exposed, Nullable @@ -28,3 +29,11 @@ data Exposed x -- -- See 'Columnar' for more information. data Nullable (c :: Type -> Type) x + +class KnownBool (x :: Bool) where + knownBool :: Bool + +instance KnownBool 'True where + knownBool = True +instance KnownBool 'False where + knownBool = False diff --git a/beam-core/Database/Beam/Query/Combinators.hs b/beam-core/Database/Beam/Query/Combinators.hs index 07119295..850391e3 100644 --- a/beam-core/Database/Beam/Query/Combinators.hs +++ b/beam-core/Database/Beam/Query/Combinators.hs @@ -112,11 +112,17 @@ values_ :: forall be db s a , BeamSqlBackend be ) => [ a ] -> Q be db s a values_ rows = - Q $ liftF (QAll (\tblPfx -> fromTable (tableFromValues (map (\row -> project (Proxy @be) row tblPfx) rows)) . Just . (,Just fieldNames)) - (\tblNm' -> fst $ mkFieldNames (qualifiedField tblNm')) + Q $ liftF (QAll (\tblPfx -> fromTable (tableFromValues colCount (map (\row -> project (Proxy @be) row tblPfx) rows)) . Just . (,colAliases)) + (\tblNm' -> if useAliases + then fst $ mkFieldNames (qualifiedField tblNm') + else fst $ mkDefaultFieldNames (qualifiedField tblNm')) (\_ -> Nothing) snd) where + useAliases = knownBool @(BeamSqlBackendSupportsColumnAliases be) + colAliases | useAliases = Just fieldNames + | otherwise = Nothing fieldNames = snd $ mkFieldNames @be @a unqualifiedField + colCount = length fieldNames -- | Introduce all entries of a table into the 'Q' monad based on the -- given QExpr. The join condition is expected to return a diff --git a/beam-core/Database/Beam/Query/Internal.hs b/beam-core/Database/Beam/Query/Internal.hs index 631dca68..0dd3e713 100644 --- a/beam-core/Database/Beam/Query/Internal.hs +++ b/beam-core/Database/Beam/Query/Internal.hs @@ -663,6 +663,20 @@ mkFieldNames mkField = tell [ fieldName' ] pure (\_ -> BeamSqlBackendExpressionSyntax' (fieldE (mkField fieldName'))) +mkDefaultFieldNames :: forall be res + . ( BeamSqlBackend be, Projectible be res ) + => (T.Text -> BeamSqlBackendFieldNameSyntax be) -> (res, [T.Text]) +mkDefaultFieldNames mkField = + runWriter . flip evalStateT (beamSqlBackendDefaultColumnNames @be) . flip evalStateT 0 $ + mkFieldsSkeleton @be @res $ \_ -> do + cols <- lift get + (x, xs) <- case cols of + [] -> error "Not enough default column names" + x:xs -> pure (x, xs) + tell [x] + lift (put xs) + pure (\_ -> BeamSqlBackendExpressionSyntax' (fieldE (mkField x))) + tableNameFromEntity :: IsSql92TableNameSyntax name => DatabaseEntityDescriptor be (TableEntity tbl) -> name diff --git a/beam-postgres/Database/Beam/Postgres/Syntax.hs b/beam-postgres/Database/Beam/Postgres/Syntax.hs index 59925d04..da2f7085 100644 --- a/beam-postgres/Database/Beam/Postgres/Syntax.hs +++ b/beam-postgres/Database/Beam/Postgres/Syntax.hs @@ -1033,11 +1033,12 @@ instance IsSql92TableSourceSyntax PgTableSourceSyntax where tableNamed = PgTableSourceSyntax . fromPgTableName tableFromSubSelect s = PgTableSourceSyntax $ emit "(" <> fromPgSelect s <> emit ")" - tableFromValues vss = PgTableSourceSyntax . pgParens $ - emit "VALUES " <> - pgSepBy (emit ", ") - (map (\vs -> pgParens (pgSepBy (emit ", ") - (map fromPgExpression vs))) vss) + tableFromValues _cnt vss = + PgTableSourceSyntax . pgParens $ + emit "VALUES " <> + pgSepBy (emit ", ") + (map (\vs -> pgParens (pgSepBy (emit ", ") + (map fromPgExpression vs))) vss) instance IsSql92ProjectionSyntax PgProjectionSyntax where type Sql92ProjectionExpressionSyntax PgProjectionSyntax = PgExpressionSyntax diff --git a/beam-postgres/Database/Beam/Postgres/Types.hs b/beam-postgres/Database/Beam/Postgres/Types.hs index efc93b5d..91ae3a1a 100644 --- a/beam-postgres/Database/Beam/Postgres/Types.hs +++ b/beam-postgres/Database/Beam/Postgres/Types.hs @@ -1,5 +1,6 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -162,7 +163,8 @@ instance FromBackendRow Postgres (Pg.Binary BL.ByteString) instance (Pg.FromField a, Typeable a) => FromBackendRow Postgres (Pg.PGRange a) instance (Pg.FromField a, Pg.FromField b, Typeable a, Typeable b) => FromBackendRow Postgres (Either a b) -instance BeamSqlBackend Postgres +instance BeamSqlBackend Postgres where + type BeamSqlBackendSupportsColumnAliases Postgres = 'True instance BeamMigrateOnlySqlBackend Postgres type instance BeamSqlBackendSyntax Postgres = PgCommandSyntax diff --git a/beam-sqlite/Database/Beam/Sqlite/Connection.hs b/beam-sqlite/Database/Beam/Sqlite/Connection.hs index 2307fc5a..30d31d4a 100644 --- a/beam-sqlite/Database/Beam/Sqlite/Connection.hs +++ b/beam-sqlite/Database/Beam/Sqlite/Connection.hs @@ -167,7 +167,8 @@ instance FromField SqliteScientific where "No conversion to Scientific for '" <> s <> "'" Just s' -> pure s' -instance BeamSqlBackend Sqlite +instance BeamSqlBackend Sqlite where + type BeamSqlBackendSupportsColumnAliases Sqlite = 'False instance BeamMigrateOnlySqlBackend Sqlite type instance BeamSqlBackendSyntax Sqlite = SqliteCommandSyntax @@ -380,9 +381,8 @@ runInsertReturningList SqlInsertNoRows = pure [] runInsertReturningList (SqlInsert tblSettings insertStmt_@(SqliteInsertSyntax nm _ _ _)) = do (logger, conn) <- SqliteM ask SqliteM . liftIO $ do - -- We create a pseudo-random savepoint identification that can be referenced - -- throughout this operation. -- This used to be based on the process ID + -- throughout this operation. -- This used to be based on the process ID -- (e.g. `System.Posix.Process.getProcessID` for UNIX), -- but using timestamps is more portable; see #738 -- diff --git a/beam-sqlite/Database/Beam/Sqlite/Syntax.hs b/beam-sqlite/Database/Beam/Sqlite/Syntax.hs index fddde6a6..de9d149e 100644 --- a/beam-sqlite/Database/Beam/Sqlite/Syntax.hs +++ b/beam-sqlite/Database/Beam/Sqlite/Syntax.hs @@ -92,10 +92,17 @@ import GHC.Generics -- example), the data builder attempts to properly format and escape the data. -- This returns syntax suitable for inclusion in scripts. In this case, the -- value list is ignored. -data SqliteSyntax = SqliteSyntax ((SQLData -> Builder) -> Builder) (DL.DList SQLData) +data SqliteSyntax = SqliteSyntax ((SQLData -> Builder) -> AnonTable -> (Builder, AnonTable)) (DL.DList SQLData) newtype SqliteData = SqliteData SQLData -- newtype for Hashable deriving Eq +newtype AnonTable = AnonTable Int deriving (Show, Eq, Ord) +nextTable :: AnonTable -> AnonTable +nextTable (AnonTable x) = AnonTable (succ x) + +anonTableSyntax :: AnonTable -> Builder +anonTableSyntax (AnonTable n) = "tbl_" <> fromString (show n) + instance Show SqliteSyntax where show (SqliteSyntax s d) = "SqliteSyntax (" <> show (toLazyByteString (withPlaceholders s)) <> ") " <> show d @@ -105,10 +112,14 @@ instance Sql92DisplaySyntax SqliteSyntax where instance Semigroup SqliteSyntax where (<>) (SqliteSyntax ab av) (SqliteSyntax bb bv) = - SqliteSyntax (\v -> ab v <> bb v) (av <> bv) + SqliteSyntax (\v tbl -> + let (a, tbl') = ab v tbl + (b, tbl'') = bb v tbl' + in (a <> b, tbl'')) + (av <> bv) instance Monoid SqliteSyntax where - mempty = SqliteSyntax (\_ -> mempty) mempty + mempty = SqliteSyntax (\_ x -> (mempty, x)) mempty instance Eq SqliteSyntax where SqliteSyntax ab av == SqliteSyntax bb bv = @@ -129,18 +140,21 @@ instance Hashable SqliteData where -- | Convert the first argument of 'SQLiteSyntax' to a 'ByteString' 'Builder', -- where all the data has been replaced by @"?"@ placeholders. -withPlaceholders :: ((SQLData -> Builder) -> Builder) -> Builder -withPlaceholders build = build (\_ -> "?") +withPlaceholders :: ((SQLData -> Builder) -> AnonTable -> (Builder, AnonTable)) -> Builder +withPlaceholders build = fst $ build (\_ -> "?") (AnonTable 0) -- | Embed a 'ByteString' directly in the syntax emit :: ByteString -> SqliteSyntax -emit b = SqliteSyntax (\_ -> byteString b) mempty +emit b = SqliteSyntax (\_ t -> (byteString b, t)) mempty emit' :: Show a => a -> SqliteSyntax -emit' x = SqliteSyntax (\_ -> byteString (fromString (show x))) mempty +emit' x = SqliteSyntax (\_ t -> (byteString (fromString (show x)), t)) mempty + +tableRef :: SqliteSyntax +tableRef = SqliteSyntax (\_ n -> (anonTableSyntax n, nextTable n)) mempty quotedIdentifier :: T.Text -> SqliteSyntax -quotedIdentifier txt = emit "\"" <> SqliteSyntax (\_ -> stringUtf8 (T.unpack (sqliteEscape txt))) mempty <> emit "\"" +quotedIdentifier txt = emit "\"" <> SqliteSyntax (\_ t -> (stringUtf8 (T.unpack (sqliteEscape txt)), t)) mempty <> emit "\"" -- | A best effort attempt to implement the escaping rules of SQLite. This is -- never used to escape data sent to the database; only for emitting scripts or @@ -152,14 +166,14 @@ sqliteEscape = T.concatMap (\c -> if c == '"' then "\"\"" else T.singleton c) -- -- This causes a literal @?@ 3 emitValue :: SQLData -> SqliteSyntax -emitValue v = SqliteSyntax ($ v) (DL.singleton v) +emitValue v = SqliteSyntax (\emitValue t -> (emitValue v, t)) (DL.singleton v) -- | Render a 'SqliteSyntax' as a lazy 'BL.ByteString', for purposes of -- displaying to a user. Embedded 'SQLData' is directly embedded into the -- concrete syntax, with a best effort made to escape strings. sqliteRenderSyntaxScript :: SqliteSyntax -> BL.ByteString sqliteRenderSyntaxScript (SqliteSyntax s _) = - toLazyByteString . s $ \case + toLazyByteString . fst . flip s (AnonTable 0) $ \case SQLInteger i -> int64Dec i SQLFloat d -> doubleDec d SQLText t -> TE.encodeUtf8Builder (sqliteEscape t) @@ -174,8 +188,8 @@ sqliteRenderSyntaxScript (SqliteSyntax s _) = -- columns. The 'fromSqliteCommand' function will take an 'SqliteCommandSyntax' -- and convert it into the correct 'SqliteSyntax'. data SqliteCommandSyntax - = SqliteCommandSyntax SqliteSyntax - | SqliteCommandInsert SqliteInsertSyntax + = SqliteCommandSyntax !SqliteSyntax + | SqliteCommandInsert !SqliteInsertSyntax -- | Convert a 'SqliteCommandSyntax' into a renderable 'SqliteSyntax' fromSqliteCommand :: SqliteCommandSyntax -> SqliteSyntax @@ -206,7 +220,8 @@ newtype SqliteUpdateSyntax = SqliteUpdateSyntax { fromSqliteUpdate :: SqliteSynt -- | SQLite @DELETE@ syntax newtype SqliteDeleteSyntax = SqliteDeleteSyntax { fromSqliteDelete :: SqliteSyntax } -newtype SqliteSelectTableSyntax = SqliteSelectTableSyntax { fromSqliteSelectTable :: SqliteSyntax } +data SqliteSelectTableSyntax = SqliteSelectTableSyntax { sqliteSelectCTEs :: [SqliteCTE] + , fromSqliteSelectTable :: SqliteSyntax } -- | Implements beam SQL expression syntaxes data SqliteExpressionSyntax @@ -214,7 +229,7 @@ data SqliteExpressionSyntax | SqliteExpressionDefault deriving (Show, Eq, Generic) instance Hashable SqliteExpressionSyntax -newtype SqliteFromSyntax = SqliteFromSyntax { fromSqliteFromSyntax :: SqliteSyntax } +data SqliteFromSyntax = SqliteFromSyntax { sqliteFromCTEs :: [SqliteCTE], fromSqliteFromSyntax :: SqliteSyntax } newtype SqliteComparisonQuantifierSyntax = SqliteComparisonQuantifierSyntax { fromSqliteComparisonQuantifier :: SqliteSyntax } newtype SqliteAggregationSetQuantifierSyntax = SqliteAggregationSetQuantifierSyntax { fromSqliteAggregationSetQuantifier :: SqliteSyntax } newtype SqliteProjectionSyntax = SqliteProjectionSyntax { fromSqliteProjection :: SqliteSyntax } @@ -222,9 +237,14 @@ newtype SqliteGroupingSyntax = SqliteGroupingSyntax { fromSqliteGrouping :: Sqli newtype SqliteOrderingSyntax = SqliteOrderingSyntax { fromSqliteOrdering :: SqliteSyntax } -- | SQLite syntax for values that can be embedded in 'SqliteSyntax' newtype SqliteValueSyntax = SqliteValueSyntax { fromSqliteValue :: SqliteSyntax } -newtype SqliteTableSourceSyntax = SqliteTableSourceSyntax { fromSqliteTableSource :: SqliteSyntax } +data SqliteTableSourceSyntax = SqliteTableSourceSyntax + { sqliteTableSourceCTEs :: [SqliteCTE] + , fromSqliteTableSource :: SqliteSyntax } newtype SqliteFieldNameSyntax = SqliteFieldNameSyntax { fromSqliteFieldNameSyntax :: SqliteSyntax } +data SqliteCTE = SqliteCTE { sqliteCteColumnNames :: Maybe [T.Text] + , sqliteCteSelect :: SqliteSelectSyntax } + -- | SQLite @VALUES@ clause in @INSERT@. Expressions need to be handled -- explicitly in order to deal with @DEFAULT@ values and @AUTO INCREMENT@ -- columns. @@ -603,6 +623,7 @@ instance IsSql92SelectSyntax SqliteSelectSyntax where selectStmt tbl ordering limit offset = SqliteSelectSyntax $ + withClause <> fromSqliteSelectTable tbl <> (case ordering of [] -> mempty @@ -613,6 +634,16 @@ instance IsSql92SelectSyntax SqliteSelectSyntax where (Nothing, Just offset) -> emit " LIMIT -1 OFFSET " <> emit' offset (Just limit, Just offset) -> emit " LIMIT " <> emit' limit <> emit " OFFSET " <> emit' offset + where + withClause = case sqliteSelectCTEs tbl of + [] -> mempty + _ -> emit "WITH " <> commas (zipWith buildCte [0..] (sqliteSelectCTEs tbl)) + buildCte :: Int -> SqliteCTE -> SqliteSyntax + buildCte n (SqliteCTE mColNames cte) = emit "tbl_" <> emit' n <> colNames <> emit " AS " <> parens (fromSqliteSelect cte) + where + colNames = case mColNames of + Nothing -> mempty + Just nms -> parens (commas (map quotedIdentifier nms)) instance IsSql92SelectTableSyntax SqliteSelectTableSyntax where type Sql92SelectTableSelectSyntax SqliteSelectTableSyntax = SqliteSelectSyntax @@ -623,7 +654,7 @@ instance IsSql92SelectTableSyntax SqliteSelectTableSyntax where type Sql92SelectTableSetQuantifierSyntax SqliteSelectTableSyntax = SqliteAggregationSetQuantifierSyntax selectTableStmt setQuantifier proj from where_ grouping having = - SqliteSelectTableSyntax $ + SqliteSelectTableSyntax (fromMaybe [] (sqliteFromCTEs <$> from)) $ emit "SELECT " <> maybe mempty (<> emit " ") (fromSqliteAggregationSetQuantifier <$> setQuantifier) <> fromSqliteProjection proj <> @@ -638,17 +669,18 @@ instance IsSql92SelectTableSyntax SqliteSelectTableSyntax where tableOp :: ByteString -> SqliteSelectTableSyntax -> SqliteSelectTableSyntax -> SqliteSelectTableSyntax tableOp op a b = - SqliteSelectTableSyntax $ + SqliteSelectTableSyntax (sqliteSelectCTEs a <> sqliteSelectCTEs b) $ fromSqliteSelectTable a <> spaces (emit op) <> fromSqliteSelectTable b instance IsSql92FromSyntax SqliteFromSyntax where type Sql92FromExpressionSyntax SqliteFromSyntax = SqliteExpressionSyntax type Sql92FromTableSourceSyntax SqliteFromSyntax = SqliteTableSourceSyntax - fromTable tableSrc Nothing = SqliteFromSyntax (fromSqliteTableSource tableSrc) - fromTable tableSrc (Just (nm, colNms)) = - SqliteFromSyntax (fromSqliteTableSource tableSrc <> emit " AS " <> quotedIdentifier nm <> - maybe mempty (\colNms' -> parens (commas (map quotedIdentifier colNms'))) colNms) + fromTable tableSrc Nothing = SqliteFromSyntax (sqliteTableSourceCTEs tableSrc) (fromSqliteTableSource tableSrc) + fromTable tableSrc (Just (nm, Nothing)) = + SqliteFromSyntax (sqliteTableSourceCTEs tableSrc) + (fromSqliteTableSource tableSrc <> emit " AS " <> quotedIdentifier nm) + fromTable _ (Just (_, Just _)) = error "beam-sqlite cannot support table names with column aliases" innerJoin = _join "INNER JOIN" leftJoin = _join "LEFT JOIN" @@ -656,9 +688,11 @@ instance IsSql92FromSyntax SqliteFromSyntax where _join :: ByteString -> SqliteFromSyntax -> SqliteFromSyntax -> Maybe SqliteExpressionSyntax -> SqliteFromSyntax _join joinType a b Nothing = - SqliteFromSyntax (fromSqliteFromSyntax a <> spaces (emit joinType) <> fromSqliteFromSyntax b) + SqliteFromSyntax (sqliteFromCTEs a <> sqliteFromCTEs b) + (fromSqliteFromSyntax a <> spaces (emit joinType) <> fromSqliteFromSyntax b) _join joinType a b (Just on) = - SqliteFromSyntax (fromSqliteFromSyntax a <> spaces (emit joinType) <> fromSqliteFromSyntax b <> emit " ON " <> fromSqliteExpression on) + SqliteFromSyntax (sqliteFromCTEs a <> sqliteFromCTEs b) + (fromSqliteFromSyntax a <> spaces (emit joinType) <> fromSqliteFromSyntax b <> emit " ON " <> fromSqliteExpression on) instance IsSql92ProjectionSyntax SqliteProjectionSyntax where type Sql92ProjectionExpressionSyntax SqliteProjectionSyntax = SqliteExpressionSyntax @@ -681,12 +715,14 @@ instance IsSql92TableSourceSyntax SqliteTableSourceSyntax where type Sql92TableSourceSelectSyntax SqliteTableSourceSyntax = SqliteSelectSyntax type Sql92TableSourceExpressionSyntax SqliteTableSourceSyntax = SqliteExpressionSyntax - tableNamed = SqliteTableSourceSyntax . fromSqliteTableName + tableNamed = SqliteTableSourceSyntax [] . fromSqliteTableName tableFromSubSelect s = - SqliteTableSourceSyntax (parens (fromSqliteSelect s)) - tableFromValues vss = SqliteTableSourceSyntax . parens $ - emit "VALUES " <> - commas (map (\vs -> parens (commas (map fromSqliteExpression vs))) vss) + SqliteTableSourceSyntax [] (parens (fromSqliteSelect s)) + tableFromValues colCount vss = SqliteTableSourceSyntax [SqliteCTE (Just (take colCount beamSqlDefaultColumnNames)) valuesTable] tableRef + where + valuesTable = SqliteSelectSyntax $ + emit "VALUES " <> + commas (map (\vs -> parens (commas (map fromSqliteExpression vs))) vss) instance IsSql92GroupingSyntax SqliteGroupingSyntax where type Sql92GroupingExpressionSyntax SqliteGroupingSyntax = SqliteExpressionSyntax diff --git a/beam-sqlite/test/Database/Beam/Sqlite/Test/Select.hs b/beam-sqlite/test/Database/Beam/Sqlite/Test/Select.hs index 511de1da..2829a1e1 100644 --- a/beam-sqlite/test/Database/Beam/Sqlite/Test/Select.hs +++ b/beam-sqlite/test/Database/Beam/Sqlite/Test/Select.hs @@ -14,9 +14,10 @@ import Database.Beam.Sqlite.Test tests :: TestTree tests = testGroup "Selection tests" - [ expectFail testExceptValues + [ testExceptValues , testInRowValues , testInSelect + , testSelectFromValues ] data Pair f = Pair @@ -46,10 +47,19 @@ testInSelect = testCase "IN (SELECT ...) works" $ return $ x `inQuery_` (pure (as_ @Int32 $ val_ 2)) assertEqual "result" [False] result --- | Regression test for testExceptValues :: TestTree testExceptValues = testCase "EXCEPT with VALUES works" $ withTestDb $ \conn -> do result <- runBeamSqlite conn $ runSelectReturningList $ select $ values_ [as_ @Bool $ val_ True, val_ False] `except_` values_ [val_ False] assertEqual "result" [True] result + +testSelectFromValues :: TestTree +testSelectFromValues = testCase "SELECT * FROM (VALUES ...) works by factoring out a common CTE" $ + withTestDb $ \conn -> do + xs <- runBeamSqlite conn $ runSelectReturningList $ select $ do + (a, b) <- values_ [(val_ 1, val_ 2), (val_ 2, val_ 3)] + (c, d) <- values_ [(val_ 2, val_ 4), (val_ 2, val_ 3)] + guard_ (as_ @Int32 b ==. as_ @Int32 c) + pure (as_ @Int32 a, as_ @Int32 d) + assertEqual "result" [(1,4), (1,3)] xs