Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,4 @@ Paket.Restore.targets
.paket
/docs/output
docs/output/**/*.*
*.orig
1 change: 1 addition & 0 deletions src/SqlClient.Tests/Lib/Lib.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<Compile Include="Library1.fs" />
<None Include="Script.fsx" />
<Content Include="App.config" />
<Content Include="packages.config" />
</ItemGroup>
<ItemGroup>
<Reference Include="FSharp.Data.SqlClient">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
<ItemGroup>
<Compile Include="Program.fs" />
<Content Include="Uncomment.App.config" />
<Content Include="packages.config" />
</ItemGroup>
<ItemGroup>
<Reference Include="FSharp.Data.SqlClient">
Expand Down
1 change: 1 addition & 0 deletions src/SqlClient.Tests/SqlClient.Tests.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
<Compile Include="SynonymsTests.fs" />
<Compile Include="CreateCommand.fs" />
<Compile Include="UnitsOfMeasure.fs" />
<Compile Include="TempTableTests.fs" />
<None Include="sampleCommand.sql" />
<None Include="extensions.sql" />
<None Include="MySqlFolder\sampleCommand.sql" />
Expand Down
20 changes: 20 additions & 0 deletions src/SqlClient.Tests/TVPTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,23 @@ let UsingTVPInQuery() =
|> Seq.toList

Assert.Equal<_ list>(expected, actual)

type MappedTVP =
SqlCommandProvider<"
SELECT myId, myName from @input
", ConnectionStrings.AdventureWorksLiteral, TableVarMapping = "@input=dbo.MyTableType">
[<Fact>]
let UsingMappedTVPInQuery() =
printfn "%s" ConnectionStrings.AdventureWorksLiteral
use cmd = new MappedTVP(ConnectionStrings.AdventureWorksLiteral)
let expected = [
1, Some "monkey"
2, Some "donkey"
]

let actual =
cmd.Execute(input = [ for id, name in expected -> MappedTVP.MyTableType(id, name) ])
|> Seq.map(fun x -> x.myId, x.myName)
|> Seq.toList

Assert.Equal<_ list>(expected, actual)
106 changes: 106 additions & 0 deletions src/SqlClient.Tests/TempTableTests.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
module FSharp.Data.TempTableTests

open FSharp.Data
open Xunit
open System.Data.SqlClient

type TempTable =
SqlCommandProvider<
TempTableDefinitions = "
CREATE TABLE #Temp (
Id INT NOT NULL,
Name NVARCHAR(100) NULL)",
CommandText = "
SELECT Id, Name FROM #Temp",
ConnectionStringOrName =
ConnectionStrings.AdventureWorksLiteral>

[<Fact>]
let usingTempTable() =
use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
conn.Open()

use cmd = new TempTable(conn)

cmd.LoadTempTables(
Temp =
[ TempTable.Temp(Id = 1, Name = Some "monkey")
TempTable.Temp(Id = 2, Name = Some "donkey") ])

let actual =
cmd.Execute()
|> Seq.map(fun x -> x.Id, x.Name)
|> Seq.toList

let expected = [
1, Some "monkey"
2, Some "donkey"
]

Assert.Equal<_ list>(expected, actual)

[<Fact>]
let queryWithHash() =
// We shouldn't mangle the statement when it's run
use cmd =
new SqlCommandProvider<
CommandText = "
SELECT Id, Name
FROM
(
SELECT 1 AS Id, '#name' AS Name UNION
SELECT 2, 'some other value'
) AS a
WHERE Name = '#name'",
ConnectionStringOrName =
ConnectionStrings.AdventureWorksLiteral>(ConnectionStrings.AdventureWorksLiteral)

let actual =
cmd.Execute()
|> Seq.map(fun x -> x.Id, x.Name)
|> Seq.toList

let expected = [
1, "#name"
]

Assert.Equal<_ list>(expected, actual)

type TempTableHash =
SqlCommandProvider<
TempTableDefinitions = "
CREATE TABLE #Temp (
Id INT NOT NULL)",
CommandText = "
SELECT a.Id, a.Name
FROM
(
SELECT 1 AS Id, '#Temp' AS Name UNION
SELECT 2, 'some other value'
) AS a
INNER JOIN #Temp t ON t.Id = a.Id",
ConnectionStringOrName =
ConnectionStrings.AdventureWorksLiteral>

[<Fact>]
let queryWithHashAndTempTable() =
// We shouldn't mangle the statement when it's run
use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
conn.Open()

use cmd = new TempTableHash(conn)

cmd.LoadTempTables(
Temp =
[ TempTableHash.Temp(Id = 1) ])

let actual =
cmd.Execute()
|> Seq.map(fun x -> x.Id, x.Name)
|> Seq.toList

let expected = [
1, "#Temp"
]

Assert.Equal<_ list>(expected, actual)
10 changes: 8 additions & 2 deletions src/SqlClient/AssemblyInfo.fs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
namespace System
// Auto-Generated by FAKE; do not edit
namespace System
open System.Reflection
open System.Runtime.CompilerServices

