Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Python/tigre/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .art_family_algorithms import ossart
from .art_family_algorithms import sart_tv
from .art_family_algorithms import ossart_tv
from .art_family_algorithms import fast_os_sart
from .ista_algorithms import fista
from .ista_algorithms import ista
from .iterative_recon_alg import iterativereconalg
Expand Down Expand Up @@ -38,6 +39,7 @@
"ossart",
"sart_tv",
"ossart_tv",
"fast_os_sart",
"iterativereconalg",
"FDK",
"asd_pocs",
Expand Down
41 changes: 40 additions & 1 deletion Python/tigre/algorithms/art_family_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy

import numpy as np
from tigre.algorithms.iterative_recon_alg import IterativeReconAlg
from tigre.algorithms.iterative_recon_alg import decorator
from tigre.utilities.im_3d_denoise import im3ddenoise
Expand Down Expand Up @@ -150,3 +150,42 @@ def run_main_iter(self):
self.error_measurement(res_prev, i)

ossart_tv = decorator(OSSART_TV, name="ossart_tv")

class Fast_OS_SART(IterativeReconAlg):
__doc__ = (
"Fast_OS_SART solves Cone Beam CT image reconstruction using Nesterov accelerated\n"
"Oriented Subsets Simultaneous Algebraic Reconstruction Technique algorithm\n"
"Fast_OS_SART(PROJ,GEO,ALPHA,NITER,BLOCKSIZE=20) solves the reconstruction problem\n"
"using the projection data PROJ taken over ALPHA angles, corresponding\n"
"to the geometry described in GEO, using NITER iterations.\n"
) + IterativeReconAlg.__doc__

def __init__(self, proj, geo, angles, niter, **kwargs):
self.blocksize = 20 if 'blocksize' not in kwargs else kwargs["blocksize"]
IterativeReconAlg.__init__(self, proj, geo, angles, niter, **kwargs)
self.__t__ = 1.0

def run_main_iter(self):
Quameasopts = self.Quameasopts
t = self.__t__
y_rec = copy.deepcopy(self.res)

for i in range(self.niter):
res_prev = copy.deepcopy(self.res) if Quameasopts is not None else None
if self.verbose:
self._estimate_time_until_completion(i)

x_rec_old = copy.deepcopy(self.res)

self.res = copy.deepcopy(y_rec)
getattr(self, self.dataminimizing)()

t_old = t
t = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
y_rec = self.res + (t_old - 1.0) / t * (self.res - x_rec_old)
y_rec = np.float32(y_rec)

if Quameasopts is not None:
self.error_measurement(res_prev, i)

fast_os_sart = decorator(Fast_OS_SART, name="fast_os_sart")
70 changes: 70 additions & 0 deletions generate_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import sys
import time
import numpy as np
import matplotlib.pyplot as plt

# Attempt to import TIGRE
try:
import tigre
import tigre.algorithms as algs
from tigre.utilities.sample_loader import load_head_phantom
from tigre.utilities.Measure_Quality import Measure_Quality
except ImportError:
print("ERROR: TIGRE is not properly installed or compiled.")
print("Please run this script in an environment with TIGRE's C++/CUDA backend compiled.")
sys.exit(1)

def run_benchmarks():
print("--- Setting up TIGRE Geometry & Phantom ---")
# 1. Setup geometry and phantom
geo = tigre.geometry_default(high_resolution=False)
geo.nVoxel = np.array([64, 64, 64]) # Use small voxel size for fast benchmarking

# Generate angles
angles = np.linspace(0, 2 * np.pi, 100)

# Load phantom
head = load_head_phantom(geo.nVoxel)

# Generate projection data
print("Generating forward projections...")
proj = tigre.Ax(head, geo, angles)

niter = 30
blocksize = 20

print("\n--- Benchmark 1: Convergence Speed & Time (OS_SART vs Fast_OS_SART) ---")

# Standard OS_SART
print("Running standard OS_SART...")
start_time = time.time()
res_os_sart, err_os_sart = algs.ossart(proj, geo, angles, niter=niter, blocksize=blocksize, computel2=True)
time_os_sart = time.time() - start_time

# Fast OS_SART
print("Running Fast_OS_SART...")
start_time = time.time()
res_fast, err_fast = algs.fast_os_sart(proj, geo, angles, niter=niter, blocksize=blocksize, computel2=True)
time_fast = time.time() - start_time

print(f"OS_SART Total Time: {time_os_sart:.2f}s ({time_os_sart/niter:.3f}s per iteration)")
print(f"Fast_OS_SART Total Time: {time_fast:.2f}s ({time_fast/niter:.3f}s per iteration)")
print(f"Final L2 Error -> OS_SART: {err_os_sart[0][-1]:.4f} | Fast_OS_SART: {err_fast[0][-1]:.4f}")

# Plot convergence
plt.figure(figsize=(8, 5))
plt.plot(err_os_sart[0], label="OS_SART", linewidth=2)
plt.plot(err_fast[0], label="Fast_OS_SART", linewidth=2)
plt.title("Convergence Speed: OS_SART vs Fast_OS_SART")
plt.xlabel("Iteration")
plt.ylabel("L2 Error")
plt.legend()
plt.grid(True)
plt.savefig("convergence_benchmark.png")
print("Saved convergence plot to convergence_benchmark.png")

print("\n--- Benchmarks Complete! ---")
print("You can copy these results into your GitHub PR.")

if __name__ == '__main__':
run_benchmarks()
Loading