Skip to content

Commit 906f184

Browse files
authored
Refactor and generalize tensor network constructors (#155)
1 parent ae4ad2c commit 906f184

30 files changed

+657
-691
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]> and contributors"]
4-
version = "0.6"
4+
version = "0.7"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -56,14 +56,14 @@ GraphsFlows = "0.1.1"
5656
ITensors = "0.3.58"
5757
IsApprox = "0.1"
5858
IterTools = "1.4.0"
59-
KrylovKit = "0.6.0"
59+
KrylovKit = "0.6, 0.7"
6060
NamedGraphs = "0.1.23"
6161
Observers = "0.2"
6262
PackageExtensionCompat = "1"
6363
Requires = "1.3"
6464
SerializedElementArrays = "0.1"
6565
SimpleTraits = "0.9"
66-
SparseArrayKit = "0.2.1"
66+
SparseArrayKit = "0.2.1, 0.3"
6767
SplitApplyCombine = "1.2"
6868
StaticArrays = "1.5.12"
6969
StructWalk = "0.2"

README.md

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ and 4 edge(s):
105105

106106
with vertex data:
107107
4-element Dictionaries.Dictionary{Tuple{Int64, Int64}, Any}
108-
(1, 1) │ ((dim=2|id=712|"1×1,2×1"), (dim=2|id=598|"1×1,1×2"))
109-
(2, 1) │ ((dim=2|id=712|"1×1,2×1"), (dim=2|id=457|"2×1,2×2"))
110-
(1, 2) │ ((dim=2|id=598|"1×1,1×2"), (dim=2|id=683|"1×2,2×2"))
111-
(2, 2) │ ((dim=2|id=457|"2×1,2×2"), (dim=2|id=683|"1×2,2×2"))
108+
(1, 1) │ ((dim=2|id=74|"1×1,2×1"), (dim=2|id=723|"1×1,1×2"))
109+
(2, 1) │ ((dim=2|id=74|"1×1,2×1"), (dim=2|id=823|"2×1,2×2"))
110+
(1, 2) │ ((dim=2|id=723|"1×1,1×2"), (dim=2|id=712|"1×2,2×2"))
111+
(2, 2) │ ((dim=2|id=823|"2×1,2×2"), (dim=2|id=712|"1×2,2×2"))
112112

113113
julia> tn[1, 1]
114-
ITensor ord=2 (dim=2|id=712|"1×1,2×1") (dim=2|id=598|"1×1,1×2")
114+
ITensor ord=2 (dim=2|id=74|"1×1,2×1") (dim=2|id=723|"1×1,1×2")
115115
NDTensors.EmptyStorage{NDTensors.EmptyNumber, NDTensors.Dense{NDTensors.EmptyNumber, Vector{NDTensors.EmptyNumber}}}
116116

117117
julia> neighbors(tn, (1, 1))
@@ -135,8 +135,8 @@ and 1 edge(s):
135135

136136
with vertex data:
137137
2-element Dictionaries.Dictionary{Tuple{Int64, Int64}, Any}
138-
(1, 1) │ ((dim=2|id=712|"1×1,2×1"), (dim=2|id=598|"1×1,1×2"))
139-
(1, 2) │ ((dim=2|id=598|"1×1,1×2"), (dim=2|id=683|"1×2,2×2"))
138+
(1, 1) │ ((dim=2|id=74|"1×1,2×1"), (dim=2|id=723|"1×1,1×2"))
139+
(1, 2) │ ((dim=2|id=723|"1×1,1×2"), (dim=2|id=712|"1×2,2×2"))
140140

141141
julia> tn_2 = subgraph(v -> v[1] == 2, tn)
142142
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
@@ -149,8 +149,8 @@ and 1 edge(s):
149149

