# Source code for autods_pet.ops.dicom_seg

"""Read DICOM SEG segmentation objects as SimpleITK images.

This module lazily imports ``highdicom`` and ``pydicom`` so that the
optional dependency is only required when a ``.dcm`` mask file is
actually encountered.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

log = logging.getLogger(__name__)

# DICOM SEG SOP Class UID (Segmentation Storage).
_SEG_SOP_CLASS_UID = "1.2.840.10008.5.1.4.1.1.66.4"

_INSTALL_HINT = (
    "DICOM SEG support requires the 'highdicom' package. "
    "Install it with:  pip install autods-pet[dicom-seg]"
)


def _import_highdicom() -> Any:
    """Return the ``highdicom`` module, importing it on first use.

    Returns
    -------
    module
        The imported ``highdicom`` package.

    Raises
    ------
    ModuleNotFoundError
        If the import fails with ``ModuleNotFoundError``; the message is
        replaced by the installation hint and the original traceback
        context is suppressed.
    """
    try:
        import highdicom as hd_module
    except ModuleNotFoundError:
        # ``from None`` hides the low-level import error so the user only
        # sees the actionable install hint.
        raise ModuleNotFoundError(_INSTALL_HINT) from None
    else:
        return hd_module


def is_dicom_seg(path: Path) -> bool:
    """Report whether *path* holds a DICOM SEG object.

    Only the header is parsed (pixel data is skipped), so the check is
    cheap even for large files.

    Parameters
    ----------
    path : Path
        Path to a ``.dcm`` file.

    Returns
    -------
    bool
        ``True`` when the dataset's ``SOPClassUID`` equals the
        Segmentation Storage UID; ``False`` when it differs, is absent,
        or the file cannot be read at all.
    """
    import pydicom

    try:
        header = pydicom.dcmread(str(path), stop_before_pixels=True, force=True)
        sop_class_uid = getattr(header, "SOPClassUID", None)
        matches = sop_class_uid == _SEG_SOP_CLASS_UID
    except Exception:
        # Best-effort probe: any read failure simply means "not a SEG".
        log.debug("Cannot read %s as DICOM SEG", path, exc_info=True)
        return False
    else:
        return matches


def read_referenced_series_uids(path: Path) -> list[str]:
    """Collect the SeriesInstanceUIDs listed in a SEG's ``ReferencedSeriesSequence``.

    Only the header is read and ``highdicom`` is not needed. Mask
    discovery uses the result to pair a SEG file with the patient's PET
    series.

    Parameters
    ----------
    path : Path
        Path to a DICOM SEG (``.dcm``) file.

    Returns
    -------
    list[str]
        One ``SeriesInstanceUID`` string per sequence item that carries
        one, in sequence order. Empty when the sequence is absent or the
        file cannot be parsed.
    """
    import pydicom

    try:
        header = pydicom.dcmread(str(path), stop_before_pixels=True, force=True)
    except Exception:
        log.debug("Cannot read DICOM headers from %s", path, exc_info=True)
        return []

    referenced = getattr(header, "ReferencedSeriesSequence", None)
    if referenced is None:
        return []

    # Items without a SeriesInstanceUID are skipped rather than reported.
    candidates = (getattr(entry, "SeriesInstanceUID", None) for entry in referenced)
    return [str(value) for value in candidates if value is not None]


def list_segments(path: Path) -> list[dict[str, Any]]:
    """Describe every segment stored in a DICOM SEG file.

    Parameters
    ----------
    path : Path
        Path to a DICOM SEG (``.dcm``) file.

    Returns
    -------
    list[dict[str, Any]]
        One dict per segment, in segment-number order as reported by
        the file, with keys ``number`` (int), ``label`` (str) and
        ``description`` (str, empty when the segment has none).
    """
    highdicom = _import_highdicom()
    segmentation = highdicom.seg.segread(str(path))

    def _summarise(number: Any) -> dict[str, Any]:
        info = segmentation.get_segment_description(number)
        return {
            "number": int(number),
            "label": str(info.segment_label),
            "description": str(getattr(info, "SegmentDescription", "")),
        }

    return [_summarise(number) for number in segmentation.segment_numbers]


def _select_segment_number(
    seg: Any,
    segment_label: str | None,
) -> int:
    """Resolve the segment number to extract.

    Parameters
    ----------
    seg : highdicom.seg.Segmentation
        Loaded DICOM SEG object. Only ``segment_numbers`` and
        ``get_segment_description(n).segment_label`` are used.
    segment_label : str or None
        User-requested label. An exact-case match wins when it is
        unique; otherwise the label is matched case-insensitively.
        ``None`` selects the only segment of a single-segment file.

    Returns
    -------
    int
        The matching segment number.

    Raises
    ------
    ValueError
        If the label is not found, if it matches more than one segment
        (ambiguous), or if no label is given and the file does not
        contain exactly one segment.
    """
    # Keep (label, number) pairs in a list instead of a dict keyed by label:
    # nothing here guarantees labels are unique, and a dict would silently
    # drop all but the last segment sharing a label.
    segments = [
        (str(seg.get_segment_description(num).segment_label), int(num))
        for num in seg.segment_numbers
    ]
    available = ", ".join(f"'{lab}' (#{num})" for lab, num in segments)

    if segment_label is None:
        if len(segments) == 1:
            return segments[0][1]
        raise ValueError(
            f"DICOM SEG contains {len(segments)} segments: {available}. "
            "Specify 'segment_label' in the config to select one."
        )

    wanted = segment_label.lower()
    matches = [(lab, num) for lab, num in segments if lab.lower() == wanted]

    # A unique exact-case hit disambiguates labels that differ only in case.
    exact = [num for lab, num in matches if lab == segment_label]
    if len(exact) == 1:
        return exact[0]
    if len(matches) == 1:
        return matches[0][1]

    if not matches:
        raise ValueError(
            f"Segment label '{segment_label}' not found in DICOM SEG. "
            f"Available segments: {available}"
        )

    clashes = ", ".join(f"'{lab}' (#{num})" for lab, num in matches)
    raise ValueError(
        f"Segment label '{segment_label}' is ambiguous in DICOM SEG: "
        f"it matches {len(matches)} segments ({clashes}). "
        f"Available segments: {available}"
    )


def _sitk_geometry_from_affine(
    affine: Any,
    name: str,
) -> tuple[list[float], list[float], list[float]]:
    """Turn an index-to-patient affine into SimpleITK origin/spacing/direction.

    Parameters
    ----------
    affine : array-like
        4x4 matrix mapping ``(slice, row, col)`` indices to patient
        ``(x, y, z)`` coordinates.
    name : str
        File name used in the error message.

    Returns
    -------
    tuple[list[float], list[float], list[float]]
        ``(origin, spacing, direction)`` where ``direction`` is the 3x3
        direction-cosine matrix flattened row-major, ready for
        ``SetOrigin`` / ``SetSpacing`` / ``SetDirection``.

    Raises
    ------
    ValueError
        If any axis has a (near-)zero step vector.
    """
    import numpy as np

    affine = np.asarray(affine, dtype=np.float64)
    origin = affine[:3, 3]
    # Affine columns are ordered (slice, row, col); SimpleITK's index axes
    # (i, j, k) are (col, row, slice), so reverse the column order.
    spacing_vecs = affine[:3, [2, 1, 0]]
    spacing = np.linalg.norm(spacing_vecs, axis=0)
    if np.any(spacing < 1e-10):
        raise ValueError(
            f"Degenerate affine in DICOM SEG '{name}': "
            f"zero-length spacing vector (spacing={spacing.tolist()})."
        )
    # Column n of ``direction`` is the unit vector of image axis n, which is
    # exactly the direction-cosine matrix layout. SetDirection takes that
    # matrix flattened row-major, so it must NOT be transposed first
    # (transposing only happens to work for symmetric, e.g. axial, matrices).
    direction = spacing_vecs / spacing
    return origin.tolist(), spacing.tolist(), direction.flatten().tolist()


def read_dicom_seg(
    path: Path,
    segment_label: str | None = None,
) -> Any:
    """Read a DICOM SEG file and return the selected segment as a SimpleITK image.

    Parameters
    ----------
    path : Path
        Path to a DICOM SEG (``.dcm``) file.
    segment_label : str or None
        Label of the segment to extract. Required when the file contains
        more than one segment. Matched against each segment's
        ``SegmentLabel`` attribute by ``_select_segment_number``.

    Returns
    -------
    SimpleITK.Image
        Binary ``uint8`` mask whose origin, spacing and direction are
        derived from the volume's affine (patient coordinates).

    Raises
    ------
    ModuleNotFoundError
        If ``highdicom`` is not installed.
    ValueError
        If the segment cannot be resolved (label missing/ambiguous, or a
        multi-segment file without a label), the extracted array is not
        a 3-D volume, or the affine is degenerate.
    """
    import numpy as np
    import SimpleITK as sitk

    hd = _import_highdicom()
    seg = hd.seg.segread(str(path))

    seg_num = _select_segment_number(seg, segment_label)

    # Handle FRACTIONAL segmentation type.
    is_fractional = seg.segmentation_type == hd.seg.SegmentationTypeValues.FRACTIONAL
    if is_fractional:
        log.warning(
            "DICOM SEG '%s' uses FRACTIONAL segmentation type; "
            "thresholding at 0.5 to produce a binary mask.",
            path.name,
        )

    # Request the single segment un-combined.
    # NOTE(review): per the highdicom documentation, combine_segments=True is
    # rejected for FRACTIONAL data containing values other than 0 and 1, which
    # would make the 0.5 threshold below unreachable - confirm against the
    # installed highdicom version.
    volume = seg.get_volume(
        segment_numbers=[seg_num],
        combine_segments=False,
        rescale_fractional=is_fractional,
    )
    arr = np.asarray(volume.array)

    # The first three axes are treated as (slices, rows, cols). Any further
    # axes (e.g. a one-entry segment axis) must be singleton. Drop only those,
    # never a size-1 spatial axis, so single-slice masks stay 3-D.
    if arr.ndim > 3 and all(extent == 1 for extent in arr.shape[3:]):
        arr = arr.reshape(arr.shape[:3])
    if arr.ndim != 3:
        raise ValueError(
            f"Unexpected array shape {arr.shape} from DICOM SEG '{path.name}'."
        )

    # Binarize: fractional data is thresholded, anything else is "non-zero".
    if is_fractional:
        arr = (arr >= 0.5).astype(np.uint8)
    else:
        arr = (arr > 0).astype(np.uint8)

    # GetImageFromArray maps numpy axes (0, 1, 2) to SimpleITK (k, j, i),
    # which is why the geometry helper reorders the affine columns.
    origin, spacing, direction = _sitk_geometry_from_affine(volume.affine, path.name)

    img = sitk.GetImageFromArray(arr)
    img.SetOrigin(origin)
    img.SetSpacing(spacing)
    img.SetDirection(direction)
    img = sitk.Cast(img, sitk.sitkUInt8)

    desc = seg.get_segment_description(seg_num)
    log.info(
        "Loaded DICOM SEG segment '%s' (#%d) from %s [%s].",
        desc.segment_label,
        seg_num,
        path.name,
        "FRACTIONAL→binary" if is_fractional else "BINARY",
    )
    return img