From 127af0b9260623cc349a0601b97c8a8098fc2f48 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Mon, 6 Feb 2023 18:19:06 +0900 Subject: [PATCH 1/9] solve file random ordering after mmseqs server --- colabfold/inputs.py | 12 +++++------- colabfold/run_alphafold.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/colabfold/inputs.py b/colabfold/inputs.py index 9bc43342..68ec8ea9 100644 --- a/colabfold/inputs.py +++ b/colabfold/inputs.py @@ -259,7 +259,7 @@ def get_queries( return queries, is_complex def get_queries_pairwise( - input_path: Union[str, Path], sort_queries_by: str = "length", batch_size: int = 10, + input_path: Union[str, Path], batch_size: int = 10, ) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]: """Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple of job name, sequence and the optional a3m lines""" @@ -272,10 +272,13 @@ def get_queries_pairwise( df = pandas.read_csv(input_path, sep=sep) assert "id" in df.columns and "sequence" in df.columns queries = [] + seq_id_list = [] for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)): if i>0 and i % 10 == 0: queries.append(queries[0].upper()) queries.append(sequence.upper()) + seq_id_list.append(seq_id) + return queries, True, seq_id_list elif input_path.suffix == ".a3m": raise NotImplementedError() elif input_path.suffix in [".fasta", ".faa", ".fa"]: @@ -290,16 +293,11 @@ def get_queries_pairwise( else: # Complex mode queries.append((header, sequence.upper().split(":"), None)) + return queries, True, headers else: raise ValueError(f"Unknown file format {input_path.suffix}") else: raise NotImplementedError() - is_complex = True - if sort_queries_by == "length": - queries.sort(key=lambda t: len(''.join(t[1])),reverse=True) - elif sort_queries_by == "random": - random.shuffle(queries) - return queries, is_complex def pair_sequences( a3m_lines: List[str], query_sequences: List[str], query_cardinality: List[int] diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index 4dcfe617..eab6c232 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -20,6 +20,7 @@ import numpy as np from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from pathlib import Path +import random import logging logger = logging.getLogger(__name__) @@ -867,7 +868,7 @@ def main(): if args.interaction_scan: # protocol from @Dohyun-s batch_size = 10 - queries, is_complex = get_queries_pairwise(args.input, args.sort_queries_by, batch_size) + queries, is_complex, headers = get_queries_pairwise(args.input, batch_size) else: queries, is_complex = get_queries(args.input, args.sort_queries_by) @@ -919,7 +920,9 @@ def main(): # protocol from @Dohyun-s from colabfold.mmseqs.api import run_mmseqs2 output = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)] - dirnum = 0 + headers_list = [headers[i:i + batch_size] for i in range(0, len(headers), batch_size)] + headers_list[0].remove(headers_list[0][0]) + header_first = headers[0] for jobname, batch in enumerate(output): query_seqs_unique = [] @@ -947,10 +950,15 @@ def main(): (seqs, header) = parse_fasta(Path(file).read_text()) query_sequence = seqs[0] a3m_lines = [Path(file).read_text()] - queries_new.append((outdir.joinpath(file).stem+'_'+str(dirnum), query_sequence, a3m_lines)) + val = int(header[0].split('\t')[1][1:]) - 102 + queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines)) + + if args.sort_queries_by == "length": + queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True) + elif args.sort_queries_by == "random": + random.shuffle(queries_new) run(queries=queries_new, **run_params) - dirnum += 1 else: run(queries=queries, **run_params) From 7ad01b8960feef0019cfc047630eaffe611c8c2f Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Mon, 20 Feb 2023 16:45:54 +0900 Subject: [PATCH 2/9] fix search.py --- colabfold/mmseqs/search.py | 50 +++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index a333708b..b2a3b46e 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -11,12 +11,55 @@ from argparse import ArgumentParser from pathlib import Path from typing import List, Union -import os +import os, pandas -from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise +from colabfold.inputs import get_queries, msa_to_str, parse_fasta +from typing import Any, Callable, Dict, List, Optional, Tuple, Union logger = logging.getLogger(__name__) +def get_queries_pairwise( + input_path: Union[str, Path], sort_queries_by: str = "length" +) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]: + """Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple + of job name, sequence and the optional a3m lines""" + input_path = Path(input_path) + if not input_path.exists(): + raise OSError(f"{input_path} could not be found") + if input_path.is_file(): + if input_path.suffix == ".csv" or input_path.suffix == ".tsv": + sep = "\t" if input_path.suffix == ".tsv" else "," + df = pandas.read_csv(input_path, sep=sep) + assert "id" in df.columns and "sequence" in df.columns + queries = [ + (str(df["id"][0])+'&'+str(seq_id), [df["sequence"][0].upper(),sequence.upper()], None) + for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) if i!=0 + ] + for i in range(len(queries)): + if len(queries[i][1]) == 1: + queries[i] = (queries[i][0], queries[i][1][0], None) + elif input_path.suffix == ".a3m": + raise NotImplementedError() + elif input_path.suffix in [".fasta", ".faa", ".fa"]: + (sequences, headers) = parse_fasta(input_path.read_text()) + queries = [] + for i, (sequence, header) in enumerate(zip(sequences, headers)): + sequence = sequence.upper() + if sequence.count(":") == 0: + # Single sequence + if i==0: + continue + queries.append((headers[0]+'&'+header, [sequences[0],sequence], None)) + else: + # Complex mode + queries.append((header, sequence.upper().split(":"), None)) + else: + raise ValueError(f"Unknown file format {input_path.suffix}") + else: + raise NotImplementedError() + + is_complex = True + return queries, is_complex def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]): params_log = " ".join(str(i) for i in params) @@ -61,8 +104,9 @@ def mmseqs_search_monomer( used_dbs.append(template_db) if use_env: used_dbs.append(metagenomic_db) - for db in used_dbs: + if str(db) == '.': + continue if not dbbase.joinpath(f"{db}.dbtype").is_file(): raise FileNotFoundError(f"Database {db} does not exist") if ( From 548fdcc1c63a6bf76345aee3f9100d4eed11eb26 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Sun, 26 Feb 2023 22:40:39 +0900 Subject: [PATCH 3/9] implement unpaired alignment && fix colabfold_search error --- colabfold/batch.py | 33 ++++++++++-- colabfold/inputs.py | 6 +-- colabfold/mmseqs/search.py | 108 +++++++++++++++++++++---------------- colabfold/run_alphafold.py | 10 +++- 4 files changed, 104 insertions(+), 53 deletions(-) diff --git a/colabfold/batch.py b/colabfold/batch.py index eeefa51b..c2b47703 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -21,11 +21,12 @@ from colabfold.inputs import ( get_queries_pairwise, unpack_a3ms, - parse_fasta, get_queries, + parse_fasta, get_queries, msa_to_str ) +from colabfold.run_alphafold import set_model_type from colabfold.download import default_data_dir, download_alphafold_params - +import sys import logging logger = logging.getLogger(__name__) @@ -304,17 +305,31 @@ def main(): for x in batch: if x not in query_seqs_unique: query_seqs_unique.append(x) + query_seqs_cardinality = [0] * len(query_seqs_unique) + for seq in batch: + seq_idx = query_seqs_unique.index(seq) + query_seqs_cardinality[seq_idx] += 1 use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode paired_a3m_lines = run_mmseqs2( query_seqs_unique, - str(Path(args.results).joinpath(str(jobname))), + str(Path(args.results).joinpath(str(jobname)+"_paired")), use_env=use_env, use_pairwise=True, use_pairing=True, host_url=args.host_url, ) - path_o = Path(args.results).joinpath(f"{jobname}_pairwise") + if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": + unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env") + unpaired_a3m_lines = run_mmseqs2( + query_seqs_unique, + str(Path(args.results).joinpath(str(jobname)+"_unpaired")), + use_env=use_env, + use_pairwise=False, + use_pairing=False, + host_url=args.host_url, + ) + path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise") for filenum in path_o.iterdir(): queries_new = [] if Path(filenum).suffix.lower() == ".a3m": @@ -326,6 +341,16 @@ def main(): query_sequence = seqs[0] a3m_lines = [Path(file).read_text()] val = int(header[0].split('\t')[1][1:]) - 102 + # match paired seq id and unpaired seq id + if args.pair_mode == "none" or "unpaired" or "unpaired_paired": + tmp = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1] + a3m_lines = [msa_to_str( + [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1] + )] + ## Another way: do not use msa_to_str and unserialize function rather + ## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments.. + ## + # a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines)) if args.sort_queries_by == "length": diff --git a/colabfold/inputs.py b/colabfold/inputs.py index cf66b89f..f100858a 100644 --- a/colabfold/inputs.py +++ b/colabfold/inputs.py @@ -85,7 +85,6 @@ def pad_input_multimer( pad_len: int, use_templates: bool, ) -> model.features.FeatureDict: - model_config = model_runner.config shape_schema = { "aatype": ["num residues placeholder"], "residue_index": ["num residues placeholder"], @@ -696,7 +695,7 @@ def unserialize_msa( ) prev_query_start += query_len paired_msa = [""] * len(query_seq_len) - unpaired_msa = None + unpaired_msa = [""] * len(query_seq_len) already_in = dict() for i in range(1, len(a3m_lines), 2): header = a3m_lines[i] @@ -734,7 +733,6 @@ def unserialize_msa( paired_msa[j] += ">" + header_no_faster_split[j] + "\n" paired_msa[j] += seqs_line[j] + "\n" else: - unpaired_msa = [""] * len(query_seq_len) for j, seq in enumerate(seqs_line): if has_amino_acid[j]: unpaired_msa[j] += header + "\n" @@ -752,6 +750,8 @@ def unserialize_msa( template_feature = mk_mock_template(query_seq) template_features.append(template_feature) + if unpaired_msa == [""] * len(query_seq_len): + unpaired_msa = None return ( unpaired_msa, paired_msa, diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index b2a3b46e..7be5f730 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -32,12 +32,9 @@ def get_queries_pairwise( df = pandas.read_csv(input_path, sep=sep) assert "id" in df.columns and "sequence" in df.columns queries = [ - (str(df["id"][0])+'&'+str(seq_id), [df["sequence"][0].upper(),sequence.upper()], None) - for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) if i!=0 + (str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None) + for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) ] - for i in range(len(queries)): - if len(queries[i][1]) == 1: - queries[i] = (queries[i][0], queries[i][1][0], None) elif input_path.suffix == ".a3m": raise NotImplementedError() elif input_path.suffix in [".fasta", ".faa", ".fa"]: @@ -47,9 +44,7 @@ def get_queries_pairwise( sequence = sequence.upper() if sequence.count(":") == 0: # Single sequence - if i==0: - continue - queries.append((headers[0]+'&'+header, [sequences[0],sequence], None)) + queries.append((header, sequence, None)) else: # Complex mode queries.append((header, sequence.upper().split(":"), None)) @@ -449,9 +444,9 @@ def main(): args = parser.parse_args() if args.interaction_scan: - queries, is_complex = get_queries_pairwise(args.query, None) + queries, is_complex = get_queries_pairwise(args.query) else: - queries, is_complex = get_queries(args.query, None) + queries, is_complex = get_queries(args.query) queries_unique = [] for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries): @@ -481,10 +476,9 @@ def main(): query_seqs_cardinality, ) in enumerate(queries_unique): if job_number==0: - f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n") - f.write(f">{raw_jobname}\n{query_sequences[1]}\n") + f.write(f">{raw_jobname}_0\n{query_sequences}\n") else: - f.write(f">{raw_jobname}\n{query_sequences[1]}\n") + f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n") else: with query_file.open("w") as f: for job_number, ( @@ -498,18 +492,6 @@ def main(): args.mmseqs, ["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"], ) - with args.base.joinpath("qdb.lookup").open("w") as f: - id = 0 - file_number = 0 - for job_number, ( - raw_jobname, - query_sequences, - query_seqs_cardinality, - ) in enumerate(queries_unique): - for seq in query_sequences: - f.write(f"{id}\t{raw_jobname}\t{file_number}\n") - id += 1 - file_number += 1 mmseqs_search_monomer( mmseqs=args.mmseqs, @@ -542,30 +524,66 @@ def main(): interaction_scan=args.interaction_scan, ) + if args.interaction_scan: + if len(queries_unique) > 1: + for i in range(len(queries_unique)-2): + idx = 2 + i*2 + ## delete duplicated query files 2.paired, 4.paired... + os.remove(args.base.joinpath(f"{idx}.paired.a3m")) + for j in range(len(queries_unique)-2): + # replace targets' right file name + id1 = j*2 + 3 + id2 = j + 2 + os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m")) + id = 0 - for job_number, ( - raw_jobname, - query_sequences, - query_seqs_cardinality, - ) in enumerate(queries_unique): - unpaired_msa = [] - paired_msa = None - if len(query_seqs_cardinality) > 1: + if not args.interaction_scan: + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + unpaired_msa = [] + paired_msa = None + if len(query_seqs_cardinality) > 1: + paired_msa = [] + else: + for seq in query_sequences: + with args.base.joinpath(f"{id}.a3m").open("r") as f: + unpaired_msa.append(f.read()) + args.base.joinpath(f"{id}.a3m").unlink() + if len(query_seqs_cardinality) > 1: + with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + args.base.joinpath(f"{id}.paired.a3m").unlink() + id += 1 + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality + ) + args.base.joinpath(f"{job_number}.a3m").write_text(msa) + else: + for job_number, _ in enumerate(queries_unique[:-1]): + query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]] + unpaired_msa = [] paired_msa = [] - for seq in query_sequences: - with args.base.joinpath(f"{id}.a3m").open("r") as f: + with args.base.joinpath(f"0.a3m").open("r") as f: + unpaired_msa.append(f.read()) + with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f: unpaired_msa.append(f.read()) - args.base.joinpath(f"{id}.a3m").unlink() - if len(query_seqs_cardinality) > 1: - with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: - paired_msa.append(f.read()) - args.base.joinpath(f"{id}.paired.a3m").unlink() - id += 1 - msa = msa_to_str( - unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality - ) - args.base.joinpath(f"{job_number}.a3m").write_text(msa) + with args.base.joinpath(f"0.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, [1,1] + ) + args.base.joinpath(f"{job_number}_final.a3m").write_text(msa) + for job_number, _ in enumerate(queries_unique): + args.base.joinpath(f"{job_number}.a3m").unlink() + args.base.joinpath(f"{job_number}.paired.a3m").unlink() + for job_number, _ in enumerate(queries_unique[:-1]): + os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m")) query_file.unlink() run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")]) run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")]) diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index 94ee7248..c3fc8149 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -259,7 +259,15 @@ def run( (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ = unserialize_msa(a3m_lines, query_sequence) if not use_templates: template_features = template_features_ - + ## Another way passing argument + ## + # (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines + # template_features_ = [] + # from colabfold.inputs import mk_mock_template + # for query_seq in query_seqs_unique: + # template_feature = mk_mock_template(query_seq) + # template_features_.append(template_feature) + # if not use_templates: template_features = template_features_ # save a3m msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) result_dir.joinpath(f"{jobname}.a3m").write_text(msa) From c2102a1309847984e2be1dc0ecfcc2dd4231b056 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Sat, 4 Mar 2023 11:21:39 +0900 Subject: [PATCH 4/9] pad the msa input & fix jax compilation --- colabfold/alphafold/msa.py | 2 +- colabfold/batch.py | 30 +++++++++++++++++++++++------- colabfold/inputs.py | 2 ++ colabfold/predict.py | 7 ++++++- colabfold/run_alphafold.py | 22 ++++++++++++---------- 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/colabfold/alphafold/msa.py b/colabfold/alphafold/msa.py index 65bd1838..9d03a0d3 100644 --- a/colabfold/alphafold/msa.py +++ b/colabfold/alphafold/msa.py @@ -47,11 +47,11 @@ def make_fixed_size_multimer( feat: Mapping[str, Any], shape_schema, num_res, + msa_cluster_size, num_templates) -> FeatureDict: NUM_RES = "num residues placeholder" NUM_MSA_SEQ = "msa placeholder" NUM_TEMPLATES = "num templates placeholder" - msa_cluster_size = feat["bert_mask"].shape[0] pad_size_map = { NUM_RES: num_res, NUM_MSA_SEQ: msa_cluster_size, diff --git a/colabfold/batch.py b/colabfold/batch.py index c2b47703..081db9b3 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -318,7 +318,7 @@ def main(): use_pairing=True, host_url=args.host_url, ) - + max_msa_cluster = None if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env") unpaired_a3m_lines = run_mmseqs2( @@ -329,6 +329,7 @@ def main(): use_pairing=False, host_url=args.host_url, ) + max_msa_cluster = 0 path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise") for filenum in path_o.iterdir(): queries_new = [] @@ -342,21 +343,36 @@ def main(): a3m_lines = [Path(file).read_text()] val = int(header[0].split('\t')[1][1:]) - 102 # match paired seq id and unpaired seq id - if args.pair_mode == "none" or "unpaired" or "unpaired_paired": + if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": tmp = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1] - a3m_lines = [msa_to_str( - [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1] - )] + # a3m_lines = [msa_to_str( + # [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1] + # )] ## Another way: do not use msa_to_str and unserialize function rather ## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments.. - ## - # a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] + a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines)) + ### generate features then find max_msa_cluster + if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": + template_features_ = [] + from colabfold.inputs import mk_mock_template + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not args.use_templates: template_features = template_features_ + + from colabfold.inputs import generate_input_feature + (feature_dict, domain_names) \ + = generate_input_feature([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], + template_features, is_complex, model_type) + max_msa_cluster = max(max_msa_cluster, feature_dict["bert_mask"].shape[0]) + if args.sort_queries_by == "length": queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True) elif args.sort_queries_by == "random": random.shuffle(queries_new) + run_params["max_msa_cluster"] = max_msa_cluster run(queries=queries_new, **run_params) diff --git a/colabfold/inputs.py b/colabfold/inputs.py index f100858a..24cf944a 100644 --- a/colabfold/inputs.py +++ b/colabfold/inputs.py @@ -83,6 +83,7 @@ def pad_input_multimer( model_runner: model.RunModel, model_name: str, pad_len: int, + msa_cluster_size: Optional[int], use_templates: bool, ) -> model.features.FeatureDict: shape_schema = { @@ -122,6 +123,7 @@ def pad_input_multimer( input_features, shape_schema, num_res=pad_len, + msa_cluster_size=msa_cluster_size, num_templates=4, ) return input_fix diff --git a/colabfold/predict.py b/colabfold/predict.py index 2cf88196..45ba9d06 100644 --- a/colabfold/predict.py +++ b/colabfold/predict.py @@ -39,6 +39,7 @@ def predict_structure( save_single_representations: bool = False, save_pair_representations: bool = False, save_recycles: bool = False, + max_msa_cluster: Optional[int] = None, ): """Predicts structure using AlphaFold for the given sequence.""" @@ -69,8 +70,12 @@ def predict_structure( input_features = feature_dict input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0] # TODO + if max_msa_cluster == None: + msa_cluster_size = input_features["bert_mask"].shape[0] + else: + msa_cluster_size = max_msa_cluster + input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, msa_cluster_size, use_templates) if seq_len < pad_len: - input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, use_templates) logger.info(f"Padding length to {pad_len}") else: if model_num == 0: diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index c3fc8149..d81a96bf 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -133,6 +133,7 @@ def run( use_fuse = kwargs.pop("use_fuse", True) use_bfloat16 = kwargs.pop("use_bfloat16", True) max_msa = kwargs.pop("max_msa",None) + max_msa_cluster = kwargs.pop("max_msa_cluster", None) if max_msa is not None: max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")] @@ -256,18 +257,18 @@ def run( = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates, custom_template_path, pair_mode, host_url) if a3m_lines is not None: - (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ - = unserialize_msa(a3m_lines, query_sequence) - if not use_templates: template_features = template_features_ + # (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ + # = unserialize_msa(a3m_lines, query_sequence) + # if not use_templates: template_features = template_features_ ## Another way passing argument ## - # (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines - # template_features_ = [] - # from colabfold.inputs import mk_mock_template - # for query_seq in query_seqs_unique: - # template_feature = mk_mock_template(query_seq) - # template_features_.append(template_feature) - # if not use_templates: template_features = template_features_ + (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines + template_features_ = [] + from colabfold.inputs import mk_mock_template + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not use_templates: template_features = template_features_ # save a3m msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) result_dir.joinpath(f"{jobname}.a3m").write_text(msa) @@ -367,6 +368,7 @@ def run( save_single_representations=save_single_representations, save_pair_representations=save_pair_representations, save_recycles=save_recycles, + max_msa_cluster=max_msa_cluster, ) result_files = results["result_files"] ranks.append(results["rank"]) From b2e92d526e03d365e82ddefeb2e3c405deca119a Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Wed, 8 Mar 2023 22:53:51 +0900 Subject: [PATCH 5/9] Update max_msa_cluster in for inner loop --- colabfold/batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colabfold/batch.py b/colabfold/batch.py index 081db9b3..9db80c98 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -329,11 +329,12 @@ def main(): use_pairing=False, host_url=args.host_url, ) - max_msa_cluster = 0 path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise") for filenum in path_o.iterdir(): queries_new = [] if Path(filenum).suffix.lower() == ".a3m": + if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": + max_msa_cluster = 0 outdir = path_o.joinpath("tmp") unpack_a3ms(filenum, outdir) for i, file in enumerate(sorted(outdir.iterdir())): From 569e3367abd9778dbdd72b22f064cc1c858139c4 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Sat, 11 Mar 2023 17:41:00 +0900 Subject: [PATCH 6/9] Fix wrong statement about db --- colabfold/mmseqs/search.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index 7be5f730..1d87dca2 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -100,8 +100,6 @@ def mmseqs_search_monomer( if use_env: used_dbs.append(metagenomic_db) for db in used_dbs: - if str(db) == '.': - continue if not dbbase.joinpath(f"{db}.dbtype").is_file(): raise FileNotFoundError(f"Database {db} does not exist") if ( From 9d9eaaa14e76f9547e39bcfe0803c31014a78d20 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Sun, 19 Mar 2023 23:18:21 +0900 Subject: [PATCH 7/9] update config to pass unpadded_seq_len & extra_msa to alphafold --- colabfold/run_alphafold.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index d81a96bf..82b61c6c 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -347,6 +347,13 @@ def run( ) first_job = False + if "multimer" in model_suffix: + for idx in range(num_models): + if max_msa_cluster == None: + model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.num_unpadded_seqs = 0 + else: + model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.num_unpadded_seqs = int(len(feature_dict["msa"])) + model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.extra_msa_seqs = int(len(feature_dict["msa"])) - 508 results = predict_structure( prefix=jobname, result_dir=result_dir, From 273098ac8348c7614733f9cf395097d5368915ba Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Tue, 21 Mar 2023 22:41:21 +0900 Subject: [PATCH 8/9] Revert "Fix wrong statement about db" This reverts commit 569e3367abd9778dbdd72b22f064cc1c858139c4. --- colabfold/mmseqs/search.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index 1d87dca2..7be5f730 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -100,6 +100,8 @@ def mmseqs_search_monomer( if use_env: used_dbs.append(metagenomic_db) for db in used_dbs: + if str(db) == '.': + continue if not dbbase.joinpath(f"{db}.dbtype").is_file(): raise FileNotFoundError(f"Database {db} does not exist") if ( From 93941fea8b7c101140c628b1fd8a4241f41a0418 Mon Sep 17 00:00:00 2001 From: dohyun-s Date: Tue, 21 Mar 2023 23:29:20 +0900 Subject: [PATCH 9/9] Change run condition in interaction_scan --- colabfold/batch.py | 59 +++++++++++++++++++++++--------------- colabfold/inputs.py | 14 +++++++++ colabfold/mmseqs/search.py | 2 -- colabfold/run_alphafold.py | 27 +++++++++-------- 4 files changed, 63 insertions(+), 39 deletions(-) diff --git a/colabfold/batch.py b/colabfold/batch.py index 9db80c98..f29be882 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -300,6 +300,12 @@ def main(): headers_list[0].remove(headers_list[0][0]) header_first = headers[0] + queries_temp = [] + queries_rest = [] + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + max_msa_cluster = 0 + else: + max_msa_cluster = None for jobname, batch in enumerate(output): query_seqs_unique = [] for x in batch: @@ -318,8 +324,7 @@ def main(): use_pairing=True, host_url=args.host_url, ) - max_msa_cluster = None - if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env") unpaired_a3m_lines = run_mmseqs2( query_seqs_unique, @@ -331,10 +336,8 @@ def main(): ) path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise") for filenum in path_o.iterdir(): - queries_new = [] + queries_new = [] if Path(filenum).suffix.lower() == ".a3m": - if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": - max_msa_cluster = 0 outdir = path_o.joinpath("tmp") unpack_a3ms(filenum, outdir) for i, file in enumerate(sorted(outdir.iterdir())): @@ -344,38 +347,48 @@ def main(): a3m_lines = [Path(file).read_text()] val = int(header[0].split('\t')[1][1:]) - 102 # match paired seq id and unpaired seq id - if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": - tmp = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1] + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + paired_query_a3m_lines = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1] # a3m_lines = [msa_to_str( # [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1] # )] ## Another way: do not use msa_to_str and unserialize function rather ## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments.. - a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] + a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines)) ### generate features then find max_msa_cluster - if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired": - template_features_ = [] - from colabfold.inputs import mk_mock_template - for query_seq in query_seqs_unique: - template_feature = mk_mock_template(query_seq) - template_features_.append(template_feature) - if not args.use_templates: template_features = template_features_ - - from colabfold.inputs import generate_input_feature - (feature_dict, domain_names) \ - = generate_input_feature([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], - template_features, is_complex, model_type) - max_msa_cluster = max(max_msa_cluster, feature_dict["bert_mask"].shape[0]) + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + inputs = ([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]]) + from colabfold.inputs import generate_msa_size + msa_size = generate_msa_size(inputs, query_seqs_unique, args.use_templates, is_complex, model_type) + # config.model.embeddings_and_evoformer.extra_msa_seqs=2048 + # config.model.embeddings_and_evoformer.num_msa=508 + # if msa_size < 2048 + 508, pop the sequences and run the model with recompilation + if msa_size < 2556: + queries_rest.append(queries_new.pop()) + continue + max_msa_cluster = max(max_msa_cluster, msa_size) if args.sort_queries_by == "length": queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True) elif args.sort_queries_by == "random": random.shuffle(queries_new) - run_params["max_msa_cluster"] = max_msa_cluster + queries_temp.append(queries_new) + + queries_sel = sum(queries_temp, []) + run_params["max_msa_cluster"] = max_msa_cluster + run_params["interaction_scan"] = args.interaction_scan + run(queries=queries_sel, **run_params) - run(queries=queries_new, **run_params) + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + if len(queries_rest) > 0: + if args.sort_queries_by == "length": + queries_rest.sort(key=lambda t: len(''.join(t[1])),reverse=True) + elif args.sort_queries_by == "random": + random.shuffle(queries_rest) + run_params["max_msa_cluster"] = None + run(queries=queries_rest, **run_params) else: run(queries=queries, **run_params) diff --git a/colabfold/inputs.py b/colabfold/inputs.py index 24cf944a..2358c96d 100644 --- a/colabfold/inputs.py +++ b/colabfold/inputs.py @@ -655,6 +655,20 @@ def generate_input_feature( } return (input_feature, domain_names) +def generate_msa_size(inputs, query_seqs_unique, use_templates, is_complex, model_type): + template_features_ = [] + from colabfold.inputs import mk_mock_template + from colabfold.inputs import generate_input_feature + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not use_templates: template_features = template_features_ + else: raise NotImplementedError + + (feature_dict, _) \ + = generate_input_feature(*inputs, template_features, is_complex, model_type) + return feature_dict["bert_mask"].shape[0] + def unserialize_msa( a3m_lines: List[str], query_sequence: Union[List[str], str] ) -> Tuple[ diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index 7be5f730..1d87dca2 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -100,8 +100,6 @@ def mmseqs_search_monomer( if use_env: used_dbs.append(metagenomic_db) for db in used_dbs: - if str(db) == '.': - continue if not dbbase.joinpath(f"{db}.dbtype").is_file(): raise FileNotFoundError(f"Database {db} does not exist") if ( diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index 82b61c6c..4ec404a7 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -134,6 +134,7 @@ def run( use_bfloat16 = kwargs.pop("use_bfloat16", True) max_msa = kwargs.pop("max_msa",None) max_msa_cluster = kwargs.pop("max_msa_cluster", None) + interaction_scan = kwargs.pop("interaction_scan", True) if max_msa is not None: max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")] @@ -262,13 +263,18 @@ def run( # if not use_templates: template_features = template_features_ ## Another way passing argument ## - (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines - template_features_ = [] - from colabfold.inputs import mk_mock_template - for query_seq in query_seqs_unique: - template_feature = mk_mock_template(query_seq) - template_features_.append(template_feature) - if not use_templates: template_features = template_features_ + if interaction_scan == True and pair_mode in ("none", "unpaired", "unpaired_paired"): + (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines + template_features_ = [] + from colabfold.inputs import mk_mock_template + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not use_templates: template_features = template_features_ + else: + (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ + = unserialize_msa(a3m_lines, query_sequence) + if not use_templates: template_features = template_features_ # save a3m msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) result_dir.joinpath(f"{jobname}.a3m").write_text(msa) @@ -347,13 +353,6 @@ def run( ) first_job = False - if "multimer" in model_suffix: - for idx in range(num_models): - if max_msa_cluster == None: - model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.num_unpadded_seqs = 0 - else: - model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.num_unpadded_seqs = int(len(feature_dict["msa"])) - model_runner_and_params[idx][1].config.model.embeddings_and_evoformer.extra_msa_seqs = int(len(feature_dict["msa"])) - 508 results = predict_structure( prefix=jobname, result_dir=result_dir,