Skip to content

Commit ac630c7

Browse files
committed
Bug fix
1 parent 3b39d80 commit ac630c7

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

pygsti/models/gaugegroup.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,22 @@
1414
from typing import (
1515
Optional,
1616
Union,
17+
Tuple,
1718
TYPE_CHECKING
1819
)
1920

2021
import scipy.linalg as _la
2122

22-
from pygsti.baseobjs import StateSpace as _StateSpace, statespace as _statespace, ExplicitStateSpace as _ExplicitStateSpace
23+
from pygsti.baseobjs import (
24+
StateSpace as _StateSpace,
25+
statespace as _statespace,
26+
ExplicitStateSpace as _ExplicitStateSpace
27+
)
2328
from pygsti.modelmembers import operations as _op
2429
from pygsti.baseobjs.basis import Basis as _Basis, BuiltinBasis as _BuiltinBasis
2530
from pygsti.baseobjs.nicelyserializable import NicelySerializable as _NicelySerializable
2631
from pygsti.evotypes.evotype import Evotype as _Evotype
2732
from pygsti.tools.optools import superop_to_unitary, unitary_to_superop
28-
from pygsti import SpaceT
2933

3034
if TYPE_CHECKING:
3135
from pygsti.modelmembers.operations import LinearOperator as _LinearOperator
@@ -98,6 +102,7 @@ def initial_params(self):
98102
"""
99103
return _np.zeros(self.num_params, dtype='d')
100104

105+
101106
class GaugeGroupElement(_NicelySerializable):
102107
"""
103108
The element of a :class:`GaugeGroup`, which represents a single gauge transformation.
@@ -129,7 +134,7 @@ def transform_matrix_inverse(self) -> Optional[_np.ndarray]:
129134
"""
130135
return None
131136

132-
def deriv_wrt_params(self, wrt_filter: Optional[Union[_np.ndarray, list]]=None) -> Optional[_np.ndarray]:
137+
def deriv_wrt_params(self, wrt_filter: Optional[Union[_np.ndarray, list]] = None) -> Optional[_np.ndarray]:
133138
"""
134139
Computes the derivative of the gauge group at this element.
135140
@@ -479,7 +484,7 @@ def transform_matrix_inverse(self) -> _np.ndarray:
479484
self._inv_matrix = _np.linalg.inv(self._operation.to_dense("minimal"))
480485
return self._inv_matrix
481486

482-
def deriv_wrt_params(self, wrt_filter: Optional[list, _np.ndarray]=None) -> _np.ndarray:
487+
def deriv_wrt_params(self, wrt_filter: Optional[Union[list, _np.ndarray]] = None) -> _np.ndarray:
483488
"""
484489
Computes the derivative of the gauge group at this element.
485490
@@ -570,7 +575,8 @@ class FullGaugeGroup(OpGaugeGroupWithBasis):
570575
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
571576
"""
572577

573-
def __init__(self, state_space: _StateSpace, model_basis: Optional[Union[_Basis, str]]='pp', evotype: Optional[Union [_Evotype, str]] ='default'):
578+
def __init__(self, state_space: _StateSpace, model_basis: Optional[Union[_Basis, str]] = 'pp',
579+
evotype: Optional[Union[_Evotype, str]] = 'default'):
574580
state_space = _StateSpace.cast(state_space)
575581
operation = _op.FullArbitraryOp(_np.identity(state_space.dim, 'd'), model_basis, evotype, state_space)
576582
OpGaugeGroupWithBasis.__init__(self, operation, FullGaugeGroupElement, "Full", model_basis)
@@ -618,7 +624,8 @@ class TPGaugeGroup(OpGaugeGroupWithBasis):
618624
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
619625
"""
620626

621-
def __init__(self, state_space: _StateSpace, model_basis: Optional[Union[_Basis, str]]='pp', evotype: Optional[Union[_Evotype, str]]='default'):
627+
def __init__(self, state_space: _StateSpace, model_basis: Optional[Union[_Basis, str]] = 'pp',
628+
evotype: Optional[Union[_Evotype, str]] = 'default'):
622629
state_space = _StateSpace.cast(state_space)
623630
operation = _op.FullTPOp(_np.identity(state_space.dim, 'd'), model_basis, evotype, state_space)
624631
OpGaugeGroupWithBasis.__init__(self, operation, TPGaugeGroupElement, "TP", model_basis)
@@ -676,7 +683,7 @@ class DiagGaugeGroup(OpGaugeGroup):
676683
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
677684
"""
678685

