Module qute.metrics
Custom metrics.
Classes
class CombinedInvMeanAbsoluteErrorBinaryDiceMetric (alpha: float = 0.5, max_mae_value: float = 1.0, regression_channel: int = 0, classification_channel: int = 1, foreground_class: int = 1, with_batch_dim: bool = True, dist_sync_on_step=False)
Combined inverse Mean Absolute Error and Dice metric to handle the output of qute.transforms.objects.WatershedAndLabelTransform(). The input prediction and ground truth are expected to have one regression and one binary classification channel (e.g., inverse distance transform and seed points).
Supported dimensionality:
* 2D: [B, C, H, W] with with_batch_dim = True, or [C, H, W] with with_batch_dim = False
* 3D: [B, C, D, H, W] with with_batch_dim = True, or [C, D, H, W] with with_batch_dim = False
The normalized inverse Mean Absolute Error is computed on the regression channel as: inv_norm_mae = 1 - mae(output, target) / max_mae_value
The Dice metric is the one implemented in monai.metrics.DiceMetric. It is computed on the thresholded classification channel and evaluated for the foreground class only. The two values are combined as: alpha * inv_norm_mae + (1 - alpha) * dice
Constructor parameters:
alpha: float Weight of the normalized inverse Mean Absolute Error, combined with the corresponding (1 - alpha) weight of the Dice metric. Must be between 0.0 and 1.0.
max_mae_value: float Maximum possible value of the MAE, used to normalize it. This should be chosen based on the range of the regression target and must be larger than zero.
regression_channel: int = 0 Regression channel (e.g., inverse distance transform), on which the Mean Absolute Error is computed.
classification_channel: int = 1 Classification channel (e.g., watershed seeds), on which the Dice metric is computed.
foreground_class: int = 1 Class corresponding to the foreground in the classification (usually, background is 0 and foreground is 1).
with_batch_dim: bool (Optional, default is True) Whether the input tensor has a batch dimension. This is needed to distinguish between the 2D case with a batch dimension (B, C, H, W) and the 3D case without one (C, D, H, W), since both have four dimensions. The other supported shapes are unambiguous.
dist_sync_on_step: bool Whether the metric states are synchronized across all processes (nodes/GPUs) after each training step (True) or at the end of the epoch (False). It can be left at False in most cases.
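The role of with_batch_dim can be illustrated with a small standalone sketch of the reshaping done in update(), which brings every supported input to the [B, C, D, H, W] layout. This is not qute's code: the helper name to_5d is hypothetical, and it infers the number of spatial dimensions from the tensor rank instead of calling qute's get_tensor_num_spatial_dims().

```python
import torch


def to_5d(x: torch.Tensor, with_batch_dim: bool = True) -> torch.Tensor:
    """Hypothetical helper: bring a 2D or 3D input to the [B, C, D, H, W] layout."""
    # Spatial dims = all dims minus the channel dim (and the batch dim, if present)
    spatial_dims = x.ndim - (2 if with_batch_dim else 1)
    if spatial_dims == 2:
        # [B, C, H, W] -> [B, C, 1, H, W], or [C, H, W] -> [1, C, 1, H, W]
        x = x.unsqueeze(2) if with_batch_dim else x.unsqueeze(1).unsqueeze(0)
    elif spatial_dims == 3:
        # [B, C, D, H, W] is kept as is; [C, D, H, W] -> [1, C, D, H, W]
        if not with_batch_dim:
            x = x.unsqueeze(0)
    else:
        raise ValueError("Unsupported geometry.")
    return x


# The same 4D shape means different things depending on with_batch_dim
print(tuple(to_5d(torch.zeros(2, 8, 32, 32), with_batch_dim=False).shape))  # (1, 2, 8, 32, 32)
print(tuple(to_5d(torch.zeros(4, 2, 32, 32), with_batch_dim=True).shape))   # (4, 2, 1, 32, 32)
```

A 4D tensor is read either as a 3D volume without a batch or as a batch of 2D images, which is why the flag cannot be inferred from the shape alone.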
from abc import ABC

import torch
import torchmetrics
from monai.metrics import DiceMetric
from torchmetrics import MeanAbsoluteError

# get_tensor_num_spatial_dims() is a qute helper; its import is not shown on this page.


class CombinedInvMeanAbsoluteErrorBinaryDiceMetric(torchmetrics.Metric, ABC):
    """Combined inverse Mean Absolute Error and Dice metric to handle the output of
    qute.transforms.objects.WatershedAndLabelTransform().

    The input prediction and ground truth are expected to have one regression and one
    binary classification channel (e.g., inverse distance transform and seed points).

    Supported dimensionality:
        * 2D: [B, C, H, W] with `with_batch_dim` = True or [C, H, W] with `with_batch_dim` = False
        * 3D: [B, C, D, H, W] with `with_batch_dim` = True or [C, D, H, W] with `with_batch_dim` = False

    The normalized inverse Mean Absolute Error is computed on the regression channel as:
        inv_norm_mae = 1 - mae(output, target) / max_mae_value

    The Dice metric is the one implemented in `monai.metrics.DiceMetric`, evaluated for
    the foreground class only. The two values are combined as:
        alpha * inv_norm_mae + (1 - alpha) * dice
    """

    def __init__(
        self,
        alpha: float = 0.5,
        max_mae_value: float = 1.0,
        regression_channel: int = 0,
        classification_channel: int = 1,
        foreground_class: int = 1,
        with_batch_dim: bool = True,
        dist_sync_on_step=False,
    ):
        """Constructor.

        alpha: float
            Weight of the normalized inverse Mean Absolute Error, combined with the
            corresponding (1 - alpha) weight of the Dice metric. Must be in [0.0, 1.0].
        max_mae_value: float
            Maximum possible value of the MAE, used to normalize it. This should be chosen
            based on the range of the regression target and must be larger than zero.
        regression_channel: int = 0
            Regression channel (e.g., inverse distance transform), on which the Mean
            Absolute Error is computed.
        classification_channel: int = 1
            Classification channel (e.g., watershed seeds), on which the Dice metric is computed.
        foreground_class: int = 1
            Class corresponding to the foreground in the classification (usually,
            background is 0 and foreground is 1).
        with_batch_dim: bool (Optional, default is True)
            Whether the input tensor has a batch dimension. This is needed to distinguish
            between the 2D case (B, C, H, W) and the 3D case (C, D, H, W), which have the
            same number of dimensions. The other supported shapes are unambiguous.
        dist_sync_on_step: bool
            Whether the metric states are synchronized across all processes (nodes/GPUs)
            after each training step (True) or at the end of the epoch (False). It can be
            left at False in most cases.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        if alpha < 0.0 or alpha > 1.0:
            raise ValueError("alpha must be between 0.0 and 1.0")
        if max_mae_value <= 0.0:
            raise ValueError("max_mae_value must be a positive number (larger than zero).")
        self.alpha = alpha
        self.max_mae_value = max_mae_value
        self.mae_metric = MeanAbsoluteError()
        self.dice_metric = DiceMetric(
            include_background=True, reduction="mean", get_not_nans=False
        )
        self.regression_channel = regression_channel
        self.classification_channel = classification_channel
        self.foreground_class = foreground_class
        self.with_batch_dim = with_batch_dim
        self.add_state("total_metric", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_updates", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, output, target):
        """Compute the combined metric for a batch and add it to the accumulated state."""
        if len(output.shape) not in [3, 4, 5]:
            raise ValueError("Unsupported geometry.")

        # Do we have a 2D or 3D tensor (excluding batch and channel dimensions)?
        effective_dims = get_tensor_num_spatial_dims(output, self.with_batch_dim)
        if effective_dims not in [2, 3]:
            raise ValueError("Unsupported geometry.")

        # Bring both tensors to the [B, C, D, H, W] layout
        if effective_dims == 2:
            if self.with_batch_dim and len(output.shape) == 4:
                # [B, C, H, W] -> [B, C, 1, H, W]
                output = output.unsqueeze(2)
                target = target.unsqueeze(2)
            elif not self.with_batch_dim and len(output.shape) == 3:
                # [C, H, W] -> [1, C, 1, H, W]
                output = output.unsqueeze(1).unsqueeze(0)
                target = target.unsqueeze(1).unsqueeze(0)
            else:
                raise ValueError("Unsupported geometry.")
        else:
            if self.with_batch_dim and len(output.shape) == 5:
                # Already [B, C, D, H, W]
                pass
            elif not self.with_batch_dim and len(output.shape) == 4:
                # [C, D, H, W] -> [1, C, D, H, W]
                output = output.unsqueeze(0)
                target = target.unsqueeze(0)
            else:
                raise ValueError("Unsupported geometry.")

        # Mean Absolute Error on the regression channel
        mae_value = self.mae_metric(
            output[:, self.regression_channel, ...].unsqueeze(1),
            target[:, self.regression_channel, ...].unsqueeze(1),
        )

        # Normalize and invert the MAE
        inv_norm_mae = 1 - (mae_value / self.max_mae_value)

        # Dice metric on the classification channel (foreground class only)
        dice_metric = self.dice_metric(
            self._as_discrete(output[:, self.classification_channel, ...].unsqueeze(1)),
            self._as_discrete(target[:, self.classification_channel, ...].unsqueeze(1)),
        )
        dice_metric = dice_metric[:, self.foreground_class].mean()

        # Combine them linearly
        combined_metric = self.alpha * inv_norm_mae + (1 - self.alpha) * dice_metric

        # Accumulate the metric
        self.total_metric += combined_metric
        self.num_updates += 1

        # Return the combined metric for this batch
        return combined_metric

    def forward(self, output, target):
        """Update the state with new predictions and targets and return the running metric."""
        # Update the metrics
        self.update(output, target)

        # Return the metric accumulated over all updates so far
        return self.compute()

    def compute(self):
        """Compute the final metric based on the state."""
        if self.num_updates == 0:
            return torch.tensor(0.0)
        return self.total_metric / self.num_updates

    def aggregate(self):
        """Aggregate the metrics (alias for compute())."""
        return self.compute()

    @staticmethod
    def _as_discrete(logits):
        """Convert [B, 1, D, H, W] logits to a [B, 2, D, H, W] one-hot tensor."""
        if logits.dim() != 5:
            raise ValueError("Unsupported geometry.")

        # Apply sigmoid to convert logits to probabilities
        # (binary {0, 1} targets survive this: sigmoid(1) > 0.5, sigmoid(0) == 0.5 is not)
        probabilities = torch.sigmoid(logits)

        # Apply a threshold to convert probabilities to binary class indices
        threshold = 0.5
        class_indices = (probabilities > threshold).long()

        # One-hot encoding for binary classification: [B, 1, D, H, W, 2]
        one_hot = torch.nn.functional.one_hot(class_indices, num_classes=2)

        # Move the class dimension to the channel position: [B, 2, D, H, W]
        one_hot = one_hot.permute(0, 5, 2, 3, 4, 1).squeeze(-1).float()
        return one_hot
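The per-batch computation in update() can be reproduced in plain PyTorch for a quick sanity check. This is a sketch under stated assumptions, not the class itself: inputs are assumed to already be [B, C, D, H, W], the MAE is the plain mean of absolute differences, and a simplified hand-written Dice stands in for monai.metrics.DiceMetric. The two agree when the ground-truth foreground is non-empty; when both masks are empty this sketch returns 1.0, while MONAI's result depends on its ignore_empty option. The function names as_discrete, foreground_dice and combined_metric are hypothetical.

```python
import torch


def as_discrete(logits: torch.Tensor) -> torch.Tensor:
    """[B, 1, D, H, W] logits -> [B, 2, D, H, W] one-hot (background, foreground)."""
    class_indices = (torch.sigmoid(logits) > 0.5).long()                 # [B, 1, D, H, W]
    one_hot = torch.nn.functional.one_hot(class_indices, num_classes=2)  # [B, 1, D, H, W, 2]
    return one_hot.permute(0, 5, 2, 3, 4, 1).squeeze(-1).float()         # [B, 2, D, H, W]


def foreground_dice(pred_oh, target_oh, foreground_class=1):
    """Simplified Dice on one class of one-hot tensors: per sample, then averaged over B."""
    p = pred_oh[:, foreground_class].flatten(1)
    t = target_oh[:, foreground_class].flatten(1)
    intersection = (p * t).sum(dim=1)
    denominator = p.sum(dim=1) + t.sum(dim=1)
    # Two empty masks count as a perfect match in this sketch
    dice = torch.where(
        denominator > 0,
        2 * intersection / denominator.clamp(min=1),
        torch.ones_like(denominator),
    )
    return dice.mean()


def combined_metric(output, target, alpha=0.5, max_mae_value=1.0,
                    regression_channel=0, classification_channel=1, foreground_class=1):
    # Normalized inverse MAE on the regression channel
    mae = (output[:, regression_channel] - target[:, regression_channel]).abs().mean()
    inv_norm_mae = 1 - mae / max_mae_value
    # Dice on the thresholded classification channel (foreground class only)
    dice = foreground_dice(
        as_discrete(output[:, classification_channel].unsqueeze(1)),
        as_discrete(target[:, classification_channel].unsqueeze(1)),
        foreground_class,
    )
    return alpha * inv_norm_mae + (1 - alpha) * dice


# Usage: a prediction identical to the target scores 1.0
torch.manual_seed(0)
target = torch.zeros(2, 2, 4, 8, 8)
target[:, 0] = torch.rand(2, 4, 8, 8)      # regression channel
target[:, 1, :, 2:5, 2:5] = 1.0            # binary seeds in the classification channel
print(combined_metric(target.clone(), target).item())  # 1.0
```

Note that inv_norm_mae is not clamped, so it becomes negative if the MAE exceeds max_mae_value; choosing max_mae_value from the range of the regression target avoids this.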
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Methods
def aggregate(self)
Aggregate the metrics. Equivalent to compute().
def compute(self)
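Compute the final metric from the state: the mean of the per-batch combined values accumulated by update(), or 0.0 if no update has been made yet.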
def forward(self, output, target)
Update the state of the metric with new predictions and targets, then return the metric accumulated over all updates so far (not the value for the current batch only).
def update(self, output, target)
Compute the combined metric for a batch of predictions and targets and add it to the accumulated state.
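The interplay of update(), compute() and forward() amounts to a running mean of per-batch values. The sketch below mimics only that state handling in plain Python; the class name RunningCombinedMetric is hypothetical and the per-batch values are passed in directly instead of being computed from tensors.

```python
class RunningCombinedMetric:
    """Hypothetical stand-in for the metric's state handling (no tensors involved)."""

    def __init__(self):
        self.total_metric = 0.0
        self.num_updates = 0

    def update(self, batch_value: float) -> None:
        # In the real class, batch_value is the combined metric of one batch
        self.total_metric += batch_value
        self.num_updates += 1

    def compute(self) -> float:
        # Same convention as compute(): 0.0 before any update
        if self.num_updates == 0:
            return 0.0
        return self.total_metric / self.num_updates

    def forward(self, batch_value: float) -> float:
        # forward() returns the running mean, not the current batch value
        self.update(batch_value)
        return self.compute()


m = RunningCombinedMetric()
print(m.forward(0.75))  # 0.75
print(m.forward(0.5))   # 0.625
print(m.compute())      # 0.625
```

This differs from the default torchmetrics behavior, where forward() returns the metric for the current batch while also updating the global state; here the second call reports 0.625 even though that batch scored 0.5.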