14
14
from typing import (
15
15
Optional ,
16
16
Union ,
17
+ Tuple ,
17
18
TYPE_CHECKING
18
19
)
19
20
20
21
import scipy .linalg as _la
21
22
22
- from pygsti .baseobjs import StateSpace as _StateSpace , statespace as _statespace , ExplicitStateSpace as _ExplicitStateSpace
23
+ from pygsti .baseobjs import (
24
+ StateSpace as _StateSpace ,
25
+ statespace as _statespace ,
26
+ ExplicitStateSpace as _ExplicitStateSpace
27
+ )
23
28
from pygsti .modelmembers import operations as _op
24
29
from pygsti .baseobjs .basis import Basis as _Basis , BuiltinBasis as _BuiltinBasis
25
30
from pygsti .baseobjs .nicelyserializable import NicelySerializable as _NicelySerializable
26
31
from pygsti .evotypes .evotype import Evotype as _Evotype
27
32
from pygsti .tools .optools import superop_to_unitary , unitary_to_superop
28
- from pygsti import SpaceT
29
33
30
34
if TYPE_CHECKING :
31
35
from pygsti .modelmembers .operations import LinearOperator as _LinearOperator
@@ -98,6 +102,7 @@ def initial_params(self):
98
102
"""
99
103
return _np .zeros (self .num_params , dtype = 'd' )
100
104
105
+
101
106
class GaugeGroupElement (_NicelySerializable ):
102
107
"""
103
108
The element of a :class:`GaugeGroup`, which represents a single gauge transformation.
@@ -129,7 +134,7 @@ def transform_matrix_inverse(self) -> Optional[_np.ndarray]:
129
134
"""
130
135
return None
131
136
132
- def deriv_wrt_params (self , wrt_filter : Optional [Union [_np .ndarray , list ]]= None ) -> Optional [_np .ndarray ]:
137
+ def deriv_wrt_params (self , wrt_filter : Optional [Union [_np .ndarray , list ]] = None ) -> Optional [_np .ndarray ]:
133
138
"""
134
139
Computes the derivative of the gauge group at this element.
135
140
@@ -479,7 +484,7 @@ def transform_matrix_inverse(self) -> _np.ndarray:
479
484
self ._inv_matrix = _np .linalg .inv (self ._operation .to_dense ("minimal" ))
480
485
return self ._inv_matrix
481
486
482
- def deriv_wrt_params (self , wrt_filter : Optional [list , _np .ndarray ]= None ) -> _np .ndarray :
487
+ def deriv_wrt_params (self , wrt_filter : Optional [Union [ list , _np .ndarray ]] = None ) -> _np .ndarray :
483
488
"""
484
489
Computes the derivative of the gauge group at this element.
485
490
@@ -570,7 +575,8 @@ class FullGaugeGroup(OpGaugeGroupWithBasis):
570
575
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
571
576
"""
572
577
573
- def __init__ (self , state_space : _StateSpace , model_basis : Optional [Union [_Basis , str ]]= 'pp' , evotype : Optional [Union [_Evotype , str ]] = 'default' ):
578
+ def __init__ (self , state_space : _StateSpace , model_basis : Optional [Union [_Basis , str ]] = 'pp' ,
579
+ evotype : Optional [Union [_Evotype , str ]] = 'default' ):
574
580
state_space = _StateSpace .cast (state_space )
575
581
operation = _op .FullArbitraryOp (_np .identity (state_space .dim , 'd' ), model_basis , evotype , state_space )
576
582
OpGaugeGroupWithBasis .__init__ (self , operation , FullGaugeGroupElement , "Full" , model_basis )
@@ -618,7 +624,8 @@ class TPGaugeGroup(OpGaugeGroupWithBasis):
618
624
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
619
625
"""
620
626
621
- def __init__ (self , state_space : _StateSpace , model_basis : Optional [Union [_Basis , str ]]= 'pp' , evotype : Optional [Union [_Evotype , str ]]= 'default' ):
627
+ def __init__ (self , state_space : _StateSpace , model_basis : Optional [Union [_Basis , str ]] = 'pp' ,
628
+ evotype : Optional [Union [_Evotype , str ]] = 'default' ):
622
629
state_space = _StateSpace .cast (state_space )
623
630
operation = _op .FullTPOp (_np .identity (state_space .dim , 'd' ), model_basis , evotype , state_space )
624
631
OpGaugeGroupWithBasis .__init__ (self , operation , TPGaugeGroupElement , "TP" , model_basis )
@@ -676,7 +683,7 @@ class DiagGaugeGroup(OpGaugeGroup):
676
683
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
677
684
"""
678
685
679
- def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]]= 'default' ):
686
+ def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]] = 'default' ):
680
687
state_space = _StateSpace .cast (state_space )
681
688
dim = state_space .dim
682
689
ltrans = _np .identity (dim , 'd' )
@@ -727,7 +734,7 @@ class TPDiagGaugeGroup(TPGaugeGroup):
727
734
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
728
735
"""
729
736
730
- def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]]= 'default' ):
737
+ def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]] = 'default' ):
731
738
"""
732
739
Create a new gauge group with gauge-transform dimension `dim`, which
733
740
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -784,7 +791,7 @@ def __init__(self, operation: _LinearOperator):
784
791
785
792
@property
786
793
def operation (self ) -> _LinearOperator :
787
- return self ._operation
794
+ return self ._operation
788
795
789
796
790
797
class UnitaryGaugeGroup (OpGaugeGroupWithBasis ):
@@ -810,7 +817,8 @@ class UnitaryGaugeGroup(OpGaugeGroupWithBasis):
810
817
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
811
818
"""
812
819
813
- def __init__ (self , state_space : _StateSpace , basis : Optional [Union [_Basis , str ]], evotype : Optional [Union [_Evotype , str ]]= 'default' ):
820
+ def __init__ (self , state_space : _StateSpace , basis : Optional [Union [_Basis , str ]],
821
+ evotype : Optional [Union [_Evotype , str ]] = 'default' ):
814
822
state_space = _StateSpace .cast (state_space )
815
823
evotype = _Evotype .cast (str (evotype ), default_prefer_dense_reps = True ) # since we use deriv_wrt_params
816
824
errgen = _op .LindbladErrorgen .from_operation_matrix (
@@ -846,7 +854,7 @@ class SpamGaugeGroup(OpGaugeGroup):
846
854
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
847
855
"""
848
856
849
- def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]]= 'default' ):
857
+ def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]] = 'default' ):
850
858
"""
851
859
Create a new gauge group with gauge-transform dimension `dim`, which
852
860
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -904,7 +912,7 @@ class TPSpamGaugeGroup(OpGaugeGroup):
904
912
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
905
913
"""
906
914
907
- def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]]= 'default' ):
915
+ def __init__ (self , state_space : _StateSpace , evotype : Optional [Union [_Evotype , str ]] = 'default' ):
908
916
"""
909
917
Create a new gauge group with gauge-transform dimension `dim`, which
910
918
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -988,7 +996,7 @@ def compute_element(self, param_vec: _np.ndarray):
988
996
-------
989
997
TrivialGaugeGroupElement
990
998
"""
991
- assert (len (param_vec ) == 0 )
999
+ assert (len (param_vec ) == 0 )
992
1000
return TrivialGaugeGroupElement (self .state_space .dim )
993
1001
994
1002
@property
@@ -1096,7 +1104,7 @@ def from_vector(self, v: _np.ndarray) -> None:
1096
1104
-------
1097
1105
None
1098
1106
"""
1099
- assert (len (v ) == 0 )
1107
+ assert (len (v ) == 0 )
1100
1108
1101
1109
@property
1102
1110
def num_params (self ) -> int :
@@ -1124,12 +1132,13 @@ class DirectSumUnitaryGroup(GaugeGroup):
1124
1132
A subgroup of the unitary group, where the unitary operators in the group all have a
1125
1133
shared block-diagonal structure.
1126
1134
1127
- Example setting where this is useful:
1135
+ Example setting where this is useful:
1128
1136
The system's Hilbert space is naturally expressed as a direct sum, H = U ⨁ V,
1129
1137
and we want gauge optimization to preserve the natural separation between U and V.
1130
1138
"""
1131
1139
1132
- def __init__ (self , subgroups : Tuple [Union [UnitaryGaugeGroup , TrivialGaugeGroup ], ...], basis , name = "Direct sum gauge group" ):
1140
+ def __init__ (self , subgroups : Tuple [Union [UnitaryGaugeGroup , TrivialGaugeGroup ], ...],
1141
+ basis , name = "Direct sum gauge group" ):
1133
1142
self .subgroups = subgroups
1134
1143
if isinstance (basis , _Basis ):
1135
1144
self .basis = basis
@@ -1223,7 +1232,7 @@ def from_vector(self, v):
1223
1232
offset += se .num_params
1224
1233
self ._update_matrices ()
1225
1234
return
1226
-
1235
+
1227
1236
def _update_matrices (self ):
1228
1237
u_blocks , num_params = [], []
1229
1238
for se in self .subelements :
0 commit comments