33import time
44from qutip_qoc .result import Result
55
6- class _GENETIC :
6+
7+ class _Genetic :
78 """
89 Genetic Algorithm-based optimizer for quantum control problems.
910
1011 This class implements a global optimization routine using a Genetic Algorithm
11- to optimize control pulses that drive a quantum system from an initial state
12- to a target state (or unitary).
12+ to optimize control pulses that drive a quantum system from an initial state
13+ to a target state (or unitary).
14+
15+ Based on the code from Jonathan Brown
1316 """
1417
1518 def __init__ (
@@ -26,7 +29,9 @@ def __init__(
2629 ):
2730 self ._objective = objectives [0 ]
2831 self ._Hd = self ._objective .H [0 ]
29- self ._Hc_lst = [H [0 ] if isinstance (H , list ) else H for H in self ._objective .H [1 :]]
32+ self ._Hc_lst = [
33+ H [0 ] if isinstance (H , list ) else H for H in self ._objective .H [1 :]
34+ ]
3035 self ._initial = self ._objective .initial
3136 self ._target = self ._objective .target
3237 self ._norm_fac = 1 / self ._target .norm ()
@@ -41,14 +46,17 @@ def __init__(
4146 self .generations = alg_kwargs .get ("generations" , 100 )
4247 self .mutation_rate = alg_kwargs .get ("mutation_rate" , 0.3 )
4348 self .fid_err_targ = alg_kwargs .get ("fid_err_targ" , 1e-4 )
44- self ._stagnation_patience = 20 # Internally fixed
49+ self ._stagnation_patience = 50 # Internally fixed
4550
4651 self ._integrator_kwargs = integrator_kwargs
4752 self ._fid_type = alg_kwargs .get ("fid_type" , "PSU" )
4853
4954 self ._generator = self ._prepare_generator ()
50- self ._solver = qt .MESolver (H = self ._generator , options = self ._integrator_kwargs ) \
51- if self ._Hd .issuper else qt .SESolver (H = self ._generator , options = self ._integrator_kwargs )
55+ self ._solver = (
56+ qt .MESolver (H = self ._generator , options = self ._integrator_kwargs )
57+ if self ._Hd .issuper
58+ else qt .SESolver (H = self ._generator , options = self ._integrator_kwargs )
59+ )
5260
5361 self ._result = Result (
5462 objectives = [self ._objective ],
@@ -64,10 +72,18 @@ def __init__(
6472 self ._result ._final_states = []
6573
6674 def _prepare_generator (self ):
67- args = {f"p{ i + 1 } _{ j } " : 0.0 for i in range (self .N_controls ) for j in range (self .N_steps )}
75+ args = {
76+ f"p{ i + 1 } _{ j } " : 0.0
77+ for i in range (self .N_controls )
78+ for j in range (self .N_steps )
79+ }
6880
6981 def make_coeff (i , j ):
70- return lambda t , args : args [f"p{ i + 1 } _{ j } " ] if int (t / (self ._evo_time / self .N_steps )) == j else 0
82+ return lambda t , args : (
83+ args [f"p{ i + 1 } _{ j } " ]
84+ if int (t / (self ._evo_time / self .N_steps )) == j
85+ else 0
86+ )
7187
7288 H_qev = [self ._Hd ]
7389 for i , Hc in enumerate (self ._Hc_lst ):
@@ -77,7 +93,11 @@ def make_coeff(i, j):
7793 return qt .QobjEvo (H_qev , args = args )
7894
7995 def _infid (self , params ):
80- args = {f"p{ i + 1 } _{ j } " : params [i * self .N_steps + j ] for i in range (self .N_controls ) for j in range (self .N_steps )}
96+ args = {
97+ f"p{ i + 1 } _{ j } " : params [i * self .N_steps + j ]
98+ for i in range (self .N_controls )
99+ for j in range (self .N_steps )
100+ }
81101 result = self ._solver .run (self ._initial , [0.0 , self ._evo_time ], args = args )
82102 final_state = result .final_state
83103 self ._result ._final_states .append (final_state )
@@ -87,7 +107,9 @@ def _infid(self, params):
87107 fid = 0.5 * np .real ((diff .dag () * diff ).tr ())
88108 else :
89109 overlap = self ._norm_fac * self ._target .overlap (final_state )
90- fid = 1 - np .abs (overlap ) if self ._fid_type == "PSU" else 1 - np .real (overlap )
110+ fid = (
111+ 1 - np .abs (overlap ) if self ._fid_type == "PSU" else 1 - np .real (overlap )
112+ )
91113
92114 return fid
93115
@@ -101,7 +123,7 @@ def initial_population(self):
101123 return np .random .uniform (- 1 , 1 , (self .N_pop , self .N_var ))
102124
103125 def darwin (self , population , fitness ):
104- indices = np .argsort (- fitness )[:self .N_pop // 2 ]
126+ indices = np .argsort (- fitness )[: self .N_pop // 2 ]
105127 return population [indices ], fitness [indices ]
106128
107129 def pairing (self , survivors , survivor_fitness ):
@@ -130,7 +152,9 @@ def build_next_gen(self, survivors, offspring):
130152 return np .vstack ((survivors , offspring ))
131153
132154 def mutate (self , population ):
133- n_mut = int ((population .shape [0 ] - 1 ) * population .shape [1 ] * self .mutation_rate )
155+ n_mut = int (
156+ (population .shape [0 ] - 1 ) * population .shape [1 ] * self .mutation_rate
157+ )
134158 row = np .random .randint (1 , population .shape [0 ], size = n_mut )
135159 col = np .random .randint (0 , population .shape [1 ], size = n_mut )
136160 population [row , col ] += np .random .normal (0 , 0.3 , size = n_mut )
@@ -183,13 +207,15 @@ def optimize(self):
183207
184208 self ._result .message = (
185209 f"Stopped early: reached infidelity target { self .fid_err_targ } "
186- if - best_fit <= self .fid_err_targ else
187- f"Stopped due to stagnation after { self ._stagnation_patience } generations"
188- if no_improvement_counter >= self ._stagnation_patience else
189- "Optimization completed successfully"
210+ if - best_fit <= self .fid_err_targ
211+ else (
212+ f"Stopped due to stagnation after { self ._stagnation_patience } generations"
213+ if no_improvement_counter >= self ._stagnation_patience
214+ else "Optimization completed successfully"
215+ )
190216 )
191217 return self ._result
192-
218+
193219 def result (self ):
194220 self ._result .start_local_time = time .strftime (
195221 "%Y-%m-%d %H:%M:%S" , time .localtime (self ._result .start_local_time )
0 commit comments