150150
with vertex data:
151151
2-element Dictionaries.Dictionary{Tuple{Int64, Int64}, Any}
152-
(2, 1) │ ((dim=2|id=712|"1×1,2×1"), (dim=2|id=457|"2×1,2×2"))
153-
(2, 2) │ ((dim=2|id=457|"2×1,2×2"), (dim=2|id=683|"1×2,2×2"))
152+
(2, 1) │ ((dim=2|id=74|"1×1,2×1"), (dim=2|id=823|"2×1,2×2"))
153+
(2, 2) │ ((dim=2|id=823|"2×1,2×2"), (dim=2|id=712|"1×2,2×2"))
154154
```
155155

156156

@@ -176,9 +176,9 @@ and 2 edge(s):
176176

177177
with vertex data:
178178
3-element Dictionaries.Dictionary{Int64, Vector{ITensors.Index}}
179-
1 │ ITensors.Index[(dim=2|id=830|"S=1/2,Site,n=1")]
180-
2 │ ITensors.Index[(dim=2|id=369|"S=1/2,Site,n=2")]
181-
3 │ ITensors.Index[(dim=2|id=558|"S=1/2,Site,n=3")]
179+
1 │ ITensors.Index[(dim=2|id=683|"S=1/2,Site,n=1")]
180+
2 │ ITensors.Index[(dim=2|id=123|"S=1/2,Site,n=2")]
181+
3 │ ITensors.Index[(dim=2|id=656|"S=1/2,Site,n=3")]
182182

183183
and edge data:
184184
0-element Dictionaries.Dictionary{NamedGraphs.NamedEdge{Int64}, Vector{ITensors.Index}}
@@ -196,9 +196,9 @@ and 2 edge(s):
196196

197197
with vertex data:
198198
3-element Dictionaries.Dictionary{Int64, Any}
199-
1 │ ((dim=2|id=830|"S=1/2,Site,n=1"), (dim=2|id=186|"1,2"))
200-
2 │ ((dim=2|id=369|"S=1/2,Site,n=2"), (dim=2|id=186|"1,2"), (dim=2|id=430|"2,3…
201-
3 │ ((dim=2|id=558|"S=1/2,Site,n=3"), (dim=2|id=430|"2,3"))
199+
1 │ ((dim=2|id=683|"S=1/2,Site,n=1"), (dim=2|id=382|"1,2"))
200+
2 │ ((dim=2|id=123|"S=1/2,Site,n=2"), (dim=2|id=382|"1,2"), (dim=2|id=190|"2,3…
201+
3 │ ((dim=2|id=656|"S=1/2,Site,n=3"), (dim=2|id=190|"2,3"))
202202
203203
julia> tn2 = ITensorNetwork(s; link_space=2)
204204
ITensorNetworks.ITensorNetwork{Int64} with 3 vertices:
@@ -213,9 +213,9 @@ and 2 edge(s):
213213
214214
with vertex data:
215215
3-element Dictionaries.Dictionary{Int64, Any}
216-
1 │ ((dim=2|id=830|"S=1/2,Site,n=1"), (dim=2|id=994|"1,2"))
217-
2 │ ((dim=2|id=369|"S=1/2,Site,n=2"), (dim=2|id=994|"1,2"), (dim=2|id=978|"2,3
218-
3 │ ((dim=2|id=558|"S=1/2,Site,n=3"), (dim=2|id=978|"2,3"))
216+
1 │ ((dim=2|id=683|"S=1/2,Site,n=1"), (dim=2|id=934|"1,2"))
217+
2 │ ((dim=2|id=123|"S=1/2,Site,n=2"), (dim=2|id=934|"1,2"), (dim=2|id=614|"2,3
218+
3 │ ((dim=2|id=656|"S=1/2,Site,n=3"), (dim=2|id=614|"2,3"))
219219

220220
julia> @visualize tn1;
221221
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀

src/ModelNetworks/ModelNetworks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ModelNetworks
22
using Graphs: degree, dst, edges, src
3-
using ..ITensorNetworks: IndsNetwork, delta_network, insert_missing_internal_inds, itensor
3+
using ..ITensorNetworks: IndsNetwork, delta_network, insert_linkinds, itensor
44
using ITensors: commoninds, diagITensor, inds, noprime
55
using LinearAlgebra: Diagonal, eigen
66
using NamedGraphs: NamedGraph
@@ -17,7 +17,7 @@ OPTIONAL ARGUMENT:
1717
function ising_network(
1818
eltype::Type, s::IndsNetwork, beta::Number; h::Number=0.0, szverts=nothing
1919
)
20-
s = insert_missing_internal_inds(s, edges(s); internal_inds_space=2)
20+
s = insert_linkinds(s; link_space=2)
2121
tn = delta_network(eltype, s)
2222
if (szverts != nothing)
2323
for v in szverts

src/abstractindsnetwork.jl

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,57 @@ end
2323
# TODO: Define a generic fallback for `AbstractDataGraph`?
2424
DataGraphs.edge_data_type(::Type{<:AbstractIndsNetwork{V,I}}) where {V,I} = Vector{I}
2525

26+
## TODO: Bring these back.
27+
## function indsnetwork_getindex(is::AbstractIndsNetwork, index)
28+
## return get(data_graph(is), index, indtype(is)[])
29+
## end
30+
##
31+
## function Base.getindex(is::AbstractIndsNetwork, index)
32+
## return indsnetwork_getindex(is, index)
33+
## end
34+
##
35+
## function Base.getindex(is::AbstractIndsNetwork, index::Pair)
36+
## return indsnetwork_getindex(is, index)
37+
## end
38+
##
39+
## function Base.getindex(is::AbstractIndsNetwork, index::AbstractEdge)
40+
## return indsnetwork_getindex(is, index)
41+
## end
42+
##
43+
## function indsnetwork_setindex!(is::AbstractIndsNetwork, value, index)
44+
## data_graph(is)[index] = value
45+
## return is
46+
## end
47+
##
48+
## function Base.setindex!(is::AbstractIndsNetwork, value, index)
49+
## indsnetwork_setindex!(is, value, index)
50+
## return is
51+
## end
52+
##
53+
## function Base.setindex!(is::AbstractIndsNetwork, value, index::Pair)
54+
## indsnetwork_setindex!(is, value, index)
55+
## return is
56+
## end
57+
##
58+
## function Base.setindex!(is::AbstractIndsNetwork, value, index::AbstractEdge)
59+
## indsnetwork_setindex!(is, value, index)
60+
## return is
61+
## end
62+
##
63+
## function Base.setindex!(is::AbstractIndsNetwork, value::Index, index)
64+
## indsnetwork_setindex!(is, value, index)
65+
## return is
66+
## end
67+
2668
#
2769
# Index access
2870
#
2971

3072
function ITensors.uniqueinds(is::AbstractIndsNetwork, edge::AbstractEdge)
73+
# TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
3174
inds = IndexSet(get(is, src(edge), Index[]))
3275
for ei in setdiff(incident_edges(is, src(edge)), [edge])
76+
# TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
3377
inds = unioninds(inds, get(is, ei, Index[]))
3478
end
3579
return inds
@@ -39,8 +83,8 @@ function ITensors.uniqueinds(is::AbstractIndsNetwork, edge::Pair)
3983
return uniqueinds(is, edgetype(is)(edge))
4084
end
4185

42-
function Base.union(tn1::AbstractIndsNetwork, tn2::AbstractIndsNetwork; kwargs...)
43-
return IndsNetwork(union(data_graph(tn1), data_graph(tn2); kwargs...))
86+
function Base.union(is1::AbstractIndsNetwork, is2::AbstractIndsNetwork; kwargs...)
87+
return IndsNetwork(union(data_graph(is1), data_graph(is2); kwargs...))
4488
end
4589

4690
function NamedGraphs.rename_vertices(f::Function, tn::AbstractIndsNetwork)
@@ -51,31 +95,49 @@ end
5195
# Convenience functions
5296
#
5397

98+
function promote_indtypeof(is::AbstractIndsNetwork)
99+
sitetype = mapreduce(promote_indtype, vertices(is); init=Index{Int}) do v
100+
# TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
101+
return mapreduce(typeof, promote_indtype, get(is, v, Index[]); init=Index{Int})
102+
end
103+
linktype = mapreduce(promote_indtype, edges(is); init=Index{Int}) do e
104+
# TODO: Replace with `is[e]` once `getindex(::IndsNetwork, ...)` is smarter.
105+
return mapreduce(typeof, promote_indtype, get(is, e, Index[]); init=Index{Int})
106+
end
107+
return promote_indtype(sitetype, linktype)
108+
end
109+
54110
function union_all_inds(is_in::AbstractIndsNetwork...)
55111
@assert all(map(ug -> ug == underlying_graph(is_in[1]), underlying_graph.(is_in)))
56112
is_out = IndsNetwork(underlying_graph(is_in[1]))
57113
for v in vertices(is_out)
114+
# TODO: Remove this check.
58115
if any(isassigned(is, v) for is in is_in)
116+
# TODO: Change `get` to `getindex`.
59117
is_out[v] = unioninds([get(is, v, Index[]) for is in is_in]...)
60118
end
61119
end
62120
for e in edges(is_out)
121+
# TODO: Remove this check.
63122
if any(isassigned(is, e) for is in is_in)
123+
# TODO: Change `get` to `getindex`.
64124
is_out[e] = unioninds([get(is, e, Index[]) for is in is_in]...)
65125
end
66126
end
67127
return is_out
68128
end
69129

70-
function insert_missing_internal_inds(
130+
function insert_linkinds(
71131
indsnetwork::AbstractIndsNetwork,
72132
edges=edges(indsnetwork);
73-
internal_inds_space=trivial_space(indsnetwork),
133+
link_space=trivial_space(indsnetwork),
74134
)
75135
indsnetwork = copy(indsnetwork)
76136
for e in edges
137+
# TODO: Change to check if it is empty.
77138
if !isassigned(indsnetwork, e)
78-
iₑ = Index(internal_inds_space, edge_tag(e))
139+
iₑ = Index(link_space, edge_tag(e))
140+
# TODO: Allow setting with just `Index`.
79141
indsnetwork[e] = [iₑ]
80142
end
81143
end

src/abstractitensornetwork.jl

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,12 @@ end
116116
# TODO: broadcasting
117117

118118
function Base.union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kwargs...)
119-
tn = ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
119+
# TODO: Use a different constructor call here?
120+
tn = _ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
120121
# Add any new edges that are introduced during the union
121122
for v1 in vertices(tn1)
122123
for v2 in vertices(tn2)
123-
if hascommoninds(tn[v1], tn[v2])
124+
if hascommoninds(tn, v1 => v2)
124125
add_edge!(tn, v1 => v2)
125126
end
126127
end
@@ -129,7 +130,8 @@ function Base.union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kw
129130
end
130131

131132
function NamedGraphs.rename_vertices(f::Function, tn::AbstractITensorNetwork)
132-
return ITensorNetwork(rename_vertices(f, data_graph(tn)))
133+
# TODO: Use a different constructor call here?
134+
return _ITensorNetwork(rename_vertices(f, data_graph(tn)))
133135
end
134136

135137
#
@@ -172,6 +174,8 @@ function Base.Vector{ITensor}(tn::AbstractITensorNetwork)
172174
end
173175

174176
# Convenience wrapper
177+
# TODO: Delete this and just use `Vector{ITensor}`, or maybe
178+
# it should output a dictionary or be called `eachtensor`?
175179
itensors(tn::AbstractITensorNetwork) = Vector{ITensor}(tn)
176180

177181
#
@@ -182,10 +186,13 @@ function LinearAlgebra.promote_leaf_eltypes(tn::AbstractITensorNetwork)
182186
return LinearAlgebra.promote_leaf_eltypes(itensors(tn))
183187
end
184188

185-
function trivial_space(tn::AbstractITensorNetwork)
186-
return trivial_space(tn[first(vertices(tn))])
189+
function promote_indtypeof(tn::AbstractITensorNetwork)
190+
return mapreduce(promote_indtype, vertices(tn)) do v
191+
return indtype(tn[v])
192+
end
187193
end
188194

195+
# TODO: Delete in favor of `scalartype`.
189196
function ITensors.promote_itensor_eltype(tn::AbstractITensorNetwork)
190197
return LinearAlgebra.promote_leaf_eltypes(tn)
191198
end
@@ -464,7 +471,6 @@ function NDTensors.contract(
464471
neighbors_src = setdiff(neighbors(tn, src(edge)), [dst(edge)])
465472
neighbors_dst = setdiff(neighbors(tn, dst(edge)), [src(edge)])
466473
new_itensor = tn[src(edge)] * tn[dst(edge)]
467-
468474
# The following is equivalent to:
469475
#
470476
# tn[dst(edge)] = new_itensor
@@ -480,6 +486,7 @@ function NDTensors.contract(
480486
for n_dst in neighbors_dst
481487
add_edge!(tn, merged_vertex => n_dst)
482488
end
489+
483490
setindex_preserve_graph!(tn, new_itensor, merged_vertex)
484491

485492
return tn
@@ -736,7 +743,8 @@ function norm_network(tn::AbstractITensorNetwork)
736743
setindex_preserve_graph!(tndag, dag(tndag[v]), v)
737744
end
738745
tnket = rename_vertices(v -> (v, 2), data_graph(prime(tndag; sites=[])))
739-
tntn = ITensorNetwork(union(tnbra, tnket))
746+
# TODO: Use a different constructor here?
747+
tntn = _ITensorNetwork(union(tnbra, tnket))
740748
for v in vertices(tn)
741749
if !isempty(commoninds(tntn[(v, 1)], tntn[(v, 2)]))
742750
add_edge!(tntn, (v, 1) => (v, 2))
@@ -809,6 +817,9 @@ end
809817

810818
Base.show(io::IO, graph::AbstractITensorNetwork) = show(io, MIME"text/plain"(), graph)
811819

820+
# TODO: Move to an `ITensorNetworksVisualizationInterfaceExt`
821+
# package extension (and define a `VisualizationInterface` package
822+
# based on `ITensorVisualizationCore`.).
812823
function ITensorVisualizationCore.visualize(
813824
tn::AbstractITensorNetwork,
814825
args...;
@@ -865,13 +876,13 @@ function site_combiners(tn::AbstractITensorNetwork{V}) where {V}
865876
return Cs
866877
end
867878

868-
function insert_missing_internal_inds(
869-
tn::AbstractITensorNetwork, edges; internal_inds_space=trivial_space(tn)
879+
function insert_linkinds(
880+
tn::AbstractITensorNetwork, edges=edges(tn); link_space=trivial_space(tn)
870881
)
871882
tn = copy(tn)
872883
for e in edges
873-
if !hascommoninds(tn[src(e)], tn[dst(e)])
874-
iₑ = Index(internal_inds_space, edge_tag(e))
884+
if !hascommoninds(tn, e)
885+
iₑ = Index(link_space, edge_tag(e))
875886
X = onehot(iₑ => 1)
876887
tn[src(e)] *= X
877888
tn[dst(e)] *= dag(X)
@@ -880,12 +891,10 @@ function insert_missing_internal_inds(
880891
return tn
881892
end
882893

883-
function insert_missing_internal_inds(
884-
tn::AbstractITensorNetwork; internal_inds_space=trivial_space(tn)
885-
)
886-
return insert_internal_inds(tn, edges(tn); internal_inds_space)
887-
end
888-
894+
# TODO: What to output? Could be an `IndsNetwork`. Or maybe
895+
# that would be a different function `commonindsnetwork`.
896+
# Even in that case, this could output a `Dictionary`
897+
# from the edges to the common inds on that edge.
889898
function ITensors.commoninds(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
890899
inds = Index[]
891900
for v1 in vertices(tn1)
@@ -911,8 +920,8 @@ function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork
911920

912921
if !issetequal(edges_tn1, edges_tn2)
913922
new_edges = union(edges_tn1, edges_tn2)
914-
tn1 = insert_missing_internal_inds(tn1, new_edges)
915-
tn2 = insert_missing_internal_inds(tn2, new_edges)
923+
tn1 = insert_linkinds(tn1, new_edges)
924+
tn2 = insert_linkinds(tn2, new_edges)
916925
end
917926

918927
edges_tn1, edges_tn2 = edges(tn1), edges(tn2)

src/indsnetwork.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ using DataGraphs: DataGraphs, DataGraph, IsUnderlyingGraph, map_data, vertex_dat
22
using Dictionaries: AbstractDictionary, Indices
33
using Graphs: Graphs
44
using Graphs.SimpleGraphs: AbstractSimpleGraph
5-
# using LinearAlgebra: I # Not sure if this is needed
65
using ITensors: Index, dag
76
using ITensors.ITensorVisualizationCore: ITensorVisualizationCore, visualize
8-
using NamedGraphs: NamedGraphs, AbstractNamedGraph, NamedEdge, NamedGraph, vertextype
7+
using NamedGraphs:
8+
NamedGraphs, AbstractNamedGraph, NamedEdge, NamedGraph, named_path_graph, vertextype
99

1010
struct IndsNetwork{V,I} <: AbstractIndsNetwork{V,I}
1111
data_graph::DataGraph{V,Vector{I},Vector{I},NamedGraph{V},NamedEdge{V}}

0 commit comments

Comments
 (0)