@@ -323,9 +323,9 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
323
323
if self .params .get ('online' , 'path_to_model' ) is None or self .params .get ('online' , 'sniper_mode' ) is False :
324
324
loaded_model = None
325
325
self .params .set ('online' , {'sniper_mode' : False })
326
- # self.tf_in = None
327
- # self.tf_out = None
328
- self .use_torch = None #fix
326
+ self .tf_in = None
327
+ self .tf_out = None
328
+ # self.use_torch = None
329
329
else :
330
330
try :
331
331
from keras .models import load_model
@@ -340,12 +340,12 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
340
340
# uses online model -> be careful
341
341
model_path = "." .join (path + ["keras" ])
342
342
loaded_model = model_load (model_path )
343
- self .use_torch = False
343
+ # self.use_torch = False
344
344
else :
345
345
model_path = '.' .join (path + ['pt' ])
346
346
loaded_model = load_graph (model_path )
347
347
loaded_model = torch .load (model_file )
348
- self .use_torch = True
348
+ # self.use_torch = True
349
349
350
350
self .loaded_model = loaded_model
351
351
@@ -547,8 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
547
547
sniper_mode = self .params .get ('online' , 'sniper_mode' ),
548
548
use_peak_max = self .params .get ('online' , 'use_peak_max' ),
549
549
mean_buff = self .estimates .mean_buff ,
550
- # tf_in=self.tf_in, tf_out=self.tf_out,
551
- use_torch = self .use_torch ,
550
+ tf_in = self .tf_in , tf_out = self .tf_out ,
551
+ # use_torch=self.use_torch,
552
552
ssub_B = ssub_B , W = self .estimates .W if self .is1p else None ,
553
553
b0 = self .estimates .b0 if self .is1p else None ,
554
554
corr_img = self .estimates .corr_img if use_corr else None ,
@@ -2003,8 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
2003
2003
patch_size = 50 , loaded_model = None , test_both = False ,
2004
2004
thresh_CNN_noisy = 0.5 , use_peak_max = False ,
2005
2005
thresh_std_peak_resid = 1 , mean_buff = None ,
2006
- # tf_in=None, tf_out=None):
2007
- use_torch = None ):
2006
+ tf_in = None , tf_out = None ):
2007
+ # use_torch=None):
2008
2008
"""
2009
2009
Extract new candidate components from the residual buffer and test them
2010
2010
using space correlation or the CNN classifier. The function runs the CNN
@@ -2146,8 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
2146
2146
corr_img = None , first_moment = None , second_moment = None ,
2147
2147
crosscorr = None , col_ind = None , row_ind = None , corr_img_mode = None ,
2148
2148
max_img = None , downscale_matrix = None , upscale_matrix = None ,
2149
- # tf_in=None, tf_out=None):
2150
- torch_in = None , torch_out = None ):
2149
+ tf_in = None , tf_out = None ):
2150
+ # torch_in=None, torch_out=None):
2151
2151
"""
2152
2152
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
2153
2153
"""
@@ -2177,8 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
2177
2177
sniper_mode = sniper_mode , rval_thr = rval_thr , patch_size = 50 ,
2178
2178
loaded_model = loaded_model , thresh_CNN_noisy = thresh_CNN_noisy ,
2179
2179
use_peak_max = use_peak_max , test_both = test_both , mean_buff = mean_buff ,
2180
- # tf_in=tf_in, tf_out=tf_out)
2181
- torch_in = torch_in , torch_out = torch_out )
2180
+ tf_in = tf_in , tf_out = tf_out )
2181
+ # torch_in=torch_in, torch_out=torch_out)
2182
2182
2183
2183
ind_new_all = ijsig_all
2184
2184
0 commit comments