Skip to content

Commit ba9c1ef

Browse files
committed
Commit fitting code and test
1 parent 4a5f686 commit ba9c1ef

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

src/solvers/fitting.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
using Graphs: vertices
2+
using NamedGraphs: AbstractNamedGraph, NamedEdge
3+
using NamedGraphs.PartitionedGraphs: partitionedges
4+
using Printf: @printf
5+
using ConstructionBase: setproperties
6+
7+
@kwdef mutable struct FittingProblem{State<:AbstractBeliefPropagationCache}
8+
state::State
9+
ket_graph::AbstractNamedGraph
10+
overlap::Number = 0
11+
gauge_region
12+
end
13+
14+
overlap(F::FittingProblem) = F.overlap
15+
ITensorNetworks.state(F::FittingProblem) = F.state
16+
ket_graph(F::FittingProblem) = F.ket_graph
17+
gauge_region(F::FittingProblem) = F.gauge_region
18+
19+
function set_state(F::FittingProblem, state)
20+
FittingProblem(state, F.ket_graph, F.overlap, F.gauge_region)
21+
end
22+
function set_overlap(F::FittingProblem, overlap)
23+
FittingProblem(F.state, F.ket_graph, overlap, F.gauge_region)
24+
end
25+
26+
function ket(F::FittingProblem)
27+
ket_vertices = vertices(ket_graph(F))
28+
return first(induced_subgraph(tensornetwork(state(F)), ket_vertices))
29+
end
30+
31+
function extract(problem::FittingProblem, region_iterator; sweep, kws...)
32+
region = current_region(region_iterator)
33+
prev_region = gauge_region(problem)
34+
tn = state(problem)
35+
path = edge_sequence_between_regions(ket_graph(problem), prev_region, region)
36+
tn = gauge_walk(Algorithm("orthogonalize"), tn, path)
37+
pe_path = partitionedges(partitioned_tensornetwork(tn), path)
38+
tn = update(
39+
Algorithm("bp"), tn, pe_path; message_update_function_kwargs=(; normalize=false)
40+
)
41+
local_tensor = environment(tn, region)
42+
sequence = contraction_sequence(local_tensor; alg="optimal")
43+
local_tensor = dag(contract(local_tensor; sequence))
44+
#problem, local_tensor = subspace_expand(problem, local_tensor, region; sweep, kws...)
45+
return setproperties(problem; state=tn, gauge_region=region), local_tensor
46+
end
47+
48+
function update(F::FittingProblem, local_tensor, region; outputlevel, kws...)
49+
n = (local_tensor * dag(local_tensor))[]
50+
F = set_overlap(F, n / sqrt(n))
51+
if outputlevel >= 2
52+
@printf(" Region %s: squared overlap = %.12f\n", region, overlap(F))
53+
end
54+
return F, local_tensor
55+
end
56+
57+
function region_plan(F::FittingProblem; nsites, sweep_kwargs...)
58+
return euler_sweep(ket_graph(F); nsites, sweep_kwargs...)
59+
end
60+
61+
function fit_tensornetwork(
62+
overlap_network,
63+
args...;
64+
nsweeps=25,
65+
nsites=1,
66+
outputlevel=0,
67+
extract_kwargs=(;),
68+
update_kwargs=(;),
69+
insert_kwargs=(;),
70+
normalize=true,
71+
kws...,
72+
)
73+
bpc = BeliefPropagationCache(overlap_network, args...)
74+
ket_graph = first(
75+
induced_subgraph(underlying_graph(overlap_network), ket_vertices(overlap_network))
76+
)
77+
init_prob = FittingProblem(;
78+
ket_graph, state=bpc, gauge_region=collect(vertices(ket_graph))
79+
)
80+
81+
insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false)
82+
common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs)
83+
kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps]
84+
sweep_iter = sweep_iterator(init_prob, kwargs_array)
85+
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
86+
return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob))
87+
end
88+
89+
function fit_tensornetwork(tn, init_state, args...; kwargs...)
90+
return fit_tensornetwork(inner_network(tn, init_state), args; kwargs...)
91+
end
92+
93+
#function truncate(tn; maxdim=default_maxdim(), cutoff=default_cutoff(), kwargs...)
94+
# init_state = ITensorNetwork(
95+
# v -> inds -> delta(inds), siteinds(tn); link_space=maxdim
96+
# )
97+
# overlap_network = inner_network(tn, init_state)
98+
# insert_kwargs = (; trunc=(; cutoff, maxdim))
99+
# return fit_tensornetwork(overlap_network; insert_kwargs, kwargs...)
100+
#end
101+
102+
function ITensors.apply(
103+
A::ITensorNetwork,
104+
x::ITensorNetwork;
105+
maxdim=default_maxdim(),
106+
cutoff=default_cutoff(),
107+
kwargs...,
108+
)
109+
init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim)
110+
overlap_network = inner_network(x, A, init_state)
111+
insert_kwargs = (; trunc=(; cutoff, maxdim))
112+
return fit_tensornetwork(overlap_network; insert_kwargs, kwargs...)
113+
end

