Skip to content

Commit 8d6b1ef

Browse files
authored
Linear form network (#228)
1 parent fa51083 commit 8d6b1ef

File tree

9 files changed

+78
-17
lines changed

9 files changed

+78
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.13.2"
4+
version = "0.13.3"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ITensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("caches/abstractbeliefpropagationcache.jl")
2929
include("caches/beliefpropagationcache.jl")
3030
include("formnetworks/abstractformnetwork.jl")
3131
include("formnetworks/bilinearformnetwork.jl")
32+
include("formnetworks/linearformnetwork.jl")
3233
include("formnetworks/quadraticformnetwork.jl")
3334
include("contraction_tree_to_graph.jl")
3435
include("gauging.jl")

src/abstractitensornetwork.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ function split_index(
751751
end
752752

753753
function inner_network(x::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...)
754-
return BilinearFormNetwork(x, y; kwargs...)
754+
return LinearFormNetwork(x, y; kwargs...)
755755
end
756756

757757
function inner_network(
@@ -760,12 +760,7 @@ function inner_network(
760760
return BilinearFormNetwork(A, x, y; kwargs...)
761761
end
762762

763-
# TODO: We should make this use the QuadraticFormNetwork constructor here.
764-
# Parts of the code (tests relying on norm_sqr being two layer and the gauging code
765-
# which relies on specific message tensors) currently would break in that case so we need to resolve
766-
function norm_sqr_network::AbstractITensorNetwork)
767-
return disjoint_union("bra" => dag(prime(ψ; sites=[])), "ket" => ψ)
768-
end
763+
norm_sqr_network::AbstractITensorNetwork) = inner_network(ψ, ψ)
769764

770765
#
771766
# Printing

src/expect.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ function expect(
2626
(cache!)=nothing,
2727
update_cache=isnothing(cache!),
2828
cache_update_kwargs=default_cache_update_kwargs(alg),
29-
cache_construction_kwargs=default_cache_construction_kwargs(alg, inner_network(ψ, ψ)),
29+
cache_construction_kwargs=default_cache_construction_kwargs(alg, QuadraticFormNetwork(ψ)),
3030
kwargs...,
3131
)
32-
ψIψ = inner_network(ψ, ψ)
32+
ψIψ = QuadraticFormNetwork(ψ)
3333
if isnothing(cache!)
3434
cache! = Ref(cache(alg, ψIψ; cache_construction_kwargs...))
3535
end
@@ -42,7 +42,7 @@ function expect(
4242
end
4343

4444
function expect(alg::Algorithm"exact", ψ::AbstractITensorNetwork, ops; kwargs...)
45-
ψIψ = inner_network(ψ, ψ)
45+
ψIψ = QuadraticFormNetwork(ψ)
4646
return map(op -> expect(ψIψ, op; alg, kwargs...), ops)
4747
end
4848

src/formnetworks/abstractformnetwork.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ function SimilarType.similar_type(f::AbstractFormNetwork)
2020
return typeof(tensornetwork(f))
2121
end
2222

23+
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph_type`.
24+
function data_graph_type(f::AbstractFormNetwork)
25+
return data_graph_type(tensornetwork(f))
26+
end
27+
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
28+
data_graph(f::AbstractFormNetwork) = data_graph(tensornetwork(f))
29+
2330
function operator_vertices(f::AbstractFormNetwork)
2431
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
2532
end

src/formnetworks/bilinearformnetwork.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ bra_vertex_suffix(blf::BilinearFormNetwork) = blf.bra_vertex_suffix
4242
ket_vertex_suffix(blf::BilinearFormNetwork) = blf.ket_vertex_suffix
4343
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
4444
tensornetwork(blf::BilinearFormNetwork) = blf.tensornetwork
45-
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph_type`.
46-
data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf))
47-
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
48-
data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf))
4945

5046
function Base.copy(blf::BilinearFormNetwork)
5147
return BilinearFormNetwork(
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using ITensors: ITensor, prime
2+
3+
default_dual_link_index_map = prime
4+
5+
struct LinearFormNetwork{
6+
V,TensorNetwork<:AbstractITensorNetwork{V},BraVertexSuffix,KetVertexSuffix
7+
} <: AbstractFormNetwork{V}
8+
tensornetwork::TensorNetwork
9+
bra_vertex_suffix::BraVertexSuffix
10+
ket_vertex_suffix::KetVertexSuffix
11+
end
12+
13+
function LinearFormNetwork(
14+
bra::AbstractITensorNetwork,
15+
ket::AbstractITensorNetwork;
16+
bra_vertex_suffix=default_bra_vertex_suffix(),
17+
ket_vertex_suffix=default_ket_vertex_suffix(),
18+
dual_link_index_map=default_dual_link_index_map,
19+
)
20+
bra_mapped = dual_link_index_map(bra; sites=[])
21+
tn = disjoint_union(bra_vertex_suffix => dag(bra_mapped), ket_vertex_suffix => ket)
22+
return LinearFormNetwork(tn, bra_vertex_suffix, ket_vertex_suffix)
23+
end
24+
25+
function LinearFormNetwork(blf::BilinearFormNetwork)
26+
bra, ket, operator = subgraph(blf, bra_vertices(blf)),
27+
subgraph(blf, ket_vertices(blf)),
28+
subgraph(blf, operator_vertices(blf))
29+
bra_suffix, ket_suffix = bra_vertex_suffix(blf), ket_vertex_suffix(blf)
30+
operator = rename_vertices(v -> bra_vertex_map(blf)(v), operator)
31+
tn = union(bra, ket, operator)
32+
return LinearFormNetwork(tn, bra_suffix, ket_suffix)
33+
end
34+
35+
bra_vertex_suffix(lf::LinearFormNetwork) = lf.bra_vertex_suffix
36+
ket_vertex_suffix(lf::LinearFormNetwork) = lf.ket_vertex_suffix
37+
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
38+
tensornetwork(lf::LinearFormNetwork) = lf.tensornetwork
39+
40+
function Base.copy(lf::LinearFormNetwork)
41+
return LinearFormNetwork(
42+
copy(tensornetwork(lf)), bra_vertex_suffix(lf), ket_vertex_suffix(lf)
43+
)
44+
end
45+
46+
function update(lf::LinearFormNetwork, original_ket_state_vertex, ket_state::ITensor)
47+
lf = copy(lf)
48+
# TODO: Maybe add a check that it really does preserve the graph.
49+
setindex_preserve_graph!(
50+
tensornetwork(lf), ket_state, ket_vertex(blf, original_ket_state_vertex)
51+
)
52+
return lf
53+
end

test/test_forms.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using NamedGraphs.NamedGraphGenerators: named_grid
55
using ITensorNetworks:
66
BeliefPropagationCache,
77
BilinearFormNetwork,
8+
LinearFormNetwork,
89
QuadraticFormNetwork,
910
bra_network,
1011
bra_vertex,
@@ -35,6 +36,10 @@ using Test: @test, @testset
3536
ψbra = random_tensornetwork(rng, s; link_space=χ)
3637
A = random_tensornetwork(rng, s_operator; link_space=D)
3738

39+
lf = LinearFormNetwork(ψbra, ψket)
40+
@test nv(lf) == nv(ψket) + nv(ψbra)
41+
@test isempty(flatten_siteinds(lf))
42+
3843
blf = BilinearFormNetwork(A, ψbra, ψket)
3944
@test nv(blf) == nv(ψket) + nv(ψbra) + nv(A)
4045
@test isempty(flatten_siteinds(blf))
@@ -43,6 +48,9 @@ using Test: @test, @testset
4348
@test underlying_graph(operator_network(blf)) == underlying_graph(A)
4449
@test underlying_graph(bra_network(blf)) == underlying_graph(ψbra)
4550

51+
lf = LinearFormNetwork(blf)
52+
@test underlying_graph(ket_network(lf)) == underlying_graph(ψket)
53+
4654
qf = QuadraticFormNetwork(ψket)
4755
@test nv(qf) == 3 * nv(ψket)
4856
@test isempty(flatten_siteinds(qf))

test/test_itensornetwork.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using ITensors:
3030
itensor,
3131
onehot,
3232
order,
33+
prime,
3334
random_itensor,
3435
scalartype,
3536
sim,
@@ -55,7 +56,7 @@ using ITensorNetworks:
5556
ttn
5657
using LinearAlgebra: factorize
5758
using NamedGraphs: NamedEdge
58-
using NamedGraphs.GraphsExtensions: incident_edges
59+
using NamedGraphs.GraphsExtensions: disjoint_union, incident_edges
5960
using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid
6061
using NDTensors: NDTensors, dim
6162
using Random: randn!
@@ -140,7 +141,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
140141
g = named_grid(dims)
141142
s = siteinds("S=1/2", g)
142143
ψ = ITensorNetwork(v -> "", s)
143-
tn = norm_sqr_network)
144+
tn = disjoint_union("bra" => ψ, "ket" => prime(dag(ψ); sites=[]))
144145
tn_2 = contract(tn, ((1, 2), "ket") => ((1, 2), "bra"))
145146
@test !has_vertex(tn_2, ((1, 2), "ket"))
146147
@test tn_2[((1, 2), "bra")] tn[((1, 2), "ket")] * tn[((1, 2), "bra")]

0 commit comments

Comments
 (0)