Skip to content

Commit 6004120

Browse files
committed
Correct the generation of cfg.loo
1 parent 93df2d0 commit 6004120

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

cellbox/cellbox/dataset.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,11 @@ def factory(cfg):
2424
cfg.expr_out = tf.compat.v1.placeholder(tf.float32, [None, cfg.n_x], name='expr_out')
2525
cfg.pert = pd.read_csv(os.path.join(cfg.root_dir, cfg.pert_file), header=None, dtype=np.float32)
2626
cfg.expr = pd.read_csv(os.path.join(cfg.root_dir, cfg.expr_file), header=None, dtype=np.float32)
27-
cfg.loo = np.vstack(np.where(cfg.pert!=0)).T + 1
27+
group_df = pd.DataFrame(np.where(cfg.pert != 0), index=['row_id', 'pert_idx']).T.groupby('row_id')
28+
max_combo_degree = group_df.pert_idx.count().max()
29+
cfg.loo = pd.DataFrame(group_df.pert_idx.apply(
30+
lambda x: pad_and_realign(x, max_combo_degree, cfg.n_activity_nodes - 1)
31+
).tolist())
2832

2933
# add noise
3034
if cfg.add_noise_level > 0:
@@ -68,6 +72,12 @@ def factory(cfg):
6872
return cfg
6973

7074

75+
def pad_and_realign(x, length, idx_shift=0):
    """Shift the entries of ``x`` by ``idx_shift`` and zero-pad on the right.

    Args:
        x: Sequence of indices (NumPy array, pandas Series or list).
            Intended for one-dimensional input.
        length: Number of entries in the returned array; must be at least
            ``len(x)``.
        idx_shift: Value subtracted from every entry of ``x`` before padding.

    Returns:
        A new ``numpy.ndarray`` holding the shifted entries followed by
        ``length - len(x)`` zeros. ``x`` itself is left untouched.

    Raises:
        ValueError: If ``x`` has more than ``length`` entries.
    """
    # Subtract out of place: an in-place ``x -= idx_shift`` would silently
    # modify the caller's array/Series on every call.
    shifted = np.asarray(x) - idx_shift
    n_missing = length - len(shifted)
    if n_missing < 0:
        raise ValueError(
            "cannot pad %d entries to a shorter length of %d"
            % (len(shifted), length)
        )
    # NOTE(review): zero doubles as the padding value, so callers presumably
    # choose idx_shift such that real entries are never 0 - confirm.
    return np.pad(shifted, (0, n_missing), 'constant')
79+
80+
7181
def get_tensors(cfg):
7282
# prepare training placeholders
7383
cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')

0 commit comments

Comments
 (0)