Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions functions/colabdesign_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
####################################
############## ColabDesign functions
####################################
# MODIFICATION FOR PARALLEL EXECUTION
# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.

### Import dependencies
import os, re, shutil, math, pickle
import matplotlib.pyplot as plt
Expand All @@ -18,7 +21,7 @@
from .generic_utils import update_failures

# hallucinate a binder
def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv):
def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv, failure_csv_lock):
model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb")

# clear GPU memory for new trajectory
Expand Down Expand Up @@ -157,15 +160,15 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
num_models=1, sample_models=advanced_settings["sample_models"], ramp_models=False, save_best=True)

else:
update_failures(failure_csv, 'Trajectory_one-hot_pLDDT')
update_failures(failure_csv, 'Trajectory_one-hot_pLDDT', failure_csv_lock)
print("One-hot trajectory pLDDT too low to continue: "+str(onehot_plddt))

else:
update_failures(failure_csv, 'Trajectory_softmax_pLDDT')
update_failures(failure_csv, 'Trajectory_softmax_pLDDT', failure_csv_lock)
print("Softmax trajectory pLDDT too low to continue: "+str(softmax_plddt))

else:
update_failures(failure_csv, 'Trajectory_logits_pLDDT')
update_failures(failure_csv, 'Trajectory_logits_pLDDT', failure_csv_lock)
print("Initial trajectory pLDDT too low to continue: "+str(initial_plddt))

else:
Expand All @@ -186,14 +189,14 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
#if clash_interface > 25 or ca_clashes > 0:
if ca_clashes > 0:
af_model.aux["log"]["terminate"] = "Clashing"
update_failures(failure_csv, 'Trajectory_Clashes')
update_failures(failure_csv, 'Trajectory_Clashes', failure_csv_lock)
print("Severe clashes detected, skipping analysis and MPNN optimisation")
print("")
else:
# check if low quality prediction
if final_plddt < 0.7:
af_model.aux["log"]["terminate"] = "LowConfidence"
update_failures(failure_csv, 'Trajectory_final_pLDDT')
update_failures(failure_csv, 'Trajectory_final_pLDDT', failure_csv_lock)
print("Trajectory starting confidence low, skipping analysis and MPNN optimisation")
print("")
else:
Expand All @@ -204,7 +207,7 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
# if less than 3 contacts then protein is floating above and is not binder
if binder_contacts_n < 3:
af_model.aux["log"]["terminate"] = "LowConfidence"
update_failures(failure_csv, 'Trajectory_Contacts')
update_failures(failure_csv, 'Trajectory_Contacts', failure_csv_lock)
print("Too few contacts at the interface, skipping analysis and MPNN optimisation")
print("")
else:
Expand Down Expand Up @@ -235,7 +238,7 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
return af_model

# run prediction for binder with masked template target
def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name, target_pdb, chain, length, trajectory_pdb, prediction_models, advanced_settings, filters, design_paths, failure_csv, seed=None):
def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name, target_pdb, chain, length, trajectory_pdb, prediction_models, advanced_settings, filters, design_paths, failure_csv, failure_csv_lock, seed=None):
prediction_stats = {}

# clean sequence
Expand Down Expand Up @@ -290,7 +293,7 @@ def predict_binder_complex(prediction_model, binder_sequence, mpnn_design_name,

# Update the CSV file with the failure counts
if filter_failures:
update_failures(failure_csv, filter_failures)
update_failures(failure_csv, filter_failures, failure_csv_lock)

# AF2 filters passed, contuing with relaxation
for model_num in prediction_models:
Expand Down