Skip to content

Commit 88f6b2c

Browse files
authored
Update gauge_walk functionality and fix bug in VidalITensorNetwork (#222)
1 parent f1cce56 commit 88f6b2c

File tree

7 files changed

+40
-21
lines changed

7 files changed

+40
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.12.0"
4+
version = "0.12.1"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ITensorNetworks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ include("specialitensornetworks.jl")
2525
include("boundarymps.jl")
2626
include("partitioneditensornetwork.jl")
2727
include("edge_sequences.jl")
28+
include("caches/abstractbeliefpropagationcache.jl")
29+
include("caches/beliefpropagationcache.jl")
2830
include("formnetworks/abstractformnetwork.jl")
2931
include("formnetworks/bilinearformnetwork.jl")
3032
include("formnetworks/quadraticformnetwork.jl")
31-
include("caches/abstractbeliefpropagationcache.jl")
32-
include("caches/beliefpropagationcache.jl")
3333
include("contraction_tree_to_graph.jl")
3434
include("gauging.jl")
3535
include("utils.jl")

src/abstractitensornetwork.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,18 +623,42 @@ function gauge_walk(
623623
return gauge_walk(alg, tn, edgetype(tn).(edges); kwargs...)
624624
end
625625

626+
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region)
627+
return tree_gauge(alg, ψ, [region])
628+
end
629+
630+
#Get the path that moves the gauge from a to b on a tree
631+
#TODO: Move to NamedGraphs
632+
function edge_sequence_between_regions(g::AbstractGraph, region_a::Vector, region_b::Vector)
633+
issetequal(region_a, region_b) && return edgetype(g)[]
634+
st = steiner_tree(g, union(region_a, region_b))
635+
path = post_order_dfs_edges(st, first(region_b))
636+
path = filter(e -> !((src(e) region_b) && (dst(e) region_b)), path)
637+
return path
638+
end
639+
640+
# Gauge a ITensorNetwork from cur_region towards new_region, treating
641+
# the network as a tree spanned by a spanning tree.
642+
function tree_gauge(
643+
alg::Algorithm,
644+
ψ::AbstractITensorNetwork,
645+
cur_region::Vector,
646+
new_region::Vector;
647+
kwargs...,
648+
)
649+
es = edge_sequence_between_regions(ψ, cur_region, new_region)
650+
ψ = gauge_walk(alg, ψ, es; kwargs...)
651+
return ψ
652+
end
653+
626654
# Gauge a ITensorNetwork towards a region, treating
627655
# the network as a tree spanned by a spanning tree.
628656
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region::Vector)
629-
region_center =
630-
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
631-
path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center)
632-
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
633-
return gauge_walk(alg, ψ, path)
657+
return tree_gauge(alg, ψ, collect(vertices(ψ)), region)
634658
end
635659

636-
function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region)
637-
return tree_gauge(alg, ψ, [region])
660+
function tree_orthogonalize(ψ::AbstractITensorNetwork, cur_region, new_region; kwargs...)
661+
return tree_gauge(Algorithm("orthogonalize"), ψ, cur_region, new_region; kwargs...)
638662
end
639663

640664
function tree_orthogonalize::AbstractITensorNetwork, region; kwargs...)

src/apply.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ function ITensors.apply(o, ψ::VidalITensorNetwork; normalize=false, apply_kwarg
378378

379379
else
380380
updated_ψ = apply(o, updated_ψ; normalize)
381-
return VidalITensorNetwork(ψ, updated_bond_tensors)
381+
return VidalITensorNetwork(updated_ψ, updated_bond_tensors)
382382
end
383383
end
384384

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ default_messages(ptn::PartitionedGraph) = Dictionary()
4040
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
4141
end
4242
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
43-
function default_partitioned_vertices(f::AbstractFormNetwork)
44-
return group(v -> original_state_vertex(f, v), vertices(f))
45-
end
4643

4744
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
4845
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()

src/formnetworks/abstractformnetwork.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,7 @@ operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v)
8080
bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v)
8181
ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v)
8282
original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v)
83+
84+
function default_partitioned_vertices(f::AbstractFormNetwork)
85+
return group(v -> original_state_vertex(f, v), vertices(f))
86+
end

src/treetensornetworks/abstracttreetensornetwork.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ function set_ortho_region(tn::AbstractTTN, new_region)
3636
end
3737

3838
function gauge(alg::Algorithm, ttn::AbstractTTN, region::Vector; kwargs...)
39-
issetequal(region, ortho_region(ttn)) && return ttn
40-
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
41-
path = post_order_dfs_edges(st, first(region))
42-
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
43-
if !isempty(path)
44-
ttn = typeof(ttn)(gauge_walk(alg, ITensorNetwork(ttn), path; kwargs...))
45-
end
39+
ttn = tree_gauge(alg, ttn, collect(ortho_region(ttn)), region; kwargs...)
4640
return set_ortho_region(ttn, region)
4741
end
4842

0 commit comments

Comments
 (0)