2 changes: 1 addition & 1 deletion colabfold/alphafold/msa.py
@@ -47,11 +47,11 @@ def make_fixed_size_multimer(
feat: Mapping[str, Any],
shape_schema,
num_res,
msa_cluster_size,
num_templates) -> FeatureDict:
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
msa_cluster_size = feat["bert_mask"].shape[0]
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
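The hunk above removes msa_cluster_size from the signature and derives it from the features instead: bert_mask is laid out as (number of MSA sequences, number of residues), so its first dimension recovers the MSA depth the caller used to pass in. A minimal sketch of that relationship (the concrete shape is an assumed example):

import numpy as np

# Hypothetical feature dict: bert_mask has shape (msa depth, residue count),
# so shape[0] is the msa_cluster_size the old signature took as a parameter.
feat = {"bert_mask": np.zeros((508, 256), dtype=np.float32)}
msa_cluster_size = feat["bert_mask"].shape[0]
assert msa_cluster_size == 508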
69 changes: 62 additions & 7 deletions colabfold/batch.py
@@ -21,11 +21,12 @@

from colabfold.inputs import (
get_queries_pairwise, unpack_a3ms,
parse_fasta, get_queries,
parse_fasta, get_queries, msa_to_str
)
from colabfold.run_alphafold import set_model_type

from colabfold.download import default_data_dir, download_alphafold_params

import sys
import logging
logger = logging.getLogger(__name__)

@@ -299,24 +300,43 @@ def main():
headers_list[0].remove(headers_list[0][0])
header_first = headers[0]

queries_temp = []
queries_rest = []
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
max_msa_cluster = 0
else:
max_msa_cluster = None
for jobname, batch in enumerate(output):
query_seqs_unique = []
for x in batch:
if x not in query_seqs_unique:
query_seqs_unique.append(x)
query_seqs_cardinality = [0] * len(query_seqs_unique)
for seq in batch:
seq_idx = query_seqs_unique.index(seq)
query_seqs_cardinality[seq_idx] += 1
use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode
paired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname))),
str(Path(args.results).joinpath(str(jobname)+"_paired")),
use_env=use_env,
use_pairwise=True,
use_pairing=True,
host_url=args.host_url,
)

path_o = Path(args.results).joinpath(f"{jobname}_pairwise")
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env")
unpaired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname)+"_unpaired")),
use_env=use_env,
use_pairwise=False,
use_pairing=False,
host_url=args.host_url,
)
path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise")
for filenum in path_o.iterdir():
queries_new = []
if Path(filenum).suffix.lower() == ".a3m":
outdir = path_o.joinpath("tmp")
unpack_a3ms(filenum, outdir)
@@ -326,14 +346,49 @@
query_sequence = seqs[0]
a3m_lines = [Path(file).read_text()]
val = int(header[0].split('\t')[1][1:]) - 102  # pairing IDs start at 101 (the query), so 102 is the first target
# match each paired sequence ID to its unpaired counterpart
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
paired_query_a3m_lines = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1]
# a3m_lines = [msa_to_str(
# [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]
# )]
## Alternative: skip msa_to_str and the later unserialize step, and instead
## pass unpaired_msa, paired_msa, query_seqs_unique, and query_seqs_cardinality directly as arguments.
a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]]
queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines))

### generate features then find max_msa_cluster
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
inputs = ([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]])
from colabfold.inputs import generate_msa_size
msa_size = generate_msa_size(inputs, query_seqs_unique, args.use_templates, is_complex, model_type)
# config.model.embeddings_and_evoformer.extra_msa_seqs=2048
# config.model.embeddings_and_evoformer.num_msa=508
# if msa_size < 2048 + 508, pop the query here and run the model on it later, accepting recompilation
if msa_size < 2556:
queries_rest.append(queries_new.pop())
continue
max_msa_cluster = max(max_msa_cluster, msa_size)

if args.sort_queries_by == "length":
queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True)
elif args.sort_queries_by == "random":
random.shuffle(queries_new)
queries_temp.append(queries_new)

queries_sel = sum(queries_temp, [])
run_params["max_msa_cluster"] = max_msa_cluster
run_params["interaction_scan"] = args.interaction_scan
run(queries=queries_sel, **run_params)

run(queries=queries_new, **run_params)
if args.pair_mode in ("none", "unpaired", "unpaired_paired"):
if len(queries_rest) > 0:
if args.sort_queries_by == "length":
queries_rest.sort(key=lambda t: len(''.join(t[1])),reverse=True)
elif args.sort_queries_by == "random":
random.shuffle(queries_rest)
run_params["max_msa_cluster"] = None
run(queries=queries_rest, **run_params)

else:
run(queries=queries, **run_params)
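The interaction-scan branch above amounts to a two-pass schedule: pairs whose generated MSA reaches the fixed bucket (extra_msa_seqs + num_msa = 2048 + 508 = 2556 rows, per the comments in the hunk) are padded to a shared max_msa_cluster and run together, while smaller pairs are deferred to queries_rest and run afterwards with max_msa_cluster = None. A simplified sketch of that split, with (query, msa_size) pairs assumed as input:

# Simplified sketch; msa_size is what generate_msa_size returns per pair.
EXTRA_MSA_SEQS = 2048  # config.model.embeddings_and_evoformer.extra_msa_seqs
NUM_MSA = 508          # config.model.embeddings_and_evoformer.num_msa
BUCKET = EXTRA_MSA_SEQS + NUM_MSA

def split_by_bucket(jobs):
    padded, deferred = [], []
    for query, msa_size in jobs:
        (deferred if msa_size < BUCKET else padded).append(query)
    return padded, deferred  # padded share one compilation; deferred recompile later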
22 changes: 19 additions & 3 deletions colabfold/inputs.py
@@ -83,9 +83,9 @@ def pad_input_multimer(
model_runner: model.RunModel,
model_name: str,
pad_len: int,
msa_cluster_size: Optional[int],
use_templates: bool,
) -> model.features.FeatureDict:
model_config = model_runner.config
shape_schema = {
"aatype": ["num residues placeholder"],
"residue_index": ["num residues placeholder"],
@@ -123,6 +123,7 @@ def pad_input_multimer(
input_features,
shape_schema,
num_res=pad_len,
msa_cluster_size=msa_cluster_size,
num_templates=4,
)
return input_fix
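A hypothetical call site for the updated signature. The first positional argument (the raw feature dict) sits above the visible hunk and is assumed here; the concrete values are illustrative only:

# Hypothetical usage; input_features and model_runner come from the pipeline.
padded_features = pad_input_multimer(
    input_features,
    model_runner,
    model_name="model_1_multimer_v3",  # assumed model name
    pad_len=256,                       # residue count to pad to
    msa_cluster_size=2556,             # Optional[int] per the new signature
    use_templates=False,
)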
@@ -654,6 +655,20 @@ def generate_input_feature(
}
return (input_feature, domain_names)

def generate_msa_size(inputs, query_seqs_unique, use_templates, is_complex, model_type):
    if use_templates:
        raise NotImplementedError
    # mk_mock_template and generate_input_feature are defined in this module,
    # so no self-import is needed
    template_features = [mk_mock_template(query_seq) for query_seq in query_seqs_unique]
    (feature_dict, _) = generate_input_feature(
        *inputs, template_features, is_complex, model_type
    )
    # bert_mask is (MSA depth, residue count), so shape[0] is the MSA size
    return feature_dict["bert_mask"].shape[0]
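A hypothetical invocation, mirroring the call added in colabfold/batch.py above. The inputs tuple packs (sequences, cardinalities, unpaired a3m lines, paired a3m lines) in the order generate_input_feature expects; the a3m variables are assumed to come from run_mmseqs2:

# Hypothetical two-chain interaction query.
seqs = ["MKTAYIAKQR", "MALWMRLLPL"]
inputs = (seqs, [1, 1], unpaired_a3m_lines, paired_a3m_lines)
msa_size = generate_msa_size(
    inputs, seqs, use_templates=False, is_complex=True,
    model_type="alphafold2_multimer_v3",  # assumed model type string
)
# msa_size equals feature_dict["bert_mask"].shape[0], the stacked MSA depth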

def unserialize_msa(
a3m_lines: List[str], query_sequence: Union[List[str], str]
) -> Tuple[
@@ -696,7 +711,7 @@ def unserialize_msa(
)
prev_query_start += query_len
paired_msa = [""] * len(query_seq_len)
unpaired_msa = None
unpaired_msa = [""] * len(query_seq_len)
already_in = dict()
for i in range(1, len(a3m_lines), 2):
header = a3m_lines[i]
@@ -734,7 +749,6 @@
paired_msa[j] += ">" + header_no_faster_split[j] + "\n"
paired_msa[j] += seqs_line[j] + "\n"
else:
unpaired_msa = [""] * len(query_seq_len)
for j, seq in enumerate(seqs_line):
if has_amino_acid[j]:
unpaired_msa[j] += header + "\n"
@@ -752,6 +766,8 @@
template_feature = mk_mock_template(query_seq)
template_features.append(template_feature)

if unpaired_msa == [""] * len(query_seq_len):
unpaired_msa = None
return (
unpaired_msa,
paired_msa,
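Taken together, the unserialize_msa hunks above switch the per-chain unpaired MSA from lazy to eager allocation, then collapse it back to None when no unpaired rows were ever appended, which preserves the function's old return contract. The sentinel pattern in isolation (chain lengths assumed):

# Sketch of the sentinel logic: allocate eagerly, collapse to None at the end.
query_seq_len = [120, 87]                 # assumed two-chain example
unpaired_msa = [""] * len(query_seq_len)  # one growing string per chain
# ... parsing appends ">header\nSEQUENCE\n" blocks to unpaired_msa[j] ...
if unpaired_msa == [""] * len(query_seq_len):
    unpaired_msa = None                   # nothing unpaired was found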
140 changes: 100 additions & 40 deletions colabfold/mmseqs/search.py
@@ -11,12 +11,50 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Union
import os
import os, pandas

from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise
from colabfold.inputs import get_queries, msa_to_str, parse_fasta
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

def get_queries_pairwise(
input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]:
"""Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple
of job name, sequence and the optional a3m lines"""
input_path = Path(input_path)
if not input_path.exists():
raise OSError(f"{input_path} could not be found")
if input_path.is_file():
if input_path.suffix == ".csv" or input_path.suffix == ".tsv":
sep = "\t" if input_path.suffix == ".tsv" else ","
df = pandas.read_csv(input_path, sep=sep)
assert "id" in df.columns and "sequence" in df.columns
queries = [
(str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None)
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False))
]
elif input_path.suffix == ".a3m":
raise NotImplementedError()
elif input_path.suffix in [".fasta", ".faa", ".fa"]:
(sequences, headers) = parse_fasta(input_path.read_text())
queries = []
for i, (sequence, header) in enumerate(zip(sequences, headers)):
sequence = sequence.upper()
if sequence.count(":") == 0:
# Single sequence
queries.append((header, sequence, None))
else:
# Complex mode
queries.append((header, sequence.split(":"), None))  # already uppercased above
else:
raise ValueError(f"Unknown file format {input_path.suffix}")
else:
raise NotImplementedError()

is_complex = True
return queries, is_complex
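A hypothetical CSV screen to illustrate the naming scheme: the id of the first row is prefixed with '&' onto every row's id, so each job name encodes bait&candidate:

# screen.csv (hypothetical contents):
#   id,sequence
#   bait,MKTAYIAKQR
#   prey1,MALWMRLLPL
queries, is_complex = get_queries_pairwise("screen.csv")
# queries == [("bait&bait", "MKTAYIAKQR", None),
#             ("bait&prey1", "MALWMRLLPL", None)]
# is_complex is always True in this mode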

def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
params_log = " ".join(str(i) for i in params)
@@ -61,7 +99,6 @@ def mmseqs_search_monomer(
used_dbs.append(template_db)
if use_env:
used_dbs.append(metagenomic_db)

for db in used_dbs:
if not dbbase.joinpath(f"{db}.dbtype").is_file():
raise FileNotFoundError(f"Database {db} does not exist")
@@ -405,9 +442,9 @@ def main():
args = parser.parse_args()

if args.interaction_scan:
queries, is_complex = get_queries_pairwise(args.query, None)
queries, is_complex = get_queries_pairwise(args.query)
else:
queries, is_complex = get_queries(args.query, None)
queries, is_complex = get_queries(args.query)

queries_unique = []
for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries):
@@ -437,10 +474,9 @@
query_seqs_cardinality,
) in enumerate(queries_unique):
if job_number==0:
f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n")
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{raw_jobname}_0\n{query_sequences}\n")
else:
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n")
else:
with query_file.open("w") as f:
for job_number, (
@@ -454,18 +490,6 @@
args.mmseqs,
["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"],
)
with args.base.joinpath("qdb.lookup").open("w") as f:
id = 0
file_number = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
for seq in query_sequences:
f.write(f"{id}\t{raw_jobname}\t{file_number}\n")
id += 1
file_number += 1

mmseqs_search_monomer(
mmseqs=args.mmseqs,
@@ -498,30 +522,66 @@ def main():
interaction_scan=args.interaction_scan,
)

if args.interaction_scan:
if len(queries_unique) > 1:
for i in range(len(queries_unique)-2):
idx = 2 + i*2
## delete duplicated query files 2.paired, 4.paired...
os.remove(args.base.joinpath(f"{idx}.paired.a3m"))
for j in range(len(queries_unique)-2):
# renumber the remaining target files: 3.paired -> 2.paired, 5.paired -> 3.paired, ...
id1 = j*2 + 3
id2 = j + 2
os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m"))
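Assuming the pairing step emits 2(N-1) paired files for N unique sequences, laid out as [query, target1, query, target2, ...] (which is what the index pattern above implies), the loops drop the duplicated query copies at even indices of 2 and higher, then renumber the remaining targets into consecutive slots. A dry run of the index arithmetic for an assumed N = 4:

# Dry run of the cleanup above for N = 4 unique sequences.
N = 4
deleted = [2 + i * 2 for i in range(N - 2)]           # [2, 4]
renamed = [(j * 2 + 3, j + 2) for j in range(N - 2)]  # [(3, 2), (5, 3)]
# paired files [0,1,2,3,4,5] -> keep 0 and 1, delete 2 and 4,
# move 3 -> 2 and 5 -> 3, leaving [0,1,2,3]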

id = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
if not args.interaction_scan:
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
paired_msa = []
else:
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
else:
for job_number, _ in enumerate(queries_unique[:-1]):
query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]]
unpaired_msa = []
paired_msa = []
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
with args.base.joinpath(f"0.a3m").open("r") as f:
unpaired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)

with args.base.joinpath(f"0.paired.a3m").open("r") as f:
paired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, [1,1]
)
args.base.joinpath(f"{job_number}_final.a3m").write_text(msa)
for job_number, _ in enumerate(queries_unique):
args.base.joinpath(f"{job_number}.a3m").unlink()
args.base.joinpath(f"{job_number}.paired.a3m").unlink()
for job_number, _ in enumerate(queries_unique[:-1]):
os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m"))
query_file.unlink()
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])