diff --git a/colabfold/alphafold/msa.py b/colabfold/alphafold/msa.py index 65bd1838..9d03a0d3 100644 --- a/colabfold/alphafold/msa.py +++ b/colabfold/alphafold/msa.py @@ -47,11 +47,11 @@ def make_fixed_size_multimer( feat: Mapping[str, Any], shape_schema, num_res, + msa_cluster_size, num_templates) -> FeatureDict: NUM_RES = "num residues placeholder" NUM_MSA_SEQ = "msa placeholder" NUM_TEMPLATES = "num templates placeholder" - msa_cluster_size = feat["bert_mask"].shape[0] pad_size_map = { NUM_RES: num_res, NUM_MSA_SEQ: msa_cluster_size, diff --git a/colabfold/batch.py b/colabfold/batch.py index eeefa51b..f29be882 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -21,11 +21,12 @@ from colabfold.inputs import ( get_queries_pairwise, unpack_a3ms, - parse_fasta, get_queries, + parse_fasta, get_queries, msa_to_str ) +from colabfold.run_alphafold import set_model_type from colabfold.download import default_data_dir, download_alphafold_params - +import sys import logging logger = logging.getLogger(__name__) @@ -299,24 +300,43 @@ def main(): headers_list[0].remove(headers_list[0][0]) header_first = headers[0] + queries_temp = [] + queries_rest = [] + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + max_msa_cluster = 0 + else: + max_msa_cluster = None for jobname, batch in enumerate(output): query_seqs_unique = [] for x in batch: if x not in query_seqs_unique: query_seqs_unique.append(x) + query_seqs_cardinality = [0] * len(query_seqs_unique) + for seq in batch: + seq_idx = query_seqs_unique.index(seq) + query_seqs_cardinality[seq_idx] += 1 use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode paired_a3m_lines = run_mmseqs2( query_seqs_unique, - str(Path(args.results).joinpath(str(jobname))), + str(Path(args.results).joinpath(str(jobname)+"_paired")), use_env=use_env, use_pairwise=True, use_pairing=True, host_url=args.host_url, ) - - path_o = Path(args.results).joinpath(f"{jobname}_pairwise") + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env") + unpaired_a3m_lines = run_mmseqs2( + query_seqs_unique, + str(Path(args.results).joinpath(str(jobname)+"_unpaired")), + use_env=use_env, + use_pairwise=False, + use_pairing=False, + host_url=args.host_url, + ) + path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise") for filenum in path_o.iterdir(): - queries_new = [] + queries_new = [] if Path(filenum).suffix.lower() == ".a3m": outdir = path_o.joinpath("tmp") unpack_a3ms(filenum, outdir) @@ -326,14 +346,49 @@ def main(): query_sequence = seqs[0] a3m_lines = [Path(file).read_text()] val = int(header[0].split('\t')[1][1:]) - 102 + # match paired seq id and unpaired seq id + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + paired_query_a3m_lines = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1] + # a3m_lines = [msa_to_str( + # [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1] + # )] + ## Another way: do not use msa_to_str and unserialize function rather + ## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments.. + a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]] queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines)) + ### generate features then find max_msa_cluster + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + inputs = ([batch[0], batch[val+1]], [1, 1], [unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [paired_query_a3m_lines, paired_a3m_lines[val+1]]) + from colabfold.inputs import generate_msa_size + msa_size = generate_msa_size(inputs, query_seqs_unique, args.use_templates, is_complex, model_type) + # config.model.embeddings_and_evoformer.extra_msa_seqs=2048 + # config.model.embeddings_and_evoformer.num_msa=508 + # if msa_size < 2048 + 508, pop the sequences and run the model with recompilation + if msa_size < 2556: + queries_rest.append(queries_new.pop()) + continue + max_msa_cluster = max(max_msa_cluster, msa_size) + if args.sort_queries_by == "length": queries_new.sort(key=lambda t: len(''.join(t[1])),reverse=True) elif args.sort_queries_by == "random": random.shuffle(queries_new) + queries_temp.append(queries_new) + + queries_sel = sum(queries_temp, []) + run_params["max_msa_cluster"] = max_msa_cluster + run_params["interaction_scan"] = args.interaction_scan + run(queries=queries_sel, **run_params) - run(queries=queries_new, **run_params) + if args.pair_mode in ("none", "unpaired", "unpaired_paired"): + if len(queries_rest) > 0: + if args.sort_queries_by == "length": + queries_rest.sort(key=lambda t: len(''.join(t[1])),reverse=True) + elif args.sort_queries_by == "random": + random.shuffle(queries_rest) + run_params["max_msa_cluster"] = None + run(queries=queries_rest, **run_params) else: run(queries=queries, **run_params) diff --git a/colabfold/inputs.py b/colabfold/inputs.py index cf66b89f..2358c96d 100644 --- a/colabfold/inputs.py +++ b/colabfold/inputs.py @@ -83,9 +83,9 @@ def pad_input_multimer( model_runner: model.RunModel, model_name: str, pad_len: int, + msa_cluster_size: Optional[int], use_templates: bool, ) -> model.features.FeatureDict: - model_config = model_runner.config shape_schema = { "aatype": ["num residues placeholder"], "residue_index": ["num residues placeholder"], @@ -123,6 +123,7 @@ def pad_input_multimer( input_features, shape_schema, num_res=pad_len, + msa_cluster_size=msa_cluster_size, num_templates=4, ) return input_fix @@ -654,6 +655,20 @@ def generate_input_feature( } return (input_feature, domain_names) +def generate_msa_size(inputs, query_seqs_unique, use_templates, is_complex, model_type): + template_features_ = [] + from colabfold.inputs import mk_mock_template + from colabfold.inputs import generate_input_feature + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not use_templates: template_features = template_features_ + else: raise NotImplementedError + + (feature_dict, _) \ + = generate_input_feature(*inputs, template_features, is_complex, model_type) + return feature_dict["bert_mask"].shape[0] + def unserialize_msa( a3m_lines: List[str], query_sequence: Union[List[str], str] ) -> Tuple[ @@ -696,7 +711,7 @@ def unserialize_msa( ) prev_query_start += query_len paired_msa = [""] * len(query_seq_len) - unpaired_msa = None + unpaired_msa = [""] * len(query_seq_len) already_in = dict() for i in range(1, len(a3m_lines), 2): header = a3m_lines[i] @@ -734,7 +749,6 @@ def unserialize_msa( paired_msa[j] += ">" + header_no_faster_split[j] + "\n" paired_msa[j] += seqs_line[j] + "\n" else: - unpaired_msa = [""] * len(query_seq_len) for j, seq in enumerate(seqs_line): if has_amino_acid[j]: unpaired_msa[j] += header + "\n" @@ -752,6 +766,8 @@ def unserialize_msa( template_feature = mk_mock_template(query_seq) template_features.append(template_feature) + if unpaired_msa == [""] * len(query_seq_len): + unpaired_msa = None return ( unpaired_msa, paired_msa, diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index a333708b..1d87dca2 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -11,12 +11,50 @@ from argparse import ArgumentParser from pathlib import Path from typing import List, Union -import os +import os, pandas -from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise +from colabfold.inputs import get_queries, msa_to_str, parse_fasta +from typing import Any, Callable, Dict, List, Optional, Tuple, Union logger = logging.getLogger(__name__) +def get_queries_pairwise( + input_path: Union[str, Path], sort_queries_by: str = "length" +) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]: + """Reads a directory of fasta files, a single fasta file or a csv file and returns a tuple + of job name, sequence and the optional a3m lines""" + input_path = Path(input_path) + if not input_path.exists(): + raise OSError(f"{input_path} could not be found") + if input_path.is_file(): + if input_path.suffix == ".csv" or input_path.suffix == ".tsv": + sep = "\t" if input_path.suffix == ".tsv" else "," + df = pandas.read_csv(input_path, sep=sep) + assert "id" in df.columns and "sequence" in df.columns + queries = [ + (str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None) + for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) + ] + elif input_path.suffix == ".a3m": + raise NotImplementedError() + elif input_path.suffix in [".fasta", ".faa", ".fa"]: + (sequences, headers) = parse_fasta(input_path.read_text()) + queries = [] + for i, (sequence, header) in enumerate(zip(sequences, headers)): + sequence = sequence.upper() + if sequence.count(":") == 0: + # Single sequence + queries.append((header, sequence, None)) + else: + # Complex mode + queries.append((header, sequence.upper().split(":"), None)) + else: + raise ValueError(f"Unknown file format {input_path.suffix}") + else: + raise NotImplementedError() + + is_complex = True + return queries, is_complex def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]): params_log = " ".join(str(i) for i in params) @@ -61,7 +99,6 @@ def mmseqs_search_monomer( used_dbs.append(template_db) if use_env: used_dbs.append(metagenomic_db) - for db in used_dbs: if not dbbase.joinpath(f"{db}.dbtype").is_file(): raise FileNotFoundError(f"Database {db} does not exist") @@ -405,9 +442,9 @@ def main(): args = parser.parse_args() if args.interaction_scan: - queries, is_complex = get_queries_pairwise(args.query, None) + queries, is_complex = get_queries_pairwise(args.query) else: - queries, is_complex = get_queries(args.query, None) + queries, is_complex = get_queries(args.query) queries_unique = [] for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries): @@ -437,10 +474,9 @@ def main(): query_seqs_cardinality, ) in enumerate(queries_unique): if job_number==0: - f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n") - f.write(f">{raw_jobname}\n{query_sequences[1]}\n") + f.write(f">{raw_jobname}_0\n{query_sequences}\n") else: - f.write(f">{raw_jobname}\n{query_sequences[1]}\n") + f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n") else: with query_file.open("w") as f: for job_number, ( @@ -454,18 +490,6 @@ def main(): args.mmseqs, ["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"], ) - with args.base.joinpath("qdb.lookup").open("w") as f: - id = 0 - file_number = 0 - for job_number, ( - raw_jobname, - query_sequences, - query_seqs_cardinality, - ) in enumerate(queries_unique): - for seq in query_sequences: - f.write(f"{id}\t{raw_jobname}\t{file_number}\n") - id += 1 - file_number += 1 mmseqs_search_monomer( mmseqs=args.mmseqs, @@ -498,30 +522,66 @@ def main(): interaction_scan=args.interaction_scan, ) + if args.interaction_scan: + if len(queries_unique) > 1: + for i in range(len(queries_unique)-2): + idx = 2 + i*2 + ## delete duplicated query files 2.paired, 4.paired... + os.remove(args.base.joinpath(f"{idx}.paired.a3m")) + for j in range(len(queries_unique)-2): + # replace targets' right file name + id1 = j*2 + 3 + id2 = j + 2 + os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m")) + id = 0 - for job_number, ( - raw_jobname, - query_sequences, - query_seqs_cardinality, - ) in enumerate(queries_unique): - unpaired_msa = [] - paired_msa = None - if len(query_seqs_cardinality) > 1: + if not args.interaction_scan: + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + unpaired_msa = [] + paired_msa = None + if len(query_seqs_cardinality) > 1: + paired_msa = [] + else: + for seq in query_sequences: + with args.base.joinpath(f"{id}.a3m").open("r") as f: + unpaired_msa.append(f.read()) + args.base.joinpath(f"{id}.a3m").unlink() + if len(query_seqs_cardinality) > 1: + with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + args.base.joinpath(f"{id}.paired.a3m").unlink() + id += 1 + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality + ) + args.base.joinpath(f"{job_number}.a3m").write_text(msa) + else: + for job_number, _ in enumerate(queries_unique[:-1]): + query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]] + unpaired_msa = [] paired_msa = [] - for seq in query_sequences: - with args.base.joinpath(f"{id}.a3m").open("r") as f: + with args.base.joinpath(f"0.a3m").open("r") as f: + unpaired_msa.append(f.read()) + with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f: unpaired_msa.append(f.read()) - args.base.joinpath(f"{id}.a3m").unlink() - if len(query_seqs_cardinality) > 1: - with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: - paired_msa.append(f.read()) - args.base.joinpath(f"{id}.paired.a3m").unlink() - id += 1 - msa = msa_to_str( - unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality - ) - args.base.joinpath(f"{job_number}.a3m").write_text(msa) + with args.base.joinpath(f"0.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, [1,1] + ) + args.base.joinpath(f"{job_number}_final.a3m").write_text(msa) + for job_number, _ in enumerate(queries_unique): + args.base.joinpath(f"{job_number}.a3m").unlink() + args.base.joinpath(f"{job_number}.paired.a3m").unlink() + for job_number, _ in enumerate(queries_unique[:-1]): + os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m")) query_file.unlink() run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")]) run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")]) diff --git a/colabfold/predict.py b/colabfold/predict.py index 2cf88196..45ba9d06 100644 --- a/colabfold/predict.py +++ b/colabfold/predict.py @@ -39,6 +39,7 @@ def predict_structure( save_single_representations: bool = False, save_pair_representations: bool = False, save_recycles: bool = False, + max_msa_cluster: Optional[int] = None, ): """Predicts structure using AlphaFold for the given sequence.""" @@ -69,8 +70,12 @@ def predict_structure( input_features = feature_dict input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0] # TODO + if max_msa_cluster == None: + msa_cluster_size = input_features["bert_mask"].shape[0] + else: + msa_cluster_size = max_msa_cluster + input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, msa_cluster_size, use_templates) if seq_len < pad_len: - input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, use_templates) logger.info(f"Padding length to {pad_len}") else: if model_num == 0: diff --git a/colabfold/run_alphafold.py b/colabfold/run_alphafold.py index 94ee7248..4ec404a7 100644 --- a/colabfold/run_alphafold.py +++ b/colabfold/run_alphafold.py @@ -133,6 +133,8 @@ def run( use_fuse = kwargs.pop("use_fuse", True) use_bfloat16 = kwargs.pop("use_bfloat16", True) max_msa = kwargs.pop("max_msa",None) + max_msa_cluster = kwargs.pop("max_msa_cluster", None) + interaction_scan = kwargs.pop("interaction_scan", True) if max_msa is not None: max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")] @@ -256,10 +258,23 @@ def run( = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates, custom_template_path, pair_mode, host_url) if a3m_lines is not None: - (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ - = unserialize_msa(a3m_lines, query_sequence) - if not use_templates: template_features = template_features_ - + # (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ + # = unserialize_msa(a3m_lines, query_sequence) + # if not use_templates: template_features = template_features_ + ## Another way passing argument + ## + if interaction_scan == True and pair_mode in ("none", "unpaired", "unpaired_paired"): + (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines + template_features_ = [] + from colabfold.inputs import mk_mock_template + for query_seq in query_seqs_unique: + template_feature = mk_mock_template(query_seq) + template_features_.append(template_feature) + if not use_templates: template_features = template_features_ + else: + (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \ + = unserialize_msa(a3m_lines, query_sequence) + if not use_templates: template_features = template_features_ # save a3m msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) result_dir.joinpath(f"{jobname}.a3m").write_text(msa) @@ -359,6 +374,7 @@ def run( save_single_representations=save_single_representations, save_pair_representations=save_pair_representations, save_recycles=save_recycles, + max_msa_cluster=max_msa_cluster, ) result_files = results["result_files"] ranks.append(results["rank"])