# Source code for autods_pet.ops.masks

"""Mask operations for multilabel segmentations."""

from __future__ import annotations

from collections.abc import Iterable

import numpy as np
import SimpleITK as sitk


def label_mask(seg: sitk.Image, label: int) -> sitk.Image:
    """Return a binary mask selecting one label of a multilabel segmentation.

    Parameters
    ----------
    seg : sitk.Image
        Multilabel segmentation image (integer-valued).
    label : int
        Label value to select; it is coerced with ``int()`` before comparison.

    Returns
    -------
    sitk.Image
        ``uint8`` image that is 1 where ``seg == label`` and 0 elsewhere.
    """
    wanted = int(label)
    selected = sitk.Equal(seg, wanted)
    return sitk.Cast(selected, sitk.sitkUInt8)
def label_union(seg: sitk.Image, labels: Iterable[int]) -> sitk.Image:
    """Binary mask that is the union of several labels.

    Parameters
    ----------
    seg : sitk.Image
        Multilabel segmentation image (integer-valued).
    labels : Iterable[int]
        Label values to include. Any iterable is accepted, including
        one-shot iterators such as generators. Duplicates are ignored. An
        empty iterable returns an all-zero mask with the geometry of *seg*.

    Returns
    -------
    sitk.Image
        Binary ``uint8`` mask where any of the selected labels is 1.
    """
    # Materialise before testing for emptiness: a generator is always truthy,
    # so ``not labels`` would miss an empty one and the loop below would leave
    # ``mask`` unset. Deduplicating also avoids redundant full-image passes.
    unique_labels = sorted({int(lab) for lab in labels})

    if not unique_labels:
        empty = sitk.Image(seg.GetSize(), sitk.sitkUInt8)
        # Match origin/spacing/direction so the result aligns with *seg*.
        empty.CopyInformation(seg)
        return empty

    mask = label_mask(seg, unique_labels[0])
    for lab in unique_labels[1:]:
        mask = sitk.Or(mask, label_mask(seg, lab))
    return sitk.Cast(mask, sitk.sitkUInt8)
def keep_largest_component(
    mask: sitk.Image,
    *,
    fully_connected: bool = False,
) -> sitk.Image:
    """Keep only the largest connected component of a binary mask.

    Parameters
    ----------
    mask : sitk.Image
        Binary mask (non-zero values are foreground).
    fully_connected : bool, optional
        Connectivity flag forwarded to ``sitk.ConnectedComponent``. The
        default ``False`` matches SimpleITK's default (face connectivity);
        ``True`` also joins voxels touching along edges or corners.

    Returns
    -------
    sitk.Image
        Binary ``uint8`` mask containing only the largest connected
        component. If the mask has no foreground, the binarised (all-zero)
        ``uint8`` mask is returned. This is a new image, not the input object.
    """
    # Normalise to a 0/1 uint8 image so any non-zero value counts as foreground.
    mask = sitk.Cast(mask != 0, sitk.sitkUInt8)
    cc = sitk.ConnectedComponent(mask, fully_connected)

    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(cc)
    if stats.GetNumberOfLabels() == 0:
        # No components: the all-zero binarised mask is already the answer.
        return mask

    best = max(stats.GetLabels(), key=stats.GetNumberOfPixels)
    return sitk.Cast(sitk.Equal(cc, best), sitk.sitkUInt8)
def _border_labels(label_arr: np.ndarray) -> set[int]:
    """Return every value that occurs on any face of an N-D label array."""
    found: set[int] = set()
    for axis in range(label_arr.ndim):
        for index in (0, -1):
            face = np.take(label_arr, index, axis=axis)
            found.update(int(value) for value in np.unique(face))
    return found


def fill_holes(
    mask: sitk.Image,
    max_hole_volume_mm3: float | None = None,
) -> sitk.Image:
    """Fill holes in a binary mask.

    Parameters
    ----------
    mask : sitk.Image
        Binary mask (non-zero values are foreground). Works for any image
        dimensionality (e.g. 2-D slices as well as 3-D volumes).
    max_hole_volume_mm3 : float or None
        If *None*, fills all holes (original behavior). If set, only fills
        holes whose physical size is strictly below this threshold. For 3-D
        images with spacing in mm this is mm³; for 2-D images it is an area
        in squared spacing units. Large holes (e.g. portal vein in liver)
        are preserved.

    Returns
    -------
    sitk.Image
        Binary ``uint8`` mask with holes filled.
    """
    mask = sitk.Cast(mask != 0, sitk.sitkUInt8)
    if max_hole_volume_mm3 is None:
        filled = sitk.BinaryFillhole(mask, fullyConnected=True)
        return sitk.Cast(filled, sitk.sitkUInt8)

    # Label the connected components of the complement. In ``cc``, label 0 is
    # the mask foreground; every positive label is a background region. The
    # connectivity flag is the same ``True`` used by the BinaryFillhole path.
    inverted = sitk.Cast(mask == 0, sitk.sitkUInt8)
    cc = sitk.ConnectedComponent(inverted, True)
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(cc)
    if stats.GetNumberOfLabels() == 0:
        # Mask fills the whole image: nothing to fill.
        return mask

    # Background regions touching any image face are exterior. The rest are
    # enclosed by the mask and are holes. If no region touches a face, the
    # mask covers the entire image border and every region is a hole.
    cc_arr = sitk.GetArrayFromImage(cc)
    exterior = _border_labels(cc_arr)
    exterior.discard(0)  # 0 is the mask itself, not a background region

    # GetPhysicalSize = pixel count * product of spacing, for any dimension.
    small_holes = [
        lab
        for lab in stats.GetLabels()
        if lab not in exterior
        and stats.GetPhysicalSize(lab) < max_hole_volume_mm3
    ]
    if not small_holes:
        return mask

    # Build the fill mask in a single vectorised pass rather than one
    # Equal+Or full-image pass per hole.
    fill_img = sitk.GetImageFromArray(
        np.isin(cc_arr, small_holes).astype(np.uint8)
    )
    fill_img.CopyInformation(mask)
    return sitk.Cast(sitk.Or(mask, fill_img), sitk.sitkUInt8)
def subtract_mask(base: sitk.Image, subtract: sitk.Image) -> sitk.Image:
    """Remove the foreground of *subtract* from *base* (set difference).

    Parameters
    ----------
    base : sitk.Image
        Mask to subtract from; any non-zero voxel counts as foreground.
    subtract : sitk.Image
        Mask whose non-zero voxels are cleared from *base*.

    Returns
    -------
    sitk.Image
        ``uint8`` 0/1 mask that is 1 exactly where *base* is non-zero and
        *subtract* is zero.
    """
    keep = sitk.Cast(base != 0, sitk.sitkUInt8)
    remove = sitk.Cast(subtract != 0, sitk.sitkUInt8)
    difference = sitk.And(keep, sitk.Not(remove))
    return sitk.Cast(difference, sitk.sitkUInt8)