Skip to content

Commit 9532291

Browse files
author
mannypaeza
committed
torch dev for ring_cnn + 2p spatial
1 parent e33642e commit 9532291

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

caiman/source_extraction/cnmf/online_cnmf.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,9 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
323323
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
324324
loaded_model = None
325325
self.params.set('online', {'sniper_mode': False})
326-
# self.tf_in = None
327-
# self.tf_out = None
328-
self.use_torch = None #fix
326+
self.tf_in = None
327+
self.tf_out = None
328+
# self.use_torch = None
329329
else:
330330
try:
331331
from keras.models import load_model
@@ -340,12 +340,12 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
340340
# uses online model -> be careful
341341
model_path = ".".join(path + ["keras"])
342342
loaded_model = model_load(model_path)
343-
self.use_torch = False
343+
# self.use_torch = False
344344
else:
345345
model_path = '.'.join(path + ['pt'])
346346
loaded_model = load_graph(model_path)
347347
loaded_model = torch.load(model_file)
348-
self.use_torch = True
348+
# self.use_torch = True
349349

350350
self.loaded_model = loaded_model
351351

@@ -547,8 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
547547
sniper_mode=self.params.get('online', 'sniper_mode'),
548548
use_peak_max=self.params.get('online', 'use_peak_max'),
549549
mean_buff=self.estimates.mean_buff,
550-
# tf_in=self.tf_in, tf_out=self.tf_out,
551-
use_torch=self.use_torch,
550+
tf_in=self.tf_in, tf_out=self.tf_out,
551+
# use_torch=self.use_torch,
552552
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
553553
b0=self.estimates.b0 if self.is1p else None,
554554
corr_img=self.estimates.corr_img if use_corr else None,
@@ -2003,8 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
20032003
patch_size=50, loaded_model=None, test_both=False,
20042004
thresh_CNN_noisy=0.5, use_peak_max=False,
20052005
thresh_std_peak_resid = 1, mean_buff=None,
2006-
# tf_in=None, tf_out=None):
2007-
use_torch=None):
2006+
tf_in=None, tf_out=None):
2007+
# use_torch=None):
20082008
"""
20092009
Extract new candidate components from the residual buffer and test them
20102010
using space correlation or the CNN classifier. The function runs the CNN
@@ -2146,8 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
21462146
corr_img=None, first_moment=None, second_moment=None,
21472147
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
21482148
max_img=None, downscale_matrix=None, upscale_matrix=None,
2149-
# tf_in=None, tf_out=None):
2150-
torch_in=None, torch_out=None):
2149+
tf_in=None, tf_out=None):
2150+
# torch_in=None, torch_out=None):
21512151
"""
21522152
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
21532153
"""
@@ -2177,8 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
21772177
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
21782178
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
21792179
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
2180-
# tf_in=tf_in, tf_out=tf_out)
2181-
torch_in=torch_in, torch_out=torch_out)
2180+
tf_in=tf_in, tf_out=tf_out)
2181+
#torch_in=torch_in, torch_out=torch_out)
21822182

21832183
ind_new_all = ijsig_all
21842184

caiman/utils/nn_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
555555
Y = np.expand_dims(Y, axis=-1)
556556
run_logdir = get_run_logdir()
557557
os.mkdir(run_logdir)
558-
path_to_model = os.path.join(run_logdir, 'model.h5')
558+
path_to_model = os.path.join(run_logdir, 'model.weights.h5')
559559
chk = ModelCheckpoint(filepath=path_to_model,
560560
verbose=0, save_best_only=True, save_weights_only=True)
561561
es = EarlyStopping(monitor='val_loss', patience=patience,
@@ -566,7 +566,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
566566
history_NL = model_NL.fit(Y, Y, epochs=epochs, batch_size=batch_size,
567567
shuffle=True, validation_split=val_split,
568568
callbacks=callbacks)
569-
model_NL.load_weights(os.path.join(run_logdir, 'model.h5'))
569+
model_NL.load_weights(os.path.join(run_logdir, 'model.weights.h5'))
570570
return model_NL, history_NL, path_to_model
571571

572572
def get_MCNN_model(Y, gSig=5, n_channels=8, lr=1e-4, pct=10, r_factor=1.5,

0 commit comments

Comments
 (0)