test/solvers/test_fitting.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using ITensors: apply, inner
2+
using ITensorNetworks: ITensorNetwork, siteinds, ttn, random_tensornetwork
3+
using ITensorNetworks.ModelHamiltonians: heisenberg
4+
using NamedGraphs.NamedGraphGenerators: named_comb_tree
5+
using Test: @test, @testset
6+
using Printf
7+
using StableRNGs: StableRNG
8+
using TensorOperations: TensorOperations #For contraction order finding
9+
10+
@testset "Fitting Tests" begin
11+
outputlevel = 1
12+
for elt in (Float32, Float64, Complex{Float32}, Complex{Float64})
13+
(outputlevel >= 1) && println("\nFitting tests with elt = ", elt)
14+
g = named_comb_tree((3, 2))
15+
s = siteinds("S=1/2", g)
16+
17+
rng = StableRNG(1234)
18+
19+
##One-site truncation
20+
#a = random_tensornetwork(rng, elt, s; link_space=3)
21+
#b = truncate(a; maxdim=3)
22+
#f =
23+
# inner(a, b; alg="exact") /
24+
# sqrt(inner(a, a; alg="exact") * inner(b, b; alg="exact"))
25+
#(outputlevel >= 1) && @printf("One-site truncation. Fidelity = %s\n", f)
26+
#@test abs(abs(f) - 1.0) <= 10*eps(real(elt))
27+
28+
##Two-site truncation
29+
#a = random_tensornetwork(rng, elt, s; link_space=3)
30+
#b = truncate(a; maxdim=3, cutoff=1e-16, nsites=2)
31+
#f =
32+
# inner(a, b; alg="exact") /
33+
# sqrt(inner(a, a; alg="exact") * inner(b, b; alg="exact"))
34+
#(outputlevel >= 1) && @printf("Two-site truncation. Fidelity = %s\n", f)
35+
#@test abs(abs(f) - 1.0) <= 10*eps(real(elt))
36+
37+
# #One-site apply (no normalization)
38+
a = random_tensornetwork(rng, elt, s; link_space=2)
39+
H = ITensorNetwork(ttn(heisenberg(g), s))
40+
Ha = apply(H, a; maxdim=4, nsites=1, normalize=false)
41+
f = inner(Ha, a; alg="exact") / inner(a, H, a; alg="exact")
42+
(outputlevel >= 1) && @printf("One-site apply. Fidelity = %s\n", f)
43+
@test abs(f - 1.0) <= 500*eps(real(elt))
44+
45+
# #Two-site apply (no normalization)
46+
a = random_tensornetwork(rng, elt, s; link_space=2)
47+
H = ITensorNetwork(ttn(heisenberg(g), s))
48+
Ha = apply(H, a; maxdim=4, cutoff=1e-16, nsites=2, normalize=false)
49+
f = inner(Ha, a; alg="exact") / inner(a, H, a; alg="exact")
50+
(outputlevel >= 1) && @printf("Two-site apply. Fidelity = %s\n", f)
51+
@test abs(f - 1.0) <= 500*eps(real(elt))
52+
end
53+
end

0 commit comments

Comments
 (0)