diff --git a/Python/tigre/algorithms/__init__.py b/Python/tigre/algorithms/__init__.py index c3f1e926..557ab968 100644 --- a/Python/tigre/algorithms/__init__.py +++ b/Python/tigre/algorithms/__init__.py @@ -7,6 +7,7 @@ from .art_family_algorithms import ossart from .art_family_algorithms import sart_tv from .art_family_algorithms import ossart_tv +from .art_family_algorithms import fast_os_sart from .ista_algorithms import fista from .ista_algorithms import ista from .iterative_recon_alg import iterativereconalg @@ -38,6 +39,7 @@ "ossart", "sart_tv", "ossart_tv", + "fast_os_sart", "iterativereconalg", "FDK", "asd_pocs", diff --git a/Python/tigre/algorithms/art_family_algorithms.py b/Python/tigre/algorithms/art_family_algorithms.py index 18f0fa33..ce6f6465 100644 --- a/Python/tigre/algorithms/art_family_algorithms.py +++ b/Python/tigre/algorithms/art_family_algorithms.py @@ -1,5 +1,5 @@ import copy - +import numpy as np from tigre.algorithms.iterative_recon_alg import IterativeReconAlg from tigre.algorithms.iterative_recon_alg import decorator from tigre.utilities.im_3d_denoise import im3ddenoise @@ -150,3 +150,42 @@ def run_main_iter(self): self.error_measurement(res_prev, i) ossart_tv = decorator(OSSART_TV, name="ossart_tv") + +class Fast_OS_SART(IterativeReconAlg): + __doc__ = ( + "Fast_OS_SART solves Cone Beam CT image reconstruction using Nesterov accelerated\n" + "Oriented Subsets Simultaneous Algebraic Reconstruction Technique algorithm\n" + "Fast_OS_SART(PROJ,GEO,ALPHA,NITER,BLOCKSIZE=20) solves the reconstruction problem\n" + "using the projection data PROJ taken over ALPHA angles, corresponding\n" + "to the geometry described in GEO, using NITER iterations.\n" + ) + IterativeReconAlg.__doc__ + + def __init__(self, proj, geo, angles, niter, **kwargs): + self.blocksize = 20 if 'blocksize' not in kwargs else kwargs["blocksize"] + IterativeReconAlg.__init__(self, proj, geo, angles, niter, **kwargs) + self.__t__ = 1.0 + + def run_main_iter(self): + Quameasopts = self.Quameasopts + t = self.__t__ + y_rec = copy.deepcopy(self.res) + + for i in range(self.niter): + res_prev = copy.deepcopy(self.res) if Quameasopts is not None else None + if self.verbose: + self._estimate_time_until_completion(i) + + x_rec_old = copy.deepcopy(self.res) + + self.res = copy.deepcopy(y_rec) + getattr(self, self.dataminimizing)() + + t_old = t + t = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0 + y_rec = self.res + (t_old - 1.0) / t * (self.res - x_rec_old) + y_rec = np.float32(y_rec) + + if Quameasopts is not None: + self.error_measurement(res_prev, i) + +fast_os_sart = decorator(Fast_OS_SART, name="fast_os_sart") diff --git a/generate_benchmarks.py b/generate_benchmarks.py new file mode 100644 index 00000000..2c7a8a2c --- /dev/null +++ b/generate_benchmarks.py @@ -0,0 +1,70 @@ +import sys +import time +import numpy as np +import matplotlib.pyplot as plt + +# Attempt to import TIGRE +try: + import tigre + import tigre.algorithms as algs + from tigre.utilities.sample_loader import load_head_phantom + from tigre.utilities.Measure_Quality import Measure_Quality +except ImportError: + print("ERROR: TIGRE is not properly installed or compiled.") + print("Please run this script in an environment with TIGRE's C++/CUDA backend compiled.") + sys.exit(1) + +def run_benchmarks(): + print("--- Setting up TIGRE Geometry & Phantom ---") + # 1. Setup geometry and phantom + geo = tigre.geometry_default(high_resolution=False) + geo.nVoxel = np.array([64, 64, 64]) # Use small voxel size for fast benchmarking + + # Generate angles + angles = np.linspace(0, 2 * np.pi, 100) + + # Load phantom + head = load_head_phantom(geo.nVoxel) + + # Generate projection data + print("Generating forward projections...") + proj = tigre.Ax(head, geo, angles) + + niter = 30 + blocksize = 20 + + print("\n--- Benchmark 1: Convergence Speed & Time (OS_SART vs Fast_OS_SART) ---") + + # Standard OS_SART + print("Running standard OS_SART...") + start_time = time.time() + res_os_sart, err_os_sart = algs.ossart(proj, geo, angles, niter=niter, blocksize=blocksize, computel2=True) + time_os_sart = time.time() - start_time + + # Fast OS_SART + print("Running Fast_OS_SART...") + start_time = time.time() + res_fast, err_fast = algs.fast_os_sart(proj, geo, angles, niter=niter, blocksize=blocksize, computel2=True) + time_fast = time.time() - start_time + + print(f"OS_SART Total Time: {time_os_sart:.2f}s ({time_os_sart/niter:.3f}s per iteration)") + print(f"Fast_OS_SART Total Time: {time_fast:.2f}s ({time_fast/niter:.3f}s per iteration)") + print(f"Final L2 Error -> OS_SART: {err_os_sart[0][-1]:.4f} | Fast_OS_SART: {err_fast[0][-1]:.4f}") + + # Plot convergence + plt.figure(figsize=(8, 5)) + plt.plot(err_os_sart[0], label="OS_SART", linewidth=2) + plt.plot(err_fast[0], label="Fast_OS_SART", linewidth=2) + plt.title("Convergence Speed: OS_SART vs Fast_OS_SART") + plt.xlabel("Iteration") + plt.ylabel("L2 Error") + plt.legend() + plt.grid(True) + plt.savefig("convergence_benchmark.png") + print("Saved convergence plot to convergence_benchmark.png") + + print("\n--- Benchmarks Complete! ---") + print("You can copy these results into your GitHub PR.") + +if __name__ == '__main__': + run_benchmarks()