Skip to content

Commit cc227c2

Browse files
authored
Fix expect and correlation_matrix for complex operator and real states (#163)
1 parent f23415f commit cc227c2

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorMPS"
22
uuid = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
33
authors = ["Matthew Fishman <[email protected]>", "Miles Stoudenmire <[email protected]>"]
4-
version = "0.3.22"
4+
version = "0.3.23"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/mps.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Adapt: adapt
22
using NDTensors: using_auto_fermion
3+
using NDTensors.TypeParameterAccessors: unspecify_type_parameters
34
using Random: Random
45
using ITensors.SiteTypes: SiteTypes, siteind, siteinds, state
56

@@ -737,7 +738,8 @@ Cuu = correlation_matrix(psi, "Cdagup", "Cup"; sites=2:8)
737738
```
738739
"""
739740
function correlation_matrix(
740-
psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing, ishermitian = nothing
741+
psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing,
742+
ishermitian = nothing
741743
)
742744
if !isnothing(site_range)
743745
@warn "The `site_range` keyword arg. to `correlation_matrix` is deprecated: use the keyword `sites` instead"
@@ -796,7 +798,12 @@ function correlation_matrix(
796798
# Nb = size of block of correlation matrix
797799
Nb = length(sites)
798800

799-
C = zeros(ElT, Nb, Nb)
801+
op1_start = op(_Op1, s[start_site])
802+
op2_start = op(_Op2, s[start_site])
803+
ElT1 = eltype(op1_start)
804+
ElT2 = eltype(op2_start)
805+
ElT′ = promote_type(ElT1, ElT2, ElT)
806+
C = zeros(ElT′, Nb, Nb)
800807

801808
if start_site == 1
802809
L = ITensor(1.0)
@@ -817,15 +824,15 @@ function correlation_matrix(
817824

818825
# Get j == i diagonal correlations
819826
rind = commonind(psi[i], psi[i + 1])
820-
oᵢ = adapt(datatype(Li), op(onsiteOp, s, i))
827+
oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(onsiteOp, s, i))
821828
C[ni, ni] = ((Li * oᵢ) * prime(dag(psi[i]), !rind))[] / norm2_psi
822829

823830
# Get j > i correlations
824831
if !using_auto_fermion() && fermionic2
825832
Op1 = "$Op1 * F"
826833
end
827834

828-
oᵢ = adapt(datatype(Li), op(Op1, s, i))
835+
oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(Op1, s, i))
829836

830837
Li12 = (dag(psi[i])' * oᵢ) * Li
831838
pL12 = i
@@ -836,7 +843,8 @@ function correlation_matrix(
836843
while pL12 < j - 1
837844
pL12 += 1
838845
if !using_auto_fermion() && fermionic2
839-
oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12]))
846+
dtype = unspecify_type_parameters(datatype(psi[pL12]))
847+
oᵢ = adapt(dtype, op("F", s[pL12]))
840848
Li12 *= (oᵢ * dag(psi[pL12])')
841849
else
842850
sᵢ = siteind(psi, pL12)
@@ -848,7 +856,7 @@ function correlation_matrix(
848856
lind = commonind(psi[j], Li12)
849857
Li12 *= psi[j]
850858

851-
oⱼ = adapt(datatype(Li12), op(Op2, s, j))
859+
oⱼ = adapt(unspecify_type_parameters(datatype(Li12)), op(Op2, s, j))
852860
sⱼ = siteind(psi, j)
853861
val = (Li12 * oⱼ) * prime(dag(psi[j]), (sⱼ, lind))
854862

@@ -863,7 +871,8 @@ function correlation_matrix(
863871

864872
pL12 += 1
865873
if !using_auto_fermion() && fermionic2
866-
oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12]))
874+
dtype = unspecify_type_parameters(datatype(psi[pL12]))
875+
oᵢ = adapt(dtype, op("F", s[pL12]))
867876
Li12 *= (oᵢ * dag(psi[pL12])')
868877
else
869878
sᵢ = siteind(psi, pL12)
@@ -879,7 +888,7 @@ function correlation_matrix(
879888
if !using_auto_fermion() && fermionic1
880889
Op2 = "$Op2 * F"
881890
end
882-
oᵢ = adapt(datatype(psi[i]), op(Op2, s, i))
891+
oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(Op2, s, i))
883892
Li21 = (Li * oᵢ) * dag(psi[i])'
884893
pL21 = i
885894
if !using_auto_fermion() && fermionic1
@@ -892,7 +901,8 @@ function correlation_matrix(
892901
while pL21 < j - 1
893902
pL21 += 1
894903
if !using_auto_fermion() && fermionic1
895-
oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21]))
904+
dtype = unspecify_type_parameters(datatype(psi[pL21]))
905+
oᵢ = adapt(dtype, op("F", s[pL21]))
896906
Li21 *= oᵢ * dag(psi[pL21])'
897907
else
898908
sᵢ = siteind(psi, pL21)
@@ -904,14 +914,15 @@ function correlation_matrix(
904914
lind = commonind(psi[j], Li21)
905915
Li21 *= psi[j]
906916

907-
oⱼ = adapt(datatype(psi[j]), op(Op1, s, j))
917+
oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(Op1, s, j))
908918
sⱼ = siteind(psi, j)
909919
val = (prime(dag(psi[j]), (sⱼ, lind)) * (oⱼ * Li21))[]
910920
C[nj, ni] = val / norm2_psi
911921

912922
pL21 += 1
913923
if !using_auto_fermion() && fermionic1
914-
oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21]))
924+
dtype = unspecify_type_parameters(datatype(psi[pL21]))
925+
oᵢ = adapt(dtype, op("F", s[pL21]))
915926
Li21 *= (oᵢ * dag(psi[pL21])')
916927
else
917928
sᵢ = siteind(psi, pL21)
@@ -935,7 +946,7 @@ function correlation_matrix(
935946
L = L * psi[pL] * prime(dag(psi[pL]), !sᵢ)
936947
end
937948
lind = commonind(psi[i], psi[i - 1])
938-
oᵢ = adapt(datatype(psi[i]), op(onsiteOp, s, i))
949+
oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(onsiteOp, s, i))
939950
sᵢ = siteind(psi, i)
940951
val = (L * (oᵢ * psi[i]) * prime(dag(psi[i]), (sᵢ, lind)))[]
941952
C[Nb, Nb] = val / norm2_psi
@@ -1007,7 +1018,7 @@ function expect(psi::MPS, ops; sites = 1:length(psi), site_range = nothing)
10071018
for (entry, j) in enumerate(site_range)
10081019
psi = orthogonalize(psi, j)
10091020
for (n, opname) in enumerate(ops)
1010-
oⱼ = adapt(datatype(psi[j]), op(opname, s[j]))
1021+
oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(opname, s[j]))
10111022
val = inner(psi[j], apply(oⱼ, psi[j])) / norm2_psi
10121023
ex[n][entry] = (el_types[n] <: Real) ? real(val) : val
10131024
end

test/base/test_mps.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,21 @@ end
913913
@test_throws ErrorException expect(psi0, "Sz")
914914
end
915915

916+
@testset "expect real wavefunction complex operator" for elt in (Float32, Float64)
917+
N = 8
918+
s = siteinds("S=1/2", N)
919+
using StableRNGs: StableRNG
920+
rng = StableRNG(123)
921+
psi = random_mps(rng, elt, s; linkdims = 2)
922+
eSy = zeros(complex(elt), N)
923+
for j in 1:N
924+
psi = orthogonalize(psi, j)
925+
eSy[j] = (dag(psi[j]) * apply(op("Sy", s[j]), psi[j]))[]
926+
end
927+
res = expect(psi, "Sy")
928+
@test res eSy atol = eps(elt)
929+
end
930+
916931
@testset "Expected value and Correlations" begin
917932
m = 2
918933

@@ -935,6 +950,7 @@ end
935950
("Sz", "Sz"),
936951
("iSy", "iSy"),
937952
("Sx", "Sx"),
953+
("Sy", "Sy"),
938954
("Sz", "Sx"),
939955
("S+", "S+"),
940956
("S-", "S+"),

0 commit comments

Comments
 (0)