Commit 9f8c441

Add H5 GenomicFeatures support for more flexible target datatypes (#200)
* selene changes for methylation model
* remove f1 metric
* bugfix and some temporary profiling
* adjustments to loss breakdown and logging
* adjust return type for get_features_data
* bugfix for get_feature_data
* remove event object init
* initial changes to h5 dataloader
* methylation sampler excl
* add type specification to unpackbits
* change loss
* experimenting w loss
* add pearsonr
* fix metrix nan removal bug
* changes to sampler for positives-only hdf5
* loss weighting
* adjust methylation perf metric output checking
* revert multisampler and add spearmanr to training log
* explicit metric fn checking, refine later
* minor adjustment to dataloader
* non strand specific temp changes
* positives only sampler
* revamp dataloader, seq length flexibility
* attempted a nonstrandspecific utils module, not being used
* evaluation classes
* revert non strand specific module
* trying to figure out unet dataloader changes
* trying to figure out unet dataloader changes - add tgt shift
* remove print debug statements
* addr memory issue in eval
* comment
* troubleshooting dataloader shift
* add strand arg to _retrieve
* shift testing
* adjust casting of targets in dataloader
* minor changes to eval and sampling
* remove unused code in nonstrandspecific module
* remove indicators from dataloader, clean up methylation performance metrics
* remove unused files from version control, adjust multi sampler get dataset batches function
* remove commented out code in shift sections
* remove ind commented out code in multisampler
* add excl chr optional arg
* make adjustments to nonstrandspecific, performancemetrics, for what can be generalized from the methylation specific code
* remove files that we have merged the functionality into existing classes
* clean up indicator code that is no longer used
* integrate changes for methylation prediction in training and metrics
* incorporate changes from previous PR on config yaml and model file saving
* update pytorch version constraint
* variable name fix for copying model file / directory
* remove train methylation model class
* fix bug in random positions sampler with exclude_chrs, overload targets_path in both sampler classes
* remove unused method in genomicfeaturesh5
* adjust strand vs feature (target) column ordering assumptions in tabix-indexed BED file
* adjust descriptions for target classes
* minor adjustments to docstrings
* tuple output handling for non strand specific
* line breaks for formatting in performance metrics file
* adjust CLI config parsing so that a copy of the input config file is made and saved to the output dir
* update versioning
* add new dependency
* adjustment in
* update release notes for 0.6.0
* adjustment for sampling at the end of N_samples
* adjust expected BED file format when strand is included for seqweaver script
* adjust method in seqweaver script to accept lr input and new config loading function
1 parent ad4cd51 commit 9f8c441

20 files changed: +621 -155 lines changed

RELEASE_NOTES.md

Lines changed: 20 additions & 0 deletions
@@ -2,6 +2,26 @@
 
 This is a document describing new functionality, bug fixes, breaking changes, etc. associated with Selene version releases from v0.5.0 onwards.
 
+## Version 0.6.0
+- `config_utils.py`: Add additional information saved upon running Selene. Specifically, we now save the version of Selene that the latest run used, make a copy of the input configuration file, and save this along with the model architecture file in the output directory. This adds a new dependency to Selene, the package `ruamel.yaml`.
+- `H5DataLoader` and `_H5Dataset`: Previously `H5DataLoader` had a number of arguments that were used to then initialize `_H5Dataset` internally. One major change in this version is that we now ask that users initialize `_H5Dataset` explicitly and then pass it to `H5DataLoader` as a class argument. This makes the two classes consistent with the PyTorch specifications for `Dataset` and `DataLoader` classes, enabling them to be compatible with different data parallelization configurations supported by PyTorch and the PyTorch Lightning framework.
+- `_H5Dataset` class initialization optional arguments:
+  - `unpackbits` can now be specified separately for sequences and targets by way of `unpackbits_seq` and `unpackbits_tgt`.
+  - `use_seq_len` enables subsetting to the center `use_seq_len` length of the sequences in the dataset.
+  - `shift` (particularly paired with `use_seq_len`) allows for retrieving sequences shifted from the center position by `shift` bases. Note that `shift` currently only allows shifting in one direction, depending on whether you pass in a positive or negative integer.
+- `GenomicFeaturesH5`: This is a new targets class to handle continuous-valued targets, stored in an HDF5 file, that can be retrieved based on genomic coordinate. Once again, genomic regions are stored in a tabix-indexed .bed file, with the main change being that the BED file now specifies for each genomic region the index of the row in the HDF5 matrix that contains all the target values to predict. If multiple target rows are returned for a query region, the average of those rows is returned.
+- `RandomPositionsSampler`:
+  - `exclude_chrs`: Added a new optional argument that, by default (`exclude_chrs=['_']`), excludes all nonstandard chromosomes by ignoring every chromosome with an underscore in its name. Pass in a list of chromosomes or substrings to exclude. When loading possible sampling positions, the class now iterates through the `exclude_chrs` list, checks whether each substring `s` satisfies `s in chrom`, and if so skips that chromosome entirely.
+  - Internal function `_retrieve` now takes in an optional argument `strand` (default `None`) to enable explicit retrieval of a sequence at `chrom, position` for a specific strand. The default behavior of the `RandomPositionsSampler` class remains the same, with the strand randomly selected for each genomic position sampled.
+- `PerformanceMetrics`:
+  - Now supports `spearmanr` and `pearsonr` from `scipy.stats`. Room for improvement to generalize this class in the future.
+  - The `update` function now has an optional argument `scores`, a `list(str)` naming the subset of metrics to compute.
+- `TrainModel`:
+  - `self.step` starts from `self._start_step`, which is non-zero if loaded from a Selene-saved checkpoint.
+  - Removed the call to `self._test_metrics.visualize` in `evaluate` since the visualize method does not generalize well.
+- `NonStrandSpecific`: Can now handle a model that returns two outputs from `forward`, taking either the mean or max of each of the two individual outputs across their forward and reverse predictions. A custom `NonStrandSpecific` class is recommended for more specific cases.
+
+
 ## Version 0.5.3
 - Adjust dependency requirements

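The `exclude_chrs` behaviour described in the release notes above amounts to a substring test against each chromosome name. The following is a minimal sketch of that test, not Selene's implementation: `filter_chromosomes` is a hypothetical helper and the chromosome names are made up for illustration.

```python
def filter_chromosomes(chroms, exclude_chrs=("_",)):
    """Keep chromosomes whose name contains none of the excluded substrings.

    Hypothetical helper for illustration only; it mirrors the `s in chrom`
    check described in the release notes, not Selene's actual code.
    """
    kept = []
    for chrom in chroms:
        if any(s in chrom for s in exclude_chrs):
            continue  # skip this chromosome entirely
        kept.append(chrom)
    return kept


# Made-up chromosome names, including two nonstandard contigs.
chroms = ["chr1", "chr2", "chrX", "chrUn_gl000220", "chr6_apd_hap1", "chrM"]
default_kept = filter_chromosomes(chroms)
custom_kept = filter_chromosomes(chroms, exclude_chrs=["_", "chrM", "chrX"])
```

Because the match is by substring, an entry such as `chr1` would also drop `chr10` through `chr19`, so exact names need care when building the list.
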
scripts/update_seqweaver.py

Lines changed: 13 additions & 11 deletions
@@ -13,15 +13,14 @@
 from selene_sdk.targets.genomic_features import GenomicFeatures
 from selene_sdk.samplers.dataloader import H5DataLoader
 from selene_sdk.train_model import TrainModel
-from selene_sdk.utils.config import load_path
-from selene_sdk.utils.config_utils import parse_configs_and_run
+from selene_sdk.utils import load_and_parse_configs_and_run
 
 class UpdateSeqweaver():
     """
     Stores a dataset specifying sequence regions and features.
     Accepts a tabix-indexed `*.bed` file with the following columns,
     in order:
-        [chrom, start, end, feature, strand]
+        [chrom, start, end, strand, feature]
 
     Parameters
     ----------
@@ -40,7 +39,15 @@ class UpdateSeqweaver():
         should contain the target organism's genome sequence.
 
     """
-    def __init__(self, input_path, train_path, validate_path, feature_path, hg_fasta, yaml_path, val_prop=0.1, sequence_len=1000):
+    def __init__(self,
+                 input_path,
+                 train_path,
+                 validate_path,
+                 feature_path,
+                 hg_fasta,
+                 yaml_path,
+                 val_prop=0.1,
+                 sequence_len=1000):
         """
         Constructs a new `UpdateSeqweaver` object.
         """
@@ -142,11 +149,6 @@ def construct_training_data(self):
             fh.create_dataset("train_sequences", data=np.array(training_seqs, dtype=np.int64))
             fh.create_dataset("train_targets", data=np.array(training_labels, dtype=np.int64))
 
-    def _load_yaml(self):
-        # load yaml configuration
-        return load_path(self.yaml_path)
-
-    def train_model(self):
+    def train_model(self, lr):
         # load config file and train model
-        yaml_config = self._load_yaml()
-        parse_configs_and_run(yaml_config)
+        load_and_parse_configs_and_run(self.yaml_path, lr=lr)

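The docstring change above swaps the last two BED columns, so the expected order is now `[chrom, start, end, strand, feature]`. The sketch below shows how a line in that order might be parsed and how a line in the old order could be caught early. `BedRecord` and `parse_bed_line` are hypothetical names, and treating `+`, `-`, and `.` as the only valid strand values is an assumption rather than something this commit states.

```python
from typing import NamedTuple


class BedRecord(NamedTuple):
    chrom: str
    start: int
    end: int
    strand: str
    feature: str


def parse_bed_line(line):
    """Parse one tab-separated line in [chrom, start, end, strand, feature] order.

    Hypothetical helper; the accepted strand symbols are an assumption.
    """
    chrom, start, end, strand, feature = line.rstrip("\n").split("\t")[:5]
    if strand not in {"+", "-", "."}:
        # A feature name in the strand column usually means the file
        # still uses the old [.., feature, strand] ordering.
        raise ValueError(f"unexpected strand value: {strand!r}")
    return BedRecord(chrom, int(start), int(end), strand, feature)


record = parse_bed_line("chr1\t1000\t2000\t+\tCTCF\n")
```
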
selene-cpu.yml

Lines changed: 2 additions & 1 deletion
@@ -11,8 +11,9 @@ dependencies:
   - matplotlib=2.0.2
   - numpy=1.21.4
   - pandas=0.20.3
-  - python=3.6.2
+  - python=3.9
   - pyyaml=5.1
+  - ruamel.yaml=0.18.6
   - scikit-learn=0.19.0
   - scipy=1.1.0
   - seaborn=0.8.1

selene-gpu.yml

Lines changed: 6 additions & 5 deletions
@@ -4,7 +4,7 @@ channels:
   - bioconda
   - conda-forge
 dependencies:
-  - cython=0.29.3
+  - cython=0.29.24
   - click==7.1.2
   - docopt=0.6.2
   - h5py=2.9.0
@@ -15,11 +15,12 @@ dependencies:
   - statsmodels=0.9.0
   - pytabix=0.0.2
   - matplotlib=2.2.2
-  - python=3.6.5
-  - numpy=1.15.1
+  - python=3.9
+  - ruamel.yaml=0.18.6
+  - numpy=1.21.4
   - plotly=2.7.0
   - cudatoolkit=10.0.130=0
-  - pytorch=1.0.1=py3.6_cuda10.0.130_cudnn7.4.2_2
-  - torchvision=0.2.2=py_3
+  - pytorch=2.4.1=py3.9_cuda11.8_cudnn9.1.0_0
+  - torchvision=0.20.0=py39_cu118
   - pyfaidx=0.5.5.2
   - seaborn=0.8.1

selene_sdk/cli.py

Lines changed: 2 additions & 3 deletions
@@ -14,7 +14,7 @@
 import click
 
 from selene_sdk import __version__
-from selene_sdk.utils import load_path, parse_configs_and_run
+from selene_sdk.utils import load_and_parse_configs_and_run
 
 
 @click.command()
@@ -23,8 +23,7 @@
 @click.option('--lr', type=float, help='If training, the optimizer learning rate', show_default=True)
 def main(path, lr):
     """Build the model and trains it using user-specified input data."""
-    configs = load_path(path, instantiate=False)
-    parse_configs_and_run(configs, lr=lr)
+    load_and_parse_configs_and_run(path, lr=lr)
 
 
 if __name__ == "__main__":

selene_sdk/evaluate_model.py

Lines changed: 3 additions & 3 deletions
@@ -253,9 +253,6 @@ def evaluate(self):
         average_scores = self._metrics.update(
             all_predictions, self._all_test_targets)
 
-        self._metrics.visualize(
-            all_predictions, self._all_test_targets, self.output_dir)
-
         np.savez_compressed(
             os.path.join(self.output_dir, "test_predictions.npz"),
             data=all_predictions)
@@ -270,4 +267,7 @@ def evaluate(self):
         feature_scores_dict = self._metrics.write_feature_scores_to_file(
             test_performance)
 
+        self._metrics.visualize(
+            all_predictions, self._all_test_targets, self.output_dir)
+
         return feature_scores_dict

selene_sdk/samplers/dataloader.py

Lines changed: 97 additions & 35 deletions
@@ -3,6 +3,7 @@
 which allow parallel sampling for any Sampler using
 torch DataLoader mechanism.
 """
+import random
 import sys
 
 import h5py
@@ -125,6 +126,27 @@ def worker_init_fn(worker_id):
         self.seed = seed
 
 
+def unpackbits_sequence(sequence, s_len):
+    sequence = np.unpackbits(sequence.astype(np.uint8), axis=-2)
+    nulls = np.sum(sequence, axis=-1) == sequence.shape[-1]
+    sequence = sequence.astype(float)
+    sequence[nulls, :] = 1.0 / sequence.shape[-1]
+    if sequence.ndim == 3:
+        sequence = sequence[:, :s_len, :]
+    else:
+        sequence = sequence[:s_len, :]
+    return sequence
+
+
+def unpackbits_targets(targets, t_len):
+    targets = np.unpackbits(targets, axis=-1).astype(float)
+    if targets.ndim == 2:
+        targets = targets[:, :t_len]
+    else:
+        targets = targets[:t_len]
+    return targets
+
+
 class _H5Dataset(Dataset):
     """
     This class provides a Dataset that directly loads sequences and targets
@@ -160,13 +182,24 @@ class _H5Dataset(Dataset):
     def __init__(self,
                  file_path,
                  in_memory=False,
-                 unpackbits=False,
+                 unpackbits=False, # implies unpackbits for both
+                 unpackbits_seq=False,
+                 unpackbits_tgt=False,
                  sequence_key="sequences",
-                 targets_key="targets"):
+                 targets_key="targets",
+                 use_seq_len=None,
+                 shift=None):
         super(_H5Dataset, self).__init__()
         self.file_path = file_path
         self.in_memory = in_memory
+
         self.unpackbits = unpackbits
+        self.unpackbits_seq = unpackbits_seq
+        self.unpackbits_tgt = unpackbits_tgt
+
+        self.use_seq_len = use_seq_len
+        self.shift = shift
+        self._seq_start, self._seq_end = None, None
 
         self._initialized = False
         self._sequence_key = sequence_key
@@ -178,15 +211,22 @@ def init(func):
         def dfunc(self, *args, **kwargs):
             if not self._initialized:
                 self.db = h5py.File(self.file_path, 'r')
+
                 if self.unpackbits:
                     self.s_len = self.db['{0}_length'.format(self._sequence_key)][()]
                     self.t_len = self.db['{0}_length'.format(self._targets_key)][()]
+                elif self.unpackbits_seq:
+                    self.s_len = self.db['{0}_length'.format(self._sequence_key)][()]
+                elif self.unpackbits_tgt:
+                    self.t_len = self.db['{0}_length'.format(self._targets_key)][()]
+
                 if self.in_memory:
                     self.sequences = np.asarray(self.db[self._sequence_key])
                     self.targets = np.asarray(self.db[self._targets_key])
                 else:
                     self.sequences = self.db[self._sequence_key]
                     self.targets = self.db[self._targets_key]
+
                 self._initialized = True
             return func(self, *args, **kwargs)
         return dfunc
@@ -195,25 +235,33 @@ def dfunc(self, *args, **kwargs):
     def __getitem__(self, index):
         if isinstance(index, int):
             index = index % self.sequences.shape[0]
-        sequence = self.sequences[index, :, :]
-        targets = self.targets[index, :]
+        sequence = self.sequences[index]
+        targets = self.targets[index]
+
         if self.unpackbits:
-            sequence = np.unpackbits(sequence, axis=-2)
-            nulls = np.sum(sequence, axis=-1) == sequence.shape[-1]
-            sequence = sequence.astype(float)
-            sequence[nulls, :] = 1.0 / sequence.shape[-1]
-            targets = np.unpackbits(
-                targets, axis=-1).astype(float)
-            if sequence.ndim == 3:
-                sequence = sequence[:, :self.s_len, :]
-            else:
-                sequence = sequence[:self.s_len, :]
-            if targets.ndim == 2:
-                targets = targets[:, :self.t_len]
-            else:
-                targets = targets[:self.t_len]
-        return (torch.from_numpy(sequence.astype(np.float32)),
-                torch.from_numpy(targets.astype(np.float32)))
+            sequence = unpackbits_sequence(sequence, self.s_len)
+            targets = unpackbits_targets(targets, self.t_len)
+        elif self.unpackbits_seq:
+            sequence = unpackbits_sequence(sequence, self.s_len)
+        elif self.unpackbits_tgt:
+            targets = unpackbits_targets(targets, self.t_len)
+
+        if self._seq_start is None:
+            self._seq_start = 0
+            self._seq_end = len(sequence)
+
+        if self.use_seq_len is not None:
+            mid = len(sequence) // 2
+            self._seq_start = int(mid - np.ceil(self.use_seq_len / 2))
+            self._seq_end = mid + self.use_seq_len // 2
+        if self.shift is not None:
+            self._seq_start += self.shift
+            self._seq_end += self.shift
+        sequence = sequence[self._seq_start:self._seq_end]
+
+        s = sequence.astype(np.float32)
+        return (torch.from_numpy(s), torch.from_numpy(targets))
+
 
     @init
     def __len__(self):
@@ -288,20 +336,38 @@ class H5DataLoader(DataLoader):
 
     """
     def __init__(self,
-                 filepath,
-                 in_memory=False,
+                 dataset,
                  num_workers=1,
                  use_subset=None,
                  batch_size=1,
-                 shuffle=True,
-                 unpackbits=False,
-                 sequence_key="sequences",
-                 targets_key="targets"):
+                 seed=436,
+                 sampler=None,
+                 batch_sampler=None,
+                 shuffle=True):
+        g = torch.Generator()
+        g.manual_seed(seed)
+
+        def worker_init_fn(worker_id):
+            worker_seed = torch.initial_seed() % 2**32
+            print("Worker seed", worker_seed)
+            np.random.seed(worker_seed)
+            random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+
         args = {
             "batch_size": batch_size,
-            "num_workers": 0 if in_memory else num_workers,
-            "pin_memory": True
+            "pin_memory": True,
+            "worker_init_fn": worker_init_fn,
+            "sampler": sampler,
+            "batch_sampler": batch_sampler,
+            "generator": g,
         }
+
+        if hasattr(dataset, 'in_memory'):
+            args['num_workers'] = 0 if dataset.in_memory else num_workers
+        else:
+            args['num_workers'] = num_workers
+
         if use_subset is not None:
             from torch.utils.data.sampler import SubsetRandomSampler
             if isinstance(use_subset, int):
@@ -311,10 +377,6 @@ def __init__(self,
             args["sampler"] = SubsetRandomSampler(use_subset)
         else:
             args["shuffle"] = shuffle
-        super(H5DataLoader, self).__init__(
-            _H5Dataset(filepath,
-                       in_memory=in_memory,
-                       unpackbits=unpackbits,
-                       sequence_key=sequence_key,
-                       targets_key=targets_key),
-            **args)
+
+        super(H5DataLoader, self).__init__(dataset, **args)
+

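The `use_seq_len` and `shift` handling added to `_H5Dataset.__getitem__` above is index arithmetic around the sequence midpoint. The sketch below restates that arithmetic as a standalone, stateless function so the window bounds are easy to check by hand. `crop_bounds` is a hypothetical name, and unlike the diff it recomputes the bounds on every call instead of caching them on the dataset object.

```python
import math


def crop_bounds(seq_len, use_seq_len=None, shift=None):
    """Return (start, end) of the window taken from a sequence of length seq_len.

    Hypothetical helper that restates the midpoint arithmetic from the diff;
    it performs no bounds checking, which matches the code it illustrates.
    """
    start, end = 0, seq_len
    if use_seq_len is not None:
        mid = seq_len // 2
        # ceil on the left, floor on the right, so odd lengths are preserved
        start = mid - math.ceil(use_seq_len / 2)
        end = mid + use_seq_len // 2
    if shift is not None:
        # a positive shift moves the window right, a negative one left
        start += shift
        end += shift
    return start, end


centered = crop_bounds(1000, use_seq_len=200)
shifted = crop_bounds(1000, use_seq_len=200, shift=50)
odd_window = crop_bounds(1000, use_seq_len=201)
```

Since nothing clamps the result, a shift that pushes the window past either end of the stored sequence would produce an empty or truncated slice rather than an error, so it may be worth validating `shift` against the stored sequence length before training.
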
selene_sdk/samplers/file_samplers/mat_file_sampler.py

Lines changed: 6 additions & 7 deletions
@@ -240,13 +240,12 @@ def get_data_and_targets(self, batch_size, n_samples=None):
         sequences_and_targets = []
         targets_mat = []
 
-        count = 0
-        while count < n_samples:
-            sample_size = min(n_samples - count, batch_size)
-            seqs, tgts = self.sample(batch_size=sample_size)
-            sequences_and_targets.append((seqs, tgts))
-            targets_mat.append(tgts)
-            count += sample_size
+        for ix in range(0, n_samples, batch_size):
+            s = ix
+            e = min(ix+batch_size, n_samples)
+            seqs, tgts = self.sample(batch_size=batch_size)
+            sequences_and_targets.append((seqs[:e-s], tgts[:e-s]))
+            targets_mat.append(tgts[:e-s])
 
         # TODO: should not assume targets are always integers
         targets_mat = np.vstack(targets_mat).astype(float)

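The rewritten loop above always requests a full batch and then truncates the final one, so that exactly `n_samples` examples are collected. Here is a minimal sketch of that pattern using plain lists. `collect_batches` and `fake_sample` are hypothetical stand-ins for the sampler methods in the diff.

```python
def collect_batches(sample_fn, batch_size, n_samples):
    """Draw full batches and trim the last one so the total is n_samples."""
    batches = []
    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        batch = sample_fn(batch_size)          # always a full-size draw
        batches.append(batch[: end - start])   # keep only what is still needed
    return batches


def fake_sample(batch_size):
    # Stand-in for a sampler call; returns batch_size dummy examples.
    return list(range(batch_size))


batches = collect_batches(fake_sample, batch_size=4, n_samples=10)
sizes = [len(b) for b in batches]
```

Requesting a constant batch size keeps the underlying sampler configuration unchanged between calls, at the cost of discarding a few examples from the final draw.
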
selene_sdk/samplers/multi_sampler.py

Lines changed: 6 additions & 10 deletions
@@ -221,6 +221,7 @@ def sample(self, batch_size=1, mode=None):
         except StopIteration:
             #If DataLoader iterator reaches its length, reinitialize
             self._iterators[mode] = iter(self._dataloaders[mode])
+
             data, targets = next(self._iterators[mode])
             return data.numpy(), targets.numpy()
 
@@ -260,16 +261,11 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None):
         self._set_batch_size(batch_size, mode=mode)
         data_and_targets = []
         targets_mat = []
-        count = batch_size
-        while count < n_samples:
-            data, tgts = self.sample(batch_size=batch_size, mode=mode)
-            data_and_targets.append((data, tgts))
-            targets_mat.append(tgts)
-            count += batch_size
-        remainder = batch_size - (count - n_samples)
-        data, tgts = self.sample(batch_size=remainder)
-        data_and_targets.append((data, tgts))
-        targets_mat.append(tgts)
+        for s in range(0, n_samples, batch_size):
+            e = min(n_samples, s+batch_size)
+            data, targets = self.sample(batch_size=batch_size, mode=mode)
+            data_and_targets.append((data[:e-s], targets[:e-s]))
+            targets_mat.append(targets[:e-s])
         targets_mat = np.vstack(targets_mat)
         return data_and_targets, targets_mat
 

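The `sample` method in the first hunk above restarts its iterator when the underlying loader is exhausted. The sketch below shows that restart-on-`StopIteration` pattern with a plain Python iterable standing in for the `DataLoader`. `CyclingBatches` is a hypothetical class, not part of Selene.

```python
class CyclingBatches:
    """Serve batches indefinitely by restarting the iterator when it runs out."""

    def __init__(self, make_iterable):
        # make_iterable is a zero-argument callable returning a fresh iterable,
        # standing in for a DataLoader that can be iterated repeatedly.
        self._make_iterable = make_iterable
        self._iterator = iter(make_iterable())

    def sample(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # The iterator reached its end: start a new pass and draw again.
            self._iterator = iter(self._make_iterable())
            return next(self._iterator)


sampler = CyclingBatches(lambda: [[0, 1], [2, 3], [4]])
drawn = [sampler.sample() for _ in range(5)]
```

With a shuffling loader, each restart would presumably begin a new pass in a new order, so callers should not assume that successive passes return batches in the same sequence.
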