Skip to content

Commit 8a1788d

Browse files
Merge pull request #302 from vasily-kirichenko/feature/temp-tvp-types
temp tvp types
2 parents b872536 + 6ad7dd7 commit 8a1788d

File tree

9 files changed

+346
-8
lines changed

9 files changed

+346
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,4 @@ Paket.Restore.targets
235235
.paket
236236
/docs/output
237237
docs/output/**/*.*
238+
*.orig

src/SqlClient.Tests/Lib/Lib.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
<Compile Include="Library1.fs" />
5454
<None Include="Script.fsx" />
5555
<Content Include="App.config" />
56+
<Content Include="packages.config" />
5657
</ItemGroup>
5758
<ItemGroup>
5859
<Reference Include="FSharp.Data.SqlClient">

src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
<ItemGroup>
5959
<Compile Include="Program.fs" />
6060
<Content Include="Uncomment.App.config" />
61+
<Content Include="packages.config" />
6162
</ItemGroup>
6263
<ItemGroup>
6364
<Reference Include="FSharp.Data.SqlClient">

src/SqlClient.Tests/SqlClient.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
<Compile Include="SynonymsTests.fs" />
8888
<Compile Include="CreateCommand.fs" />
8989
<Compile Include="UnitsOfMeasure.fs" />
90+
<Compile Include="TempTableTests.fs" />
9091
<None Include="sampleCommand.sql" />
9192
<None Include="extensions.sql" />
9293
<None Include="MySqlFolder\sampleCommand.sql" />

src/SqlClient.Tests/TVPTests.fs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,23 @@ let UsingTVPInQuery() =
145145
|> Seq.toList
146146

147147
Assert.Equal<_ list>(expected, actual)
148+
149+
type MappedTVP =
150+
SqlCommandProvider<"
151+
SELECT myId, myName from @input
152+
", ConnectionStrings.AdventureWorksLiteral, TableVarMapping = "@input=dbo.MyTableType">
153+
[<Fact>]
154+
let UsingMappedTVPInQuery() =
155+
printfn "%s" ConnectionStrings.AdventureWorksLiteral
156+
use cmd = new MappedTVP(ConnectionStrings.AdventureWorksLiteral)
157+
let expected = [
158+
1, Some "monkey"
159+
2, Some "donkey"
160+
]
161+
162+
let actual =
163+
cmd.Execute(input = [ for id, name in expected -> MappedTVP.MyTableType(id, name) ])
164+
|> Seq.map(fun x -> x.myId, x.myName)
165+
|> Seq.toList
166+
167+
Assert.Equal<_ list>(expected, actual)

src/SqlClient.Tests/TempTableTests.fs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
module FSharp.Data.TempTableTests
2+
3+
open FSharp.Data
4+
open Xunit
5+
open System.Data.SqlClient
6+
7+
type TempTable =
8+
SqlCommandProvider<
9+
TempTableDefinitions = "
10+
CREATE TABLE #Temp (
11+
Id INT NOT NULL,
12+
Name NVARCHAR(100) NULL)",
13+
CommandText = "
14+
SELECT Id, Name FROM #Temp",
15+
ConnectionStringOrName =
16+
ConnectionStrings.AdventureWorksLiteral>
17+
18+
[<Fact>]
19+
let usingTempTable() =
20+
use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
21+
conn.Open()
22+
23+
use cmd = new TempTable(conn)
24+
25+
cmd.LoadTempTables(
26+
Temp =
27+
[ TempTable.Temp(Id = 1, Name = Some "monkey")
28+
TempTable.Temp(Id = 2, Name = Some "donkey") ])
29+
30+
let actual =
31+
cmd.Execute()
32+
|> Seq.map(fun x -> x.Id, x.Name)
33+
|> Seq.toList
34+
35+
let expected = [
36+
1, Some "monkey"
37+
2, Some "donkey"
38+
]
39+
40+
Assert.Equal<_ list>(expected, actual)
41+
42+
[<Fact>]
43+
let queryWithHash() =
44+
// We shouldn't mangle the statement when it's run
45+
use cmd =
46+
new SqlCommandProvider<
47+
CommandText = "
48+
SELECT Id, Name
49+
FROM
50+
(
51+
SELECT 1 AS Id, '#name' AS Name UNION
52+
SELECT 2, 'some other value'
53+
) AS a
54+
WHERE Name = '#name'",
55+
ConnectionStringOrName =
56+
ConnectionStrings.AdventureWorksLiteral>(ConnectionStrings.AdventureWorksLiteral)
57+
58+
let actual =
59+
cmd.Execute()
60+
|> Seq.map(fun x -> x.Id, x.Name)
61+
|> Seq.toList
62+
63+
let expected = [
64+
1, "#name"
65+
]
66+
67+
Assert.Equal<_ list>(expected, actual)
68+
69+
type TempTableHash =
70+
SqlCommandProvider<
71+
TempTableDefinitions = "
72+
CREATE TABLE #Temp (
73+
Id INT NOT NULL)",
74+
CommandText = "
75+
SELECT a.Id, a.Name
76+
FROM
77+
(
78+
SELECT 1 AS Id, '#Temp' AS Name UNION
79+
SELECT 2, 'some other value'
80+
) AS a
81+
INNER JOIN #Temp t ON t.Id = a.Id",
82+
ConnectionStringOrName =
83+
ConnectionStrings.AdventureWorksLiteral>
84+
85+
[<Fact>]
86+
let queryWithHashAndTempTable() =
87+
// We shouldn't mangle the statement when it's run
88+
use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
89+
conn.Open()
90+
91+
use cmd = new TempTableHash(conn)
92+
93+
cmd.LoadTempTables(
94+
Temp =
95+
[ TempTableHash.Temp(Id = 1) ])
96+
97+
let actual =
98+
cmd.Execute()
99+
|> Seq.map(fun x -> x.Id, x.Name)
100+
|> Seq.toList
101+
102+
let expected = [
103+
1, "#Temp"
104+
]
105+
106+
Assert.Equal<_ list>(expected, actual)

