Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 51 additions & 59 deletions garf/bin/garf_compare_image_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,106 +14,98 @@
@click.argument("image1_mhd")
@click.argument("image2_mhd")
@click.option(
"--events",
"-e",
default=float(1),
"--scaling",
"-s",
default=1.0,
help="Scale the image2 by this value before comparing",
)
@click.option("--islice", "-s", default=int(64), help="Image slice for the profile")
@click.option(
"--islice", "-i", default=None, help="Image slice for the profile (middle if None)"
)
@click.option("--wslice", "-w", default=int(3), help="Slice width (to smooth)")
def garf_compare_image_profile(image1_mhd, image2_mhd, islice, events, wslice):
@click.option("--output", "-o", default="output.pdf", help="output")
@click.option(
"-rmf",
default=False,
is_flag=True,
help="Remove first slice of ref data (hit slice)",
)
def garf_compare_image_profile(
image1_mhd, image2_mhd, islice, scaling, wslice, rmf, output
):
# Load image
img_ref = sitk.ReadImage(image1_mhd)
img = sitk.ReadImage(image2_mhd)
events = float(events)
islice = int(islice)
scaling = float(scaling)
wslice = int(wslice)

# Scale data to the ref nb of particles
img = img * events
# slice
if islice is None:
islice = int(img.GetSize()[0] / 2)
else:
islice = int(islice)

# Get the pixels values as np array
data_ref = sitk.GetArrayFromImage(img_ref).astype(float)
data = sitk.GetArrayFromImage(img).astype(float)

# Scale data to the ref nb of particles
data = data * scaling

print(f"Reference image shape : {data_ref.shape}")
print(f"Test image shape : {data.shape}")

# Sometimes not same nb of slices -> crop the data_ref
if len(data_ref) > len(data):
data_ref = data_ref[0 : len(data), :, :]

# Remove first slice ?
if rmf:
data_ref = data_ref[1:, :, :]

# Criterion1: global counts in every windows
s_ref = np.sum(data_ref, axis=(1, 2))
s = np.sum(data, axis=(1, 2))
ratio = (s - s_ref) / s_ref * 100.0

print("Ref: Singles/Scatter/Peak1/Peak2: {}".format(s_ref))
print("Img: WARNING/Scatter/Peak1/Peak2: {}".format(s))
print(
"% diff : WARNING/Scatter/Peak1/Peak2: {}".format((s - s_ref) / s_ref * 100.0)
)
# global counts
print(f"Global counts, reference : {s_ref}")
print(f"Global counts, test image: {s}")
print(f"Global counts, % diff : {ratio} %")

# Profiles
# data image: !eee!Z,Y,X
p_ref = np.mean(data_ref[:, islice - wslice : islice + wslice - 1, :], axis=1)
p = np.mean(data[:, islice - wslice : islice + wslice - 1, :], axis=1)
x = np.arange(0, 128, 1)
x = np.arange(0, data.shape[1], 1)

# max
vmax_ref = np.max(p_ref[1:, :])
vmax = np.max(p[1:, :])
print(f"Max value in ref image : {vmax_ref}")
print(f"Max value in test image : {vmax}")

# nb of energy windows
nb_ene = len(data)
print("Nb of energy windows: ", nb_ene)
win = [f"win {i}" for i in np.arange(nb_ene)]

if nb_ene == 3: # Tc99m
win = ["WARNING", "Scatter", "Peak 140"]

if nb_ene == 6: # In111
win = ["WARNING", "Scatter1", "Peak171", "Scatter2", "Scatter3", "Peak245"]

if nb_ene == 7: # Lu177
win = [
"WARNING",
"Scatter1",
"Peak113",
"Scatter2",
"Scatter3",
"Peak208",
"Scatter4",
]

if nb_ene == 8:
win = [
"WARNING",
"Scatter1",
"Peak364",
"Scatter2",
"Scatter3",
"Scatter4",
"Peak637",
"Peak722",
]

fig, ax = plt.subplots(ncols=nb_ene - 1, nrows=1, figsize=(35, 5))

i = 1
vmax = np.max(p_ref[1:, :])
vmax = np.max(p[1:, :])
print("Max value in ref image for the scale : {}".format(vmax))

# figure
fig, ax = plt.subplots(ncols=nb_ene, nrows=1, figsize=(35, 5))
fs = 12

plt.rc("font", size=fs)
while i < nb_ene:
a = ax[i - 1]

for i in range(nb_ene):
a = ax[i]
a.plot(x, p_ref[i], "g", label="Analog", alpha=0.5, linewidth=2.0)
a.plot(x, p[i], "k--", label="ARF", alpha=0.9, linewidth=1.0)
a.set_title(win[i], fontsize=fs + 5)
a.legend(loc="best")
# a.labelsize = 40
a.tick_params(labelsize=fs)
# a.set_ylim([0, vmax])
i += 1

