From 689500643e667cf077242d5f307f2b2060f9f2f5 Mon Sep 17 00:00:00 2001 From: mmb78 <62362216+mmb78@users.noreply.github.com> Date: Tue, 24 Jun 2025 22:39:35 +0200 Subject: [PATCH] Update colabdesign_utils.py to allow safe parallel execution # MODIFICATION FOR PARALLEL EXECUTION # Minimal changes from v1.5.1 to allow safe execution as a SLURM job array. --- functions/colabdesign_utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/functions/colabdesign_utils.py b/functions/colabdesign_utils.py index 4bb1db5..d5c878a 100644 --- a/functions/colabdesign_utils.py +++ b/functions/colabdesign_utils.py @@ -1,6 +1,9 @@ #################################### ############## ColabDesign functions #################################### +# MODIFICATION FOR PARALLEL EXECUTION +# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array. + ### Import dependencies import os, re, shutil, math, pickle import matplotlib.pyplot as plt @@ -18,7 +21,7 @@ from .generic_utils import update_failures # hallucinate a binder -def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv): +def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv, failure_csv_lock): model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb") # clear GPU memory for new trajectory @@ -157,15 +160,15 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu num_models=1, sample_models=advanced_settings["sample_models"], ramp_models=False, save_best=True) else: - update_failures(failure_csv, 'Trajectory_one-hot_pLDDT') + update_failures(failure_csv, 'Trajectory_one-hot_pLDDT', failure_csv_lock) print("One-hot trajectory pLDDT too low to continue: "+str(onehot_plddt)) else: - update_failures(failure_csv, 'Trajectory_softmax_pLDDT') + update_failures(failure_csv, 'Trajectory_softmax_pLDDT', failure_csv_lock) print("Softmax trajectory pLDDT too low to continue: "+str(softmax_plddt)) else: - update_failures(failure_csv, 'Trajectory_logits_pLDDT') + update_failures(failure_csv, 'Trajectory_logits_pLDDT', failure_csv_lock) print("Initial trajectory pLDDT too low to continue: "+str(initial_plddt)) else: @@ -186,14 +189,14 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu #if clash_interface > 25 or ca_clashes > 0: if ca_clashes > 0: af_model.aux["log"]["terminate"] = "Clashing" - update_failures(failure_csv, 'Trajectory_Clashes') + update_failures(failure_csv, 'Trajectory_Clashes', failure_csv_lock) print("Severe clashes detected, skipping analysis and MPNN optimisation") print("") else: # check if low quality prediction if final_plddt < 0.7: af_model.aux["log"]["terminate"] = "LowConfidence" - update_failures(failure_csv, 'Trajectory_final_pLDDT') + update_failures(failure_csv, 'Trajectory_final_pLDDT', failure_csv_lock) print("Trajectory starting confidence low, skipping analysis and MPNN optimisation") print("") else: @@ -204,7 +207,7 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu # if less than 3 contacts then protein is floating above and is not binder if binder_contacts_n < 3: af_model.aux["log"]["terminate"] = "LowConfidence" - update_failures(failure_csv, 'Trajectory_Contacts') + update_failures(failure_csv, 'Trajectory_Contacts', failure_csv_lock) print("Too few contacts at the interface, skipping analysis and MPNN optimisation") print("") else: @@ -235,7 +238,7 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu return af_model # run prediction for binder with masked template target -def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name, target_pdb, chain, length, trajectory_pdb, prediction_models, advanced_settings, filters, design_paths, failure_csv, seed=None): +def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name, target_pdb, chain, length, trajectory_pdb, prediction_models, advanced_settings, filters, design_paths, failure_csv, failure_csv_lock, seed=None): prediction_stats = {} # clean sequence @@ -290,7 +293,7 @@ def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name, # Update the CSV file with the failure counts if filter_failures: - update_failures(failure_csv, filter_failures) + update_failures(failure_csv, filter_failures, failure_csv_lock) # AF2 filters passed, contuing with relaxation for model_num in prediction_models: