10
10
11
11
from qutip_qoc ._optimizer import _global_local_optimization
12
12
from qutip_qoc ._time import _TimeInterval
13
- from qutip_qoc ._rl import _RL
13
+
14
+ import qutip as qt
14
15
from qutip_qoc ._genetic import _GENETIC
15
16
17
+ try :
18
+ from qutip_qoc ._rl import _RL
19
+ _rl_available = True
20
+ except ImportError :
21
+ _rl_available = False
22
+
16
23
__all__ = ["optimize_pulses" ]
17
24
18
25
@@ -24,6 +31,7 @@ def optimize_pulses(
24
31
optimizer_kwargs = None ,
25
32
minimizer_kwargs = None ,
26
33
integrator_kwargs = None ,
34
+ optimization_type = None ,
27
35
):
28
36
"""
29
37
Run GOAT, JOPT, GRAPE, CRAB or RL optimization.
@@ -120,6 +128,11 @@ def optimize_pulses(
120
128
Options for the solver, see :obj:`MESolver.options` and
121
129
`Integrator <./classes.html#classes-ode>`_ for a list of all options.
122
130
131
+ optimization_type : str, optional
132
+ Type of optimization. By default, QuTiP-QOC will try to automatically determine
133
+ whether this is a *state transfer* or a *gate synthesis* problem. Set this
134
+ flag to ``"state_transfer"`` or ``"gate_synthesis"`` to set the mode manually.
135
+
123
136
Returns
124
137
-------
125
138
result : :class:`qutip_qoc.Result`
@@ -183,10 +196,43 @@ def optimize_pulses(
183
196
"maxiter" : algorithm_kwargs .get ("max_iter" , 1000 ),
184
197
"gtol" : algorithm_kwargs .get ("min_grad" , 0.0 if alg == "CRAB" else 1e-8 ),
185
198
}
199
+ # Iterate over objectives and convert initial and target states based on the optimization type
200
+ for objective in objectives :
201
+ H_list = objective .H if isinstance (objective .H , list ) else [objective .H ]
202
+ if any (qt .issuper (H_i ) for H_i in H_list ):
203
+ if isinstance (optimization_type , str ) and optimization_type .lower () == "state_transfer" :
204
+ if qt .isket (objective .initial ):
205
+ objective .initial = qt .operator_to_vector (qt .ket2dm (objective .initial ))
206
+ elif qt .isoper (objective .initial ):
207
+ objective .initial = qt .operator_to_vector (objective .initial )
208
+ if qt .isket (objective .target ):
209
+ objective .target = qt .operator_to_vector (qt .ket2dm (objective .target ))
210
+ elif qt .isoper (objective .target ):
211
+ objective .target = qt .operator_to_vector (objective .target )
212
+ elif isinstance (optimization_type , str ) and optimization_type .lower () == "gate_synthesis" :
213
+ objective .initial = qt .to_super (objective .initial )
214
+ objective .target = qt .to_super (objective .target )
215
+ elif optimization_type is None :
216
+ if qt .isoper (objective .initial ) and qt .isoper (objective .target ):
217
+ if np .isclose ((objective .initial ).tr (), 1 ) and np .isclose ((objective .target ).tr (), 1 ):
218
+ objective .initial = qt .operator_to_vector (objective .initial )
219
+ objective .target = qt .operator_to_vector (objective .target )
220
+ else :
221
+ objective .initial = qt .to_super (objective .initial )
222
+ objective .target = qt .to_super (objective .target )
223
+ if qt .isket (objective .initial ):
224
+ objective .initial = qt .operator_to_vector (qt .ket2dm (objective .initial ))
225
+ if qt .isket (objective .target ):
226
+ objective .target = qt .operator_to_vector (qt .ket2dm (objective .target ))
186
227
187
228
# prepare qtrl optimizers
188
229
qtrl_optimizers = []
189
230
if alg == "CRAB" or alg == "GRAPE" :
231
+ dyn_type = "GEN_MAT"
232
+ for objective in objectives :
233
+ if any (qt .isoper (H_i ) for H_i in (objective .H if isinstance (objective .H , list ) else [objective .H ])):
234
+ dyn_type = "UNIT"
235
+
190
236
if alg == "GRAPE" : # algorithm specific kwargs
191
237
use_as_amps = True
192
238
minimizer_kwargs .setdefault ("method" , "L-BFGS-B" ) # gradient
@@ -243,7 +289,7 @@ def optimize_pulses(
243
289
"accuracy_factor" : None , # deprecated
244
290
"alg_params" : alg_params ,
245
291
"optim_params" : algorithm_kwargs .get ("optim_params" , None ),
246
- "dyn_type" : algorithm_kwargs .get ("dyn_type" , "GEN_MAT" ),
292
+ "dyn_type" : algorithm_kwargs .get ("dyn_type" , dyn_type ),
247
293
"dyn_params" : algorithm_kwargs .get ("dyn_params" , None ),
248
294
"prop_type" : algorithm_kwargs .get (
249
295
"prop_type" , "DEF"
@@ -354,6 +400,12 @@ def optimize_pulses(
354
400
qtrl_optimizers .append (qtrl_optimizer )
355
401
356
402
elif alg == "RL" :
403
+ if not _rl_available :
404
+ raise ImportError (
405
+ "The required dependencies (gymnasium, stable-baselines3) for "
406
+ "the reinforcement learning algorithm are not available."
407
+ )
408
+
357
409
rl_env = _RL (
358
410
objectives ,
359
411
control_parameters ,
@@ -393,4 +445,4 @@ def optimize_pulses(
393
445
minimizer_kwargs ,
394
446
integrator_kwargs ,
395
447
qtrl_optimizers ,
396
- )
448
+ )
0 commit comments