Expand All @@ -11,4 +12,9 @@ open System.Runtime.CompilerServices
do ()

module internal AssemblyVersionInformation =
let [<Literal>] Version = "1.8.4"
let [<Literal>] AssemblyTitle = "SqlClient"
let [<Literal>] AssemblyProduct = "FSharp.Data.SqlClient"
let [<Literal>] AssemblyDescription = "SqlClient F# type providers"
let [<Literal>] AssemblyVersion = "1.8.4"
let [<Literal>] AssemblyFileVersion = "1.8.4"
let [<Literal>] InternalsVisibleTo = "SqlClient.Tests"
178 changes: 177 additions & 1 deletion src/SqlClient/DesignTime.fs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ open System.Diagnostics
open Microsoft.FSharp.Quotations
open ProviderImplementation.ProvidedTypes
open FSharp.Data
open System.Text.RegularExpressions

type internal RowType = {
Provided: Type
Expand Down Expand Up @@ -40,7 +41,52 @@ module internal SharedLogic =
// add .Table
returnType.Single |> cmdProvidedType.AddMember

type DesignTime private() =
module Prefixes =
let tempTable = "##SQLCOMMANDPROVIDER_"
let tableVar = "@SQLCOMMANDPROVIDER_"

type TempTableLoader(fieldCount, items: obj seq) =
let enumerator = items.GetEnumerator()

interface IDataReader with
member this.FieldCount: int = fieldCount
member this.Read(): bool = enumerator.MoveNext()
member this.GetValue(i: int): obj =
let row : obj[] = unbox enumerator.Current
row.[i]
member this.Dispose(): unit = ()

member __.Close(): unit = invalidOp "NotImplementedException"
member __.Depth: int = invalidOp "NotImplementedException"
member __.GetBoolean(_: int): bool = invalidOp "NotImplementedException"
member __.GetByte(_ : int): byte = invalidOp "NotImplementedException"
member __.GetBytes(_ : int, _ : int64, _ : byte [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
member __.GetChar(_ : int): char = invalidOp "NotImplementedException"
member __.GetChars(_ : int, _ : int64, _ : char [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
member __.GetData(_ : int): IDataReader = invalidOp "NotImplementedException"
member __.GetDataTypeName(_ : int): string = invalidOp "NotImplementedException"
member __.GetDateTime(_ : int): System.DateTime = invalidOp "NotImplementedException"
member __.GetDecimal(_ : int): decimal = invalidOp "NotImplementedException"
member __.GetDouble(_ : int): float = invalidOp "NotImplementedException"
member __.GetFieldType(_ : int): System.Type = invalidOp "NotImplementedException"
member __.GetFloat(_ : int): float32 = invalidOp "NotImplementedException"
member __.GetGuid(_ : int): System.Guid = invalidOp "NotImplementedException"
member __.GetInt16(_ : int): int16 = invalidOp "NotImplementedException"
member __.GetInt32(_ : int): int = invalidOp "NotImplementedException"
member __.GetInt64(_ : int): int64 = invalidOp "NotImplementedException"
member __.GetName(_ : int): string = invalidOp "NotImplementedException"
member __.GetOrdinal(_ : string): int = invalidOp "NotImplementedException"
member __.GetSchemaTable(): DataTable = invalidOp "NotImplementedException"
member __.GetString(_ : int): string = invalidOp "NotImplementedException"
member __.GetValues(_ : obj []): int = invalidOp "NotImplementedException"
member __.IsClosed: bool = invalidOp "NotImplementedException"
member __.IsDBNull(_ : int): bool = invalidOp "NotImplementedException"
member __.Item with get (_ : int): obj = invalidOp "NotImplementedException"
member __.Item with get (_ : string): obj = invalidOp "NotImplementedException"
member __.NextResult(): bool = invalidOp "NotImplementedException"
member __.RecordsAffected: int = invalidOp "NotImplementedException"

type DesignTime private() =
static member internal AddGeneratedMethod
(sqlParameters: Parameter list, hasOutputParameters, executeArgs: ProvidedParameter list, erasedType, providedOutputType, name) =

Expand Down Expand Up @@ -632,3 +678,133 @@ type DesignTime private() =
then
yield upcast ProvidedMethod(factoryMethodName.Value, parameters2, returnType = cmdProvidedType, IsStaticMethod = true, InvokeCode = body2)
]

static member private CreateTempTableRecord(name, cols) =
let rowType = ProvidedTypeDefinition(name, Some typeof<obj>, HideObjectMethods = true)

let parameters =
[
for (p : Column) in cols do
let name = p.Name
let param = ProvidedParameter( name, p.GetProvidedType(), ?optionalValue = if p.Nullable then Some null else None)
yield param
]

let ctor = ProvidedConstructor( parameters)
ctor.InvokeCode <- fun args ->
let optionsToNulls = QuotationsFactory.MapArrayNullableItems(cols, "MapArrayOptionItemToObj")

<@@ let values: obj[] = %%Expr.NewArray(typeof<obj>, [ for a in args -> Expr.Coerce(a, typeof<obj>) ])
(%%optionsToNulls) values
values @@>

rowType.AddMember ctor
rowType.AddXmlDoc "Type Table Type"

rowType

// Changes any temp tables in to a global temp table (##name) then creates them on the open connection.
static member internal SubstituteTempTables(connection, commandText: string, tempTableDefinitions : string, connectionId) =
// Extract and temp tables
let tempTableRegex = Regex("#([a-z0-9\-_]+)", RegexOptions.IgnoreCase)
let tempTableNames =
tempTableRegex.Matches(tempTableDefinitions)
|> Seq.cast<Match>
|> Seq.map (fun m -> m.Groups.[1].Value)
|> Seq.toList

match tempTableNames with
| [] -> commandText, None
| _ ->
// Create temp table(s), extracts the columns then drop it.
let tableTypes =
use create = new SqlCommand(tempTableDefinitions, connection)
create.ExecuteScalar() |> ignore

tempTableNames
|> List.map(fun name ->
let cols = DesignTime.GetOutputColumns(connection, "SELECT * FROM #"+name, [], isStoredProcedure = false)
use drop = new SqlCommand("DROP TABLE #"+name, connection)
drop.ExecuteScalar() |> ignore
DesignTime.CreateTempTableRecord(name, cols), cols)

let parameters =
tableTypes
|> List.map (fun (typ, _) ->
ProvidedParameter(typ.Name, parameterType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ seq>, [ typ ])))

// Build the values load method.
let loadValues (exprArgs: Expr list) (connection) =
(exprArgs.Tail, tableTypes)
||> List.map2 (fun expr (typ, cols) ->
let destinationTableName = typ.Name
let colsLength = cols.Length

<@@
let items = (%%expr : obj seq)
use reader = new TempTableLoader(colsLength, items)

use bulkCopy = new SqlBulkCopy((%%connection : SqlConnection))
bulkCopy.BulkCopyTimeout <- 0
bulkCopy.BatchSize <- 5000
bulkCopy.DestinationTableName <- "#" + destinationTableName
bulkCopy.WriteToServer(reader)

@@>
)
|> List.fold (fun acc x -> Expr.Sequential(acc, x)) <@@ () @@>

let loadTempTablesMethod = ProvidedMethod("LoadTempTables", parameters, typeof<unit>)

loadTempTablesMethod.InvokeCode <- fun exprArgs ->

let command = Expr.Coerce(exprArgs.[0], typedefof<ISqlCommand>)

let connection =
<@@ let cmd = (%%command : ISqlCommand)
cmd.Raw.Connection @@>

<@@ do
use create = new SqlCommand(tempTableDefinitions, (%%connection : SqlConnection))
create.ExecuteNonQuery() |> ignore

(%%loadValues exprArgs connection)
ignore() @@>

// Create the temp table(s) but as a global temp table with a unique name. This can be used later down stream on the open connection.
use cmd = new SqlCommand(tempTableRegex.Replace(tempTableDefinitions, Prefixes.tempTable+connectionId+"$1"), connection)
cmd.ExecuteScalar() |> ignore

// Only replace temp tables we find in our list.
tempTableRegex.Replace(commandText, MatchEvaluator(fun m ->
match tempTableNames |> List.tryFind((=) m.Groups.[1].Value) with
| Some name -> Prefixes.tempTable + connectionId + name
| None -> m.Groups.[0].Value)),

Some(loadTempTablesMethod, tableTypes |> List.unzip |> fst)

static member internal RemoveSubstitutedTempTables(connection, tempTables : ProvidedTypeDefinition list, connectionId) =
if not tempTables.IsEmpty then
use cmd = new SqlCommand(tempTables |> List.map(fun tempTable -> sprintf "DROP TABLE [%s%s%s]" Prefixes.tempTable connectionId tempTable.Name) |> String.concat ";", connection)
cmd.ExecuteScalar() |> ignore

// tableVarMapping(s) is converted into DECLARE statements then prepended to the command text.
static member internal SubstituteTableVar(commandText: string, tableVarMapping : string) =
let varRegex = Regex("@([a-z0-9_]+)", RegexOptions.IgnoreCase)

let vars =
tableVarMapping.Split([|';'|], System.StringSplitOptions.RemoveEmptyEntries)
|> Array.choose(fun (x : string) ->
match x.Split([|'='|]) with
| [|name;typ|] -> Some(name.TrimStart('@'), typ)
| _ -> None)

// Only replace table vars we find in our list.
let commandText =
varRegex.Replace(commandText, MatchEvaluator(fun m ->
match vars |> Array.tryFind(fun (n,_) -> n = m.Groups.[1].Value) with
| Some (name, _) -> Prefixes.tableVar + name
| None -> m.Groups.[0].Value))

(vars |> Array.map(fun (name,typ) -> sprintf "DECLARE %s%s %s = @%s" Prefixes.tableVar name typ name) |> String.concat "; ") + "; " + commandText

Loading