
Commit 6ff4a5e

Updating seqweaver (#188)
* init get_new_training_data script and strand spec
* refactor main script, fix strand-spec
* debugging and testing update_seqweaver
* fixed h5 file output
* added class modules for training seqweaver
* added validation/training strat
* debugging main update seqweaver module
* strand backward compatibility
* further fixes to backward compatibility
* val partition fix
* indexing fix for backward compatibility
* addressing kathy's comments
* fixed relative paths in update_seqweaver
* handling strand=. as None
1 parent 86f4df3 commit 6ff4a5e


3 files changed: +244 / -9 lines changed


models/seqweaver.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
"""
Seqweaver architecture (Park & Troyanskaya, 2021).
"""
import torch
import torch.nn as nn


class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input


class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))


class Seqweaver(nn.Module):

    def __init__(self, n_classes):  # 217 human, 43 mouse
        super(Seqweaver, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(4, 160, (1, 8)),
            nn.ReLU(),
            nn.MaxPool2d((1, 4), (1, 4)),
            nn.Dropout(0.1),
            nn.Conv2d(160, 320, (1, 8)),
            nn.ReLU(),
            nn.MaxPool2d((1, 4), (1, 4)),
            nn.Dropout(0.1),
            nn.Conv2d(320, 480, (1, 8)),
            nn.ReLU(),
            nn.Dropout(0.3))
        self.fc = nn.Sequential(
            Lambda(lambda x: torch.reshape(x, (x.size(0), 25440))),
            nn.Sequential(
                Lambda(lambda x: x.reshape(1, -1)
                       if 1 == len(x.size()) else x),
                nn.Linear(25440, n_classes)
            ),
            nn.ReLU(),
            nn.Sequential(
                Lambda(lambda x: x.view(1, -1)
                       if 1 == len(x.size()) else x),
                nn.Linear(n_classes, n_classes)
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.unsqueeze(2)
        x = self.model(x)
        x = self.fc(x)
        return x


def criterion():
    return nn.BCELoss()


def get_optimizer(lr):
    return (torch.optim.SGD,
            {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9})
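A quick smoke test of the new module (not part of this commit). It assumes the repository root is on `sys.path` so that `models/seqweaver.py` imports as `models.seqweaver`; the input shape follows from the architecture above (4 channels x 1,000 bp, which flattens to 480 x 53 = 25,440 features after the third convolution).

import torch

from models.seqweaver import Seqweaver, criterion, get_optimizer

model = Seqweaver(217)              # 217 human classes, per the comment in __init__
x = torch.rand(8, 4, 1000)          # random stand-in for a one-hot encoded batch
y = model(x)                        # forward() inserts the height dim via unsqueeze(2)
print(y.shape)                      # torch.Size([8, 217]); sigmoid outputs in [0, 1]

loss_fn = criterion()               # BCELoss over the multi-label targets
opt_class, opt_kwargs = get_optimizer(lr=0.1)
optimizer = opt_class(model.parameters(), **opt_kwargs)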

scripts/update_seqweaver.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
"""
This module provides the `UpdateSeqweaver` class, which wraps the master bed file
containing all of the features' binding sites parsed from CLIP-seq.
It supports new dataset construction and training for Seqweaver.

"""
import h5py
import gzip
import numpy as np
import sys

from selene_sdk.sequences.genome import Genome
from selene_sdk.targets.genomic_features import GenomicFeatures
from selene_sdk.samplers.dataloader import H5DataLoader
from selene_sdk.train_model import TrainModel
from selene_sdk.utils.config import load_path
from selene_sdk.utils.config_utils import parse_configs_and_run


class UpdateSeqweaver():
    """
    Stores a dataset specifying sequence regions and features.
    Accepts a tabix-indexed `*.bed` file with the following columns,
    in order:
        [chrom, start, end, feature, strand]

    Parameters
    ----------
    input_path : str
        Path to the tabix-indexed dataset. Note that for the file to
        be tabix-indexed, it must have been compressed with `bgzip`.
        Thus, `input_path` should be a `*.gz` file with a
        corresponding `*.tbi` file in the same directory.
    train_path : str
        Path to the output file for the constructed training data.
    validate_path : str
        Path to the output file for the constructed validation data.
    feature_path : str
        Path to a newline-delimited .txt file containing feature names.
    hg_fasta : str
        Path to an indexed FASTA file -- a `*.fasta` file with
        a corresponding `*.fai` file in the same directory. This file
        should contain the target organism's genome sequence.
    yaml_path : str
        Path to the YAML configuration file used to train the model.
    val_prop : float, optional
        Default is 0.1. Proportion of the constructed examples held out
        for validation.
    sequence_len : int, optional
        Default is 1000. Length of the sequence window centered on each
        peak midpoint.

    """
    def __init__(self, input_path, train_path, validate_path, feature_path,
                 hg_fasta, yaml_path, val_prop=0.1, sequence_len=1000):
        """
        Constructs a new `UpdateSeqweaver` object.
        """
        self.input_path = input_path
        self.train_path = train_path
        self.validate_path = validate_path
        self.feature_path = feature_path
        self.yaml_path = yaml_path
        self.val_prop = val_prop

        self.hg_fasta = hg_fasta

        self.sequence_len = sequence_len

        with open(self.feature_path, 'r') as handle:
            self.feature_set = [line.split('\n')[0] for line in handle.readlines()]

    def _from_midpoint(self, start, end):
        """
        Computes the start and end of the sequence window centered on the
        peak midpoint.

        Parameters
        ----------
        start : int
            The 0-based first position in the region.
        end : int
            One past the 0-based last position in the region.

        Returns
        -------
        seq_start : int
            Sequence start position about the peak midpoint.
        seq_end : int
            Sequence end position about the peak midpoint.
        """
        region_len = end - start
        midpoint = start + region_len // 2
        seq_start = midpoint - np.floor(self.sequence_len / 2.)
        seq_end = midpoint + np.ceil(self.sequence_len / 2.)

        return int(seq_start), int(seq_end)

    def construct_training_data(self):
        """
        Constructs the training and validation datasets from the bed file
        and writes them to `self.train_path` and `self.validate_path`.
        """
        list_of_regions = []
        with gzip.open(self.input_path) as f:
            for line in f:
                line = [str(data, 'utf-8') for data in line.strip().split()]
                list_of_regions.append(line)

        seqs = Genome(self.hg_fasta, blacklist_regions='hg19')
        targets = GenomicFeatures(self.input_path,
                                  features=self.feature_set,
                                  feature_thresholds=0.5)

        data_seqs = []
        data_labels = []
        for r in list_of_regions:
            chrom, start, end, target, strand = r
            start, end = int(start), int(end)
            sstart, ssend = self._from_midpoint(start, end)

            # 1 x 4 x 1000 bp
            # get_encoding_from_coords_check_unk: converts the sequence to a
            # one-hot encoding over the 4 bases and flags unknown bases
            dna_seq, has_unk = seqs.get_encoding_from_coords_check_unk(
                chrom, sstart, ssend, strand=strand)
            if has_unk:
                continue
            if len(dna_seq) != self.sequence_len:
                continue

            # 1 x n_features
            # get_feature_data: computes which features overlap with the given region
            labels = targets.get_feature_data(chrom, start, end, strand=strand)

            data_seqs.append(dna_seq)
            data_labels.append(labels)

        # partition some examples to validation before writing
        val_count = int(np.floor(self.val_prop * len(data_seqs)))
        validate_seqs = data_seqs[:val_count]
        validate_labels = data_labels[:val_count]
        training_seqs = data_seqs[val_count:]
        training_labels = data_labels[val_count:]

        with h5py.File(self.validate_path, "w") as fh:
            fh.create_dataset("valid_sequences", data=np.array(validate_seqs, dtype=np.int64))
            fh.create_dataset("valid_targets", data=np.array(validate_labels, dtype=np.int64))

        with h5py.File(self.train_path, "w") as fh:
            fh.create_dataset("train_sequences", data=np.array(training_seqs, dtype=np.int64))
            fh.create_dataset("train_targets", data=np.array(training_labels, dtype=np.int64))

    def _load_yaml(self):
        # load the YAML configuration
        return load_path(self.yaml_path)

    def train_model(self):
        # load the configuration file and train the model
        yaml_config = self._load_yaml()
        parse_configs_and_run(yaml_config)
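A hypothetical end-to-end driver for the new script (every path below is a placeholder, not taken from the commit). It sketches the intended flow, assuming `scripts/update_seqweaver.py` is importable from the working directory: build the HDF5 training/validation sets from the CLIP-seq bed file, then hand the YAML config to selene's `parse_configs_and_run`.

from update_seqweaver import UpdateSeqweaver

updater = UpdateSeqweaver(
    input_path="clip_peaks.sorted.bed.gz",    # bgzip-compressed, tabix-indexed master bed file (placeholder)
    train_path="seqweaver_train.h5",          # HDF5 written by construct_training_data()
    validate_path="seqweaver_validate.h5",
    feature_path="rbp_features.txt",          # newline-delimited feature names (placeholder)
    hg_fasta="hg19.fa",                       # indexed FASTA (*.fa with matching *.fai)
    yaml_path="seqweaver_train_config.yml",   # selene training configuration (placeholder)
    val_prop=0.1,                             # hold out 10% of examples for validation
    sequence_len=1000)                        # window centered on each peak midpoint

updater.construct_training_data()    # one-hot sequences + multi-label targets -> train/validate HDF5
updater.train_model()                # loads the YAML and calls parse_configs_and_run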

selene_sdk/targets/genomic_features.py

Lines changed: 22 additions & 9 deletions
@@ -102,7 +102,7 @@ def _is_positive_row(start, end,
 
 
 def _get_feature_data(chrom, start, end,
-                      thresholds, feature_index_dict, get_feature_rows):
+                      thresholds, feature_index_dict, get_feature_rows, strand=None):
     """
     Generates a target vector for the given query region.
 
@@ -125,6 +125,9 @@ def _get_feature_data(chrom, start, end,
     get_feature_rows : types.FunctionType
         A function that takes coordinates and returns rows
         (`list(tuple(int, int, str))`).
+    strand : {'+', '-'}, optional
+        The strand the sequence is located on. Default is None (no strand provided).
+        If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
 
     Returns
     -------
@@ -133,7 +136,7 @@ def _get_feature_data(chrom, start, end,
         `i`th feature is positive, and zero otherwise.
 
     """
-    rows = get_feature_rows(chrom, start, end)
+    rows = get_feature_rows(chrom, start, end, strand=strand)
     return _fast_get_feature_data(
         start, end, thresholds, feature_index_dict, rows)
 
@@ -303,7 +306,7 @@ def dfunc(self, *args, **kwargs):
             return func(self, *args, **kwargs)
         return dfunc
 
-    def _query_tabix(self, chrom, start, end):
+    def _query_tabix(self, chrom, start, end, strand=None):
         """
         Queries a tabix-indexed `*.bed` file for features falling into
         the specified region.
@@ -317,6 +320,9 @@ def _query_tabix(self, chrom, start, end):
             The 0-based start position of the query coordinates.
         end : int
             One past the last position of the query coordinates.
+        strand : {'+', '-'}, optional
+            The strand the sequence is located on. Default is None (no strand provided).
+            If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
 
         Returns
         -------
@@ -329,12 +335,16 @@ def _query_tabix(self, chrom, start, end):
 
         """
         try:
-            return self.data.query(chrom, start, end)
+            tabix_query = self.data.query(chrom, start, end)
+            if strand == '+' or strand == '-':
+                return [line for line in tabix_query if str(line[4]) == strand]  # strand specificity
+            else:  # not strand specific
+                return tabix_query
         except tabix.TabixError:
             return None
 
     @init
-    def is_positive(self, chrom, start, end):
+    def is_positive(self, chrom, start, end, strand=None):
         """
         Determines whether the query the `chrom` queried contains any
         genomic features within the :math:`[start, end)` region. If so,
@@ -357,11 +367,11 @@ def is_positive(self, chrom, start, end):
             assume the error was the result of no features being present
             in the queried region and return `False`.
         """
-        rows = self._query_tabix(chrom, start, end)
+        rows = self._query_tabix(chrom, start, end, strand=strand)
         return _any_positive_rows(rows, start, end, self.feature_thresholds)
 
     @init
-    def get_feature_data(self, chrom, start, end):
+    def get_feature_data(self, chrom, start, end, strand=None):
         """
         Computes which features overlap with the given region.
 
@@ -373,6 +383,9 @@ def get_feature_data(self, chrom, start, end):
             The 0-based first position in the region.
         end : int
             One past the 0-based last position in the region.
+        strand : {'+', '-'}, optional
+            The strand the sequence is located on. Default is None (no strand provided).
+            If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
 
         Returns
         -------
@@ -388,7 +401,7 @@ def get_feature_data(self, chrom, start, end):
         """
         if self._feature_thresholds_vec is None:
             features = np.zeros(self.n_features)
-            rows = self._query_tabix(chrom, start, end)
+            rows = self._query_tabix(chrom, start, end, strand=strand)  # strand specificity
             if not rows:
                 return features
             for r in rows:
@@ -398,4 +411,4 @@ def get_feature_data(self, chrom, start, end):
             return features
         return _get_feature_data(
             chrom, start, end, self._feature_thresholds_vec,
-            self.feature_index_dict, self._query_tabix)
+            self.feature_index_dict, self._query_tabix, strand=strand)
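To illustrate the behavioral change (the file name and coordinates below are hypothetical): with the default `strand=None` the query is unchanged from the previous release, while passing `'+'` or `'-'` filters the tabix rows on the bed file's fifth column before the feature vector is computed.

from selene_sdk.targets.genomic_features import GenomicFeatures

targets = GenomicFeatures("clip_peaks.sorted.bed.gz",              # placeholder tabix-indexed bed file
                          features=["RBP_A", "RBP_B"],             # placeholder feature names
                          feature_thresholds=0.5)

both = targets.get_feature_data("chr1", 1000, 2000)                # strand=None: same rows as before
plus = targets.get_feature_data("chr1", 1000, 2000, strand='+')    # only rows whose strand column is '+'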
