diff --git a/fastMONAI/__init__.py b/fastMONAI/__init__.py index dd9b22c..7225152 100644 --- a/fastMONAI/__init__.py +++ b/fastMONAI/__init__.py @@ -1 +1 @@ -__version__ = "0.5.1" +__version__ = "0.5.2" diff --git a/fastMONAI/_modidx.py b/fastMONAI/_modidx.py index d59ef24..9562d1a 100644 --- a/fastMONAI/_modidx.py +++ b/fastMONAI/_modidx.py @@ -99,6 +99,12 @@ 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.CustomDictTransform.encodes': ( 'vision_augment.html#customdicttransform.encodes', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.NormalizeIntensity': ( 'vision_augment.html#normalizeintensity', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.NormalizeIntensity.__init__': ( 'vision_augment.html#normalizeintensity.__init__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.NormalizeIntensity.encodes': ( 'vision_augment.html#normalizeintensity.encodes', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.OneOf': ( 'vision_augment.html#oneof', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.OneOf.__init__': ( 'vision_augment.html#oneof.__init__', @@ -163,6 +169,12 @@ 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.RandomSpike.encodes': ( 'vision_augment.html#randomspike.encodes', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RescaleIntensity': ( 'vision_augment.html#rescaleintensity', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RescaleIntensity.__init__': ( 'vision_augment.html#rescaleintensity.__init__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RescaleIntensity.encodes': ( 'vision_augment.html#rescaleintensity.encodes', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.ZNormalization': ( 'vision_augment.html#znormalization', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.ZNormalization.__init__': ( 'vision_augment.html#znormalization.__init__', diff --git a/fastMONAI/vision_augmentation.py b/fastMONAI/vision_augmentation.py index 16f1ff6..445c82e 100644 --- a/fastMONAI/vision_augmentation.py +++ b/fastMONAI/vision_augmentation.py @@ -1,14 +1,16 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_vision_augment.ipynb. # %% auto 0 -__all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'BraTSMaskConverter', 'BinaryConverter', - 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField', 'RandomBlur', 'RandomGamma', - 'RandomMotion', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf'] +__all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity', + 'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField', + 'RandomBlur', 'RandomGamma', 'RandomMotion', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', + 'OneOf'] # %% ../nbs/03_vision_augment.ipynb 2 from fastai.data.all import * from .vision_core import * import torchio as tio +from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity # %% ../nbs/03_vision_augment.ipynb 5 class CustomDictTransform(ItemTransform): @@ -84,9 +86,29 @@ def __init__(self, masking_method=None, channel_wise=True): self.channel_wise = channel_wise def encodes(self, o: MedImage): - if self.channel_wise: - o = torch.stack([self.z_normalization(c[None])[0] for c in o]) - else: o = self.z_normalization(o) + try: + if self.channel_wise: + o = torch.stack([self.z_normalization(c[None])[0] for c in o]) + else: + o = self.z_normalization(o) + except RuntimeError as e: + if "Standard deviation is 0" in str(e): + # Calculate mean for debugging information + mean = float(o.mean()) + + error_msg = ( + f"Standard deviation is 0 for image (mean={mean:.3f}).\n" + f"This indicates uniform pixel values.\n\n" + f"Possible causes:\n" + f"• Corrupted or blank image\n" + f"• Oversaturated regions\n" + f"• Background-only regions\n" + f"• All-zero mask being processed as image\n\n" + f"Suggested solutions:\n" + f"• Check image quality and acquisition\n" + f"• Verify image vs mask data loading" + ) + raise RuntimeError(error_msg) from e return MedImage.create(o) @@ -94,6 +116,68 @@ def encodes(self, o: MedMask): return o # %% ../nbs/03_vision_augment.ipynb 10 +class RescaleIntensity(DisplayedTransform): + """Apply TorchIO RescaleIntensity for robust intensity scaling. + + Args: + out_min_max (tuple[float, float]): Output intensity range (min, max) + in_min_max (tuple[float, float]): Input intensity range (min, max) + + Example for CT images: + # Normalize CT from air (-1000 HU) to bone (1000 HU) into range (-1, 1) + transform = RescaleIntensity(out_min_max=(-1, 1), in_min_max=(-1000, 1000)) + """ + + order = 0 + + def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]): + self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max) + + def encodes(self, o: MedImage): + return MedImage.create(self.rescale(o)) + + def encodes(self, o: MedMask): + return o + +# %% ../nbs/03_vision_augment.ipynb 11 +class NormalizeIntensity(DisplayedTransform): + """Apply MONAI NormalizeIntensity. + + Args: + nonzero (bool): Only normalize non-zero values (default: True) + channel_wise (bool): Apply normalization per channel (default: True) + subtrahend (float, optional): Value to subtract + divisor (float, optional): Value to divide by + """ + + order = 0 + + def __init__(self, nonzero: bool = True, channel_wise: bool = True, + subtrahend: float = None, divisor: float = None): + self.nonzero = nonzero + self.channel_wise = channel_wise + self.subtrahend = subtrahend + self.divisor = divisor + + self.transform = MonaiNormalizeIntensity( + nonzero=nonzero, + channel_wise=False, # Always 'False', we handle channel-wise manually + subtrahend=subtrahend, + divisor=divisor + ) + + def encodes(self, o: MedImage): + if self.channel_wise: + result = torch.stack([self.transform(c[None])[0] for c in o]) + else: + result = torch.Tensor(self.transform(o)) + + return MedImage.create(result) + + def encodes(self, o: MedMask): + return o + +# %% ../nbs/03_vision_augment.ipynb 12 class BraTSMaskConverter(DisplayedTransform): '''Convert BraTS masks.''' @@ -105,7 +189,7 @@ def encodes(self, o:(MedMask)): o = torch.where(o==4, 3., o) return MedMask.create(o) -# %% ../nbs/03_vision_augment.ipynb 11 +# %% ../nbs/03_vision_augment.ipynb 13 class BinaryConverter(DisplayedTransform): '''Convert to binary mask.''' @@ -118,7 +202,7 @@ def encodes(self, o: MedMask): o = torch.where(o>0, 1., 0) return MedMask.create(o) -# %% ../nbs/03_vision_augment.ipynb 12 +# %% ../nbs/03_vision_augment.ipynb 14 class RandomGhosting(DisplayedTransform): """Apply TorchIO `RandomGhosting`.""" @@ -133,7 +217,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 13 +# %% ../nbs/03_vision_augment.ipynb 15 class RandomSpike(DisplayedTransform): '''Apply TorchIO `RandomSpike`.''' @@ -148,7 +232,7 @@ def encodes(self, o:MedImage): def encodes(self, o:MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 14 +# %% ../nbs/03_vision_augment.ipynb 16 class RandomNoise(DisplayedTransform): '''Apply TorchIO `RandomNoise`.''' @@ -163,7 +247,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 15 +# %% ../nbs/03_vision_augment.ipynb 17 class RandomBiasField(DisplayedTransform): '''Apply TorchIO `RandomBiasField`.''' @@ -178,7 +262,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 16 +# %% ../nbs/03_vision_augment.ipynb 18 class RandomBlur(DisplayedTransform): '''Apply TorchIO `RandomBiasField`.''' @@ -193,7 +277,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 17 +# %% ../nbs/03_vision_augment.ipynb 19 class RandomGamma(DisplayedTransform): '''Apply TorchIO `RandomGamma`.''' @@ -209,7 +293,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 18 +# %% ../nbs/03_vision_augment.ipynb 20 class RandomMotion(DisplayedTransform): """Apply TorchIO `RandomMotion`.""" @@ -237,7 +321,7 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o -# %% ../nbs/03_vision_augment.ipynb 20 +# %% ../nbs/03_vision_augment.ipynb 22 class RandomElasticDeformation(CustomDictTransform): """Apply TorchIO `RandomElasticDeformation`.""" @@ -250,7 +334,7 @@ def __init__(self, num_control_points=7, max_displacement=7.5, image_interpolation=image_interpolation, p=p)) -# %% ../nbs/03_vision_augment.ipynb 21 +# %% ../nbs/03_vision_augment.ipynb 23 class RandomAffine(CustomDictTransform): """Apply TorchIO `RandomAffine`.""" @@ -266,14 +350,14 @@ def __init__(self, scales=0, degrees=10, translation=0, isotropic=False, default_pad_value=default_pad_value, p=p)) -# %% ../nbs/03_vision_augment.ipynb 22 +# %% ../nbs/03_vision_augment.ipynb 24 class RandomFlip(CustomDictTransform): """Apply TorchIO `RandomFlip`.""" def __init__(self, axes='LR', p=0.5): super().__init__(tio.RandomFlip(axes=axes, flip_probability=p)) -# %% ../nbs/03_vision_augment.ipynb 23 +# %% ../nbs/03_vision_augment.ipynb 25 class OneOf(CustomDictTransform): """Apply only one of the given transforms using TorchIO `OneOf`.""" diff --git a/nbs/03_vision_augment.ipynb b/nbs/03_vision_augment.ipynb index f34ec56..5587ca6 100644 --- a/nbs/03_vision_augment.ipynb +++ b/nbs/03_vision_augment.ipynb @@ -28,7 +28,8 @@ "#| export\n", "from fastai.data.all import *\n", "from fastMONAI.vision_core import *\n", - "import torchio as tio" + "import torchio as tio\n", + "from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity" ] }, { @@ -153,9 +154,29 @@ " self.channel_wise = channel_wise\n", "\n", " def encodes(self, o: MedImage):\n", - " if self.channel_wise:\n", - " o = torch.stack([self.z_normalization(c[None])[0] for c in o])\n", - " else: o = self.z_normalization(o) \n", + " try:\n", + " if self.channel_wise:\n", + " o = torch.stack([self.z_normalization(c[None])[0] for c in o])\n", + " else: \n", + " o = self.z_normalization(o)\n", + " except RuntimeError as e:\n", + " if \"Standard deviation is 0\" in str(e):\n", + " # Calculate mean for debugging information\n", + " mean = float(o.mean())\n", + " \n", + " error_msg = (\n", + " f\"Standard deviation is 0 for image (mean={mean:.3f}).\\n\"\n", + " f\"This indicates uniform pixel values.\\n\\n\"\n", + " f\"Possible causes:\\n\"\n", + " f\"• Corrupted or blank image\\n\"\n", + " f\"• Oversaturated regions\\n\" \n", + " f\"• Background-only regions\\n\"\n", + " f\"• All-zero mask being processed as image\\n\\n\"\n", + " f\"Suggested solutions:\\n\"\n", + " f\"• Check image quality and acquisition\\n\"\n", + " f\"• Verify image vs mask data loading\"\n", + " )\n", + " raise RuntimeError(error_msg) from e\n", "\n", " return MedImage.create(o)\n", "\n", @@ -163,6 +184,82 @@ " return o" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "class RescaleIntensity(DisplayedTransform):\n", + " \"\"\"Apply TorchIO RescaleIntensity for robust intensity scaling.\n", + " \n", + " Args:\n", + " out_min_max (tuple[float, float]): Output intensity range (min, max)\n", + " in_min_max (tuple[float, float]): Input intensity range (min, max) \n", + " \n", + " Example for CT images:\n", + " # Normalize CT from air (-1000 HU) to bone (1000 HU) into range (-1, 1)\n", + " transform = RescaleIntensity(out_min_max=(-1, 1), in_min_max=(-1000, 1000))\n", + " \"\"\"\n", + " \n", + " order = 0\n", + " \n", + " def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):\n", + " self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)\n", + " \n", + " def encodes(self, o: MedImage):\n", + " return MedImage.create(self.rescale(o))\n", + " \n", + " def encodes(self, o: MedMask):\n", + " return o" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "class NormalizeIntensity(DisplayedTransform):\n", + " \"\"\"Apply MONAI NormalizeIntensity.\n", + " \n", + " Args:\n", + " nonzero (bool): Only normalize non-zero values (default: True)\n", + " channel_wise (bool): Apply normalization per channel (default: True)\n", + " subtrahend (float, optional): Value to subtract \n", + " divisor (float, optional): Value to divide by\n", + " \"\"\"\n", + " \n", + " order = 0\n", + " \n", + " def __init__(self, nonzero: bool = True, channel_wise: bool = True, \n", + " subtrahend: float = None, divisor: float = None):\n", + " self.nonzero = nonzero\n", + " self.channel_wise = channel_wise\n", + " self.subtrahend = subtrahend\n", + " self.divisor = divisor\n", + " \n", + " self.transform = MonaiNormalizeIntensity(\n", + " nonzero=nonzero,\n", + " channel_wise=False, # Always 'False', we handle channel-wise manually\n", + " subtrahend=subtrahend,\n", + " divisor=divisor\n", + " )\n", + " \n", + " def encodes(self, o: MedImage):\n", + " if self.channel_wise:\n", + " result = torch.stack([self.transform(c[None])[0] for c in o])\n", + " else:\n", + " result = torch.Tensor(self.transform(o))\n", + " \n", + " return MedImage.create(result)\n", + " \n", + " def encodes(self, o: MedMask):\n", + " return o" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/10e_tutorial_inference.ipynb b/nbs/10e_tutorial_inference.ipynb index 781bb06..1b815de 100644 --- a/nbs/10e_tutorial_inference.ipynb +++ b/nbs/10e_tutorial_inference.ipynb @@ -96,43 +96,7 @@ "execution_count": null, "id": "f2dd02b4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading artifacts from run: 1566b936ef474ad495149496560365c2\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f8bb7ca59cae4c67bcd9926aa6b47b96", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading artifacts: 0%| | 0/1 [00:00=1.2.0 scikit-image==0.25.2 imagedata==3.8.4 mlflow==3.3.1 huggingface-hub gdown gradio opencv-python plum-dispatch