Skip to content

Commit a7a2f98

Browse files
committed
Add option to use precomputed templates
My experience is that template processing takes a huge fraction of the overall runtime. It's also all on CPU, which is a waste if you're using GPU machines. Note that this uses the features.pkl which is already saved by the existing code. For more stability across numpy versions, etc. we could instead use `npz` files written with `numpy.savez` and read with `numpy.load` (or `.npy` files via `numpy.save`), as `features.pkl` just holds a dictionary of numpy arrays. I wanted to first gauge the reaction to this general idea, though. See discussion in #895
1 parent 09ed0c5 commit a7a2f98

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

run_alphafold.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,16 @@ class ModelsToRelax(enum.Enum):
210210
'check if the sequence, database or configuration have '
211211
'changed.',
212212
)
213+
flags.DEFINE_boolean(
214+
'use_precomputed_features',
215+
False,
216+
'Whether to read features that have been written to disk instead '
217+
'of running the template search. The features.pkl file is looked '
218+
'up in the output directory, so it must stay the same between '
219+
'multiple runs that are to reuse the features. WARNING: This will '
220+
'not check if the sequence, database or configuration have '
221+
'changed.',
222+
)
213223
flags.DEFINE_enum_class(
214224
'models_to_relax',
215225
ModelsToRelax.BEST,
@@ -352,6 +362,7 @@ def predict_structure(
352362
benchmark: bool,
353363
random_seed: int,
354364
models_to_relax: ModelsToRelax,
365+
use_precomputed_features: bool,
355366
model_type: str,
356367
):
357368
"""Predicts structure using AlphaFold for the given sequence."""
@@ -364,17 +375,29 @@ def predict_structure(
364375
if not os.path.exists(msa_output_dir):
365376
os.makedirs(msa_output_dir)
366377

367-
# Get features.
368-
t_0 = time.time()
369-
feature_dict = data_pipeline.process(
370-
input_fasta_path=fasta_path, msa_output_dir=msa_output_dir
371-
)
372-
timings['features'] = time.time() - t_0
373-
374-
# Write out features as a pickled dictionary.
375378
features_output_path = os.path.join(output_dir, 'features.pkl')
376-
with open(features_output_path, 'wb') as f:
377-
pickle.dump(feature_dict, f, protocol=4)
379+
if use_precomputed_features and not os.path.exists(features_output_path):
380+
logging.warning(
381+
'use_precomputed_features is set but %s does not exist, running '
382+
'full feature pipeline',
383+
features_output_path,
384+
)
385+
386+
if use_precomputed_features and os.path.exists(features_output_path):
387+
logging.info('Reading features from %s', features_output_path)
388+
with open(features_output_path, 'rb') as f:
389+
feature_dict = pickle.load(f)
390+
else:
391+
# Get features.
392+
t_0 = time.time()
393+
feature_dict = data_pipeline.process(
394+
input_fasta_path=fasta_path,
395+
msa_output_dir=msa_output_dir)
396+
timings['features'] = time.time() - t_0
397+
398+
# Write out features as a pickled dictionary.
399+
with open(features_output_path, 'wb') as f:
400+
pickle.dump(feature_dict, f, protocol=4)
378401

379402
unrelaxed_pdbs = {}
380403
unrelaxed_proteins = {}
@@ -712,6 +735,7 @@ def main(argv):
712735
benchmark=FLAGS.benchmark,
713736
random_seed=random_seed,
714737
models_to_relax=FLAGS.models_to_relax,
738+
use_precomputed_features=FLAGS.use_precomputed_features,
715739
model_type=model_type,
716740
)
717741

0 commit comments

Comments (0)