
"""
Minimal benchmark: elastix.exe without TBB vs elastix.exe with TBB.

Compares the performance of elastix when it internally uses ITK's ThreadPool versus TBB (Threading Building Blocks).

Author: Niels Dekker, LKEB, Leiden University Medical Center, 2026
Inspired by benchmark-itk.zip by Nicolas Chiaruttini, 22 March 2026
from https://discourse.itk.org/t/8x-slower-registration-with-itk-elastix-python-api-vs-elastix-cli-minimal-reproducible-example/7736

Usage
-----
    TBB-bench.py
        --elastix-without-TBB /file-path/to/elastix-without-TBB.exe
        --elastix-with-TBB /file-path/to/elastix-with-TBB.exe
        --fixed /file-path/to/blobs.tif
        --moving /file-path/to/blobs-rot15deg.tif
        --param-file  /file-path/to/params_bspline.txt
        --out /path/to/output/directory
        --min-threads 1
        --max-threads 42
"""

import argparse
import os
import subprocess
import sys
import time


def run_cli(elastix_exe, fixed, moving, param_file, out, n_threads):
    """Run a single elastix registration and return its wall-clock duration.

    Parameters
    ----------
    elastix_exe : str
        Path to the elastix executable to run.
    fixed, moving : str
        Paths to the fixed and moving images (passed as ``-f`` and ``-m``).
    param_file : str
        Path to the elastix parameter file (passed as ``-p``).
    out : str
        Base output directory. A subdirectory named ``out_default_n_threads``
        or ``out_n_threads=<n>`` is created inside it (if needed) and passed
        to elastix as ``-out``.
    n_threads : str or int
        Thread count passed as ``-threads``. An empty string (or None) means
        "use elastix's default"; in that case ``-threads`` is omitted.

    Returns
    -------
    float
        Elapsed wall-clock time in seconds, as measured around the whole
        subprocess call (so it includes process start-up).

    Raises
    ------
    RuntimeError
        If elastix exits with a non-zero return code.
    """
    use_default = n_threads is None or n_threads == ""
    subdir = "out_default_n_threads" if use_default else f"out_n_threads={n_threads}"
    # os.path.join avoids doubled separators when `out` ends with a slash.
    out = os.path.join(out, subdir)
    os.makedirs(out, exist_ok=True)

    cmd = [
        elastix_exe,
        "-f",   fixed,
        "-m",   moving,
        "-p",   param_file,
        "-out", out,
    ]
    if not use_default:
        # For the default case, leave "-threads" out entirely instead of
        # passing an empty value, so elastix's own default applies without
        # depending on how it parses an empty argument.
        cmd += ["-threads", str(n_threads)]

    t0 = time.perf_counter()
    result = subprocess.run(cmd, capture_output=True, text=True)
    elapsed = time.perf_counter() - t0
    if result.returncode != 0:
        print("  [CLI] STDERR (last 600 chars):", result.stderr[-600:], file=sys.stderr)
        # elastix may also report the failure reason on stdout; show its tail too.
        print("  [CLI] STDOUT (last 600 chars):", result.stdout[-600:], file=sys.stderr)
        raise RuntimeError(f"elastix CLI failed (rc={result.returncode})")
    return elapsed


def run_multiple_times(elastix_exe, fixed, moving, param_file, out, n_threads, n_runs=3):
    """Run the same elastix registration several times and return the fastest.

    Every run uses identical arguments (and therefore the same output
    directory, which later runs overwrite). Reporting the minimum duration
    keeps a single unusually slow run from dominating the measurement.

    Parameters
    ----------
    elastix_exe, fixed, moving, param_file, out, n_threads
        Forwarded unchanged to ``run_cli`` for every run.
    n_runs : int, optional
        Number of runs to perform (default 3, matching the previous fixed
        behavior). Must be at least 1.

    Returns
    -------
    float
        The smallest elapsed time, in seconds, over all runs.

    Raises
    ------
    ValueError
        If ``n_runs`` is less than 1.
    RuntimeError
        Propagated from ``run_cli`` if any run fails.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")
    return min(
        run_cli(elastix_exe, fixed, moving, param_file, out, n_threads)
        for _ in range(n_runs)
    )


def main():
    """Parse the command line, run the TBB benchmark and print the results.

    Both elastix executables are first benchmarked with ``n_threads=""``
    (elastix's default number of threads), and then once per thread count in
    the inclusive range ``--min-threads`` .. ``--max-threads``. Each duration
    is the value returned by ``run_multiple_times`` for that configuration.
    Output goes to ``<out>/without-TBB`` and ``<out>/with-TBB`` respectively.

    Results are printed as they are measured, followed by a space-separated
    summary table. In that table, thread count 0 denotes the default-thread run.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark elastix built with ITK's ThreadPool (without TBB) "
                    "vs elastix built with TBB"
    )
    parser.add_argument("--elastix-without-TBB", required=True,
                        help="Path to elastix executable that was built using ITK's ThreadPool (without TBB)")
    parser.add_argument("--elastix-with-TBB", required=True,
                        help="Path to elastix executable that was built with TBB")
    parser.add_argument("--fixed",  required=True, help="Fixed image (TIFF/MHD/...)")
    parser.add_argument("--moving", required=True, help="Moving image")
    parser.add_argument("--param-file",  required=True, help="Parameter text file")
    parser.add_argument("--out",  required=True, help="Output directory. (Created if it does not exist.)")
    parser.add_argument("--min-threads", type=int, default=1, help="Minimum number of threads.")
    parser.add_argument("--max-threads", type=int, required=True, help="Maximum number of threads.")

    args = parser.parse_args()
    # Thread count 0 is reserved for the default-thread row of the summary,
    # and a non-positive thread count is not a meaningful benchmark setting.
    if args.min_threads < 1:
        parser.error("--min-threads must be at least 1")
    if args.max_threads < args.min_threads:
        parser.error("--max-threads must be greater than or equal to --min-threads")

    print('TBB-bench arguments:')
    print('\n'.join(f'  {k}={v}' for k, v in vars(args).items()))
    print()

    out_without_tbb = os.path.join(args.out, "without-TBB")
    out_with_tbb = os.path.join(args.out, "with-TBB")

    def bench(elastix_exe, out, n_threads):
        # Shared arguments for every measurement in this benchmark.
        return run_multiple_times(elastix_exe, args.fixed, args.moving, args.param_file, out, n_threads)

    # First measure each executable with elastix's default number of threads.
    t = bench(args.elastix_without_TBB, out_without_tbb, "")
    disabled_times = [t]
    print("Disable TBB: ", t, " seconds")
    t = bench(args.elastix_with_TBB, out_with_tbb, "")
    enabled_times = [t]
    print("Enable TBB: ", t, " seconds")

    print()
    thread_counts = range(args.min_threads, args.max_threads + 1)
    for n_threads in thread_counts:
        print(f"n_threads = {n_threads}")
        t = bench(args.elastix_without_TBB, out_without_tbb, str(n_threads))
        print("  Disable TBB: ", t, " seconds")
        disabled_times.append(t)
        t = bench(args.elastix_with_TBB, out_with_tbb, str(n_threads))
        print("  Enable TBB: ", t, " seconds")
        enabled_times.append(t)

    print("n_threads duration_without_TBB duration_with_TBB")
    # Row 0 is the default-thread run; the remaining rows follow thread_counts.
    for n_threads, t_without, t_with in zip([0, *thread_counts], disabled_times, enabled_times):
        print(n_threads, t_without, t_with)


if __name__ == "__main__":
    # Run the benchmark only when this file is executed as a script, not when imported.
    main()