plt.suptitle("Compare " + image1_mhd + " vs " + image2_mhd + " w=" + str(wslice))
plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.savefig("output.pdf")
plt.savefig(output)
plt.show()


Expand Down
4 changes: 3 additions & 1 deletion garf/bin/garf_nn_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def garf_nn_info(filename_pth):

loss_values = p["loss_values"]
x = np.arange(0, len(loss_values))
plt.plot(x, loss_values)
# plt.plot(x, loss_values)
# remove first and last
plt.plot(x[1:-1], loss_values[1:-1])
plt.tight_layout()
plt.show()

Expand Down
5 changes: 0 additions & 5 deletions garf/bin/garf_plot_training_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import garf
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import uproot
import ntpath
import click

# -----------------------------------------------------------------------------
Expand Down
89 changes: 89 additions & 0 deletions garf/bin/garf_plot_training_dataset2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import garf
import matplotlib.pyplot as plt
import numpy as np
import click

# -----------------------------------------------------------------------------
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument("data_file")
@click.option(
"--sample-size", default=5000, help="Number of non-detected (w=0) events to plot."
)
def garf_plot_training_dataset_2d(data_file, sample_size):
"""
\b
Display 2D scatter plots of the training dataset to show relationships
between angles, energy, and the final energy window.

<DATA_FILE> : dataset in root format
"""
print(f"Loading data from '{data_file}'")
data, theta, phi, E, w = garf.load_training_dataset(data_file)
print("Data loaded. Preparing plots...")

# Find the unique window IDs and assign colors
window_ids = np.unique(w)
colors = plt.cm.viridis(np.linspace(0, 1, len(window_ids)))

# Create a 1x3 figure for the plots
fig, ax = plt.subplots(1, 3, figsize=(18, 5.5))

# --- Sub-sample the dominant class for clarity ---
# Separate data by window ID
data_by_window = {win_id: data[w == win_id] for win_id in window_ids}

# Sub-sample the window=0 class if it's too large
if 0 in data_by_window and len(data_by_window[0]) > sample_size:
print(
f"Sub-sampling non-detected (window=0) class from {len(data_by_window[0])} to {sample_size} points for clarity."
)
indices = np.random.choice(
data_by_window[0].shape[0], sample_size, replace=False
)
data_by_window[0] = data_by_window[0][indices]

# --- Create the plots ---
plot_titles = ["Theta vs. Phi", "Energy vs. Theta", "Energy vs. Phi"]
plot_vars = [(1, 0), (0, 2), (1, 2)] # (x_col_idx, y_col_idx) from `data` array

for i, p_ax in enumerate(ax):
for win_id, color in zip(window_ids, colors):
if win_id not in data_by_window:
continue

subset = data_by_window[win_id]
x_data = subset[:, plot_vars[i][0]]
y_data = subset[:, plot_vars[i][1]]

# For Energy plots, convert to keV
if plot_vars[i][1] == 2: # Energy is the Y-axis
y_data = y_data * 1000

p_ax.scatter(
x_data,
y_data,
color=color,
label=f"Window {int(win_id)}",
alpha=0.5,
s=5,
)

p_ax.set_title(plot_titles[i])
p_ax.set_xlabel(f"{['Theta', 'Phi', 'Energy (keV)'][plot_vars[i][0]]}")
p_ax.set_ylabel(f"{['Theta', 'Phi', 'Energy (keV)'][plot_vars[i][1]]}")
p_ax.legend()
p_ax.grid(True, linestyle="--", alpha=0.6)

plt.tight_layout()
plt.show()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
garf_plot_training_dataset_2d()
16 changes: 15 additions & 1 deletion garf/bin/garf_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
@click.argument("data")
@click.argument("output")
@click.option("--progress-bar/--no-progress-bar", default=True)
def garf_train(param, data, output, progress_bar):
@click.option(
"--rr", default=None, help="RR value (overwrite the one in the param file)"
)
@click.option(
"--epoch", default=None, help="Nb of epoch (overwrite the one in the param file)"
)
def garf_train(param, data, output, rr, epoch, progress_bar):
"""
\b
Train a ARF-nn (neural network) from a training dataset.
Expand All @@ -34,6 +40,14 @@ def garf_train(param, data, output, progress_bar):
params = json.loads(param_file)
params["progress_bar"] = progress_bar

# RR ?
if rr is not None:
params["RR"] = float(rr)

# epoch ?
if epoch is not None:
params["epoch_max"] = int(epoch)

# Print info
print("Training dataset", data_filename)
garf.print_training_dataset_info(data, params["RR"])
Expand Down
Loading