From 65c14ddbea9eb5eceb0aa836c96e451dc3a5c947 Mon Sep 17 00:00:00 2001 From: XC Date: Thu, 30 May 2024 09:15:11 -0400 Subject: [PATCH] Add threads option to prediction (replacing multiprocessing.cpu_count()). Set tensorflow version to 2.14.0 for compatibility of using OrderedEnqueuer in training. --- docs/readme/predict.md | 4 ++++ maxatac/analyses/predict.py | 2 +- maxatac/utilities/parser.py | 6 ++++++ setup.py | 2 +- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/readme/predict.md b/docs/readme/predict.md index 2572889..47b3d17 100644 --- a/docs/readme/predict.md +++ b/docs/readme/predict.md @@ -85,3 +85,7 @@ The windows to use for prediction. These windows must be 1,024 bp wide and have ### `-skip_call_peaks, --skip_call_peaks` This will skip calling peaks at the end of predictions. + +### `--threads` + +Set number of parallel threads in prediction tasks. If GPUs are used, set this value to be the number of GPUs used for the task. Default: 24. \ No newline at end of file diff --git a/maxatac/analyses/predict.py b/maxatac/analyses/predict.py index f6c9c47..f41b6ac 100644 --- a/maxatac/analyses/predict.py +++ b/maxatac/analyses/predict.py @@ -102,7 +102,7 @@ def run_prediction(args): f"Output filename: {outfile_name_bigwig}" ) - with Pool(int(multiprocessing.cpu_count())) as p: + with Pool(args.threads) as p: forward_strand_predictions = p.starmap(make_stranded_predictions, [(regions_pool, args.signal, diff --git a/maxatac/utilities/parser.py b/maxatac/utilities/parser.py index 1ce7b01..7aaa9de 100644 --- a/maxatac/utilities/parser.py +++ b/maxatac/utilities/parser.py @@ -330,6 +330,12 @@ def get_parser(): help="Skip calling peaks on prediction tracks" ) + predict_parser.add_argument("--threads", + dest="threads", + type=int, + default=24, + help="Number of processes to run prediction in parallel. Default: 24" + ) ############################################# # Train parser ############################################# diff --git a/setup.py b/setup.py index b4b037e..960a83c 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def get_description(): ] ), install_requires=[ - "tensorflow", + "tensorflow==2.14.0", "tensorboard", "biopython", "py2bit",