679-
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]]='default'):
686+
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]] = 'default'):
680687
state_space = _StateSpace.cast(state_space)
681688
dim = state_space.dim
682689
ltrans = _np.identity(dim, 'd')
@@ -727,7 +734,7 @@ class TPDiagGaugeGroup(TPGaugeGroup):
727734
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
728735
"""
729736

730-
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]]='default'):
737+
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]] = 'default'):
731738
"""
732739
Create a new gauge group with gauge-transform dimension `dim`, which
733740
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -784,7 +791,7 @@ def __init__(self, operation: _LinearOperator):
784791

785792
@property
786793
def operation(self) -> _LinearOperator:
787-
return self._operation
794+
return self._operation
788795

789796

790797
class UnitaryGaugeGroup(OpGaugeGroupWithBasis):
@@ -810,7 +817,8 @@ class UnitaryGaugeGroup(OpGaugeGroupWithBasis):
810817
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
811818
"""
812819

813-
def __init__(self, state_space: _StateSpace, basis: Optional[Union[_Basis, str]], evotype: Optional[Union[_Evotype, str]]='default'):
820+
def __init__(self, state_space: _StateSpace, basis: Optional[Union[_Basis, str]],
821+
evotype: Optional[Union[_Evotype, str]] = 'default'):
814822
state_space = _StateSpace.cast(state_space)
815823
evotype = _Evotype.cast(str(evotype), default_prefer_dense_reps=True) # since we use deriv_wrt_params
816824
errgen = _op.LindbladErrorgen.from_operation_matrix(
@@ -846,7 +854,7 @@ class SpamGaugeGroup(OpGaugeGroup):
846854
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
847855
"""
848856

849-
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]]='default'):
857+
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]] = 'default'):
850858
"""
851859
Create a new gauge group with gauge-transform dimension `dim`, which
852860
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -904,7 +912,7 @@ class TPSpamGaugeGroup(OpGaugeGroup):
904912
to specifying the value of `pygsti.evotypes.Evotype.default_evotype`.
905913
"""
906914

907-
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]]='default'):
915+
def __init__(self, state_space: _StateSpace, evotype: Optional[Union[_Evotype, str]] = 'default'):
908916
"""
909917
Create a new gauge group with gauge-transform dimension `dim`, which
910918
should be the same as `mdl.dim` where `mdl` is a :class:`Model` you
@@ -988,7 +996,7 @@ def compute_element(self, param_vec: _np.ndarray):
988996
-------
989997
TrivialGaugeGroupElement
990998
"""
991-
assert(len(param_vec) == 0)
999+
assert (len(param_vec) == 0)
9921000
return TrivialGaugeGroupElement(self.state_space.dim)
9931001

9941002
@property
@@ -1096,7 +1104,7 @@ def from_vector(self, v: _np.ndarray) -> None:
10961104
-------
10971105
None
10981106
"""
1099-
assert(len(v) == 0)
1107+
assert (len(v) == 0)
11001108

11011109
@property
11021110
def num_params(self) -> int:
@@ -1124,12 +1132,13 @@ class DirectSumUnitaryGroup(GaugeGroup):
11241132
A subgroup of the unitary group, where the unitary operators in the group all have a
11251133
shared block-diagonal structure.
11261134
1127-
Example setting where this is useful:
1135+
Example setting where this is useful:
11281136
The system's Hilbert space is naturally expressed as a direct sum, H = U ⨁ V,
11291137
and we want gauge optimization to preserve the natural separation between U and V.
11301138
"""
11311139

1132-
def __init__(self, subgroups: Tuple[Union[UnitaryGaugeGroup, TrivialGaugeGroup], ...], basis, name="Direct sum gauge group"):
1140+
def __init__(self, subgroups: Tuple[Union[UnitaryGaugeGroup, TrivialGaugeGroup], ...],
1141+
basis, name="Direct sum gauge group"):
11331142
self.subgroups = subgroups
11341143
if isinstance(basis, _Basis):
11351144
self.basis = basis
@@ -1223,7 +1232,7 @@ def from_vector(self, v):
12231232
offset += se.num_params
12241233
self._update_matrices()
12251234
return
1226-
1235+
12271236
def _update_matrices(self):
12281237
u_blocks, num_params = [], []
12291238
for se in self.subelements:

0 commit comments

Comments
 (0)