|
| 1 | +using Graphs: vertices |
| 2 | +using NamedGraphs: AbstractNamedGraph, NamedEdge |
| 3 | +using NamedGraphs.PartitionedGraphs: partitionedges |
| 4 | +using Printf: @printf |
| 5 | +using ConstructionBase: setproperties |
| 6 | + |
| 7 | +@kwdef mutable struct FittingProblem{State<:AbstractBeliefPropagationCache} |
| 8 | + state::State |
| 9 | + ket_graph::AbstractNamedGraph |
| 10 | + overlap::Number = 0 |
| 11 | + gauge_region |
| 12 | +end |
| 13 | + |
| 14 | +overlap(F::FittingProblem) = F.overlap |
| 15 | +ITensorNetworks.state(F::FittingProblem) = F.state |
| 16 | +ket_graph(F::FittingProblem) = F.ket_graph |
| 17 | +gauge_region(F::FittingProblem) = F.gauge_region |
| 18 | + |
| 19 | +function set_state(F::FittingProblem, state) |
| 20 | + FittingProblem(state, F.ket_graph, F.overlap, F.gauge_region) |
| 21 | +end |
| 22 | +function set_overlap(F::FittingProblem, overlap) |
| 23 | + FittingProblem(F.state, F.ket_graph, overlap, F.gauge_region) |
| 24 | +end |
| 25 | + |
| 26 | +function ket(F::FittingProblem) |
| 27 | + ket_vertices = vertices(ket_graph(F)) |
| 28 | + return first(induced_subgraph(tensornetwork(state(F)), ket_vertices)) |
| 29 | +end |
| 30 | + |
| 31 | +function extract(problem::FittingProblem, region_iterator; sweep, kws...) |
| 32 | + region = current_region(region_iterator) |
| 33 | + prev_region = gauge_region(problem) |
| 34 | + tn = state(problem) |
| 35 | + path = edge_sequence_between_regions(ket_graph(problem), prev_region, region) |
| 36 | + tn = gauge_walk(Algorithm("orthogonalize"), tn, path) |
| 37 | + pe_path = partitionedges(partitioned_tensornetwork(tn), path) |
| 38 | + tn = update( |
| 39 | + Algorithm("bp"), tn, pe_path; message_update_function_kwargs=(; normalize=false) |
| 40 | + ) |
| 41 | + local_tensor = environment(tn, region) |
| 42 | + sequence = contraction_sequence(local_tensor; alg="optimal") |
| 43 | + local_tensor = dag(contract(local_tensor; sequence)) |
| 44 | + #problem, local_tensor = subspace_expand(problem, local_tensor, region; sweep, kws...) |
| 45 | + return setproperties(problem; state=tn, gauge_region=region), local_tensor |
| 46 | +end |
| 47 | + |
| 48 | +function update(F::FittingProblem, local_tensor, region; outputlevel, kws...) |
| 49 | + n = (local_tensor * dag(local_tensor))[] |
| 50 | + F = set_overlap(F, n / sqrt(n)) |
| 51 | + if outputlevel >= 2 |
| 52 | + @printf(" Region %s: squared overlap = %.12f\n", region, overlap(F)) |
| 53 | + end |
| 54 | + return F, local_tensor |
| 55 | +end |
| 56 | + |
| 57 | +function region_plan(F::FittingProblem; nsites, sweep_kwargs...) |
| 58 | + return euler_sweep(ket_graph(F); nsites, sweep_kwargs...) |
| 59 | +end |
| 60 | + |
| 61 | +function fit_tensornetwork( |
| 62 | + overlap_network, |
| 63 | + args...; |
| 64 | + nsweeps=25, |
| 65 | + nsites=1, |
| 66 | + outputlevel=0, |
| 67 | + extract_kwargs=(;), |
| 68 | + update_kwargs=(;), |
| 69 | + insert_kwargs=(;), |
| 70 | + normalize=true, |
| 71 | + kws..., |
| 72 | +) |
| 73 | + bpc = BeliefPropagationCache(overlap_network, args...) |
| 74 | + ket_graph = first( |
| 75 | + induced_subgraph(underlying_graph(overlap_network), ket_vertices(overlap_network)) |
| 76 | + ) |
| 77 | + init_prob = FittingProblem(; |
| 78 | + ket_graph, state=bpc, gauge_region=collect(vertices(ket_graph)) |
| 79 | + ) |
| 80 | + |
| 81 | + insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false) |
| 82 | + common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs) |
| 83 | + kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps] |
| 84 | + sweep_iter = sweep_iterator(init_prob, kwargs_array) |
| 85 | + converged_prob = sweep_solve(sweep_iter; outputlevel, kws...) |
| 86 | + return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob)) |
| 87 | +end |
| 88 | + |
| 89 | +function fit_tensornetwork(tn, init_state, args...; kwargs...) |
| 90 | + return fit_tensornetwork(inner_network(tn, init_state), args; kwargs...) |
| 91 | +end |
| 92 | + |
| 93 | +#function truncate(tn; maxdim=default_maxdim(), cutoff=default_cutoff(), kwargs...) |
| 94 | +# init_state = ITensorNetwork( |
| 95 | +# v -> inds -> delta(inds), siteinds(tn); link_space=maxdim |
| 96 | +# ) |
| 97 | +# overlap_network = inner_network(tn, init_state) |
| 98 | +# insert_kwargs = (; trunc=(; cutoff, maxdim)) |
| 99 | +# return fit_tensornetwork(overlap_network; insert_kwargs, kwargs...) |
| 100 | +#end |
| 101 | + |
| 102 | +function ITensors.apply( |
| 103 | + A::ITensorNetwork, |
| 104 | + x::ITensorNetwork; |
| 105 | + maxdim=default_maxdim(), |
| 106 | + cutoff=default_cutoff(), |
| 107 | + kwargs..., |
| 108 | +) |
| 109 | + init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim) |
| 110 | + overlap_network = inner_network(x, A, init_state) |
| 111 | + insert_kwargs = (; trunc=(; cutoff, maxdim)) |
| 112 | + return fit_tensornetwork(overlap_network; insert_kwargs, kwargs...) |
| 113 | +end |
0 commit comments