Skip to content

Commit 6a9f2a0

Browse files
authored
Merge pull request #135 from kathyxchen/pytorch-140-update
Minor changes and updating Selene to be compatible with PyTorch 1.4.0
2 parents ec315d9 + 8edb637 commit 6a9f2a0

File tree

10 files changed

+78
-27
lines changed

10 files changed

+78
-27
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ Selene is a Python library and command line interface for training deep neural n
99
We recommend using Selene with Python 3.6 or above.
1010
Package installation should only take a few minutes (less than 10 minutes, typically ~2-3 minutes) with any of these methods (conda, pip, source).
1111

12-
**Install [PyTorch](https://pytorch.org/get-started/locally/).** If you have an NVIDIA GPU, install a version of PyTorch that supports it--Selene will run much faster with a discrete GPU.
12+
**First, install [PyTorch](https://pytorch.org/get-started/locally/).** If you have an NVIDIA GPU, install a version of PyTorch that supports it--Selene will run much faster with a discrete GPU.
13+
The library is currently compatible with PyTorch versions between 0.4.1 and 1.4.0.
14+
We will continue to update Selene to be compatible with the latest version of PyTorch.
1315

1416
### Installing selene with [Anaconda](https://www.anaconda.com/download/) (for Linux):
1517

docs/source/overview/cli.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ evaluate_model: !obj:selene_sdk.EvaluateModel {
172172
features: !obj:selene_sdk.utils.load_features_list {
173173
input_path: /path/to/features_list.txt
174174
},
175+
use_features_ord: !obj:selene_sdk.utils.load_features_list {
176+
input_path: /path/to/features_subset_ordered.txt
177+
},
175178
trained_model_path: /path/to/trained/model.pth.tar,
176179
batch_size: 64,
177180
n_test_samples: 640000,
@@ -190,6 +193,7 @@ evaluate_model: !obj:selene_sdk.EvaluateModel {
190193
- `report_gt_feature_n_positives`: Default is 10. In total, each class/feature must have more than `report_gt_feature_n_positives` positive examples in the test set to be considered in the performance computation. The output file that reports each class's performance will report 'NA' for classes that do not have enough positive samples.
191194
- `use_cuda`: Default is False. Specify whether CUDA-enabled GPUs are available for torch to use.
192195
- `data_parallel`: Default is False. Specify whether multiple GPUs are available for torch to use.
196+
- `use_features_ord`: Default is None. Specify an ordered list of the features to evaluate. Each feature in this list should also appear in `features` (any feature that does not is skipped with a warning). List the features in the order in which their columns should appear in the saved `test_targets.npz` and `test_predictions.npz`.
193197

194198
#### Additional notes
195199
Similar to the `train_model` configuration, any arguments that you find in [the documentation](https://selene.flatironinstitute.org/selene.html#evaluatemodel) that are not present in the function-type value's arguments are automatically instantiated and passed in by Selene.

selene_sdk/evaluate_model.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import logging
55
import os
6+
import warnings
67

78
import numpy as np
89
import torch
@@ -57,6 +58,11 @@ class EvaluateModel(object):
5758
data_parallel : bool, optional
5859
Default is `False`. Specify whether multiple GPUs are available
5960
for torch to use during evaluation.
61+
use_features_ord : list(str) or None, optional
62+
Default is None. Specify an ordered list of features for which to
63+
run the evaluation. The features in this list must be identical to or
64+
a subset of `features`, and in the order you want the resulting
65+
`test_targets.npz` and `test_predictions.npz` to be saved.
6066
6167
Attributes
6268
----------
@@ -88,7 +94,8 @@ def __init__(self,
8894
n_test_samples=None,
8995
report_gt_feature_n_positives=10,
9096
use_cuda=False,
91-
data_parallel=False):
97+
data_parallel=False,
98+
use_features_ord=None):
9299
self.criterion = criterion
93100

94101
trained_model = torch.load(
@@ -103,11 +110,26 @@ def __init__(self,
103110

104111
self.sampler = data_sampler
105112

106-
self.features = features
107-
108113
self.output_dir = output_dir
109114
os.makedirs(output_dir, exist_ok=True)
110115

116+
self.features = features
117+
self._use_ixs = list(range(len(features)))
118+
if use_features_ord is not None:
119+
feature_ixs = {f: ix for (ix, f) in enumerate(features)}
120+
self._use_ixs = []
121+
self.features = []
122+
123+
for f in use_features_ord:
124+
if f in feature_ixs:
125+
self._use_ixs.append(feature_ixs[f])
126+
self.features.append(f)
127+
else:
128+
warnings.warn(("Feature {0} in `use_features_ord` "
129+
"does not match any features in the list "
130+
"`features` and will be skipped.").format(f))
131+
self._write_features_ordered_to_file()
132+
111133
initialize_logger(
112134
os.path.join(self.output_dir, "{0}.log".format(
113135
__name__)),
@@ -130,11 +152,30 @@ def __init__(self,
130152

131153
self._test_data, self._all_test_targets = \
132154
self.sampler.get_data_and_targets(self.batch_size, n_test_samples)
155+
# TODO: we should be able to do this on the sampler end instead of
156+
# here. the current workaround is problematic, since
157+
# self._test_data still has the full featureset in it, and we
158+
# select the subset during `evaluate`
159+
self._all_test_targets = self._all_test_targets[:, self._use_ixs]
133160

161+
# reset Genome base ordering when applicable.
134162
if (hasattr(self.sampler, "reference_sequence") and
135-
isinstance(self.sampler.reference_sequence, Genome) and
136-
_is_lua_trained_model(model)):
137-
Genome.update_bases_order(['A', 'G', 'C', 'T'])
163+
isinstance(self.sampler.reference_sequence, Genome)):
164+
if _is_lua_trained_model(model):
165+
Genome.update_bases_order(['A', 'G', 'C', 'T'])
166+
else:
167+
Genome.update_bases_order(['A', 'C', 'G', 'T'])
168+
169+
def _write_features_ordered_to_file(self):
170+
"""
171+
Write the feature ordering specified by `use_features_ord`
172+
after matching it with the `features` list from the class
173+
initialization parameters.
174+
"""
175+
fp = os.path.join(self.output_dir, 'use_features_ord.txt')
176+
with open(fp, 'w+') as file_handle:
177+
for f in self.features:
178+
file_handle.write('{0}\n'.format(f))
138179

139180
def _get_feature_from_index(self, index):
140181
"""
@@ -170,7 +211,7 @@ def evaluate(self):
170211
all_predictions = []
171212
for (inputs, targets) in self._test_data:
172213
inputs = torch.Tensor(inputs)
173-
targets = torch.Tensor(targets)
214+
targets = torch.Tensor(targets[:, self._use_ixs])
174215

175216
if self.use_cuda:
176217
inputs = inputs.cuda()
@@ -182,10 +223,11 @@ def evaluate(self):
182223
predictions = None
183224
if _is_lua_trained_model(self.model):
184225
predictions = self.model.forward(
185-
inputs.transpose(1, 2).unsqueeze_(2))
226+
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
186227
else:
187228
predictions = self.model.forward(
188229
inputs.transpose(1, 2))
230+
predictions = predictions[:, self._use_ixs]
189231
loss = self.criterion(predictions, targets)
190232

191233
all_predictions.append(predictions.data.cpu().numpy())

selene_sdk/predict/_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def predict(model, batch_sequences, use_cuda=False):
9393
inputs = Variable(inputs)
9494

9595
if _is_lua_trained_model(model):
96-
outputs = model.forward(inputs.transpose(1, 2).unsqueeze_(2))
96+
outputs = model.forward(
97+
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
9798
else:
9899
outputs = model.forward(inputs.transpose(1, 2))
99100
return outputs.data.cpu().numpy()

selene_sdk/predict/_variant_effect_prediction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def read_vcf_file(input_path,
123123
if not reference_sequence.coords_in_bounds(chrom, start, end):
124124
na_rows.append(line)
125125
continue
126+
alt = alt.replace('.', ',') # consider '.' a valid delimiter
126127
for a in alt.split(','):
127128
variants.append((chrom, pos, name, ref, a, strand))
128129

selene_sdk/predict/model_predict.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def __init__(self,
176176
if type(self.reference_sequence) == Genome and \
177177
_is_lua_trained_model(model):
178178
Genome.update_bases_order(['A', 'G', 'C', 'T'])
179+
else: # even if not using Genome, I guess we can update?
180+
Genome.update_bases_order(['A', 'C', 'G', 'T'])
179181
self._write_mem_limit = write_mem_limit
180182

181183
def _initialize_reporters(self,
@@ -424,11 +426,11 @@ def get_predictions_for_bed_file(self,
424426
batch_ids.append(label+(contains_unk,))
425427
sequences[ i % self.batch_size, :, :] = encoding
426428
if contains_unk:
427-
warnings.warn("For region {0}, "
428-
"reference sequence contains unknown base(s). "
429-
"--will be marked `True` in the `contains_unk` column "
430-
"of the .tsv or the row_labels .txt file.".format(
431-
label))
429+
warnings.warn(("For region {0}, "
430+
"reference sequence contains unknown "
431+
"base(s). --will be marked `True` in the "
432+
"`contains_unk` column of the .tsv or "
433+
"row_labels .txt file.").format(label))
432434

433435
if (batch_ids and i == 0) or i % self.batch_size != 0:
434436
sequences = sequences[:i % self.batch_size + 1, :, :]

selene_sdk/sequences/genome.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def _not_blacklist_region(chrom, start, end, blacklist_tabix):
1818
"""
1919
Check if the input coordinates are not overlapping with blacklist regions.
20-
20+
2121
Parameters
2222
----------
2323
chrom : str
@@ -29,14 +29,14 @@ def _not_blacklist_region(chrom, start, end, blacklist_tabix):
2929
blacklist_tabix : tabix.open or None, optional
3030
Default is `None`. Tabix file handle if a file of blacklist regions
3131
is available.
32-
32+
3333
Returns
3434
-------
3535
bool
3636
False if the coordinates are overlapping with blacklist regions
3737
(if specified). Otherwise, return True.
38-
39-
38+
39+
4040
"""
4141
if blacklist_tabix is not None:
4242
try:
@@ -203,7 +203,7 @@ class Genome(Sequence):
203203
204204
"""
205205

206-
BASES_ARR = np.array(['A', 'C', 'G', 'T'])
206+
BASES_ARR = ['A', 'C', 'G', 'T']
207207
"""
208208
This is an array with the alphabet (i.e. all possible symbols
209209
that may occur in a sequence). We expect that
@@ -463,7 +463,7 @@ def get_encoding_from_coords_check_unk(self,
463463
strand='+',
464464
pad=False):
465465
"""Gets the one-hot encoding of the genomic sequence at the
466-
queried coordinates and check whether the sequence contains
466+
queried coordinates and checks whether the sequence contains
467467
unknown base(s).
468468
469469
Parameters

selene_sdk/sequences/proteome.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ class around the `pyfaidx.Fasta` class.
7171
7272
"""
7373

74-
BASES_ARR = np.array(['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
75-
'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'])
74+
BASES_ARR = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
75+
'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
7676
"""
7777
This is an array with the alphabet (i.e. all possible symbols
7878
that may occur in a sequence). We expect that

selene_sdk/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.4"
1+
__version__ = "0.4.5"

setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
cmdclass = {'build_ext': build_ext}
2626

2727
setup(name="selene-sdk",
28-
version="0.4.4",
28+
version="0.4.5",
2929
long_description=long_description,
3030
long_description_content_type='text/markdown',
3131
description=("framework for developing sequence-level "
@@ -62,6 +62,5 @@
6262
"scipy",
6363
"seaborn",
6464
"statsmodels",
65-
"torch>=0.4.1",
66-
"torchvision"
65+
"torch>=0.4.1, <=1.4.0",
6766
])

0 commit comments

Comments
 (0)