diff --git a/garf/bin/garf_compare_image_profile.py b/garf/bin/garf_compare_image_profile.py
index 17accb7..51bc8cd 100755
--- a/garf/bin/garf_compare_image_profile.py
+++ b/garf/bin/garf_compare_image_profile.py
@@ -14,106 +14,97 @@
 @click.argument("image1_mhd")
 @click.argument("image2_mhd")
 @click.option(
-    "--events",
-    "-e",
-    default=float(1),
+    "--scaling",
+    "-s",
+    default=1.0,
     help="Scale the image2 by this value before comparing",
 )
-@click.option("--islice", "-s", default=int(64), help="Image slice for the profile")
+@click.option(
+    "--islice", "-i", default=None, help="Image slice for the profile (middle if None)"
+)
 @click.option("--wslice", "-w", default=int(3), help="Slice width (to smooth)")
-def garf_compare_image_profile(image1_mhd, image2_mhd, islice, events, wslice):
+@click.option("--output", "-o", default="output.pdf", help="Output figure filename")
+@click.option(
+    "-rmf",
+    default=False,
+    is_flag=True,
+    help="Remove first slice of ref data (hit slice)",
+)
+def garf_compare_image_profile(
+    image1_mhd, image2_mhd, islice, scaling, wslice, rmf, output
+):
     # Load image
     img_ref = sitk.ReadImage(image1_mhd)
     img = sitk.ReadImage(image2_mhd)
-    events = float(events)
-    islice = int(islice)
+    scaling = float(scaling)
     wslice = int(wslice)
 
-    # Scale data to the ref nb of particles
-    img = img * events
+    # slice (middle one by default)
+    if islice is None:
+        islice = int(img.GetSize()[0] / 2)
+    else:
+        islice = int(islice)
 
     # Get the pixels values as np array
    data_ref = sitk.GetArrayFromImage(img_ref).astype(float)
     data = sitk.GetArrayFromImage(img).astype(float)
 
+    # Scale data to the ref nb of particles
+    data = data * scaling
+
+    print(f"Reference image shape : {data_ref.shape}")
+    print(f"Test image shape : {data.shape}")
+
     # Sometimes not same nb of slices -> crop the data_ref
     if len(data_ref) > len(data):
         data_ref = data_ref[0 : len(data), :, :]
 
+    # Remove first slice (hit slice) of the reference data?
+    if rmf:
+        data_ref = data_ref[1:, :, :]
+
     # Criterion1: global counts in every windows
     s_ref = np.sum(data_ref, axis=(1, 2))
     s = np.sum(data, axis=(1, 2))
+    ratio = (s - s_ref) / s_ref * 100.0
 
-    print("Ref: Singles/Scatter/Peak1/Peak2: {}".format(s_ref))
-    print("Img: WARNING/Scatter/Peak1/Peak2: {}".format(s))
-    print(
-        "% diff : WARNING/Scatter/Peak1/Peak2: {}".format((s - s_ref) / s_ref * 100.0)
-    )
+    # global counts
+    print(f"Global counts, reference : {s_ref}")
+    print(f"Global counts, test image: {s}")
+    print(f"Global counts, % diff : {ratio} %")
 
     # Profiles
-    # data image: !eee!Z,Y,X
     p_ref = np.mean(data_ref[:, islice - wslice : islice + wslice - 1, :], axis=1)
     p = np.mean(data[:, islice - wslice : islice + wslice - 1, :], axis=1)
-    x = np.arange(0, 128, 1)
+    x = np.arange(0, data.shape[2], 1)
+
+    # max
+    vmax_ref = np.max(p_ref[1:, :])
+    vmax = np.max(p[1:, :])
+    print(f"Max value in ref image : {vmax_ref}")
+    print(f"Max value in test image : {vmax}")
 
+    # nb of energy windows
     nb_ene = len(data)
     print("Nb of energy windows: ", nb_ene)
+    win = [f"win {i}" for i in np.arange(nb_ene)]
 
-    if nb_ene == 3:  # Tc99m
-        win = ["WARNING", "Scatter", "Peak 140"]
-
-    if nb_ene == 6:  # In111
-        win = ["WARNING", "Scatter1", "Peak171", "Scatter2", "Scatter3", "Peak245"]
-
-    if nb_ene == 7:  # Lu177
-        win = [
-            "WARNING",
-            "Scatter1",
-            "Peak113",
-            "Scatter2",
-            "Scatter3",
-            "Peak208",
-            "Scatter4",
-        ]
-
-    if nb_ene == 8:
-        win = [
-            "WARNING",
-            "Scatter1",
-            "Peak364",
-            "Scatter2",
-            "Scatter3",
-            "Scatter4",
-            "Peak637",
-            "Peak722",
-        ]
-
-    fig, ax = plt.subplots(ncols=nb_ene - 1, nrows=1, figsize=(35, 5))
-
-    i = 1
-    vmax = np.max(p_ref[1:, :])
-    vmax = np.max(p[1:, :])
-    print("Max value in ref image for the scale : {}".format(vmax))
-
+    # figure
+    fig, ax = plt.subplots(ncols=nb_ene, nrows=1, figsize=(35, 5))
     fs = 12
-    plt.rc("font", size=fs)
-    while i < nb_ene:
-        a = ax[i - 1]
-
+    for i in range(nb_ene):
+        a = ax[i]
         a.plot(x, p_ref[i], "g", label="Analog", alpha=0.5, linewidth=2.0)
         a.plot(x, p[i], "k--", label="ARF", alpha=0.9, linewidth=1.0)
         a.set_title(win[i], fontsize=fs + 5)
         a.legend(loc="best")
-        # a.labelsize = 40
         a.tick_params(labelsize=fs)
-        # a.set_ylim([0, vmax])
-        i += 1
 
     plt.suptitle("Compare " + image1_mhd + " vs " + image2_mhd + " w=" + str(wslice))
     plt.tight_layout()
     plt.subplots_adjust(top=0.85)
-    plt.savefig("output.pdf")
+    plt.savefig(output)
     plt.show()
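Note on the profile computation above: after sitk.GetArrayFromImage the data is indexed as [energy window, Y, X], so the script averages a small band of rows around islice, yielding one profile per window. A minimal standalone sketch of the same extraction (the file name test.mhd is illustrative):

    import SimpleITK as sitk
    import numpy as np

    # layout after GetArrayFromImage: [energy window, Y, X]
    data = sitk.GetArrayFromImage(sitk.ReadImage("test.mhd")).astype(float)
    islice = data.shape[1] // 2  # middle row, as with the default --islice
    wslice = 3                   # band width, as with --wslice 3
    # average a band of rows -> one 1D profile per energy window
    profile = np.mean(data[:, islice - wslice : islice + wslice - 1, :], axis=1)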
diff --git a/garf/bin/garf_nn_info.py b/garf/bin/garf_nn_info.py
index 18256fb..47eef12 100755
--- a/garf/bin/garf_nn_info.py
+++ b/garf/bin/garf_nn_info.py
@@ -23,7 +23,9 @@ def garf_nn_info(filename_pth):
     loss_values = p["loss_values"]
     x = np.arange(0, len(loss_values))
 
-    plt.plot(x, loss_values)
+    # plt.plot(x, loss_values)
+    # remove first and last
+    plt.plot(x[1:-1], loss_values[1:-1])
     plt.tight_layout()
     plt.show()
 
diff --git a/garf/bin/garf_plot_training_dataset.py b/garf/bin/garf_plot_training_dataset.py
index 0f0914d..ea6c109 100755
--- a/garf/bin/garf_plot_training_dataset.py
+++ b/garf/bin/garf_plot_training_dataset.py
@@ -1,13 +1,8 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import sys
 import garf
 import matplotlib.pyplot as plt
-from matplotlib import cm
-import numpy as np
-import uproot
-import ntpath
 import click
 
 # -----------------------------------------------------------------------------
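The new script below sub-samples the dominant non-detected class (window 0) before scatter-plotting, since it would otherwise drown the detected windows. The pattern in isolation, with an illustrative stand-in array (np.random.choice with replace=False picks distinct rows):

    import numpy as np

    non_detected = np.random.rand(100000, 3)  # stand-in for data[w == 0]
    sample_size = 5000
    if len(non_detected) > sample_size:
        idx = np.random.choice(non_detected.shape[0], sample_size, replace=False)
        non_detected = non_detected[idx]  # at most sample_size points to plot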
diff --git a/garf/bin/garf_plot_training_dataset2.py b/garf/bin/garf_plot_training_dataset2.py
new file mode 100755
index 0000000..c33454c
--- /dev/null
+++ b/garf/bin/garf_plot_training_dataset2.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import garf
+import matplotlib.pyplot as plt
+import numpy as np
+import click
+
+# -----------------------------------------------------------------------------
+CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
+
+
+@click.command(context_settings=CONTEXT_SETTINGS)
+@click.argument("data_file")
+@click.option(
+    "--sample-size", default=5000, help="Number of non-detected (w=0) events to plot."
+)
+def garf_plot_training_dataset_2d(data_file, sample_size):
+    """
+    \b
+    Display 2D scatter plots of the training dataset to show relationships
+    between angles, energy, and the final energy window.
+
+    <data_file>: dataset in root format
+    """
+    print(f"Loading data from '{data_file}'")
+    data, theta, phi, E, w = garf.load_training_dataset(data_file)
+    print("Data loaded. Preparing plots...")
+
+    # Find the unique window IDs and assign colors
+    window_ids = np.unique(w)
+    colors = plt.cm.viridis(np.linspace(0, 1, len(window_ids)))
+
+    # Create a 1x3 figure for the plots
+    fig, ax = plt.subplots(1, 3, figsize=(18, 5.5))
+
+    # --- Sub-sample the dominant class for clarity ---
+    # Separate data by window ID
+    data_by_window = {win_id: data[w == win_id] for win_id in window_ids}
+
+    # Sub-sample the window=0 class if it's too large
+    if 0 in data_by_window and len(data_by_window[0]) > sample_size:
+        print(
+            f"Sub-sampling non-detected (window=0) class from {len(data_by_window[0])} to {sample_size} points for clarity."
+        )
+        indices = np.random.choice(
+            data_by_window[0].shape[0], sample_size, replace=False
+        )
+        data_by_window[0] = data_by_window[0][indices]
+
+    # --- Create the plots ---
+    plot_titles = ["Theta vs. Phi", "Energy vs. Theta", "Energy vs. Phi"]
+    plot_vars = [(1, 0), (0, 2), (1, 2)]  # (x_col_idx, y_col_idx) from `data` array
+
+    for i, p_ax in enumerate(ax):
+        for win_id, color in zip(window_ids, colors):
+            if win_id not in data_by_window:
+                continue
+
+            subset = data_by_window[win_id]
+            x_data = subset[:, plot_vars[i][0]]
+            y_data = subset[:, plot_vars[i][1]]
+
+            # For Energy plots, convert to keV
+            if plot_vars[i][1] == 2:  # Energy is the Y-axis
+                y_data = y_data * 1000
+
+            p_ax.scatter(
+                x_data,
+                y_data,
+                color=color,
+                label=f"Window {int(win_id)}",
+                alpha=0.5,
+                s=5,
+            )
+
+        p_ax.set_title(plot_titles[i])
+        p_ax.set_xlabel(f"{['Theta', 'Phi', 'Energy (keV)'][plot_vars[i][0]]}")
+        p_ax.set_ylabel(f"{['Theta', 'Phi', 'Energy (keV)'][plot_vars[i][1]]}")
+        p_ax.legend()
+        p_ax.grid(True, linestyle="--", alpha=0.6)
+
+    plt.tight_layout()
+    plt.show()
+
+
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+    garf_plot_training_dataset_2d()
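The two options added to garf_train below let the command line override values from the json parameter file. The semantics as a self-contained sketch (parameter values illustrative; click passes option values as strings):

    import json

    params = json.loads('{"RR": 40, "epoch_max": 100}')  # stand-in param file
    rr, epoch = "50", None  # as received from --rr / --epoch
    if rr is not None:
        params["RR"] = float(rr)  # --rr given: overrides the file value
    if epoch is not None:
        params["epoch_max"] = int(epoch)  # --epoch not given: file value kept
    assert params == {"RR": 50.0, "epoch_max": 100}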
diff --git a/garf/bin/garf_train.py b/garf/bin/garf_train.py
index 86648c8..7360e54 100755
--- a/garf/bin/garf_train.py
+++ b/garf/bin/garf_train.py
@@ -16,7 +16,13 @@
 @click.argument("data")
 @click.argument("output")
 @click.option("--progress-bar/--no-progress-bar", default=True)
-def garf_train(param, data, output, progress_bar):
+@click.option(
+    "--rr", default=None, help="RR value (overrides the one in the param file)"
+)
+@click.option(
+    "--epoch", default=None, help="Number of epochs (overrides the one in the param file)"
+)
+def garf_train(param, data, output, rr, epoch, progress_bar):
     """
     \b
     Train a ARF-nn (neural network) from a training dataset.
@@ -34,6 +40,14 @@ def garf_train(param, data, output, progress_bar):
     params = json.loads(param_file)
     params["progress_bar"] = progress_bar
 
+    # RR override from the command line?
+    if rr is not None:
+        params["RR"] = float(rr)
+
+    # epoch override from the command line?
+    if epoch is not None:
+        params["epoch_max"] = int(epoch)
+
     # Print info
     print("Training dataset", data_filename)
     garf.print_training_dataset_info(data, params["RR"])
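The XGBoost trainer below gives every class a unit weight, boosts the non-detected class (0) by the Russian Roulette factor, then expands the per-class weights into per-sample weights by indexing with the label array. The expansion step in isolation (values illustrative):

    import numpy as np

    class_weights = np.array([500.0, 1.0, 1.0])  # window 0 boosted by the RR factor
    y_train = np.array([0, 0, 1, 2, 0, 1])       # per-sample window labels
    sample_weights = class_weights[y_train]      # -> [500. 500. 1. 1. 500. 1.]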
diff --git a/garf/bin/garf_train_xgboost.py b/garf/bin/garf_train_xgboost.py
new file mode 100755
index 0000000..82650c0
--- /dev/null
+++ b/garf/bin/garf_train_xgboost.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import click
+import json
+import joblib
+import xgboost
+
+# The garf package provides load_training_dataset
+import garf
+
+# -----------------------------------------------------------------------------
+CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
+
+
+@click.command(context_settings=CONTEXT_SETTINGS)
+@click.argument("json_params_file")
+@click.argument("data_file")
+@click.argument("output_model_file")
+def garf_train_xgboost(json_params_file, data_file, output_model_file):
+    """
+    \b
+    Train an XGBoost model for ARF based on a training dataset.
+    <json_params_file>: Training parameters (H, L, rr, etc) in json format
+    <data_file>: Dataset in root format
+    <output_model_file>: Filename for the output trained model (e.g., 'model.joblib')
+    """
+
+    # --- 1. Load Data and Parameters ---
+    print(f"Loading data from '{data_file}'")
+    data, theta, phi, E, w = garf.load_training_dataset(data_file)
+    x_train = np.column_stack((theta, phi, E))
+    y_train = w.astype(int)  # Labels must be integers
+    # number of energy windows = max label + 1
+    n_ene_win = np.max(y_train) + 1
+    print(f"Number of ENE win: {n_ene_win}")
+
+    print(f"Loading parameters from '{json_params_file}'")
+    with open(json_params_file) as f:
+        params = json.load(f)
+    rr_factor = 500  # FIXME: hard-coded, should come from params["RR"]
+    print(f"Russian Roulette factor used: {rr_factor}")
+
+    # --- 2. Prepare Data and Metadata ---
+    print("Preparing data and calculating weights...")
+
+    # Normalize input features (good practice, though XGBoost is less sensitive)
+    x_mean = np.mean(x_train, 0)
+    x_std = np.std(x_train, 0)
+    x_train_normalized = (x_train - x_mean) / x_std
+
+    # Create model_data dictionary to save with the model
+    model_data = {
+        "x_mean": x_mean,
+        "x_std": x_std,
+        "rr": rr_factor,
+        "N": len(x_train),
+        "n_ene_win": n_ene_win,
+        "model_type": "xgboost",  # Add model type for later use
+    }
+
+    # Calculate per-class weights: uniform, except that the non-detected
+    # class (0) is boosted by the Russian Roulette factor.
+    # Note: XGBoost uses 'sample_weight' (one weight per training sample)
+    class_counts = np.bincount(y_train)
+    # class_weights = 1.0 / (class_counts + 1e-9)  # inverse-frequency variant (unused)
+    class_weights = np.ones_like(class_counts)
+    if rr_factor > 1 and len(class_weights) > 0:
+        class_weights[0] *= rr_factor
+    # class_weights = class_weights / np.mean(class_weights)
+
+    # Create the final sample_weight array
+    sample_weights = class_weights[y_train]
+    print(f"Calculated class weights: {class_weights}")
+
+    # --- 3. Define and Train XGBoost Model ---
+    print("Training XGBoost model...")
+    # Starter parameters; tune them for better performance
+    model = xgboost.XGBClassifier(
+        objective="multi:softprob",  # Output probabilities for each class
+        n_estimators=200,  # Number of boosting rounds (trees)
+        max_depth=8,  # Maximum depth of a tree
+        learning_rate=0.01,  # Step size shrinkage
+        # use_label_encoder=False,
+        eval_metric="mlogloss",
+        n_jobs=-1,  # Use all available CPU cores
+    )
+
+    # Train the model
+    model.fit(x_train_normalized, y_train, sample_weight=sample_weights)
+    print("Training complete.")
+
+    # --- 4. Save the Model and Metadata ---
+    # We save both the trained model and the metadata needed for inference
+    output_data = {"model": model, "model_data": model_data}
+
+    print(f"Saving trained model and metadata to '{output_model_file}'")
+    joblib.dump(output_data, output_model_file)
+    print("Done.")
+
+
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+    garf_train_xgboost()
diff --git a/garf/garf_detector.py b/garf/garf_detector.py
index 006c35d..0818da2 100644
--- a/garf/garf_detector.py
+++ b/garf/garf_detector.py
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
-import numpy as np
-import torch
 from torch import Tensor
 import itk
-from .garf_model import Net_v1
+from .garf_model import *
 from .helpers import get_gpu_device
 
 
@@ -40,12 +38,31 @@ def load_nn(filename, verbose=True, gpu_mode="auto"):
             "Best epoch = {}".format(nn["optim"]["data"][best_epoch_eval]["epoch"])
         )
 
+    # which net type ?
+    model_type = "Net_v1"
+    if "model_type" in model_data:
+        model_type = model_data["model_type"]
+
+    # which angle parametrisation ?
+    # old one = acos acos
+    # new one = acos atan2
+    angle_param = "acos"
+    if "angle_param" not in model_data:
+        model_data["angle_param"] = angle_param
+
     # prepare the model
     state = nn["optim"]["model_state"][best_epoch_eval]
     H = model_data["H"]
     n_ene_win = model_data["n_ene_win"]
     L = model_data["L"]
-    model = Net_v1(H, L, n_ene_win)
+
+    if model_type == "Net_v1":
+        model = Net_v1(H, L, n_ene_win)
+    if model_type == "ResNet_v2":
+        model = ResNet_v2(H, L, n_ene_win)
+    if model_type == "MultiTask_v3":
+        model = MultiTask_v3(H, L, n_ene_win)
+
     model.load_state_dict(state)
     return nn, model
 
@@ -62,6 +79,7 @@ def __init__(self):
         self.initial_plane_rotation = None
         self.radius = None
         self.hit_slice_flag = False
+        self.plane_axis = None
 
         # computed
         self.plane_rotations = None
@@ -139,6 +157,10 @@ def initialize(self, gantry_rotations):
             )
             exit(-1)
 
+        # plane axis
+        if self.plane_axis is None:
+            self.plane_axis = [0, 1, 2]
+
         # planes
         self.initialize_detector_plane_rotations(gantry_rotations)
 
@@ -250,7 +272,7 @@ def arf_plane_init_numpy(self, rotation, nb):
 
     def project_to_planes_torch(self, batch, projected_points):
         for i, detector_plane in enumerate(self.detector_planes):
-            projected_batch = detector_plane.plane_intersection(batch)
+            projected_batch = detector_plane.plane_intersection_torch(batch)
             self.build_image_from_projected_points_torch(projected_batch, i)
 
     def project_to_planes_numpy(self, batch, i, planes, projected_points, data_img):
@@ -425,6 +447,7 @@ def image_from_coordinates_add_torch_hit_slice(self, img, vu):
 class GarfDetectorPlane:
     def __init__(self, garf_detector, center, rotation):
         self.garf_detector = garf_detector
+        self.plane_axis = self.garf_detector.plane_axis
         self.M = self.rotation_to_tensor(rotation)
         self.Mt = self.M.t()
         self.center = center
@@ -444,7 +467,7 @@ def rotation_to_tensor(self, m):
         t = t.to(self.garf_detector.current_gpu_device)
         return t
 
-    def plane_intersection(self, batch):
+    def plane_intersection_torch(self, batch):
         # See arf_plane_intersection
 
         # get energy, position and direction
@@ -486,7 +509,6 @@ def plane_intersection(self, batch):
 
         # two first coord
         pos_xy_rot = pos_xyz_rot[:, 0:2]
-        dir_xy_rot = dir_xyz_rot[:, 0:2]
 
         s = self.garf_detector.image_size_mm
         indexes_to_keep = torch.where(
@@ -495,11 +517,17 @@ def plane_intersection(self, batch):
         )[0]
 
         # convert direction into theta/phi
-        # theta is acos(dy)
-        # phi is acos(dx)
-        nb = len(dir_xy_rot)
-        theta = 
torch.rad2deg(torch.arccos(dir_xy_rot[:, 1])).reshape((nb, 1)) - phi = torch.rad2deg(torch.arccos(dir_xy_rot[:, 0])).reshape((nb, 1)) + nb = len(dir_xyz_rot) + d_x_plane = dir_xyz_rot[:, self.plane_axis[0]] + d_y_plane = dir_xyz_rot[:, self.plane_axis[1]] + d_z_plane = dir_xyz_rot[:, self.plane_axis[2]] + theta = torch.rad2deg(torch.arccos(d_z_plane)).reshape((nb, 1)) + phi = torch.rad2deg(torch.arctan2(d_y_plane, d_x_plane)).reshape((nb, 1)) + + # FIXME previous angle parametrisation ? + # theta = torch.rad2deg(torch.arccos(dir_xy_rot[:, 1])).reshape((nb, 1)) + # phi = torch.rad2deg(torch.arccos(dir_xy_rot[:, 0])).reshape((nb, 1)) + angles = torch.concat((theta, phi), dim=1) batch = torch.concat( @@ -514,9 +542,6 @@ def plane_intersection(self, batch): return batch -# noinspection PyUnreachableCode - - def normalize_logproba(x): """ Convert un-normalized log probabilities to normalized ones (0-100%) @@ -581,7 +606,7 @@ def nn_predict_numpy(model, model_data, x): # apply input model normalisation x = (x - x_mean) / x_std - # gpu ? (usually not) + # gpu? if "current_gpu_device" not in model_data: current_gpu_mode, current_gpu_device = get_gpu_device(gpu_mode="auto") model_data["current_gpu_device"] = current_gpu_device @@ -595,9 +620,15 @@ def nn_predict_numpy(model, model_data, x): # predict values vy_pred = model(vx) + # Apply correction to the logits BEFORE softmax + # Adding log(rr) to a logit is equivalent to multiplying its + # probability by rr after exponentiation. + if rr > 1: + vy_pred[:, 0] += np.log(rr) + # convert to numpy and normalize probabilities y_pred = normalize_logproba(vy_pred.data) - y_pred = normalize_proba_with_russian_roulette(y_pred, 0, rr) + # y_pred = normalize_proba_with_russian_roulette(y_pred, 0, rr) y_pred = y_pred.cpu().numpy() y_pred = y_pred.astype(np.float64) @@ -605,6 +636,74 @@ def nn_predict_numpy(model, model_data, x): return y_pred +def nn_predict_numpy_multitask(model, model_data, x): + """ + Apply the Multi-Task NN to predict y from x. + This version correctly handles the two-headed output. + """ + # Apply input normalization (same as before) + x_mean = model_data["x_mean"] + x_std = model_data["x_std"] + x = (x - x_mean) / x_std + + # Set device and model (same as before) + device = model_data.get("current_gpu_device", torch.device("cpu")) + model.to(device) + model.eval() # Set model to evaluation mode + + # Torch encapsulation + x = x.astype("float32") + vx = torch.from_numpy(x).to(device) + + # --- New Multi-Task Prediction Logic --- + + with torch.no_grad(): # Disable gradient calculation for inference + # 1. Get the two outputs from the model + acceptance_logit, energy_logits = model(vx) + + # 2. Calculate probability of detection using the sigmoid function + # This is P(detection) + p_acceptance = torch.sigmoid(acceptance_logit) + + # 3. Calculate conditional probabilities for each energy window using softmax + # This is P(window k | detected) + p_energy_windows = torch.softmax(energy_logits, dim=1) + + # 4. 
Combine the probabilities to get the final prediction + p_acceptance = p_acceptance.cpu().numpy() + p_energy_windows = p_energy_windows.cpu().numpy() + + # The number of final windows is 1 (for non-detected) + number of energy heads + n_samples = x.shape[0] + n_total_windows = p_energy_windows.shape[1] + 1 + y_pred = np.zeros((n_samples, n_total_windows), dtype=np.float64) + + # Probability of non-detection (window 0) is 1 - P(detection) + y_pred[:, 0] = 1.0 - p_acceptance.flatten() + + # Probability of a given detected window k is P(detection) * P(window k | detected) + for k in range(p_energy_windows.shape[1]): + y_pred[:, k + 1] = p_acceptance.flatten() * p_energy_windows[:, k] + + return y_pred + + +def xgb_predict_numpy(model, model_data, x): + """ + Apply a trained XGBoost model to predict y from x. + """ + # Apply the same normalization used during training + x_mean = model_data["x_mean"] + x_std = model_data["x_std"] + x_normalized = (x - x_mean) / x_std + + # Predict probabilities directly; no RR correction needed + # as the model was trained with weighted loss. + y_pred = model.predict_proba(x_normalized) + + return y_pred.astype(np.float64) + + def compute_angle_offset_torch(angles, length): """ compute the x,y offset according to the angle @@ -614,6 +713,8 @@ def compute_angle_offset_torch(angles, length): cos_theta = torch.cos(angles_rad[:, 0]) cos_phi = torch.cos(angles_rad[:, 1]) + print("FIXME acos !! ") + # see in Gate_NN_ARF_Actor, line "phi = acos(dir.x())/degree;" tx = length * cos_phi # see in Gate_NN_ARF_Actor, line "theta = acos(dir.y())/degree;" @@ -623,23 +724,6 @@ def compute_angle_offset_torch(angles, length): return t -def compute_angle_offset_numpy(angles, length): - """ - compute the x,y offset according to the angle - """ - angles_rad = np.deg2rad(angles) - cos_theta = np.cos(angles_rad[:, 0]) - cos_phi = np.cos(angles_rad[:, 1]) - - # see in Gate_NN_ARF_Actor, line "phi = acos(dir.x())/degree;" - tx = length * cos_phi - # see in Gate_NN_ARF_Actor, line "theta = acos(dir.y())/degree;" - ty = length * cos_theta - t = np.column_stack((tx, ty)) - - return t - - def normalize_proba_with_russian_roulette(w_pred, channel, rr): """ Consider rr times the values for the energy windows channel @@ -711,7 +795,7 @@ def image_from_coordinates_add_numpy(img, u, v, w_pred, hit_slice=False): img[i, uv16Bins[chx > tiny, 0], uv16Bins[chx > tiny, 1]] += chx[chx > tiny] -def arf_plane_intersection(batch, plane, image_plane_size_mm): +def arf_plane_intersection(batch, plane, image_plane_size_mm, plane_axis): """ Project the x points (Ekine X Y Z dX dY dZ) on the image plane defined by plane_U, plane_V, plane_center, plane_normal @@ -792,14 +876,18 @@ def arf_plane_intersection(batch, plane, image_plane_size_mm): # two first coord of dir dx = dir_xyz_rot[:, 0] dy = dir_xyz_rot[:, 1] + dz = dir_xyz_rot[:, 2] # FIXME -> clip arcos -1;1 ? 
# convert direction into theta/phi - # theta is acos(dy) - # phi is acos(dx) - theta = np.degrees(np.arccos(dy)).reshape((nb, 1)) - phi = np.degrees(np.arccos(dx)).reshape((nb, 1)) + dirs = np.stack((dx, dy, dz), axis=-1) + d_x_plane = dirs[:, plane_axis[0]] + d_y_plane = dirs[:, plane_axis[1]] + d_z_plane = dirs[:, plane_axis[2]] + theta = np.degrees(np.arccos(d_z_plane)).reshape((nb, 1)) + phi = np.degrees(np.arctan2(d_y_plane, d_x_plane)).reshape((nb, 1)) + y = np.concatenate((y, theta), axis=1) y = np.concatenate((y, phi), axis=1) @@ -809,112 +897,69 @@ def arf_plane_intersection(batch, plane, image_plane_size_mm): return batch -def arf_from_points_to_image_counts_OLD( - projected_batch, # 5D: 2 plane coordinates, 2 angles, 1 energy - model, # ARF neural network model - model_data, # associated model data - distance_to_crystal, # from detection plane to crystal center - image_plane_size_mm, # image plane in mm - image_plane_size_pixel, # image plane in pixel - image_plane_spacing, -): # image plane spacing - """ - Input : position, direction on the detector plane, energy - Compute - - garf.nn_predict - - garf.compute_angle_offset - - garf.remove_out_of_image_boundaries2 - - Used in 1) GarfDetector class and 2) gate ARFActor - - """ - - # get the two angles and the energy - ax = projected_batch[:, 2:5] - - # predict weights - w_pred = nn_predict_numpy(model, model_data, ax) - - # Get the two first columns = points coordinates - cx = projected_batch[:, 0:2] - - # Get the two next columns = angles - angles = projected_batch[:, 2:4] - - # Take angle into account: consider position at collimator + half crystal - t = compute_angle_offset_numpy(angles, distance_to_crystal) - cx = cx + t - - # convert coord to pixel - coord = ( - cx + image_plane_size_mm / 2 - image_plane_spacing / 2 - ) / image_plane_spacing - coord = np.around(coord).astype(int) - - # why vu and not uv ? - v = coord[:, 0] - u = coord[:, 1] - - # remove points outside the image - u, v, w_pred = remove_out_of_image_boundaries_numpy( - u, v, w_pred, image_plane_size_pixel - ) - - return u, v, w_pred - - def arf_from_points_to_image_counts( - projected_batch, # 5D: 2 plane coordinates, 2 angles, 1 energy, 1 weight - model, # ARF neural network model - model_data, # associated model data - distance_to_crystal, # from detection plane to crystal center - image_plane_size_mm, # image plane in mm - image_plane_size_pixel, # image plane in pixel + projected_batch, + model, + model_data, + distance_to_crystal, + image_plane_size_mm, + image_plane_size_pixel, image_plane_spacing, -): # image plane spacing +): """ - Input : position, direction on the detector plane, energy - Compute - - garf.nn_predict - - garf.compute_angle_offset - - garf.remove_out_of_image_boundaries2 - - Used in 1) GarfDetector class and 2) gate ARFActor - + Input: position, direction on the detector plane, energy. + This version is model-agnostic (PyTorch or XGBoost). 
""" - # get the two angles and the energy - ax = projected_batch[:, 2:5] + # Get directions and energy for model input + dirs = projected_batch[:, 2:5] + energy = projected_batch[:, 5:6] - # predict weights - w_pred = nn_predict_numpy(model, model_data, ax) + # Calculate angles from directions + if model_data["angle_param"] == "atan2": + theta = np.degrees(np.arccos(np.clip(dirs[:, 2], -1, 1))) + phi = np.degrees(np.arctan2(dirs[:, 1], dirs[:, 0])) + else: + theta = np.degrees(np.arccos(dirs[:, 1])) + phi = np.degrees(np.arccos(dirs[:, 0])) + + # Assemble input (theta, phi, E) + ax = np.column_stack((theta, phi, energy)) + + # Conditionally call the correct prediction function + model_type = model_data.get("model_type", "Net_v1") + if model_type == "xgboost": + w_pred = xgb_predict_numpy(model, model_data, ax) + elif model_type == "MultiTask_v3": + w_pred = nn_predict_numpy_multitask(model, model_data, ax) + else: + w_pred = nn_predict_numpy(model, model_data, ax) - # particle weight ? - if projected_batch.shape[1] == 6: - weights = projected_batch[:, 5] + # Apply particle weights from the simulation if they exist + if projected_batch.shape[1] == 7: + weights = projected_batch[:, 6] w_pred = w_pred * weights[:, np.newaxis] - # Get the two first columns = points coordinates + # Get initial position (px, py) cx = projected_batch[:, 0:2] - # Get the two next columns = angles - angles = projected_batch[:, 2:4] + # Correctly project the trajectory over distance_to_crystal + dir_z = dirs[:, 2] + mask = np.abs(dir_z) > 1e-9 + offset = np.zeros_like(cx) + offset[mask, 0] = distance_to_crystal * (dirs[mask, 0] / dir_z[mask]) + offset[mask, 1] = distance_to_crystal * (dirs[mask, 1] / dir_z[mask]) + final_pos = cx + offset - # Take angle into account: consider position at collimator + half crystal - t = compute_angle_offset_numpy(angles, distance_to_crystal) - cx = cx + t - - # convert coord to pixel + # Convert mm coordinates to pixel indices coord = ( - cx + image_plane_size_mm / 2 - image_plane_spacing / 2 + final_pos + image_plane_size_mm / 2 - image_plane_spacing / 2 ) / image_plane_spacing coord = np.around(coord).astype(int) - # why vu and not uv ? 
+ # Separate into u, v coordinates and remove out-of-bounds points v = coord[:, 0] u = coord[:, 1] - - # remove points outside the image u, v, w_pred = remove_out_of_image_boundaries_numpy( u, v, w_pred, image_plane_size_pixel ) diff --git a/garf/garf_model.py b/garf/garf_model.py index 73900ce..ea9c95b 100644 --- a/garf/garf_model.py +++ b/garf/garf_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn +import numpy as np +import torch.nn.functional as F class Net_v1(nn.Module): @@ -38,3 +40,144 @@ def forward(self, X): X = torch.clamp(X, min=0) # relu X = self.fc3(X) # output layer return X + + +class ResidualBlock(nn.Module): + """A simple residual block with two linear layers.""" + + def __init__(self, H): + super().__init__() + self.linear1 = nn.Linear(H, H) + self.linear2 = nn.Linear(H, H) + + def forward(self, x): + # Calculate the residual + residual = self.linear1(x) + residual = torch.clamp(residual, min=0) # ReLU + residual = self.linear2(residual) + + # Add the input to the residual (the "skip connection") + # and apply the final activation for this block + out = x + residual + out = torch.clamp(out, min=0) # ReLU + return out + + +class ResNet_v2(nn.Module): + """A ResNet-style architecture for the ARF problem.""" + + def __init__(self, H, L, n_ene_win): + super().__init__() + # Initial layer to project input from 3 dimensions to H dimensions + self.fc1 = nn.Linear(3, H) + + # A series of residual blocks + self.residual_layers = nn.ModuleList([ResidualBlock(H) for _ in range(L)]) + + # Final output layer + self.output_layer = nn.Linear(H, n_ene_win) + + def forward(self, x): + # Pass through the input layer and apply first activation + x = self.fc1(x) + x = torch.clamp(x, min=0) # ReLU + + # Pass through all the residual blocks + for layer in self.residual_layers: + x = layer(x) + + # Final prediction + x = self.output_layer(x) + return x + + +class MultiTask_v3(nn.Module): + """ + A multi-task ResNet architecture for the ARF problem. + It has two output heads: one for acceptance and one for energy window classification. + """ + + def __init__(self, H, L, n_energy_windows): + super().__init__() + # Shared backbone + self.fc1 = nn.Linear(3, H) + self.residual_layers = nn.ModuleList([ResidualBlock(H) for _ in range(L)]) + + # Head 1: Predicts detection vs. non-detection (1 output logit) + self.acceptance_head = nn.Sequential( + nn.Linear(H, H // 2), nn.ReLU(), nn.Linear(H // 2, 1) + ) + + # Head 2: Predicts which detected window (n_energy_windows-1 outputs) + # We subtract 1 because we don't need to predict the "non-detected" class here. + self.energy_head = nn.Sequential( + nn.Linear(H, H // 2), nn.ReLU(), nn.Linear(H // 2, n_energy_windows - 1) + ) + + def forward(self, x): + # Pass through the shared backbone + x = self.fc1(x) + x = torch.clamp(x, min=0) # ReLU + for layer in self.residual_layers: + x = layer(x) + + # Get predictions from each head + acceptance_logit = self.acceptance_head(x) + energy_logits = self.energy_head(x) + + return acceptance_logit, energy_logits + + +class MultiTask_v3_loss: + + def __init__(self, y_train, rr_factor, current_gpu_device): + print("init") + self.current_gpu_device = current_gpu_device + + # Get class weights for the acceptance loss (same hybrid method as before) + class_counts = np.bincount(y_train) + # Create binary weights: weight for non-detected (0) vs. 
detected (1) + weight_0 = 1.0 / (class_counts[0] + 1e-9) + weight_1 = 1.0 / (np.sum(class_counts[1:]) + 1e-9) + self.acceptance_weights = torch.tensor( + [weight_0, weight_1], dtype=torch.float + ).to(current_gpu_device) + if rr_factor > 1: + self.acceptance_weights[0] *= rr_factor + self.acceptance_weights /= torch.mean(self.acceptance_weights) + print(f"Acceptance loss weights: {self.acceptance_weights.cpu().numpy()}") + + def loss(self, Y_out, Y_true): + # Forward pass - model now returns two outputs + acceptance_logit, energy_logits = Y_out + + # --- Custom Loss Calculation --- + + # 1. Acceptance Loss (Binary) + # Create binary target: 0 if non-detected, 1 if detected + Y_binary_true = (Y_true > 0).float().view(-1, 1) + # Use BCEWithLogitsLoss which is numerically stable and takes class weights + # We need to compute pos_weight for the binary case + pos_weight = torch.tensor( + [self.acceptance_weights[1] / self.acceptance_weights[0]], + device=self.current_gpu_device, + ) + acceptance_loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + loss1 = acceptance_loss_fn(acceptance_logit, Y_binary_true) + + # 2. Energy Loss (Categorical, only for detected events) + detected_mask = Y_true > 0 + if torch.sum(detected_mask) > 0: + # We subtract 1 from the labels because energy_head has (n-1) outputs + # e.g., window 1 -> class 0, window 2 -> class 1 + Y_energy_true = Y_true[detected_mask] - 1 + energy_logits_detected = energy_logits[detected_mask] + loss2 = F.cross_entropy(energy_logits_detected, Y_energy_true) + else: + # If no detected events in this batch, loss is zero + loss2 = 0.0 + + # 3. Total Loss + loss = loss1 + loss2 # You can weight these, e.g., loss1 + 0.5 * loss2 + + return loss diff --git a/garf/garf_train.py b/garf/garf_train.py index c1bde60..59a9458 100644 --- a/garf/garf_train.py +++ b/garf/garf_train.py @@ -6,7 +6,7 @@ from torch import Tensor import copy from tqdm import tqdm -from .garf_model import Net_v1 +from .garf_model import * from .helpers import get_gpu_device @@ -49,6 +49,10 @@ def nn_prepare_data(x_train, y_train, params): model_data["x_std"] = x_std model_data["N"] = N + # this flag indicate that we use the new version of angle parametrisation + # (acos + atan2 instead of acos + acos) + model_data["angle_param"] = "atan2" + # copy param except comments for i in params: if not i[0] == "#": @@ -70,13 +74,13 @@ def nn_get_optimiser(model_data, model): # decreasing learning_rate scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, "min", verbose=False, patience=50 + optimizer, "min", patience=50 ) return optimizer, scheduler -def train_nn(x_train, y_train, params): +def train_nn_OLD_CORRECT(x_train, y_train, params): """ Train the ARF neural network. @@ -259,3 +263,314 @@ def train_nn(x_train, y_train, params): model_data["best_loss"] = best_loss return nn + + +def train_nn(x_train, y_train, params): + """ + Train the ARF neural network. 
+ + x_train -- x samples (3 dimensions: theta, phi, E) + y_train -- output probabilities vector for N energy windows + params -- dictionary of parameters and options + + params contains: + - n_ene_win + - batch_size + - batch_per_epoch + - epoch_store_every + - H + - L + - epoch_max + - early_stopping + - gpu_mode : auto cpu gpu + """ + + # Initialization + x_train, y_train, model_data, N = nn_prepare_data(x_train, y_train, params) + + # One-hot encoding + print("One-hot encoding") + y_vals, y_train = np.unique(y_train, return_inverse=True) + n_ene_win = len(y_vals) + print("Number of energy windows:", n_ene_win) + model_data["n_ene_win"] = n_ene_win + + # Device type + current_gpu_mode, current_gpu_device = get_gpu_device(params["gpu_mode"]) + model_data["current_gpu_mode"] = current_gpu_mode + print(f"Device GPU type is {current_gpu_mode}") + + # Batch parameters + batch_size = model_data["batch_size"] + epoch_store_every = model_data["epoch_store_every"] + + # DataLoader + print("Data loader batch_size", batch_size) + train_data2 = np.column_stack((x_train, y_train)) + if current_gpu_mode == "mps": + print("With device mps (gpu), convert data to float32", train_data2.dtype) + train_data2 = train_data2.astype(np.float32) + + train_loader2 = DataLoader( + train_data2, + batch_size=batch_size, + num_workers=8, + pin_memory=True, + # shuffle=True, # if false ~20% faster, seems identical + shuffle=False, # if false ~20% faster, seems identical + drop_last=True, + ) + + # Create the main NN + H = model_data["H"] + L = model_data["L"] + rr_factor = params["RR"] + + model_type = "Net_v1" + loss_function = F.cross_entropy + if "model_type" in model_data: + model_type = model_data["model_type"] + if model_type == "Net_v1": + model = Net_v1(H, L, n_ene_win) + if model_type == "ResNet_v2": + model = ResNet_v2(H, L, n_ene_win) + if model_type == "MultiTask_v3": + model = MultiTask_v3(H, L, n_ene_win) + l = MultiTask_v3_loss(y_train, rr_factor, current_gpu_device) + loss_function = l.loss + + # Create the optimizer + # optimizer, scheduler = nn_get_optimiser(model_data, model) + learning_rate = model_data["learning_rate"] + # optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) + optimizer = torch.optim.Adam( + model.parameters(), + lr=learning_rate, + weight_decay=1e-4, + ) + # decreasing learning_rate + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=3) + + # Main loop initialization + epoch_max = model_data["epoch_max"] + early_stopping = model_data["early_stopping"] + best_loss = np.Inf + best_epoch = 0 + best_train_loss = np.Inf + loss_values = np.zeros(epoch_max + 1) + + # Print parameters + print_nn_params(model_data) + + # create main structures + nn = dict() + nn["model_data"] = model_data + nn["optim"] = dict() + nn["optim"]["model_state"] = [] + nn["optim"]["data"] = [] + previous_best = 9999 + best_epoch_index = 0 + + # set the model to the device (cpu or gpu = cuda or mps) + model.to(current_gpu_device) + + # Main loop + print("\nStart learning ...") + pbar = tqdm(total=epoch_max + 1, disable=not params["progress_bar"]) + epoch = 0 + stop = False + while (not stop) and (epoch < epoch_max): + # Train pass + model.train() + train_loss = 0.0 + n_samples_processed = 0 + + # Loop on the data batch (batch_per_epoch times) + for batch_idx, data in enumerate(train_loader2): + x = data[:, 0:3] + y = data[:, 3] + X = Tensor(x.to(model.fc1.weight.dtype)).to(current_gpu_device) + Y = Tensor(y).to(current_gpu_device).long() + + # Forward pass + Y_out 
= model(X) + + # Compute expected loss + # combines log_softmax and nll_loss in a single function + loss = loss_function(Y_out, Y) + + # Backward pass + loss.backward() + + # Parameter update (gradient descent) + optimizer.step() + optimizer.zero_grad() + batch_size = X.shape[0] # important with variable batch sizes + train_loss += loss.data.item() * batch_size + n_samples_processed += batch_size + + # end for loop train_loader + + # end of train + train_loss /= n_samples_processed + if train_loss < best_train_loss * (1 - 1e-4): + best_train_loss = train_loss + mean_loss = train_loss + + loss_values[epoch] = mean_loss + if mean_loss < best_loss * (1 - 1e-4): + best_loss = mean_loss + best_epoch = epoch + elif epoch - best_epoch > early_stopping: + tqdm.write( + "{} epochs without improvement, early stop.".format(early_stopping) + ) + stop = True + + # scheduler for learning rate + scheduler.step(mean_loss) + + # FIXME WRONG + # Check if need to print and store this epoch + if best_train_loss < previous_best: + tqdm.write("Epoch {} loss is {:.5f}".format(epoch, best_loss)) + previous_best = best_train_loss + + if ( + (epoch != 0 and epoch % epoch_store_every == 0) + or stop + or epoch >= epoch_max - 1 + ): + optim_data = dict() + print("Store weights", epoch) + optim_data["epoch"] = epoch + optim_data["train_loss"] = train_loss + state = model.state_dict() + nn["optim"]["model_state"].append(state) + nn["optim"]["data"].append(optim_data) + best_epoch_index = epoch + + # update progress bar + pbar.update(1) + epoch = epoch + 1 + + # end for loop + print("Training done. Best = {:.5f} at epoch {:.0f}".format(best_loss, best_epoch)) + + # prepare data to be saved + model_data["loss_values"] = loss_values + model_data["final_epoch"] = epoch + model_data["best_epoch"] = best_epoch + model_data["best_epoch_index"] = best_epoch_index + model_data["best_loss"] = best_loss + + return nn + + +def train_nn_TEST(x_train, y_train, params): + """ + An optimized function to train the ARF neural network. + """ + # 1. PREPARE DATA AND MODEL PARAMETERS + # =================================================================== + print("Preparing data and model parameters...") + x_train, y_train, model_data, N = nn_prepare_data(x_train, y_train, params) + y_vals, y_train = np.unique(y_train, return_inverse=True) + model_data["n_ene_win"] = len(y_vals) + current_gpu_mode, current_gpu_device = get_gpu_device(params["gpu_mode"]) + model_data["current_gpu_mode"] = current_gpu_mode + print_nn_params(model_data) + + # 2. CREATE EFFICIENT DATALOADER + # =================================================================== + # This is the most efficient method for small, in-memory datasets. + print("Creating PyTorch TensorDataset and DataLoader...") + + # Convert NumPy arrays to PyTorch Tensors ONCE. + x_tensor = torch.from_numpy(x_train.astype(np.float32)) + y_tensor = torch.from_numpy(y_train.astype(np.int64)) + + # Create a TensorDataset + train_dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor) + + # Use DataLoader with num_workers=0 (no multiprocessing overhead) + # and pin_memory=True (for faster CPU to CUDA transfer). + train_loader = DataLoader( + dataset=train_dataset, + batch_size=model_data["batch_size"], + num_workers=0, + pin_memory=True, + shuffle=True, # Shuffle is good practice for training + drop_last=True, + ) + + # 3. 
SETUP MODEL, OPTIMIZER, and LOSS + # =================================================================== + print("Setting up model, optimizer, and loss function...") + model = Net_v1(model_data["H"], model_data["L"], model_data["n_ene_win"]) + model.to(current_gpu_device) + + optimizer = torch.optim.Adam(model.parameters(), lr=model_data["learning_rate"]) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5) + + # (Optional, but recommended) - Use the weighted loss for better results + # class_counts = np.bincount(y_train) + # class_weights = 1.0 / (class_counts + 1e-9) + # class_weights[0] *= model_data["RR"] + # class_weights = class_weights / np.mean(class_weights) + # class_weights = torch.tensor(class_weights, dtype=torch.float).to(current_gpu_device) + + # 4. MAIN TRAINING LOOP + # =================================================================== + print("\nStarting optimized training...") + nn = {"model_data": model_data, "optim": {"model_state": [], "data": []}} + best_loss = np.Inf + best_epoch = -1 + epochs_no_improve = 0 + pbar = tqdm(range(model_data["epoch_max"]), disable=not params["progress_bar"]) + for epoch in pbar: + model.train() + total_loss = 0 + + # An epoch is now one full pass over the entire dataset + for x_batch, y_batch in train_loader: + # Move data to GPU + X = x_batch.to(current_gpu_device) + Y = y_batch.to(current_gpu_device) + + # Standard forward/backward pass + optimizer.zero_grad() + Y_out = model(X) + + # Use unweighted loss as requested (or uncomment weighted version) + loss = F.cross_entropy(Y_out, Y) + # loss = F.cross_entropy(Y_out, Y, weight=class_weights) + + loss.backward() + optimizer.step() + total_loss += loss.item() + + # After each epoch, calculate average loss + avg_loss = total_loss / len(train_loader) + scheduler.step(avg_loss) + pbar.set_description(f"Epoch {epoch + 1}, Loss: {avg_loss:.5f}") + + # Check for improvement (for early stopping) + if avg_loss < best_loss: + best_loss = avg_loss + best_epoch = epoch + epochs_no_improve = 0 + # Store the best model state efficiently + nn["optim"]["model_state"] = [model.state_dict()] + nn["optim"]["data"] = [{"epoch": epoch, "train_loss": avg_loss}] + nn["model_data"]["best_epoch"] = epoch + nn["model_data"]["best_loss"] = best_loss + else: + epochs_no_improve += 1 + + if epochs_no_improve >= model_data["early_stopping"]: + print(f"\nEarly stopping triggered after {epoch + 1} epochs.") + break + + print(f"\nTraining done. 
Best loss = {best_loss:.5f} at epoch {best_epoch + 1}")
+    return nn
diff --git a/garf/helpers.py b/garf/helpers.py
index 8ef75fa..6b45252 100644
--- a/garf/helpers.py
+++ b/garf/helpers.py
@@ -3,6 +3,9 @@
 import os
 import uproot
 import torch
+import numpy as np
+import SimpleITK as sitk
+import matplotlib.pyplot as plt
 
 
 def load_training_dataset(filename):
@@ -43,7 +46,7 @@ def load_training_dataset(filename):
     return data, theta, phi, E, w
 
 
-def print_training_dataset_info(data, rr=40):
+def print_training_dataset_info(data, rr):
     """
     Print training dataset information
     """
@@ -225,3 +228,60 @@ def get_gpu_device(gpu_mode):
             return get_gpu_device("cpu")
 
     return current_gpu_mode, current_gpu_device
+
+
+def plot_spect_projection(
+    image1_mhd, image2_mhd, scaling=1, islice=None, wslice=1, win_labels=None
+):
+    # Load image
+    img_ref = sitk.ReadImage(image1_mhd)
+    img = sitk.ReadImage(image2_mhd)
+    scaling = float(scaling)
+    wslice = int(wslice)
+
+    # slice (middle one by default)
+    if islice is None:
+        islice = int(img.GetSize()[0] / 2)
+    else:
+        islice = int(islice)
+
+    # Get the pixels values as np array
+    data_ref = sitk.GetArrayFromImage(img_ref).astype(float)
+    data = sitk.GetArrayFromImage(img).astype(float)
+
+    # Scale data to the ref nb of particles
+    data = data * scaling
+
+    # Profiles
+    p_ref = np.mean(data_ref[:, islice - wslice : islice + wslice - 1, :], axis=1)
+    p = np.mean(data[:, islice - wslice : islice + wslice - 1, :], axis=1)
+    x = np.arange(0, data.shape[2], 1)
+
+    # nb of energy windows
+    nb_ene = len(data)
+    if win_labels is None:
+        win_labels = [f"win {i}" for i in np.arange(nb_ene)]
+
+    # Criterion1: global counts in every window
+    s_ref = np.sum(data_ref, axis=(1, 2))
+    s = np.sum(data, axis=(1, 2))
+    ratio = (s - s_ref) / s_ref * 100.0
+
+    # figure
+    fig, ax = plt.subplots(ncols=nb_ene, nrows=1, figsize=(35, 5))
+    fs = 12
+    plt.rc("font", size=fs)
+    for i in range(nb_ene):
+        a = ax[i]
+        a.plot(x, p_ref[i], "g", label="Analog", alpha=0.5, linewidth=2.0)
+        a.plot(x, p[i], "k--", label="ARF", alpha=0.9, linewidth=1.0)
+        a.set_title(win_labels[i], fontsize=fs + 5)
+        a.legend(loc="best")
+        a.tick_params(labelsize=fs)
+
+    plt.suptitle(
+        f"Slice = {islice}, w={wslice} {image1_mhd} {image2_mhd} global diff = {ratio}"
+    )
+    plt.tight_layout()
+    plt.subplots_adjust(top=0.85)
+    return plt
diff --git a/pyproject.toml b/pyproject.toml
index 955e883..c06d393 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
 [project.scripts]
 garf_compare_image_profile = "garf.bin.garf_compare_image_profile:garf_compare_image_profile"
 garf_plot_training_dataset = "garf.bin.garf_plot_training_dataset:garf_plot_training_dataset"
+garf_plot_training_dataset2 = "garf.bin.garf_plot_training_dataset2:garf_plot_training_dataset_2d"
 garf_plot_test_dataset = "garf.bin.garf_plot_test_dataset:garf_plot_training_dataset"
 garf_train = "garf.bin.garf_train:garf_train"
 garf_nn_info = "garf.bin.garf_nn_info:garf_nn_info"
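A closing numeric check of the logit correction introduced in garf_detector.py above: adding log(rr) to the logit of window 0 before softmax is equivalent to the removed normalize_proba_with_russian_roulette post-processing, which scaled that window's probability by rr and renormalised (logit values illustrative):

    import numpy as np

    rr = 40.0
    logits = np.array([1.0, 2.0, 0.5])  # illustrative raw network outputs

    p_old = np.exp(logits) / np.exp(logits).sum()  # softmax
    p_old[0] *= rr                                 # old: scale window 0 ...
    p_old /= p_old.sum()                           # ... then renormalise

    shifted = logits.copy()
    shifted[0] += np.log(rr)                       # new: shift the logit first
    p_new = np.exp(shifted) / np.exp(shifted).sum()

    assert np.allclose(p_old, p_new)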