Graph Neural Network for GlueX forward calo data - Tutorial#

# %pip -q install uproot awkward torch-geometric scikit-learn safetensors tqdm > /dev/null
%pip install -q uproot awkward scikit-learn safetensors tqdm
%pip install -q torch-geometric
%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.10.0+cu128.html
import torch
import torch_scatter
import torch_geometric

print(torch.__version__)
print(torch_scatter.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA:", torch.version.cuda)
!uname -a
!nvidia-smi

1. Data#

Download unformatted particle-gun ROOT (for GNN)#

import os
from urllib.parse import urlparse
import urllib.request
from tqdm import tqdm

class DownloadProgressBar(tqdm):
    """Custom TQDM progress bar for urllib downloads."""
    def update_to(self, blocks=1, block_size=1, total_size=None):
        """
        Update the progress bar.

        Args:
            blocks (int): Number of blocks transferred so far.
            block_size (int): Size of each block (in bytes).
            total_size (int, optional): Total size of the file (in bytes).
        """
        if total_size is not None:
            self.total = total_size
        self.update(blocks * block_size - self.n)


def download(url, target_dir):
    """
    Download a file from a URL into the target directory with progress display.

    Args:
        url (str): Direct URL to the file.
        target_dir (str): Directory to save the file.

    Returns:
        str: Path to the downloaded (or existing) file.
    """
    # Ensure the target directory exists
    os.makedirs(target_dir, exist_ok=True) #do nothing if target_dir exists

    # Infer the filename from the URL
    filename = os.path.basename(urlparse(url).path) # parse the url (scheme, netloc, path) and take the name of the file
    local_path = os.path.join(target_dir, filename)

    # If file already exists, skip download
    if os.path.exists(local_path):
        print(f"\n✅ File already exists: {local_path}\n")
        return local_path

    # Download with progress bar
    print(f"\n⬇️  Downloading {filename} from {url}\n")

    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t: #miniters controls how frequently tqdm updates the bar
        urllib.request.urlretrieve(url, filename=local_path, reporthook=t.update_to)

    print(f"\n✅ Download complete: {local_path}\n")
    return local_path

data_dir = "data"
unformatted_particle_data_url = "https://huggingface.co/datasets/AI4EIC/DNP2025-tutorial/resolve/main/unformatted_dataset/ParticleGunDataSet_800k.root"
data_dir = "data"
gun_root = download(unformatted_particle_data_url, data_dir)

Read ROOT arrays#

import uproot, awkward as ak, numpy as np

file = uproot.open(gun_root)
file.keys()
tree = uproot.open(gun_root)["FCALShowers"]

tree.keys()

Explanation of these features can be found here: data preparation

branches = ["rows","cols","energies","showerE","thrownEnergy","numBlocks","isNearBorder","isSplitOff","isPhotonShower"]
arr = tree.arrays(branches, library="ak")

rows_ak, cols_ak, en_ak = arr["rows"], arr["cols"], arr["energies"]
showerE = ak.to_numpy(arr["showerE"]).astype(np.float32)
thrownE = ak.to_numpy(arr["thrownEnergy"]).astype(np.float32)
numBlocks = ak.to_numpy(arr["numBlocks"]).astype(np.int32)
isNearBorder = ak.to_numpy(arr["isNearBorder"]).astype(bool)
isSplit = ak.to_numpy(arr["isSplitOff"]).astype(bool)
isPhoton = ak.to_numpy(arr["isPhotonShower"]).astype(bool)

# Labels consistent with GNN: 1 photon, 0 splitOff
labels = np.where(isPhoton, 1, np.where(isSplit, 0, -1)).astype(np.int64)

# Keep only clean ones
keep = labels >= 0
keep_idx = np.where(keep)[0]

print("Total:", len(labels), "Kept:", len(keep_idx))
print("Photons:", np.sum(labels[keep_idx]==1), "SplitOffs:", np.sum(labels[keep_idx]==0))

Train/Val/Test Split#

from sklearn.model_selection import train_test_split

# ------------------------------------------------------------
# Keep only a fraction of the full dataset
# ------------------------------------------------------------
fraction = 0.10   # try 10% first
rng_seed = 42

indices = keep_idx
y = labels[indices]

small_idx, _ = train_test_split(
    indices,
    train_size=fraction,
    stratify=y,
    random_state=rng_seed
)

print("Original size:", len(indices))
print("Reduced size :", len(small_idx))
print("Photons      :", np.sum(labels[small_idx] == 1))
print("SplitOffs    :", np.sum(labels[small_idx] == 0))
train_idx, temp_idx = train_test_split(small_idx, test_size=0.3, stratify=labels[small_idx], random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=labels[temp_idx], random_state=42)

# class counts (for weighting)
n_ph = np.sum(labels[train_idx] == 1)
n_so = np.sum(labels[train_idx] == 0)
counts = np.array([n_so, n_ph], dtype=np.float32)   # [splitOff, photon]
counts
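
The training function later in the tutorial turns these counts into per-class loss weights w_c = total / count_c, so the rarer class contributes more to the loss. A minimal preview of that computation (w_preview is just a throwaway name):

w_preview = counts.sum() / (counts + 1e-6)   # [weight for splitOff, weight for photon]
print(w_preview)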

2. Graph Construction#

Graph Construction (all hits)#

Data Engineering

Typical pattern in a shower:

  • central blocks -> large energy

  • nearby blocks -> moderate

  • peripheral blocks -> very small

Consequently, if you feed raw energies directly into a neural network:

  • large values dominate gradients

  • tiny values become almost invisible

  • training could become unstable

We apply a node-level energy transform:

\[ z(E) = \frac{\log(1+E/e_{0})}{\log(1+E_{\max}/e_{0})} \in [0,1] \]

We are applying a non-linear compression of the energy scale

  • large energies -> compressed

  • small energies -> expanded

This could make the shape of the shower easier to learn. In calorimeter ML this is very common.
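
To make the compression concrete, here is a minimal numerical sketch of the transform, assuming e0 = 0.05 GeV and Emax = 2 GeV (the ENERGY_SCALE and CLIP_MAX values chosen below); E_demo and z_demo are throwaway names:

e0, emax = 0.05, 2.0                      # same values as ENERGY_SCALE / CLIP_MAX below
E_demo = np.array([0.01, 0.1, 0.5, 2.0])  # GeV: peripheral, moderate, central, clipped maximum
z_demo = np.log1p(E_demo / e0) / np.log1p(emax / e0)
print(np.round(z_demo, 2))                # roughly [0.05, 0.30, 0.65, 1.00]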

import matplotlib.pyplot as plt

# Flatten all crystal energies across all showers
all_E = ak.to_numpy(ak.flatten(en_ak))

# Plot histogram
plt.figure(figsize=(6,4))
plt.hist(all_E, bins=200)
plt.xlabel("Crystal energy (GeV)")
plt.ylabel("Counts")
plt.title("Distribution of individual FCAL crystal energies")
plt.yscale("log")   # useful because distribution spans many orders
plt.show()
np.quantile(all_E, [0.9, 0.95, 0.99, 0.999])
# example output: array([0.53997727, 1.16369786, 2.37762243, 3.12767875])
import torch
from torch_geometric.data import Data

FCAL_GEOMETRY = (59, 59)
ENERGY_SCALE  = 0.05  # e0 in the energy transform, in GeV
CLIP_MAX      = 2.0   # Emax: crystal energies are clipped to this value, in GeV

#############################################################################################################
# nodes = calorimeter blocks belonging to the shower
# edges = adjacency between neighboring blocks
# node features = per-block quantities like position and energy
# global graph features = shower-level quantities like total shower energy, number of blocks, and border flag
#############################################################################################################
def logE_norm(E, e0=ENERGY_SCALE, emax=CLIP_MAX, eps=1e-9):
    E = np.clip(E.astype(np.float32, copy=False), 0.0, emax)
    Z = np.log1p(E / np.float32(e0))
    Z = Z / np.log1p(np.float32(emax) / np.float32(e0))
    return np.clip(Z, 0.0, 1.0).astype(np.float32)

def build_edges_8n(rows, cols):
    # creates a dictionary mapping grid position → node index.
    # Later we want to ask: if crystal (r,c) has a neighbour at (r+1,c), which node is that?
    idx = {(int(r), int(c)): i for i, (r, c) in enumerate(zip(rows, cols))}

    neigh = [(-1,-1), (-1,0), (-1,1),
             ( 0,-1),         ( 0,1),
             ( 1,-1), ( 1,0), ( 1,1)]
    edges = []
    for i, (r, c) in enumerate(zip(rows, cols)):
        r = int(r); c = int(c)
        for dr, dc in neigh:
            j = idx.get((r+dr, c+dc), None)
            if j is not None:
                edges.append((i, j))
    if len(edges) == 0:
        return torch.empty((2, 0), dtype=torch.long)               # returns 2 x E, where E are edges (E=0 if no edges)
    return torch.tensor(edges, dtype=torch.long).t().contiguous()  # returns 2 x E
    # .contiguous forces PyTorch to copy the data into a new block of memory arranged in the correct order after transpose (not just view)

#############################################################################################################
# the output of shower_to_graph is a torch_geometric.data.Data object representing one shower
#############################################################################################################
def shower_to_graph(rows, cols, energies, showerE, numBlocks, isNearBorder, eps=1e-9):
    rows = np.asarray(rows, dtype=np.int64)
    cols = np.asarray(cols, dtype=np.int64)
    Eraw = np.asarray(energies, dtype=np.float32)
    if len(Eraw) == 0:
        return None

    # Energy transform
    E = logE_norm(Eraw)

    Esum = float(Eraw.sum()) + eps  # use raw sum for fractions

    # Cluster "center of mass"
    # --- In FCAL, rows and cols are indices of position on the calorimeter plane. We can compute energy-weighted indices:
    row_c = float((rows * Eraw).sum() / Esum)
    col_c = float((cols * Eraw).sum() / Esum)

    # Relative coordinates: cluster shape
    dx = (rows - row_c).astype(np.float32)
    dy = (cols - col_c).astype(np.float32)
    rr = np.sqrt(dx*dx + dy*dy).astype(np.float32)

    # Normalize detector coordinates
    # rows 0->58
    # cols 0->58
    row_n = (rows.astype(np.float32) / (FCAL_GEOMETRY[0]-1)) * 2 - 1  #[0,1]->[-1,1]
    col_n = (cols.astype(np.float32) / (FCAL_GEOMETRY[1]-1)) * 2 - 1  #[0,1]->[-1,1]

    Efrac = (Eraw / Esum).astype(np.float32)
    logE  = np.log(Eraw + eps).astype(np.float32)

    # Node features
    # [row_n, col_n, E_logscaled, logEraw, Efrac, dx, dy, r] --- they all are (N,)
    x = np.stack([row_n, col_n, E, logE, Efrac, dx, dy, rr], axis=1).astype(np.float32) # --- returns (N, 8) where N: #crystals, 8: #features
    x = torch.tensor(x, dtype=torch.float32)

    edge_index = build_edges_8n(rows, cols)

    data = Data(x=x, edge_index=edge_index) # defines data.x and data.edge_index (graph connectivity)
    data.g = torch.tensor([float(showerE), float(numBlocks), float(isNearBorder)], dtype=torch.float32) #data.g is a custom attribute that stores graph-level (global) features for the shower
    # N.B.: float(isNearBorder) converts the boolean flag for the network's input tensor: False → 0.0, True → 1.0
    return data
# flatten energies
all_E = ak.to_numpy(ak.flatten(en_ak))

# apply your transform
Z = logE_norm(all_E)

plt.figure(figsize=(6,5))
plt.scatter(all_E, Z, s=1, alpha=0.1)
plt.xlabel("Raw crystal energy (GeV)")
plt.ylabel("Transformed energy z(E)")
plt.title("Energy transform used for GNN input")
plt.xlim(0,5)
plt.ylim(0,1.05)
plt.show()
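
As a quick sanity check of build_edges_8n, here is a minimal sketch on a hand-made three-hit cluster (toy_rows, toy_cols and toy_edges are throwaway names); every adjacent pair should appear in both directions:

toy_rows = np.array([10, 10, 11])
toy_cols = np.array([20, 21, 20])
toy_edges = build_edges_8n(toy_rows, toy_cols)
print(toy_edges)        # shape (2, 6): three adjacent pairs, each stored in both directions
print(toy_edges.shape)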

Visualize Graphs#

def plot_shower_nodes(rows, cols, energies=None, title="Shower nodes"):
    rows = np.asarray(rows)
    cols = np.asarray(cols)

    plt.figure(figsize=(6, 6))
    if energies is None:
        plt.scatter(cols, rows, s=80)
    else:
        sc = plt.scatter(cols, rows, c=energies, s=120)
        plt.colorbar(sc, label="Crystal energy")

    plt.gca().invert_yaxis()   # optional, often nicer for detector grids
    plt.xlabel("col")
    plt.ylabel("row")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.axis("equal")
    plt.show()
i = keep_idx[7879]
plot_shower_nodes(rows_ak[i], cols_ak[i], en_ak[i], title="FCAL shower nodes")
print(labels[i])

Dataset + DataLoader#

from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader

class FCALGraphDataset(Dataset):
    def __init__(self, idx_list):
        self.idx = np.asarray(idx_list, dtype=np.int64)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):

      j = int(self.idx[i])
      rows = ak.to_numpy(rows_ak[j])
      cols = ak.to_numpy(cols_ak[j])
      ens  = ak.to_numpy(en_ak[j])

      data = shower_to_graph(rows, cols, ens,
                            showerE=showerE[j],
                            numBlocks=numBlocks[j],
                            isNearBorder=isNearBorder[j])

      y = labels[j]  # 0/1
      data.y = torch.tensor([y], dtype=torch.long)

      # IMPORTANT: make graph-level features 2D so they batch correctly later
      data.g = torch.tensor([[float(showerE[j]),
                              float(numBlocks[j]),
                              float(isNearBorder[j])]], dtype=torch.float32)

      # Optional "metadata" (also 2D is fine; 1D works, but keep consistent)
      data.showerE = torch.tensor([float(showerE[j])], dtype=torch.float32)
      data.thrownE = torch.tensor([float(thrownE[j])], dtype=torch.float32)

      return data

BATCH_SIZE = 512
train_loader = DataLoader(FCALGraphDataset(train_idx), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(FCALGraphDataset(val_idx), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(FCALGraphDataset(test_idx), batch_size=BATCH_SIZE, shuffle=False)

it = iter(train_loader)
batch1 = next(it)
print(batch1)
batch2 = next(it)
print(batch2)

Notes:

  • ptr is a pointer array automatically created by the PyTorch Geometric DataLoader. It tells you where each graph starts and ends in the concatenated node list.

  • Formally, ptr has length batch_size + 1, i.e. ptr.shape = (batch_size + 1,)

batch2.ptr
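
For example, ptr can be used to slice out the nodes of any single graph in the batch (a minimal sketch using batch2 from the cell above; k and x_k are throwaway names):

ptr = batch2.ptr
k = 0                                        # first shower/graph in this batch
x_k = batch2.x[int(ptr[k]):int(ptr[k+1])]    # node features belonging to graph k
print(x_k.shape, "label:", int(batch2.y[k]))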

Convert a PyG graph to NetworkX and plot

from torch_geometric.utils import to_networkx
import networkx as nx

# take a single shower
i = keep_idx[7879]

rows = np.asarray(rows_ak[i])
cols = np.asarray(cols_ak[i])
ens  = np.asarray(en_ak[i])

data = shower_to_graph(rows, cols, ens,
                       showerE=showerE[i],
                       numBlocks=numBlocks[i],
                       isNearBorder=isNearBorder[i])

# convert to networkx
G = to_networkx(data, to_undirected=True)

Use detector coordinates for layout

pos = {i: (cols[i], rows[i]) for i in range(len(rows))}
plt.figure(figsize=(6,6))

nx.draw(
    G,
    pos=pos,
    node_size=220,
    node_color=ens,
    cmap="plasma",
    with_labels=True
)

plt.gca().invert_yaxis()
plt.title("FCAL Shower Graph")
plt.show()

3. GNN Model#

To-do list

  • Explain GNN — link to slides?

  • Explain SAGEConv and link to other types of graph neural networks

GNN Model (graph classifier)#

################################################################################
#                            MODEL DESCRIPTION                                 #
################################################################################
#	1.	crystals start with 8 handcrafted node features                          #
#	2.	GraphSAGE updates each crystal using neighboring crystals                #
#	3.	after 3 layers, each crystal embedding reflects local shower structure   #
#	4.	node embeddings are pooled into one vector representing the whole shower #
#	5.	that vector is combined with global shower metadata                      #
#	6.	final MLP predicts whether the shower is photon or split-off             #
################################################################################

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool

# (crystal/node features) x_i = [row_n, col_n, E, logE, Efrac, dx, dy, rr]
# (shower/graph features) [showerE, numBlocks, isNearBorder]

class SmallGNN(nn.Module):
    def __init__(self, node_in=8, g_in=3, hidden=64, n_classes=2, dropout=0.1):

        # node_in : number of input features per node (here 8 crystal features)
        # g_in    : number of graph-level features (here 3 shower-level quantities in data.g)
        # hidden  : size of the node embedding produced by the GNN layers
        # n_classes : number of output classes (photon vs split-off)
        # dropout : dropout probability used in the classification head

        super().__init__()
        # calls the constructor of the parent class (nn.Module)
        # this registers the model with PyTorch so parameters, layers,
        # and gradients are properly tracked during training

        # GraphSAGE layers (message passing between neighboring crystals)
        # N = total number of nodes in the batch

        self.conv1 = SAGEConv(node_in, 64)  # (N, node_in=8)  -> (N, 64)
        # each node updates its features by aggregating information from its neighbors

        self.conv2 = SAGEConv(64, 64)       # (N, 64) -> (N, 64)
        # deeper layer: node embeddings now encode information from a larger neighborhood

        """
        ADD AN ADDITIONAL LAYER HERE                        #:::::::::: (see Exercise 1) ::::::::::#
        """

        self.conv3 = SAGEConv(64, hidden)   # (N, 64) -> (N, hidden)
        # final node embeddings used for graph-level pooling

        self.head = nn.Sequential(
            nn.Linear(2*hidden + g_in, 128),                #:::::::::: (see Exercise 2) ::::::::::#
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, data):

        # data.batch has shape (Ntot,): one entry per node.
        # data.batch[i] tells which graph/shower node i belongs to.
        # Thus:
        #   x  : (Ntot, 8)
        #   ei : (2, Etot)
        #   b  : (Ntot,)
        x, ei, b = data.x, data.edge_index, data.batch

        #############################################################################
        # Node embedding stage:
        # start from crystal-level node features (8 per crystal)
        # each SAGEConv layer updates a crystal representation by combining
        # its own features with features aggregated from neighboring crystals
        # (defined by the graph edges)
        #############################################################################

        x = F.relu(self.conv1(x, ei))   # (Ntot, node_in=8)  -> (Ntot, 64)
        x = F.relu(self.conv2(x, ei))   # (Ntot, 64)         -> (Ntot, 64)
        """
        ADD AN ADDITIONAL LAYER HERE                        #:::::::::: (see Exercise 1)  ::::::::::#
        """
        x = F.relu(self.conv3(x, ei))   # (Ntot, 64)         -> (Ntot, hidden)

        #############################################################################
        # Pool all node embeddings into one graph-level embedding for the whole shower
        #############################################################################

        # B: number of graphs/showers in the batch
        pooled = torch.cat([global_mean_pool(x, b), global_max_pool(x, b)], dim=1)
        # (Ntot x hidden) -> (B x 2*hidden)

        out = self.head(torch.cat([pooled, data.g], dim=1))   #:::::::::: (see Exercise 2) ::::::::::#
        # (B x 2*hidden) + (B x 3) -> (B x (2*hidden + 3)) -> head -> (B x 2)
        return out
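
Before training, it helps to run one forward pass of the untrained model as a shape check (a minimal sketch using batch1 from the DataLoader cell above; check_model is a throwaway instance):

check_model = SmallGNN()
with torch.no_grad():
    logits = check_model(batch1)
print(logits.shape)   # expected: (B, 2), one pair of class logits per shower in the batch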

Evaluate + Train#

Returns y_true, y_prob where y_prob=P(photon) or P(splitoff) depending on the convention.

We consider here

  • label 1 = photon

  • label 0 = splitoff

  • y_prob = P(photon)

import os, math
from tqdm.auto import tqdm, trange
from sklearn.metrics import roc_auc_score
from contextlib import nullcontext
from safetensors.torch import save_model

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_AMP = (DEVICE == "cuda")

################################################################################
@torch.inference_mode() # disables gradient tracking and autograd
def evaluate_gnn(model, loader, desc="Eval", returnShowerE=False, returnThrownE=False):

    # switches the network to evaluation mode and turns off training behaviors
    model.eval().to(DEVICE)

    # true labels, predicted photon probs, reco shower energy, true thrown energy
    ys, ps, shE, thE = [], [], [], []
    correct = total = 0
    bar = tqdm(loader, desc=desc, leave=False)
    for batch in bar:
        batch = batch.to(DEVICE)
        logits = model(batch)                             # [B,2] row: batch of showers, col: classes
        prob_photon = torch.softmax(logits, dim=1)[:, 1]  # P(photon) sliced as 1D tensor by [:,1]
        pred = logits.argmax(dim=1)                       # returns index of largest value: 0 split-off, 1 photon

        yb = batch.y.view(-1)                             # take true class label, flattened as 1D vector
        ys.append(yb.cpu().numpy())
        ps.append(prob_photon.cpu().numpy())

        shE.append(batch.showerE.view(-1).cpu().numpy())
        thE.append(batch.thrownE.view(-1).cpu().numpy())

        correct += (pred == yb).sum().item()
        total += yb.numel() # numel() returns the number of elements in a tensor
        bar.set_postfix(acc=f"{(correct/max(1,total)):.3f}") # show running accuracy in the progress bar

    y_true = np.concatenate(ys) if ys else np.array([])      # stitching into one long array
    y_prob = np.concatenate(ps) if ps else np.array([])
    showerE_out = np.concatenate(shE) if shE else np.array([])
    thrownE_out = np.concatenate(thE) if thE else np.array([])

    acc = correct / max(1, total)
    auc = roc_auc_score(y_true, y_prob) if y_true.size and np.unique(y_true).size > 1 else float("nan")

    if returnShowerE:
        return acc, auc, y_true, y_prob, showerE_out
    if returnThrownE:
        return acc, auc, y_true, y_prob, thrownE_out
    return acc, auc, y_true, y_prob

################################################################################
def train_gnn(model, opt, train_loader, val_loader, save_path="./models", counts=counts, epochs=20):

    #---------------------------------------------------------------------------
    # Parameters
    # ----------
    # model: your SmallGNN model
    # opt: optimizer (e.g. Adam)
    # train_loader: training DataLoader - Yields mini-batches of training graphs.
    # val_loader: validation DataLoader - Yields mini-batches of validation graphs.
    # save_path: where the trained model will be saved
    # counts: array-like, shape (n_classes,) # of training examples in each class.
    #        counts[0] = number of split-off showers
    #        counts[1] = number of photon showers
    # epochs: number of training passes over the dataset

    # Returns
    # -------
    # model : nn.Module
    #    The model restored to the best validation AUC seen during training.
    #---------------------------------------------------------------------------

    os.makedirs(save_path, exist_ok=True)

    # -------------------------------------------------------------------------
    # Build class weights for the loss function.
    #
    # counts.sum() = total number of samples across classes.
    # counts + 1e-6 avoids division by zero if a class were empty.
    #
    # This produces one weight per class: w_c = total_count / count_c
    #
    # Rarer classes get a larger weight. This helps when the dataset is imbalanced.
    # -------------------------------------------------------------------------
    w = torch.tensor(
        (counts.sum() / (counts + 1e-6)),
        dtype=torch.float32,
        device=DEVICE,
    )

    # Weighted cross-entropy loss.
    # The weight used for a sample depends on its true class.
    # For binary classification with labels 0/1:
    #   weight[0] is used for class 0 samples
    #   weight[1] is used for class 1 samples
    crit = nn.CrossEntropyLoss(weight=w)

    # -------------------------------------------------------------------------
    # Automatic Mixed Precision (AMP) support.
    #
    # USE_AMP is True only when training on CUDA in this notebook.
    #
    # GradScaler is used only with float16 mixed precision on GPU.
    # It scales the loss before backpropagation so tiny gradients do not
    # underflow to zero in float16. If AMP is not used, scaler is simply None.
    # -------------------------------------------------------------------------
    scaler = torch.amp.GradScaler('cuda') if USE_AMP else None

    # Move the model parameters to the selected device (GPU or CPU).
    model.to(DEVICE)

    # -------------------------------------------------------------------------
    # Early-stopping bookkeeping.
    #
    # best_auc   : best validation AUC seen so far
    # best_state : frozen copy of the model parameters at best_auc
    # patience   : stop after this many consecutive non-improving epochs
    # no_improv  : number of consecutive epochs without improvement
    # -------------------------------------------------------------------------
    best_auc, best_state, patience, no_improv = -1.0, None, 5, 0

    # Main epoch loop.
    # trange(...) is just tqdm(range(...)) with a progress bar.
    for epoch in trange(1, epochs+1, desc="Training"):

        # Put the model in training mode.
        # This matters for layers like Dropout and BatchNorm.
        model.train()

        # Running statistics for the current epoch.
        # running : sum of per-sample losses over the epoch
        # seen    : number of graphs processed so far in this epoch
        # correct : number of correctly classified graphs so far
        running = 0.0
        seen = 0
        correct = 0

        # Wrap the train_loader in a tqdm progress bar.
        # bar now iterates over the same batches as train_loader with a live progress bar and dynamic text updates.
        bar = tqdm(
            train_loader,
            desc=f"Epoch {epoch}/{epochs} (train)",
            leave=False,
        )


        # ---------------------------------------------------------------------
        # Mini-batch training loop
        # ---------------------------------------------------------------------
        for batch in bar:

            # Move the entire batched graph object to the selected device.
            # This moves data.x, data.edge_index, data.y, data.g, etc.
            batch = batch.to(DEVICE)

            # batch.y stores one label per graph.
            # view(-1) reshapes it into a flat 1D tensor of shape (B,) where B is the number of graphs in this batch.
            yb = batch.y.view(-1)

            # Reset gradients from the previous optimization step.
            # set_to_none=True is a slightly more memory-efficient / faster way than filling old gradients with zeros.
            opt.zero_grad(set_to_none=True)


            # -----------------------------------------------------------------
            # Build the precision context.
            #
            # If AMP is enabled: use autocast so many ops run in float16 on CUDA
            # If AMP is disabled: use nullcontext(), i.e. do nothing special
            # -----------------------------------------------------------------
            ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16) if USE_AMP else nullcontext()

            # "with ctx:" enters that context manager
            # If ctx is autocast(...): operations inside the block run in mixed precision where safe
            # If ctx is nullcontext(): the block runs normally

            # The point of using "with" is that only this block is affected,
            # and once we exit the block, precision behavior returns to normal.
            with ctx:
                # Forward pass: model outputs logits of shape (B, n_classes).
                logits = model(batch)

                # Compute weighted cross-entropy loss against the true labels.
                loss = crit(logits, yb)

            # -----------------------------------------------------------------
            # Backpropagation + optimizer step
            # -----------------------------------------------------------------
            if USE_AMP:

                # Scale the loss to protect small gradients in float16.
                scaler.scale(loss).backward()

                # Perform the optimizer step using the scaled/unscaled gradients.
                scaler.step(opt)

                # Update the internal scaling factor for the next iteration.
                # The scaler dynamically adapts:
                #   - if gradients are safe, it may increase the scale
                #   - if overflow is detected, it may reduce the scale
                scaler.update()
            else:
                # Standard float32 training path.
                loss.backward()
                opt.step()  # no scaler.update() needed in the plain float32 path

            # -----------------------------------------------------------------
            # Accumulate training metrics for this epoch
            # -----------------------------------------------------------------

            # loss.item() is typically the mean loss over the current batch.
            #
            # To compute the true epoch-average loss, we multiply by batch size
            # to convert the batch mean loss into the batch sum loss, then later
            # divide by the total number of samples seen.
            running += float(loss.item()) * yb.size(0)

            # Number of graphs processed so far in this epoch.
            seen += yb.size(0)

            # logits.argmax(1):
            #   argmax over class dimension (dim=1)
            #   gives the predicted class index for each graph in the batch
            #
            # If logits has shape (B, 2), then logits.argmax(1) has shape (B,)
            # and contains 0 or 1 for each graph.
            correct += (logits.argmax(1) == yb).sum().item()

            # Update the live progress-bar text with running averages.
            bar.set_postfix(avg_loss=f"{running/max(1,seen):.4f}", acc=f"{(correct/max(1,seen)):.3f}")

        # Average training loss over the full epoch.
        train_loss = running / max(1, seen)

        # Evaluate on the validation set.
        # evaluate_gnn returns: val_acc, val_auc, y_true, y_prob
        val_acc, val_auc, _, _ = evaluate_gnn(model, val_loader, desc=f"Epoch {epoch}/{epochs} (val)")

        tqdm.write(f"Epoch {epoch:02d} | train_loss {train_loss:.4f} | val_acc {val_acc:.4f} | val_auc {val_auc:.4f}")

        # If AUC is NaN, replace it by 0.0 for comparison purposes.
        score = 0.0 if math.isnan(val_auc) else val_auc

        # ---------------------------------------------------------------------
        # Check whether this epoch improved the validation score
        # ---------------------------------------------------------------------
        if score > best_auc:
            # Improvement found: update best score and reset bad-epoch counter.
            best_auc = score
            no_improv = 0

            # Save a frozen snapshot of the model weights.
            #
            # model.state_dict().items() gives (name, tensor) pairs.
            #
            # For each tensor v:
            #   detach() : remove gradient history / autograd tracking
            #   cpu()    : move to CPU memory so it is safely stored there
            #   clone()  : make an independent copy of the tensor data
            #
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            # If you did best_state = model.state_dict(), best_state would reference the very
            # tensors the model keeps updating during training.
            # .detach() removes the tensors from the autograd graph (no gradient history),
            # .cpu() moves them to CPU memory, and .clone() copies them into new memory.
            # Together this gives a frozen snapshot of the weights, independent of further training.

            tqdm.write(f"✓ New best AUC {best_auc:.4f}")
            save_model(model, os.path.join(save_path, "FCAL_GNN_Classifier.safetensors"))

        else:
            # No improvement this epoch.
            no_improv += 1

            # If we have gone too many epochs without improvement, stop early.
            if no_improv >= patience:
                tqdm.write("Early stopping.")
                break

    # After training ends, restore the best weights we saw during validation.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model

Run Training#

EPOCHS = 1 # for lack of time
LR = 3e-4

model = SmallGNN().to(torch.float32) # forces all model parameters and buffers to use the float32 datatype.

# Adaptive Moment Estimation (Adam) + weight decay regularization
opt = torch.optim.AdamW(model.parameters(), lr=LR)

model = train_gnn(model, opt, train_loader, val_loader, save_path="./models", epochs=EPOCHS)
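
The best weights are also written to disk as a safetensors file during training, so the classifier can be reloaded later without retraining. A minimal sketch, assuming the save_path used above (reloaded is a throwaway name):

from safetensors.torch import load_model

reloaded = SmallGNN()
load_model(reloaded, "./models/FCAL_GNN_Classifier.safetensors")
reloaded.eval()   # ready for evaluate_gnn / inference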

4. Performance#

Evaluate Performance#

from sklearn.metrics import confusion_matrix
import seaborn as sns, matplotlib.pyplot as plt

test_acc, test_auc, y_true, y_prob = evaluate_gnn(model, test_loader, desc="Test")

cm = confusion_matrix(y_true, (y_prob >= 0.5).astype(int), labels=[0,1])
print(f"Test | acc {test_acc:.4f} | auc {test_auc:.4f}")

sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["SplitOff","Photon"], yticklabels=["SplitOff","Photon"])
plt.show()
############################################################
# Reference values from https://arxiv.org/pdf/2002.09530
# 85% signal with 60% background rejection
############################################################

# --- Photon efficiency vs split-off rejection utilities (GNN convention) ---

import numpy as np
from sklearn.metrics import roc_curve, confusion_matrix

# GNN convention
# y_true: 0 = SplitOff (background), 1 = Photon (signal)
# y_prob: P(Photon)

def eff_rej_at_threshold(y_true, y_prob, thr):

    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob)

    # classify photon if probability >= threshold
    y_pred = (y_prob >= thr).astype(int)

    Nsig = np.sum(y_true == 1)
    Nbkg = np.sum(y_true == 0)

    eps_sig = np.sum((y_true == 1) & (y_pred == 1)) / max(1, Nsig)
    rej_bkg = np.sum((y_true == 0) & (y_pred == 0)) / max(1, Nbkg)

    cm = confusion_matrix(y_true, y_pred, labels=[0,1])

    return eps_sig, rej_bkg, cm


def eff_rej_curve(y_true, y_prob):

    # positive class = photon
    fpr, tpr, thrs = roc_curve(y_true, y_prob, pos_label=1)

    eps_sig = tpr          # photon efficiency
    rej_bkg = 1 - fpr      # split-off rejection

    return eps_sig, rej_bkg, thrs


# --- Compute performance curve ---
eps_sig, rej_bkg, thrs = eff_rej_curve(y_true, y_prob)


# --- Working point: photon efficiency ≈ 0.85 ---
target_eps = 0.85

i = int(np.argmin(np.abs(eps_sig - target_eps)))
thr_star = thrs[i]

eps_star, rej_star, cm_star = eff_rej_at_threshold(y_true, y_prob, thr_star)

print(f"Working point (ε_sig≈{target_eps:.2f})")
print(f"threshold = {thr_star:.4f}")
print(f"Photon efficiency = {eps_star:.4f}")
print(f"Split-off rejection = {rej_star:.4f}")

print("\nConfusion matrix")
print("(rows=true, cols=pred)")
print("0=SplitOff, 1=Photon")

print(cm_star)


# --- Performance at default threshold 0.5 ---
eps_05, rej_05, cm_05 = eff_rej_at_threshold(y_true, y_prob, 0.5)

print(f"\nAt threshold 0.5")
print(f"Photon efficiency = {eps_05:.4f}")
print(f"Split-off rejection = {rej_05:.4f}")

Figure of Merit Definitions#

When optimizing a binary classifier in the presence of signal and background, the optimal threshold is often chosen by maximizing a Figure of Merit (FoM) that balances signal efficiency and background contamination.

In this tutorial:

  • Signal → photon showers

  • Background → split-off showers


Signal Significance#

A commonly used metric is the signal significance, defined as

\[ \mathrm{FoM}_{\text{sig}} = \frac{S}{\sqrt{S + B}} \]

where

  • \(S = N_S \, \varepsilon_S\) is the number of selected signal events

  • \(B = N_B \, \varepsilon_B\) is the number of selected background events

  • \(\varepsilon_S\) is the signal efficiency (fraction of photons retained)

  • \(\varepsilon_B\) is the background efficiency (fraction of split-offs misidentified as photons)

This FoM estimates the expected statistical significance of a signal observation in the presence of background.
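
For illustration (the numbers here are made up): with \(N_S = N_B = 10\,000\), \(\varepsilon_S = 0.85\) and \(\varepsilon_B = 0.40\), the selection keeps \(S = 8500\) signal and \(B = 4000\) background events, giving

\[ \mathrm{FoM}_{\text{sig}} = \frac{8500}{\sqrt{8500 + 4000}} \approx 76 \]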


Purity#

Purity measures how clean the selected sample is, i.e. the fraction of signal among all selected events:

\[ \text{Purity} = \frac{S}{S + B} \]

Substituting efficiencies and total counts:

\[ \text{Purity} = \frac{N_S \varepsilon_S}{N_S \varepsilon_S + N_B \varepsilon_B} \]

If the dataset signal-to-background ratio is known

\[ \frac{N_S}{N_B} \]

this becomes

\[ \text{Purity} = \frac{\varepsilon_S}{\varepsilon_S + \varepsilon_B / (N_S/N_B)} \]

Interpretation#

  • Signal significance is useful when the goal is to maximize discovery potential (typical in searches for new physics).

  • Purity is useful when the goal is to obtain a clean signal sample, for example when building templates or performing precision measurements.

FoMs

import numpy as np
from sklearn.metrics import roc_curve

def find_optimal_threshold(y_true, y_prob, S_over_B=1.0):
    """
    Optimize threshold using signal significance FoM.

    Convention used in this GNN tutorial:
        y_true : 0 = SplitOff (background), 1 = Photon (signal)
        y_prob : P(photon)

    S_over_B : expected signal-to-background ratio in dataset
    """

    # ROC with photon as positive class
    fpr, tpr, thrs = roc_curve(y_true, y_prob, pos_label=1)

    eps_sig = tpr          # photon efficiency
    eps_bkg = fpr          # background efficiency (split-offs misidentified as photons)

    # Signal significance
    fom = eps_sig / np.sqrt(eps_sig + eps_bkg / S_over_B + 1e-12)

    # Purity
    purity = eps_sig / (eps_sig + eps_bkg / S_over_B + 1e-12)

    i_opt = np.argmax(fom)
    thr_opt = thrs[i_opt]

    return {
        "threshold": float(thr_opt),
        "eps_sig": float(eps_sig[i_opt]),
        "rej_bkg": float(1 - eps_bkg[i_opt]),
        "SoverSqrtSB": float(fom[i_opt]),
        "purity": float(purity[i_opt]),
        "curve": {
            "thresholds": thrs,
            "eps_sig": eps_sig,
            "rej_bkg": 1 - eps_bkg,
            "FoM": fom,
            "purity": purity
        },
    }


# --- Run optimization ---
opt = find_optimal_threshold(y_true, y_prob, S_over_B=1.0)

print("Optimal point Summary:")
print(f"  threshold = {opt['threshold']:.3f}")
print(f"  Photon efficiency = {opt['eps_sig']:.3f}")
print(f"  Split-off rejection = {opt['rej_bkg']:.3f}")
print(f"  FoM (S/sqrt(S+B)) = {opt['SoverSqrtSB']:.3f}")
print(f"  Purity = {opt['purity']:.3f}")

Plot ROC + optimal point

import matplotlib.pyplot as plt

plt.plot(opt["curve"]["eps_sig"], opt["curve"]["rej_bkg"], label="ROC")
plt.scatter(opt["eps_sig"], opt["rej_bkg"], color="r", label="Optimal point")

plt.xlabel("Photon Efficiency (ε_sig)")
plt.ylabel("Split-off Rejection (R_bkg)")
plt.title("Photon vs Split-off Classification Performance")
plt.legend()
plt.grid(True)

plt.show()

Plot FoM and purity vs threshold

plt.figure()

plt.plot(opt["curve"]["thresholds"], opt["curve"]["FoM"],
         label=r"$S/\sqrt{S+B}$")

plt.plot(opt["curve"]["thresholds"], opt["curve"]["purity"],
         label="Purity")

plt.axvline(opt["threshold"],
            color="r",
            ls="--",
            label=f"Opt thr={opt['threshold']:.3f}")

plt.xlabel("Threshold on P(photon)")
plt.ylabel("Metric value")

plt.legend()
plt.grid(True)

plt.show()

Energy-dependent Performance

EBins = [0.100, 0.500, 1.0, 2.0, 4.0]


test_acc, test_auc, y_true, y_prob, showerE = evaluate_gnn(
    model,
    test_loader,
    desc="Test",
    returnShowerE=True
)
from sklearn.metrics import roc_auc_score, accuracy_score

def metrics_vs_energy(y_true, y_prob, showerE, EBins):
    """
    Compute accuracy and AUC vs shower energy bins.
    """
    EBins = np.asarray(EBins)
    labels = [f"[{EBins[i]:.1f}, {EBins[i+1]:.1f})" for i in range(len(EBins)-1)]
    results = []

    for i in range(len(EBins)-1):
        mask = (showerE >= EBins[i]) & (showerE < EBins[i+1])
        if not np.any(mask):
            results.append((labels[i], np.nan, np.nan, np.sum(mask)))
            continue
        yt = y_true[mask]
        yp = y_prob[mask]
        acc = accuracy_score(yt, yp >= 0.5)
        auc = roc_auc_score(yt, yp) if np.unique(yt).size > 1 else np.nan
        results.append((labels[i], acc, auc, np.sum(mask)))

    return results
bin_results = metrics_vs_energy(y_true, y_prob, showerE, EBins)

print(f"{'E_bin':<15}{'Acc':>10}{'AUC':>10}{'Nevents':>10}")
for label, acc_b, auc_b, n in bin_results:
    print(f"{label:<15}{acc_b:10.3f}{auc_b:10.3f}{n:10d}")
import numpy as np
import matplotlib.pyplot as plt

centers = 0.5 * np.arange(1, len(EBins))   # evenly spaced bar positions, one per energy bin
print(centers)
width = 0.15  # bar width

accs = [r[1] for r in bin_results]
aucs = [r[2] for r in bin_results]

plt.figure(figsize=(8,5))

# Horizontal offsets for each metric
plt.bar(centers - 1.5*width, accs, width, label="Accuracy")
plt.bar(centers - 0.5*width, aucs, width, label="AUC")

# Styling
plt.xlabel("Shower Energy [GeV]")
plt.ylabel("Metric Value")
plt.title("Classifier Performance vs. FCAL Shower Energy")
plt.xticks(centers - 0.15, [f"[{EBins[i]:.1f},{EBins[i+1]:.1f}) [GeV]" for i in range(len(EBins)-1)])
plt.ylim(0, 1.05)
plt.legend(frameon=False)
plt.grid(axis="y", linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()

Exercise 1:

Can you stack an additional graph convolution layer in our model?


Exercise 2:

Do you think the GNN can classify photons vs split-offs using only the crystal pattern, without global shower quantities? Try it out.


5. Physics Analysis#

Inference on Physics Events#

unformatted_omega_data_url = "https://huggingface.co/datasets/AI4EIC/DNP2025-tutorial/resolve/main/unformatted_dataset/OmegaExclusive_100k.root"
omega_dataset_path = download(unformatted_omega_data_url, data_dir)
# --- ω physics-events analysis with your trained SmallGNN on UNFORMATTED ROOT ---

import uproot, awkward as ak, numpy as np
import torch
from torch_geometric.data import Data
from itertools import combinations # to generate shower pairs
from tqdm import tqdm
import matplotlib.pyplot as plt

# ======================
# Assumptions from above
# ======================
# - model : trained SmallGNN
# - DEVICE : "cuda"/"cpu"
# - shower_to_graph(), build_edges_8n(), logE_norm() already defined exactly as in your training code
# - omega_dataset_path points to OmegaExclusive_100k.root

PI0_MASS = 0.1349768  # GeV
FCAL_NROW, FCAL_NCOL = 59, 59

# -------------------------
# Helpers: TLorentzVector IO
# -------------------------

import numpy as np
import awkward as ak

def _to_numpy_scalar_or_array(x):
    """
    Convert awkward/np/python scalar-or-array to numpy array or float.
    """
    # If it's an awkward array/record field, convert to numpy
    try:
        return ak.to_numpy(x)
    except Exception:
        return np.asarray(x)

# ---------------------------------------------------------------------------
# Convert a ROOT TLorentzVector (read by uproot/awkward) into a NumPy 4-vector
# Output convention: [px, py, pz, E]
# ---------------------------------------------------------------------------
def _tlv_to_p4(rec):
    """
    rec:
        One object that should represent a ROOT TLorentzVector as exposed
        by uproot/awkward.

    Returns:
        NumPy array with last dimension = 4, ordered as [px, py, pz, E].

        Possible shapes:
          - (4,)          for one single 4-vector
          - (..., 4)      for an array of 4-vectors
    """

    # Ask the object "rec" whether it has a .fields attribute.
    # If yes, get the list of fields. If not, use [].
    #
    # For a TLorentzVector read from ROOT, we expect top-level fields like:
    #   fE : energy
    #   fP : 3-momentum object
    fields = set(getattr(rec, "fields", []))

    if "fE" not in fields or "fP" not in fields:
        raise RuntimeError(f"Expected TLorentzVector with fields fE and fP. Got fields: {sorted(list(fields))}")

    E = _to_numpy_scalar_or_array(rec["fE"])
    P = rec["fP"]

    # -----------------------------------------------------------
    # Case A: fP is a nested record, like a TVector3
    # -----------------------------------------------------------

    # Check whether P itself has fields.
    # If so, those fields may be something like:
    #   fX, fY, fZ
    # or:
    #   X, Y, Z
    # or:
    #   x, y, z

    p_fields = set(getattr(P, "fields", []))

    # ROOT TVector3 commonly uses fX,fY,fZ (or sometimes X,Y,Z)
    if {"fX", "fY", "fZ"}.issubset(p_fields):
        px = _to_numpy_scalar_or_array(P["fX"])
        py = _to_numpy_scalar_or_array(P["fY"])
        pz = _to_numpy_scalar_or_array(P["fZ"])
    elif {"X", "Y", "Z"}.issubset(p_fields):
        px = _to_numpy_scalar_or_array(P["X"])
        py = _to_numpy_scalar_or_array(P["Y"])
        pz = _to_numpy_scalar_or_array(P["Z"])
    elif {"x", "y", "z"}.issubset(p_fields):
        px = _to_numpy_scalar_or_array(P["x"])
        py = _to_numpy_scalar_or_array(P["y"])
        pz = _to_numpy_scalar_or_array(P["z"])

    # -----------------------------------------------------------
    # Case B: fP is not a record, but just a numeric array
    # -----------------------------------------------------------

    else:
        # Case B: fP is a numeric array of shape (...,3)
        P_np = _to_numpy_scalar_or_array(P)
        P_np = np.asarray(P_np)

        # If shape is exactly (3,), we interpret it as one single vector:
        #   [px, py, pz]
        if P_np.shape == (3,):
            px, py, pz = P_np[0], P_np[1], P_np[2]

        # If the last axis has length 3, interpret it as an array of vectors.
        # Example:
        #   shape (N, 3)
        # Then moveaxis(..., -1, 0) swaps last axis with 0, giving (3,N)
        elif P_np.ndim >= 1 and P_np.shape[-1] == 3:
            px, py, pz = np.moveaxis(P_np, -1, 0)

        # Otherwise, we do not know how to interpret fP.
        else:
            raise RuntimeError(
                "Could not decode TLorentzVector.fP as a TVector3 record or (...,3) array.\n"
                f"Top fields: {sorted(list(fields))}\n"
                f"fP fields: {sorted(list(p_fields)) if p_fields else 'None'}\n"
                f"fP numpy shape: {getattr(P_np, 'shape', None)}"
            )

    # -----------------------------------------------------------
    # Build the final four-vector [px, py, pz, E]
    # -----------------------------------------------------------

    # Convert everything explicitly to NumPy arrays.
    # This makes broadcasting / stacking behave consistently. Stack to p4 = [px,py,pz,E]
    px = np.asarray(px)
    py = np.asarray(py)
    pz = np.asarray(pz)
    E  = np.asarray(E)


    # Stack the four components along the last axis.
    #
    # Example 1: px,py,pz,E are scalars -> output shape (4,)
    #
    # Example 2: px,py,pz,E are arrays of shape (N,) -> output shape (N,4)
    #
    # axis=-1 means the last axis is the component axis:
    #   [..., 0] = px, [..., 1] = py, [..., 2] = pz, [..., 3] = E

    # Ensure same shape via numpy broadcasting
    p4 = np.stack([px, py, pz, E], axis=-1).astype(np.float32, copy=False)
    return p4


def fourvec_mass(p4):

    p4 = np.asarray(p4)

    if p4.shape[-1] != 4:
        raise ValueError("Input must have last dimension = 4 (px,py,pz,E)")

    px, py, pz, E = np.moveaxis(p4, -1, 0)
    m2 = E**2 - (px**2 + py**2 + pz**2)
    m2 = np.maximum(m2, 0.0)
    return np.sqrt(m2)

def combine_p4(p4a, p4b):
    return p4a + p4b

# --------------------------------------------------------------------------
# Infer whether a shower is close to the FCAL detector border (isNearBorder)
# --------------------------------------------------------------------------
def infer_isNearBorder_from_hits(rows, cols, energies, margin=2, eps=1e-9):
    """
    Decide whether a shower is near the detector edge, using the
    energy-weighted centroid of its hits.

    Parameters
    ----------
    rows, cols : array-like
        Grid coordinates of the hit crystals belonging to one shower.
    energies : array-like
        Energy deposited in each hit crystal.
    margin : int or float
        Number of cells from the detector edge considered "near border".
        Example: margin=2 means anything within ~2 cells of an edge is flagged.
    eps : float
        Small number added to avoid division by zero.

    Returns
    -------
    bool
        True  -> shower centroid is near a detector edge
        False -> shower centroid is safely inside the detector
    """

    # If the shower has no hits, we cannot define a meaningful centroid.
    # Conservatively label it as near the border / problematic.
    if len(energies) == 0:
        return True

    # Convert inputs to NumPy arrays with consistent floating-point type.
    Eraw = np.asarray(energies, dtype=np.float32)
    rows = np.asarray(rows, dtype=np.float32)
    cols = np.asarray(cols, dtype=np.float32)

    # Energy-weighted centroid ("center of mass") of the shower in row/col space.
    Esum = float(Eraw.sum()) + eps
    r_c = float((rows * Eraw).sum() / Esum)
    c_c = float((cols * Eraw).sum() / Esum)

    # Check whether the centroid is within 'margin' cells of any FCAL edge.
    near = (
        (r_c < margin) or (r_c > (FCAL_NROW - 1 - margin)) or
        (c_c < margin) or (c_c > (FCAL_NCOL - 1 - margin))
    )
    return bool(near)

# ------------------------------------------------------------
# Build one shower graph for the omega analysis, using the same
# graph representation as the one used during training.
# ------------------------------------------------------------
def omega_shower_to_graph(rows, cols, energies, showerE_scalar, numBlocks_scalar):
    """
    Convert one neutral shower from the omega ROOT sample into a PyG graph compatible with the trained GNN.

    Parameters
    -------

      x: [row_n, col_n, E_logscaled, logEraw, Efrac, dx, dy, r]  (8) --- array-like
      g: [showerE, numBlocks, isNearBorder]                      (3) --- scalar

    Returns
    -------
    data : torch_geometric.data.Data or None
        Graph object for this shower, or None if the shower is empty.
    """

    rows = np.asarray(rows, dtype=np.int64)
    cols = np.asarray(cols, dtype=np.int64)
    Eraw = np.asarray(energies, dtype=np.float32)

    # If the shower has no hits at all, we cannot build a graph. Return None so the caller can handle this case.
    if Eraw.size == 0:
        return None

    # Same border proxy for g
    near = infer_isNearBorder_from_hits(rows, cols, Eraw)

    # Same graph-construction function
    data = shower_to_graph(
        rows, cols, Eraw,
        showerE=float(showerE_scalar),
        numBlocks=int(numBlocks_scalar),
        isNearBorder=near
    )

    # Defensive check (no graph, returns None)
    if data is None:
        return None

    # IMPORTANT: for your model, data.g must be 2D: (1,3)
    data.g = torch.tensor([[float(showerE_scalar), float(numBlocks_scalar), float(near)]], dtype=torch.float32)

    return data

# -------------------------
# GNN inference: one shower (this can be parallelized)
# -------------------------
@torch.inference_mode() # Disable gradient tracking (faster + less memory)
def gnn_photon_prob_for_shower(model, data: Data, device: str):
    """
    Returns P(photon) for a single shower graph.
    """
    # ------------------------------------------------------------
    # If graph construction failed (e.g. empty shower),
    # return probability 0 so the analysis can continue safely.
    # ------------------------------------------------------------
    if data is None:
        return 0.0

    # ------------------------------------------------------------
    # Move the entire graph object to the chosen device.
    #   data.x           node features
    #   data.edge_index  graph connectivity
    #   data.g           graph-level features
    # ------------------------------------------------------------
    data = data.to(device)

    # ------------------------------------------------------------
    # PyTorch Geometric expects a "batch vector" telling which
    # node belongs to which graph when multiple graphs are processed together.
    #
    # Example batch vector:
    #   [0,0,0,1,1,1]
    #
    #   nodes 0–2 belong to graph 0
    #   nodes 3–5 belong to graph 1
    #
    # In this function we only process ONE graph, so every node
    # belongs to graph 0. If a batch vector is missing, we create it.
    # ------------------------------------------------------------
    if not hasattr(data, "batch") or data.batch is None:

        # Create a tensor of zeros with length the number of nodes in the graph.
        # Each node is assigned to graph index 0.
        data.batch = torch.zeros(data.x.size(0), dtype=torch.long, device=device)

    # ------------------------------------------------------------
    # Run the GNN forward pass. The model outputs logits for each graph in the batch.
    #
    # Shape: (B, 2) = (number of graphs in the batch, number of classes)
    #
    # class 0 → split-off (background); class 1 → photon (signal)
    # ------------------------------------------------------------
    logits = model(data)  # (1,2) or (B,2)

    # ------------------------------------------------------------
    # Convert logits to probabilities using softmax.
    #
    # softmax(logits, dim=1) normalizes scores across the class dimension so they sum to 1.
    #
    # Since we only have one graph in this function: B = 1  → logits shape = (1,2)
    #
    # Therefore:[0,1] selects graph index 0 and class index 1 (photon)
    #
    # .item() extracts the scalar value from the tensor.
    # ------------------------------------------------------------

    prob_photon = torch.softmax(logits, dim=1)[0, 1].item()

    return float(prob_photon)

# ------------------------------------------------------------
# Event-level ω reconstruction
#
# For one event:
#   1) score each neutral shower with the trained GNN
#   2) build all possible shower pairs as π0 -> γγ candidates
#   3) combine each π0 with π+ and π− to form an ω candidate
#   4) compare three selection strategies:
#        - bench : use all shower pairs
#        - rect  : keep only pairs near the nominal π0 mass
#        - gnn   : keep only pairs where both showers look photon-like
#
# Returns:
#   results : dict of lists of tuples (m_pi0, m_omega, weight)
#   probs   : list of per-shower photon probabilities from the GNN
# ------------------------------------------------------------
def analyze_event_omega_gnn(ev, model, device, thr_photon=0.2, rect_mass_window=0.04):
    """
    Modes:
      - gnn   : keep showers with P(photon) >= thr_photon
      - rect  : select pairs with |m(pi0)-m(pi0)_PDG| < rect_mass_window
      - bench : all pairs

    Returns
    -------
    results : dict
        Keys: "gnn", "rect", "bench"
        Each entry is a list of tuples:
            (m_pi0, m_omega, weight)

    probs : list
        Per-shower photon probabilities from the GNN
    """

    # Initialize containers for the three reconstruction strategies.
    # Each list will collect tuples (m_pi0, m_omega, weight).
    results = {"gnn": [], "rect": [], "bench": []}

    # ------------------------------------------------------------
    # Charged pion four-vectors
    #
    # NOTE:
    # Here the code is using thrown π+ and π−, not reconstructed ones.
    # We pretend to use a kinematic fit that provides 4-momenta close to thrown values.
    # That means the ω reconstruction is "hybrid":
    #   - neutral part from reconstructed FCAL showers
    #   - charged part from truth-level pions
    #
    # If you want a fully reconstructed analysis, you would use:
    # ev["reconPiPlus"], ev["reconPiMinus"] instead.
    # ------------------------------------------------------------
    p4_pi_plus  = _tlv_to_p4(ev["thrownPiPlus"])
    p4_pi_minus = _tlv_to_p4(ev["thrownPiMinus"])

    # ------------------------------------------------------------
    # Neutral shower information for this event
    #
    # shower_p4       : reconstructed 4-vectors of neutral FCAL showers
    # shower_hits     : per-shower hit lists (rows, cols, energies)
    # showerE_vec     : one shower-energy value per shower
    # numBlocks_vec   : one number-of-blocks value per shower
    # ------------------------------------------------------------
    shower_p4 = ev["shower_p4"]          # shape: (nShowers,4)
    shower_hits = ev["shower_hits"]      # list of tuples: (rows, cols, energies)
    showerE_vec = ev["showerE"]          # shape: (nShowers,)
    numBlocks_vec = ev["numBlocks"]      # shape: (nShowers,)

    nShowers = int(ev["nShowers"])
    if nShowers < 2:
        return results, []


    # ------------------------------------------------------------
    # Compute one GNN photon score per shower
    #
    # Each shower is converted into a graph, then passed through the GNN to obtain P(photon).
    # ------------------------------------------------------------
    probs = []

    for s in range(nShowers):
        # Unpack the hit-level information of shower s
        rows_s, cols_s, en_s = shower_hits[s]

        # Build one graph for this shower, matching the training setup
        data = omega_shower_to_graph(rows_s, cols_s, en_s, showerE_vec[s], numBlocks_vec[s])

        # Run GNN inference on this single shower graph
        p = gnn_photon_prob_for_shower(model, data, device)

        # Store the photon probability
        probs.append(p)

    # Keep the indices of showers whose photon probability meets or exceeds the threshold
    good = [i for i, p in enumerate(probs) if p >= thr_photon]


    # ------------------------------------------------------------
    # Build all unique shower pairs
    #
    # These are all possible π0 -> γγ candidates.
    # If there are nShowers showers, combinations(range(nShowers), 2)
    # gives all pairs (i, j) with i < j.
    # ------------------------------------------------------------
    pairs = list(combinations(range(nShowers), 2))

    # Defensive check
    if not pairs:
        return results, probs

    # ============================================================
    # 1) Benchmark mode: use ALL shower pairs
    # ============================================================
    for (i, j) in pairs:
        p4_pi0   = combine_p4(shower_p4[i], shower_p4[j])
        m_pi0    = float(fourvec_mass(p4_pi0))
        p4_omega = combine_p4(p4_pi0, combine_p4(p4_pi_plus, p4_pi_minus))
        m_omega  = float(fourvec_mass(p4_omega))
        results["bench"].append((m_pi0, m_omega, 1.0)) # show the raw combinatorial bkgd

    # ============================================================
    # 2) Rectangular π0 mass cut
    #
    # Keep only shower pairs whose reconstructed π0 mass is within
    # a fixed window around the nominal π0 mass. Each surviving pair
    # gets weight 1/N so that every event contributes a total weight
    # of one, no matter how many pairs pass the cut.
    # ============================================================
    rect_pairs = []
    for (i, j) in pairs:
        m_pi0 = float(fourvec_mass(combine_p4(shower_p4[i], shower_p4[j])))
        if abs(m_pi0 - PI0_MASS) < rect_mass_window:
            rect_pairs.append((i, j))

    if rect_pairs:
        Nrect = len(rect_pairs)
        for (i, j) in rect_pairs:
            p4_pi0   = combine_p4(shower_p4[i], shower_p4[j])
            m_pi0    = float(fourvec_mass(p4_pi0))
            p4_omega = combine_p4(p4_pi0, combine_p4(p4_pi_plus, p4_pi_minus))
            m_omega  = float(fourvec_mass(p4_omega))
            results["rect"].append((m_pi0, m_omega, 1.0 / Nrect))

    # ============================================================
    # 3) GNN selection
    #
    # Keep only pairs where BOTH showers pass the photon threshold.
    # ============================================================
    gnn_pairs = [p for p in pairs if (p[0] in good and p[1] in good)]
    if gnn_pairs:
        N = len(gnn_pairs)
        for (i, j) in gnn_pairs:
            p4_pi0   = combine_p4(shower_p4[i], shower_p4[j])
            m_pi0    = float(fourvec_mass(p4_pi0))
            p4_omega = combine_p4(p4_pi0, combine_p4(p4_pi_plus, p4_pi_minus))
            m_omega  = float(fourvec_mass(p4_omega))
            results["gnn"].append((m_pi0, m_omega, 1.0 / N))

    return results, probs
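
As the comments above note, the charged pions are taken from truth level. A minimal sketch of the fully reconstructed variant, assuming the recon branches hold valid TLorentzVector records for every event (in practice you may need to skip events where track reconstruction failed) and that the _tlv_to_p4 helper accepts them:

# Hypothetical helper (not part of the original analysis): choose the
# charged-pion source in one place instead of editing analyze_event_omega_gnn.
def charged_pion_p4(ev, use_recon=False):
    if use_recon:
        return _tlv_to_p4(ev["reconPiPlus"]), _tlv_to_p4(ev["reconPiMinus"])
    return _tlv_to_p4(ev["thrownPiPlus"]), _tlv_to_p4(ev["thrownPiMinus"])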

Load ω tree and run the analysis#

# -------------------------
# Load ω tree + run analysis
# -------------------------

# Open the ROOT file using uproot and access the tree named "OmegaSampleTree".
# uproot.open(...) returns a file-like object; indexing with tree name retrieves TTree structure.
omega_tree = uproot.open(omega_dataset_path)["OmegaSampleTree"]


# ------------------------------------------------------------------------------
# Select branches we want to read from tree. uproot will only load these variables.
# NOTE:
# TLorentzVector branches are read as awkward "record" objects rather than simple arrays.
# ------------------------------------------------------------------------------
branches = [
    "thrownPiPlus", "thrownPiMinus",     # true generated π⁺ and π⁻ 4-vectors
    "reconPiPlus", "reconPiMinus",       # reconstructed π⁺ and π⁻ 4-vectors
    "nShowers",                          # number of neutral showers in FCAL
    "rows", "cols", "energies",          # hit-level information (flattened arrays)
    "startIndexPerShower", "nHitsPerShower",  # indices describing shower segmentation
    "showerE", "numBlocks",                   # shower-level features
    "FCALNeutralShowerPx", "FCALNeutralShowerPy",
    "FCALNeutralShowerPz", "FCALNeutralShowerE",  # reconstructed shower 4-vectors
]

# ------------------------------------------------------------
# Read the selected branches into an awkward array
# ------------------------------------------------------------
arr = omega_tree.arrays(branches, library="ak")

Nevt = len(arr["nShowers"])
print("✅ Loaded ω events:", Nevt)

# ------------------------------------------------------------
# Move the trained model to the chosen device and switch it to evaluation mode.
# .eval() disables training-specific behaviors such as:
#   - dropout randomness
#   - batch normalization statistics updates
# ------------------------------------------------------------
model = model.to(DEVICE).eval()

# Choose a *reasonable* default threshold for event-level selection.
# IMPORTANT: For event-level ω, thr=0.5 may be far too strict; start with 0.1–0.2 and scan later.

# ------------------------------------------------------------
# Selection parameters used later during event reconstruction
# ------------------------------------------------------------
# Photon classification threshold from the GNN.
# If P(photon) >= THR_PHOTON the shower is considered photon-like.
# NOTE: 0.5 is often too strict for event reconstruction, because losing photons destroys the π⁰ reconstruction.
THR_PHOTON   = 0.15

# Mass window used for the rectangular π⁰ selection
PI0_MASS_CUT = 0.04
#------------------------------------------------------------

# Each method stores tuples like: (m_pi0, m_omega, weight)
all_results = {"bench": [], "rect": [], "gnn": []}

# Store some photon scores for diagnostic plots later.
score_samples = []

# ------------------------------------------------------------
# Main event loop
# ------------------------------------------------------------
frac_omega = 0.10
Nfrac_omega = int(frac_omega * Nevt)

print(f"...using {frac_omega:.2f} fraction of ω-events ({Nfrac_omega:d} of {Nevt:d})")
for ievt in tqdm(range(Nfrac_omega), desc="Running ω-event inference (unformatted ROOT)"):

    nShow = int(arr["nShowers"][ievt])
    if nShow < 2:
        continue

    # ------------------------------------------------------------
    # Extract hit-level arrays for the whole event.
    #
    # These arrays contain hits from all showers concatenated
    # together in one flattened array.
    # ------------------------------------------------------------
    rows_evt = ak.to_numpy(arr["rows"][ievt]).astype(np.int64, copy=False)
    cols_evt = ak.to_numpy(arr["cols"][ievt]).astype(np.int64, copy=False)
    en_evt   = ak.to_numpy(arr["energies"][ievt]).astype(np.float32, copy=False)

    # ------------------------------------------------------------
    # startIndexPerShower and nHitsPerShower allow us to slice
    # the flattened hit arrays back into individual showers.
    # ------------------------------------------------------------
    start = ak.to_numpy(arr["startIndexPerShower"][ievt]).astype(np.int64, copy=False)
    nhits = ak.to_numpy(arr["nHitsPerShower"][ievt]).astype(np.int64, copy=False)

    # ------------------------------------------------------------
    # Shower-level features (one value per shower)
    # ------------------------------------------------------------
    showerE_vec   = ak.to_numpy(arr["showerE"][ievt]).astype(np.float32, copy=False)
    numBlocks_vec = ak.to_numpy(arr["numBlocks"][ievt]).astype(np.int32, copy=False)

    # ------------------------------------------------------------
    # Extract reconstructed shower 4-vectors directly from branches
    #
    # These arrays have length = nShowers
    # ------------------------------------------------------------
    px = ak.to_numpy(arr["FCALNeutralShowerPx"][ievt]).astype(np.float32, copy=False)
    py = ak.to_numpy(arr["FCALNeutralShowerPy"][ievt]).astype(np.float32, copy=False)
    pz = ak.to_numpy(arr["FCALNeutralShowerPz"][ievt]).astype(np.float32, copy=False)
    E  = ak.to_numpy(arr["FCALNeutralShowerE"][ievt]).astype(np.float32, copy=False)
    shower_p4 = np.stack([px, py, pz, E], axis=1)  # (nShowers,4)

    # ------------------------------------------------------------
    # Reconstruct per-shower hit lists
    #
    # Each shower corresponds to a slice of the flattened hit arrays.
    # ------------------------------------------------------------
    shower_hits = []
    for s in range(nShow):
        a = int(start[s])
        b = a + int(nhits[s])
        shower_hits.append((rows_evt[a:b], cols_evt[a:b], en_evt[a:b]))

    # ------------------------------------------------------------
    # Build an event dictionary to pass to the reconstruction code
    # ------------------------------------------------------------
    ev = dict(
        thrownPiPlus=arr["thrownPiPlus"][ievt],
        thrownPiMinus=arr["thrownPiMinus"][ievt],
        reconPiPlus=arr["reconPiPlus"][ievt],
        reconPiMinus=arr["reconPiMinus"][ievt],
        nShowers=nShow,
        shower_p4=shower_p4,
        shower_hits=shower_hits,
        showerE=showerE_vec,
        numBlocks=numBlocks_vec,
    )

    # ------------------------------------------------------------
    # Run the ω reconstruction algorithm
    #
    # This will:
    #   1) classify showers with the GNN
    #   2) build π⁰ candidates
    #   3) reconstruct ω → π⁰ π⁺ π⁻
    # ------------------------------------------------------------
    out, probs = analyze_event_omega_gnn(ev, model, DEVICE, thr_photon=THR_PHOTON, rect_mass_window=PI0_MASS_CUT)

    # ------------------------------------------------------------
    # Store reconstruction results
    # ------------------------------------------------------------
    for k in all_results:
        all_results[k].extend(out[k])

    # ------------------------------------------------------------
    # Save a subset of photon scores for diagnostic plots
    # ------------------------------------------------------------
    if len(score_samples) < 5000: # only keep a limited sample for a quick diagnostic histogram
        score_samples.extend(probs)

#-------------------------------------------------------------------------------

# -------------------------
# Print raw selection stats
# -------------------------
# Entries are (m_pi0, m_omega, weight), so x[2] is the weight
def _sumw(lst):
    return float(np.sum([x[2] for x in lst])) if len(lst) else 0.0

print("\nRaw entry counts:")
for k in ["bench", "rect", "gnn"]:
    print(f"  {k:5s}: N={len(all_results[k]):8d}  sum(weights)={_sumw(all_results[k]):.2f}")

# -------------------------
# Plot mass spectra overlays
# -------------------------
def plot_mass_spectrum(all_results, key="pi0", nbins=200, mass_range=(0.05, 0.30)):
    plt.figure(figsize=(7,5))
    for label in ["bench", "rect", "gnn"]:
        masses  = [m[0] if key=="pi0" else m[1] for m in all_results[label]]
        weights = [m[2] for m in all_results[label]]
        plt.hist(masses, bins=nbins, range=mass_range, histtype="step",
                 linewidth=1.8, label=label, weights=weights)
    plt.xlabel(f"{key} mass [GeV]")
    plt.ylabel("Weighted counts")
    plt.title(f"{key} invariant mass")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

plot_mass_spectrum(all_results, key="pi0",   nbins=200, mass_range=(0.05, 0.30))
plot_mass_spectrum(all_results, key="omega", nbins=200, mass_range=(0.60, 1.00))

# -------------------------
# Diagnostic: score spectrum
# -------------------------
score_samples = np.asarray(score_samples, dtype=np.float32)
if score_samples.size:
    print("\nGNN photon-score statistics (sampled showers):")
    print("  mean:", float(score_samples.mean()))
    print("  std :", float(score_samples.std()))
    print("  min :", float(score_samples.min()))
    print("  max :", float(score_samples.max()))
    plt.figure(figsize=(7,4))
    plt.hist(score_samples, bins=60, range=(0,1), histtype="step")
    plt.xlabel("GNN photon probability")
    plt.ylabel("Counts")
    plt.title("Distribution of GNN photon scores (ω showers, sampled)")
    plt.grid(True, alpha=0.3)
    plt.show()
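
# Hedged add-on: estimate how the choice of THR_PHOTON affects shower survival,
# using only the sampled scores collected above (a rough guide, not a full scan).
for thr in (0.05, 0.10, 0.15, 0.20, 0.30, 0.50):
    frac = float((score_samples >= thr).mean()) if score_samples.size else 0.0
    print(f"  thr={thr:.2f}: {100*frac:5.1f}% of sampled showers pass")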

#-------------------------------------------------------------------------------

Note: The GNN-based photon identification achieves performance comparable to (or slightly better than) the traditional π⁰ mass-window selection. Importantly, the classifier relies solely on calorimeter shower topology and does not use the π⁰ invariant mass as an input. When applied to ω → π⁺π⁻π⁰ events, the GNN selection reproduces the expected π⁰ and ω mass peaks while significantly reducing the combinatorial background relative to the inclusive benchmark.

# --- Improved ω mass fits with proper weighted uncertainties + better background ---

%pip install -q scipy

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.special import wofz
from scipy.stats import chi2

# ------------------------
# Signal: Voigt (BW ⊗ Gauss)
# ------------------------
def voigt_pdf(x, A, m0, gamma, sigma):
    """
    A = overall amplitude (not exactly yield unless integrated)
    m0 = peak position
    gamma = Lorentzian FWHM-ish parameter used as gamma in Voigt definition (GeV)
    sigma = Gaussian sigma (GeV)
    """
    z = ((x - m0) + 1j * gamma / 2.0) / (sigma * np.sqrt(2.0))
    return A * np.real(wofz(z)) / (sigma * np.sqrt(2.0 * np.pi))
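
# Hedged sanity check: with A=1 the Voigt profile should integrate to ~1
# (up to the Lorentzian tails outside the grid), and for gamma -> 0 it
# reduces to a Gaussian of width sigma.
_x = np.linspace(-1.0, 1.0, 20001)
print("Voigt norm check:", np.trapezoid(voigt_pdf(_x, 1.0, 0.0, 0.01, 0.01), _x))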

# ------------------------
# Background options
# ------------------------
def bkg_poly2(x, c0, c1, c2):
    return c0 + c1*x + c2*(x**2)

def bkg_exp(x, c0, c1, c2):
    # positive-ish exponential + offset: c0 + c1 * exp(c2 x)
    return c0 + c1 * np.exp(c2 * x)

# Combined models
def model_voigt_poly2(x, A, m0, gamma, sigma, c0, c1, c2):
    return voigt_pdf(x, A, m0, gamma, sigma) + bkg_poly2(x, c0, c1, c2)

def model_voigt_exp(x, A, m0, gamma, sigma, c0, c1, c2):
    return voigt_pdf(x, A, m0, gamma, sigma) + bkg_exp(x, c0, c1, c2)

# ------------------------
# Histogram w/ proper weighted uncertainties
# ------------------------
def hist_w_sumw2(masses, weights, nbins, fit_range):
    masses = np.asarray(masses, dtype=np.float64)
    weights = np.ones_like(masses) if weights is None else np.asarray(weights, dtype=np.float64)

    hist, bins = np.histogram(masses, bins=nbins, range=fit_range, weights=weights)
    sumw2, _   = np.histogram(masses, bins=nbins, range=fit_range, weights=weights**2)
    centers = 0.5*(bins[1:] + bins[:-1])

    # Per-bin uncertainty of a weighted histogram: sigma_i = sqrt(sum of w^2 in bin i)
    sigma = np.sqrt(np.maximum(sumw2, 1e-12))
    return hist, sigma, centers, bins
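
# Hedged mini-check: with unit weights, sqrt(sum w^2) per bin reduces to the
# familiar sqrt(N) Poisson uncertainty.
_m = np.random.default_rng(0).uniform(0.0, 1.0, 1000)
_h, _s, _, _ = hist_w_sumw2(_m, None, nbins=10, fit_range=(0.0, 1.0))
print("sqrt(sumw2) == sqrt(N) for unit weights:", np.allclose(_s, np.sqrt(_h)))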

# ------------------------
# Fit driver
# ------------------------
def fit_omega_spectrum(masses, weights=None, label="sel",
                       nbins=120, fit_range=(0.72, 0.84),
                       bkg="poly2", min_rel_err=1e-6):
    """
    bkg: "poly2" or "exp"
    fit_range: recommend (0.72, 0.84) first; widen later if needed
    """

    y, yerr, x, bins = hist_w_sumw2(masses, weights, nbins=nbins, fit_range=fit_range)

    # Drop bins with ~0 uncertainty (can happen if no entries)
    mask = yerr > (min_rel_err * np.max(yerr))
    xfit, yfit, sfit = x[mask], y[mask], yerr[mask]

    if xfit.size < 12:
        raise RuntimeError(f"[{label}] Too few populated bins to fit. Try fewer bins or wider range.")

    # Initial guesses (robust-ish)
    A0 = float(np.max(yfit))
    m00 = 0.782
    gamma0 = 0.0085   # ~ ω natural width ~8.5 MeV
    sigma0 = 0.012    # detector smearing guess (12 MeV)

    # background init
    c00 = float(np.median(yfit))
    c10 = 0.0
    c20 = 0.0

    if bkg == "poly2":
        f = model_voigt_poly2
        p0 = [A0, m00, gamma0, sigma0, c00, c10, c20]
        bounds = (
            [0.0,   0.76,  0.001, 0.002, -np.inf, -np.inf, -np.inf],
            [np.inf,0.80,  0.050, 0.050,  np.inf,  np.inf,  np.inf],
        )
    elif bkg == "exp":
        f = model_voigt_exp
        # exp params: c0 + c1 exp(c2 x)
        p0 = [A0, m00, gamma0, sigma0, c00, max(1.0, 0.1*c00), -5.0]
        bounds = (
            [0.0,   0.76,  0.001, 0.002, -np.inf, -np.inf, -50.0],
            [np.inf,0.80,  0.050, 0.050,  np.inf,  np.inf,  50.0],
        )
    else:
        raise ValueError("bkg must be 'poly2' or 'exp'")

    popt, pcov = curve_fit(
        f, xfit, yfit, p0=p0, bounds=bounds,
        sigma=sfit, absolute_sigma=True, maxfev=40000
    )
    perr = np.sqrt(np.diag(pcov))

    # Goodness-of-fit (using fit bins only)
    ymodel = f(xfit, *popt)
    chi2_val = np.sum(((yfit - ymodel)/sfit)**2)
    ndf = xfit.size - len(popt)
    pval = chi2.sf(chi2_val, max(ndf, 1))

    # Yield = integral of signal component over fit range
    xfine = np.linspace(fit_range[0], fit_range[1], 4000)
    A, m0, gamma, sigma_g = popt[:4]
    sig = voigt_pdf(xfine, A, m0, gamma, sigma_g)
    yield_sig = np.trapezoid(sig, xfine)

    return {
        "label": label,
        "bkg": bkg,
        "fit_range": fit_range,
        "nbins": nbins,
        "popt": popt,
        "perr": perr,
        "pcov": pcov,
        "chi2": float(chi2_val),
        "ndf": int(ndf),
        "pval": float(pval),
        "yield": float(yield_sig),
        "hist": y,
        "err": yerr,
        "centers": x,
        "bins": bins,
        "model_fn": f,
    }

def plot_fit(result, alpha_data=0.25):
    x = result["centers"]
    y = result["hist"]
    e = result["err"]
    f = result["model_fn"]
    popt = result["popt"]
    bw = result["bins"][1] - result["bins"][0]

    # show data with uncertainties (weighted)
    plt.bar(x, y, width=bw, alpha=alpha_data, label=f"{result['label']} data")
    plt.errorbar(x, y, yerr=e, fmt="none", capsize=0, alpha=0.6)

    # smooth fit curve
    xx = np.linspace(result["fit_range"][0], result["fit_range"][1], 2000)
    plt.plot(xx, f(xx, *popt), lw=2, label=f"{result['label']} fit ({result['bkg']})")

# ------------------------
# Run on your selections: all_results must exist with keys "gnn","rect","bench"
# ------------------------
range_fit = (0.72, 0.84)   # <-- start tight; if stable, expand later
nbins = 120
bkg_model = "poly2"        # try "exp" if poly2 still struggles

m_gnn  = [t[1] for t in all_results["gnn"]]
w_gnn  = [t[2] for t in all_results["gnn"]]
m_rect = [t[1] for t in all_results["rect"]]
w_rect = [t[2] for t in all_results["rect"]]

fit_gnn  = fit_omega_spectrum(m_gnn,  w_gnn,  label="GNN",        nbins=nbins, fit_range=range_fit, bkg=bkg_model)
fit_rect = fit_omega_spectrum(m_rect, w_rect, label="Rectangular",nbins=nbins, fit_range=range_fit, bkg=bkg_model)

plt.figure(figsize=(9,5))
plot_fit(fit_rect)
plot_fit(fit_gnn)
plt.xlabel(r"$M_{\pi^+\pi^-\pi^0}$ [GeV]")
plt.ylabel("Weighted counts")
plt.title(r"$\omega(782)$ fits with weighted uncertainties (Voigt + improved background)")
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

def print_summary(res):
    A,m0,gamma,sigma = res["popt"][:4]
    dA,dm0,dg,ds = res["perr"][:4]
    print(f"\n[{res['label']}]  model=Voigt+{res['bkg']}  range={res['fit_range']}  nbins={res['nbins']}")
    print(f"  m0        = {m0:.6f} ± {dm0:.6f} GeV")
    print(f"  gamma     = {gamma*1000:.2f} ± {dg*1000:.2f} MeV")
    print(f"  sigma     = {sigma*1000:.2f} ± {ds*1000:.2f} MeV")
    print(f"  Yield     = {res['yield']:.1f} (a.u., integral of signal in fit window)")
    print(f"  chi2/ndf  = {res['chi2']:.1f} / {res['ndf']} = {res['chi2']/max(1,res['ndf']):.3f}")
    print(f"  p-value   = {res['pval']:.3g}")

print_summary(fit_rect)
print_summary(fit_gnn)
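
The comment above suggests trying the exponential background if poly2 struggles; a quick stability check is to re-fit with bkg="exp" and compare the extracted m0 and yield (a hedged cross-check, not part of the baseline analysis):

# Optional cross-check: same data, exponential background model.
fit_gnn_exp = fit_omega_spectrum(m_gnn, w_gnn, label="GNN",
                                 nbins=nbins, fit_range=range_fit, bkg="exp")
print_summary(fit_gnn_exp)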

Meaning of the Voigt Fit Parameters#

The ω mass peak is modeled using a Voigt function, which is the convolution of a Breit–Wigner distribution (intrinsic particle width) and a Gaussian distribution (detector resolution).

  • Γ (gamma): the intrinsic decay width of the particle.
    This is a physical property of the resonance, related to its lifetime through
    \( \Gamma = \hbar / \tau \).
    For the ω meson, the natural width is about 8.5 MeV.

  • σ (sigma): the detector resolution.
    This represents the Gaussian smearing introduced by the detector and reconstruction, such as calorimeter energy resolution, tracking resolution, and π⁰ reconstruction uncertainties.
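
As a concrete number: with \( \hbar \approx 6.582 \times 10^{-22}\ \mathrm{MeV \cdot s} \), a width of 8.5 MeV corresponds to a lifetime of roughly \( 7.7 \times 10^{-23} \) s:

hbar_MeV_s = 6.582e-22         # reduced Planck constant in MeV·s
Gamma_MeV  = 8.5               # ω natural width in MeV
print(hbar_MeV_s / Gamma_MeV)  # ≈ 7.7e-23 s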

Interpretation of p-values in χ² goodness-of-fit tests#

If the model is correct and the uncertainties are correct, what is the probability of observing a χ² at least this large just from statistical fluctuations?

\[ p = P(\chi^2 \geq \chi^{2}_{\mathrm{obs}} \mid \text{model true}) \]

| p-value range | Interpretation |
|---|---|
| p < 0.001 | Model very unlikely to describe the data (fit clearly poor) |
| 0.001 – 0.01 | Suspicious: model may be inadequate or uncertainties underestimated |
| 0.01 – 0.05 | Borderline fit |
| 0.05 – 0.95 | Acceptable / good fit |
| p > 0.95 | Residuals smaller than expected; uncertainties may be overestimated |