src/SqlClient/AssemblyInfo.fs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
namespace System
1+
// Auto-Generated by FAKE; do not edit
2+
namespace System
23
open System.Reflection
34
open System.Runtime.CompilerServices
45

@@ -11,4 +12,9 @@ open System.Runtime.CompilerServices
1112
do ()
1213

1314
module internal AssemblyVersionInformation =
14-
let [<Literal>] Version = "1.8.4"
15+
let [<Literal>] AssemblyTitle = "SqlClient"
16+
let [<Literal>] AssemblyProduct = "FSharp.Data.SqlClient"
17+
let [<Literal>] AssemblyDescription = "SqlClient F# type providers"
18+
let [<Literal>] AssemblyVersion = "1.8.4"
19+
let [<Literal>] AssemblyFileVersion = "1.8.4"
20+
let [<Literal>] InternalsVisibleTo = "SqlClient.Tests"

src/SqlClient/DesignTime.fs

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ open System.Diagnostics
1010
open Microsoft.FSharp.Quotations
1111
open ProviderImplementation.ProvidedTypes
1212
open FSharp.Data
13+
open System.Text.RegularExpressions
1314

1415
type internal RowType = {
1516
Provided: Type
@@ -40,7 +41,52 @@ module internal SharedLogic =
4041
// add .Table
4142
returnType.Single |> cmdProvidedType.AddMember
4243

43-
type DesignTime private() =
44+
module Prefixes =
45+
let tempTable = "##SQLCOMMANDPROVIDER_"
46+
let tableVar = "@SQLCOMMANDPROVIDER_"
47+
48+
type TempTableLoader(fieldCount, items: obj seq) =
49+
let enumerator = items.GetEnumerator()
50+
51+
interface IDataReader with
52+
member this.FieldCount: int = fieldCount
53+
member this.Read(): bool = enumerator.MoveNext()
54+
member this.GetValue(i: int): obj =
55+
let row : obj[] = unbox enumerator.Current
56+
row.[i]
57+
member this.Dispose(): unit = ()
58+
59+
member __.Close(): unit = invalidOp "NotImplementedException"
60+
member __.Depth: int = invalidOp "NotImplementedException"
61+
member __.GetBoolean(_: int): bool = invalidOp "NotImplementedException"
62+
member __.GetByte(_ : int): byte = invalidOp "NotImplementedException"
63+
member __.GetBytes(_ : int, _ : int64, _ : byte [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
64+
member __.GetChar(_ : int): char = invalidOp "NotImplementedException"
65+
member __.GetChars(_ : int, _ : int64, _ : char [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
66+
member __.GetData(_ : int): IDataReader = invalidOp "NotImplementedException"
67+
member __.GetDataTypeName(_ : int): string = invalidOp "NotImplementedException"
68+
member __.GetDateTime(_ : int): System.DateTime = invalidOp "NotImplementedException"
69+
member __.GetDecimal(_ : int): decimal = invalidOp "NotImplementedException"
70+
member __.GetDouble(_ : int): float = invalidOp "NotImplementedException"
71+
member __.GetFieldType(_ : int): System.Type = invalidOp "NotImplementedException"
72+
member __.GetFloat(_ : int): float32 = invalidOp "NotImplementedException"
73+
member __.GetGuid(_ : int): System.Guid = invalidOp "NotImplementedException"
74+
member __.GetInt16(_ : int): int16 = invalidOp "NotImplementedException"
75+
member __.GetInt32(_ : int): int = invalidOp "NotImplementedException"
76+
member __.GetInt64(_ : int): int64 = invalidOp "NotImplementedException"
77+
member __.GetName(_ : int): string = invalidOp "NotImplementedException"
78+
member __.GetOrdinal(_ : string): int = invalidOp "NotImplementedException"
79+
member __.GetSchemaTable(): DataTable = invalidOp "NotImplementedException"
80+
member __.GetString(_ : int): string = invalidOp "NotImplementedException"
81+
member __.GetValues(_ : obj []): int = invalidOp "NotImplementedException"
82+
member __.IsClosed: bool = invalidOp "NotImplementedException"
83+
member __.IsDBNull(_ : int): bool = invalidOp "NotImplementedException"
84+
member __.Item with get (_ : int): obj = invalidOp "NotImplementedException"
85+
member __.Item with get (_ : string): obj = invalidOp "NotImplementedException"
86+
member __.NextResult(): bool = invalidOp "NotImplementedException"
87+
member __.RecordsAffected: int = invalidOp "NotImplementedException"
88+
89+
type DesignTime private() =
4490
static member internal AddGeneratedMethod
4591
(sqlParameters: Parameter list, hasOutputParameters, executeArgs: ProvidedParameter list, erasedType, providedOutputType, name) =
4692

@@ -632,3 +678,133 @@ type DesignTime private() =
632678
then
633679
yield upcast ProvidedMethod(factoryMethodName.Value, parameters2, returnType = cmdProvidedType, IsStaticMethod = true, InvokeCode = body2)
634680
]
681+
682+
static member private CreateTempTableRecord(name, cols) =
683+
let rowType = ProvidedTypeDefinition(name, Some typeof<obj>, HideObjectMethods = true)
684+
685+
let parameters =
686+
[
687+
for (p : Column) in cols do
688+
let name = p.Name
689+
let param = ProvidedParameter( name, p.GetProvidedType(), ?optionalValue = if p.Nullable then Some null else None)
690+
yield param
691+
]
692+
693+
let ctor = ProvidedConstructor( parameters)
694+
ctor.InvokeCode <- fun args ->
695+
let optionsToNulls = QuotationsFactory.MapArrayNullableItems(cols, "MapArrayOptionItemToObj")
696+
697+
<@@ let values: obj[] = %%Expr.NewArray(typeof<obj>, [ for a in args -> Expr.Coerce(a, typeof<obj>) ])
698+
(%%optionsToNulls) values
699+
values @@>
700+
701+
rowType.AddMember ctor
702+
rowType.AddXmlDoc "Type Table Type"
703+
704+
rowType
705+
706+
// Changes any temp tables in to a global temp table (##name) then creates them on the open connection.
707+
static member internal SubstituteTempTables(connection, commandText: string, tempTableDefinitions : string, connectionId) =
708+
// Extract and temp tables
709+
let tempTableRegex = Regex("#([a-z0-9\-_]+)", RegexOptions.IgnoreCase)
710+
let tempTableNames =
711+
tempTableRegex.Matches(tempTableDefinitions)
712+
|> Seq.cast<Match>
713+
|> Seq.map (fun m -> m.Groups.[1].Value)
714+
|> Seq.toList
715+
716+
match tempTableNames with
717+
| [] -> commandText, None
718+
| _ ->
719+
// Create temp table(s), extracts the columns then drop it.
720+
let tableTypes =
721+
use create = new SqlCommand(tempTableDefinitions, connection)
722+
create.ExecuteScalar() |> ignore
723+
724+
tempTableNames
725+
|> List.map(fun name ->
726+
let cols = DesignTime.GetOutputColumns(connection, "SELECT * FROM #"+name, [], isStoredProcedure = false)
727+
use drop = new SqlCommand("DROP TABLE #"+name, connection)
728+
drop.ExecuteScalar() |> ignore
729+
DesignTime.CreateTempTableRecord(name, cols), cols)
730+
731+
let parameters =
732+
tableTypes
733+
|> List.map (fun (typ, _) ->
734+
ProvidedParameter(typ.Name, parameterType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ seq>, [ typ ])))
735+
736+
// Build the values load method.
737+
let loadValues (exprArgs: Expr list) (connection) =
738+
(exprArgs.Tail, tableTypes)
739+
||> List.map2 (fun expr (typ, cols) ->
740+
let destinationTableName = typ.Name
741+
let colsLength = cols.Length
742+
743+
<@@
744+
let items = (%%expr : obj seq)
745+
use reader = new TempTableLoader(colsLength, items)
746+
747+
use bulkCopy = new SqlBulkCopy((%%connection : SqlConnection))
748+
bulkCopy.BulkCopyTimeout <- 0
749+
bulkCopy.BatchSize <- 5000
750+
bulkCopy.DestinationTableName <- "#" + destinationTableName
751+
bulkCopy.WriteToServer(reader)
752+
753+
@@>
754+
)
755+
|> List.fold (fun acc x -> Expr.Sequential(acc, x)) <@@ () @@>
756+
757+
let loadTempTablesMethod = ProvidedMethod("LoadTempTables", parameters, typeof<unit>)
758+
759+
loadTempTablesMethod.InvokeCode <- fun exprArgs ->
760+
761+
let command = Expr.Coerce(exprArgs.[0], typedefof<ISqlCommand>)
762+
763+
let connection =
764+
<@@ let cmd = (%%command : ISqlCommand)
765+
cmd.Raw.Connection @@>
766+
767+
<@@ do
768+
use create = new SqlCommand(tempTableDefinitions, (%%connection : SqlConnection))
769+
create.ExecuteNonQuery() |> ignore
770+
771+
(%%loadValues exprArgs connection)
772+
ignore() @@>
773+
774+
// Create the temp table(s) but as a global temp table with a unique name. This can be used later down stream on the open connection.
775+
use cmd = new SqlCommand(tempTableRegex.Replace(tempTableDefinitions, Prefixes.tempTable+connectionId+"$1"), connection)
776+
cmd.ExecuteScalar() |> ignore
777+
778+
// Only replace temp tables we find in our list.
779+
tempTableRegex.Replace(commandText, MatchEvaluator(fun m ->
780+
match tempTableNames |> List.tryFind((=) m.Groups.[1].Value) with
781+
| Some name -> Prefixes.tempTable + connectionId + name
782+
| None -> m.Groups.[0].Value)),
783+
784+
Some(loadTempTablesMethod, tableTypes |> List.unzip |> fst)
785+
786+
static member internal RemoveSubstitutedTempTables(connection, tempTables : ProvidedTypeDefinition list, connectionId) =
787+
if not tempTables.IsEmpty then
788+
use cmd = new SqlCommand(tempTables |> List.map(fun tempTable -> sprintf "DROP TABLE [%s%s%s]" Prefixes.tempTable connectionId tempTable.Name) |> String.concat ";", connection)
789+
cmd.ExecuteScalar() |> ignore
790+
791+
// tableVarMapping(s) is converted into DECLARE statements then prepended to the command text.
792+
static member internal SubstituteTableVar(commandText: string, tableVarMapping : string) =
793+
let varRegex = Regex("@([a-z0-9_]+)", RegexOptions.IgnoreCase)
794+
795+
let vars =
796+
tableVarMapping.Split([|';'|], System.StringSplitOptions.RemoveEmptyEntries)
797+
|> Array.choose(fun (x : string) ->
798+
match x.Split([|'='|]) with
799+
| [|name;typ|] -> Some(name.TrimStart('@'), typ)
800+
| _ -> None)
801+
802+
// Only replace table vars we find in our list.
803+
let commandText =
804+
varRegex.Replace(commandText, MatchEvaluator(fun m ->
805+
match vars |> Array.tryFind(fun (n,_) -> n = m.Groups.[1].Value) with
806+
| Some (name, _) -> Prefixes.tableVar + name
807+
| None -> m.Groups.[0].Value))
808+
809+
(vars |> Array.map(fun (name,typ) -> sprintf "DECLARE %s%s %s = @%s" Prefixes.tableVar name typ name) |> String.concat "; ") + "; " + commandText
810+

0 commit comments

Comments
 (0)