Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,16 @@ class ModelsToRelax(enum.Enum):
'check if the sequence, database or configuration have '
'changed.',
)
# NOTE: Reusing features.pkl bypasses data_pipeline.process entirely, i.e.
# the whole feature pipeline (MSA generation AND template search) — the help
# text must not suggest that only the template search is skipped.
flags.DEFINE_boolean(
    'use_precomputed_features',
    False,
    'Whether to read features that have been written to disk instead '
    'of running the full feature pipeline (MSA generation and '
    'template search). The features.pkl file is looked up in the '
    'output directory, so it must stay the same between multiple '
    'runs that are to reuse the features. WARNING: This will not '
    'check if the sequence, database or configuration have '
    'changed.',
)
flags.DEFINE_enum_class(
'models_to_relax',
ModelsToRelax.BEST,
Expand Down Expand Up @@ -352,6 +362,7 @@ def predict_structure(
benchmark: bool,
random_seed: int,
models_to_relax: ModelsToRelax,
use_precomputed_features: bool,
model_type: str,
):
"""Predicts structure using AlphaFold for the given sequence."""
Expand All @@ -364,17 +375,29 @@ def predict_structure(
if not os.path.exists(msa_output_dir):
os.makedirs(msa_output_dir)

# Get features.
t_0 = time.time()
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path, msa_output_dir=msa_output_dir
)
timings['features'] = time.time() - t_0

# Write out features as a pickled dictionary.
features_output_path = os.path.join(output_dir, 'features.pkl')
with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)
if use_precomputed_features and not os.path.exists(features_output_path):
logging.warning(
'use_precomputed_features is set but %s does not exist, running '
'full feature pipeline',
features_output_path,
)

if use_precomputed_features and os.path.exists(features_output_path):
logging.info('Reading features from %s', features_output_path)
with open(features_output_path, 'rb') as f:
feature_dict = pickle.load(f)
else:
# Get features.
t_0 = time.time()
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
timings['features'] = time.time() - t_0

# Write out features as a pickled dictionary.
with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

unrelaxed_pdbs = {}
unrelaxed_proteins = {}
Expand Down Expand Up @@ -712,6 +735,7 @@ def main(argv):
benchmark=FLAGS.benchmark,
random_seed=random_seed,
models_to_relax=FLAGS.models_to_relax,
use_precomputed_features=FLAGS.use_precomputed_features,
model_type=model_type,
)

Expand Down