Skip to content

Commit d94872f

Browse files
committed
fix(general): Typo annotations and ruff
For what ever reason those checks were missed. THey are now partial fixed. Check #25 for final solution
1 parent 76ba41f commit d94872f

File tree

5 files changed

+60
-58
lines changed

5 files changed

+60
-58
lines changed

tests/unit/test_builder.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,12 @@
11
from __future__ import annotations
22

33
import ast
4-
import sys
54
from pathlib import Path
5+
from tomllib import TOMLDecodeError
66
from unittest.mock import mock_open, patch
77

88
import pytest
99

10-
try:
11-
from tomllib import TOMLDecodeError
12-
except ImportError:
13-
try:
14-
from tomli import TOMLDecodeError # type: ignore # noqa: PGH003
15-
except ImportError:
16-
sys.exit("Error: This program requires either tomllib or tomli but neither is available")
17-
1810
from tipi.abstractions import Permanence, PipelineProcess
1911
from tipi.core.builder import PipelineBuilder, get_objects_for_pipeline
2012
from tipi.errors import (

tipi/core/processes.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# type: ignore
2-
# ruff: noqa
3-
41
# Random selected images for final result
52
# Training: [tensor([28]), tensor([46]), tensor([60]), tensor([63]), tensor([90])]
63
# Validation: [tensor([10]), tensor([18]), tensor([35]), tensor([57]), tensor([79]]]
@@ -14,14 +11,20 @@
1411
# mask against GT if available, Difference of Mask with GT if available.
1512

1613
from collections import OrderedDict
17-
from typing import Any
14+
from typing import Any, Protocol
1815

1916
import torch
20-
2117
import wandb
18+
2219
from tipi.abstractions import PipelineProcess
2320

2421

22+
class ProgressTaskCallable(Protocol):
23+
"""Protocol for functions decorated with progress_task."""
24+
25+
def __call__(self, total: int) -> None: ...
26+
27+
2528
class ResultProcess(PipelineProcess):
2629
"""
2730
ResultProcess is a class that handles the logging of images and their corresponding masks during
@@ -76,26 +79,27 @@ def __init__(self, controller: Any, force: bool, selected_images: dict[str, list
7679
self.valset = self.datasets.segnet_dataset_val
7780
self.testset = self.datasets.segnet_dataset_test
7881

79-
def execute(self):
82+
def execute(self) -> None:
8083
train_image_log = self._get_log_train_images()
8184
val_image_log = self._get_log_val_images()
8285
test_image_log = self._get_log_test_images()
8386

84-
train_image_log(len(self.train_images_indices))
87+
if self.train_images_indices:
88+
train_image_log(len(self.train_images_indices))
8589
if self.datasets.val_available() and self.val_images_indices:
8690
val_image_log(len(self.val_images_indices))
8791
if self.datasets.test_available() and self.test_images_indices:
8892
test_image_log(len(self.test_images_indices))
8993

90-
def _get_log_train_images(self) -> callable:
94+
def _get_log_train_images(self) -> ProgressTaskCallable:
9195
@self.progress_manager.progress_task("result", visible=False)
92-
def _inner_log_image(total, task_id, progress):
96+
def _inner_log_image(total: int, task_id: int, progress: Any) -> None:
9397
image_stack = []
9498
mask_stack = []
9599
pred_mask_stack = []
96100
for idx in range(total):
97101
progress.advance(task_id)
98-
selected_images = self.train_images_indices[idx]
102+
selected_images = self.train_images_indices[idx] # type: ignore[index]
99103
image, mask = self.trainset[selected_images]
100104
image_stack.append(image)
101105
mask_stack.append(mask)
@@ -106,17 +110,17 @@ def _inner_log_image(total, task_id, progress):
106110
pred_masks = torch.cat(pred_mask_stack, 1)
107111
self._log_image(images, masks, pred_masks, "train")
108112

109-
return _inner_log_image
113+
return _inner_log_image # type: ignore[no-any-return]
110114

111-
def _get_log_val_images(self) -> callable:
115+
def _get_log_val_images(self) -> ProgressTaskCallable:
112116
@self.progress_manager.progress_task("result", visible=False)
113-
def _inner_log_image(total, task_id, progress):
117+
def _inner_log_image(total: int, task_id: int, progress: Any) -> None:
114118
image_stack = []
115119
mask_stack = []
116120
pred_mask_stack = []
117121
for idx in range(total):
118122
progress.advance(task_id)
119-
selected_images = self.val_images_indices[idx]
123+
selected_images = self.val_images_indices[idx] # type: ignore[index]
120124
image, mask = self.valset[selected_images]
121125
image_stack.append(image)
122126
mask_stack.append(mask)
@@ -127,17 +131,17 @@ def _inner_log_image(total, task_id, progress):
127131
pred_masks = torch.cat(pred_mask_stack, 1)
128132
self._log_image(images, masks, pred_masks, "val")
129133

130-
return _inner_log_image
134+
return _inner_log_image # type: ignore[no-any-return]
131135

132-
def _get_log_test_images(self) -> callable:
136+
def _get_log_test_images(self) -> ProgressTaskCallable:
133137
@self.progress_manager.progress_task("result", visible=False)
134-
def _inner_log_image(total, task_id, progress):
138+
def _inner_log_image(total: int, task_id: int, progress: Any) -> None:
135139
image_stack = []
136140
mask_stack = []
137141
pred_mask_stack = []
138142
for idx in range(total):
139143
progress.advance(task_id)
140-
selected_images = self.test_images_indices[idx]
144+
selected_images = self.test_images_indices[idx] # type: ignore[index]
141145
image, mask = self.testset[selected_images]
142146
image_stack.append(image)
143147
mask_stack.append(mask)
@@ -148,26 +152,26 @@ def _inner_log_image(total, task_id, progress):
148152
pred_masks = torch.cat(pred_mask_stack, 1)
149153
self._log_image(images, masks, pred_masks, "test")
150154

151-
return _inner_log_image
155+
return _inner_log_image # type: ignore[no-any-return]
152156

153-
def _inference_model(self, image):
157+
def _inference_model(self, image: Any) -> Any:
154158
self.model.eval()
155159
with torch.no_grad():
156160
return self.model(image.unsqueeze(0).to(self.device))
157161

158-
def _get_pred_mask(self, pred):
162+
def _get_pred_mask(self, pred: Any) -> Any:
159163
if isinstance(pred, OrderedDict):
160164
pred = pred["out"]
161165
return pred.argmax(dim=1).squeeze(0).cpu()
162166

163-
def _get_mask_difference(self, mask, pred_mask):
167+
def _get_mask_difference(self, mask: Any, pred_mask: Any) -> Any:
164168
mask[mask == 255] = 0
165169
mask_difference = mask - pred_mask
166170
if mask_difference.min() < 0:
167171
mask_difference = mask_difference + mask_difference.min().abs()
168172
return mask_difference.to(torch.uint8)
169173

170-
def _log_image(self, image, mask, pred_mask, dataset):
174+
def _log_image(self, image: Any, mask: Any, pred_mask: Any, dataset: str) -> None:
171175
class_labels = self._swap_labels(self.datasets.data_container.classes)
172176
just_image = wandb.Image(image, caption=f"{dataset} images")
173177
image_with_mask = wandb.Image(
@@ -186,8 +190,8 @@ def _log_image(self, image, mask, pred_mask, dataset):
186190
wandb.log({f"{dataset}_images_with_mask": image_with_mask})
187191
wandb.log({f"{dataset}_mask_difference": mask_difference})
188192

189-
def _swap_labels(self, labels):
193+
def _swap_labels(self, labels: dict[Any, Any]) -> dict[Any, Any]:
190194
return {v: k for k, v in labels.items()}
191195

192-
def skip(self):
196+
def skip(self) -> bool:
193197
return False

tipi/decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, controller: Any, force: bool = False, **kwargs: Any) -> None:
9797
def execute(self) -> None:
9898
"""Execute the wrapped function."""
9999
# Set pipeline context so helpers work
100-
from tipi.helpers import ( # type: ignore[attr-defined]
100+
from tipi.helpers import (
101101
clear_pipeline_context,
102102
set_pipeline_context,
103103
)

tipi/errors.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,40 +47,47 @@ def __init__(self, code: str, message: str):
4747
@dataclass
4848
class BuilderError(RuntimeError):
4949
error_value: Any
50+
error_code: ErrorCode | None = None
5051

51-
def __post_init__(self, error_code: ErrorCode) -> None: # type: ignore # noqa: PGH003
52+
def __post_init__(self) -> None:
53+
pass
54+
55+
def _set_error_code(self, error_code: ErrorCode) -> None:
56+
"""Set the error code (called by subclasses)."""
5257
self.error_code = error_code
5358

5459
def __str__(self) -> str:
55-
return f"[{self.error_code.code}]: {self.error_code.message}: {self.error_value}"
60+
if self.error_code:
61+
return f"[{self.error_code.code}]: {self.error_code.message}: {self.error_value}"
62+
return f"BuilderError: {self.error_value}"
5663

5764

5865
class ConfigNotFoundError(BuilderError):
5966
"""Raised when the builder configuration file does not exists"""
6067

6168
def __post_init__(self) -> None:
62-
super().__post_init__(ErrorCode.CONFIG_MISSING)
69+
self._set_error_code(ErrorCode.CONFIG_MISSING)
6370

6471

6572
class ConfigPermissionError(BuilderError):
6673
"""Raised when the builder configuration file does not exists"""
6774

6875
def __post_init__(self) -> None:
69-
super().__post_init__(ErrorCode.CONFIG_PERMISSION)
76+
self._set_error_code(ErrorCode.CONFIG_PERMISSION)
7077

7178

7279
class ConfigInvalidTomlError(BuilderError):
7380
"""Raised when the configuration file is not valid toml"""
7481

7582
def __post_init__(self) -> None:
76-
super().__post_init__(ErrorCode.CONFIG_INVALID)
83+
self._set_error_code(ErrorCode.CONFIG_INVALID)
7784

7885

7986
class ConfigSectionError(BuilderError):
8087
"""Raised for config section missing"""
8188

8289
def __post_init__(self) -> None:
83-
super().__post_init__(ErrorCode.CONFIG_SECTION)
90+
self._set_error_code(ErrorCode.CONFIG_SECTION)
8491

8592

8693
class InvalidConfigError(Exception):
@@ -95,21 +102,21 @@ class RegistryError(BuilderError):
95102
"""Raised for class registration issues"""
96103

97104
def __post_init__(self) -> None:
98-
super().__post_init__(ErrorCode.REGISTRY_INVALID)
105+
self._set_error_code(ErrorCode.REGISTRY_INVALID)
99106

100107

101108
class RegistryParamError(BuilderError):
102109
"""Raised for class instatioation with wrong params"""
103110

104111
def __post_init__(self) -> None:
105-
super().__post_init__(ErrorCode.REGISTRY_PARAM)
112+
self._set_error_code(ErrorCode.REGISTRY_PARAM)
106113

107114

108115
class InstTypeError(BuilderError):
109116
"""Raised when type in config not set"""
110117

111118
def __post_init__(self) -> None:
112-
super().__post_init__(ErrorCode.INST_TYPE)
119+
self._set_error_code(ErrorCode.INST_TYPE)
113120

114121

115122
## Execution

tipi/helpers.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# type: ignore
2-
# ruff: noqa
31
"""Helper utilities for smooth transition from scripts to pipeline.
42
53
This module provides simple, script-friendly functions that work standalone
@@ -23,13 +21,14 @@
2321

2422
from __future__ import annotations
2523

26-
from typing import Any, Iterable, Optional
24+
from collections.abc import Iterable
25+
from typing import Any
2726

2827
import torch
2928
from rich.progress import track
3029

3130
# Global context that pipeline can set
32-
_pipeline_context: Optional[dict[str, Any]] = None
31+
_pipeline_context: dict[str, Any] | None = None
3332

3433

3534
def set_pipeline_context(context: dict[str, Any]) -> None:
@@ -51,7 +50,7 @@ def clear_pipeline_context() -> None:
5150
def progress_bar(
5251
iterable: Iterable,
5352
desc: str = "Processing",
54-
total: Optional[int] = None,
53+
total: int | None = None,
5554
) -> Iterable:
5655
"""Progress bar that auto-integrates with pipeline or uses rich.track.
5756
@@ -92,12 +91,12 @@ class Logger:
9291
- Pipeline: Uses pipeline's WandBManager if available
9392
"""
9493

95-
def __init__(self):
94+
def __init__(self) -> None:
9695
self._wandb_initialized = False
97-
self._project = None
98-
self._entity = None
96+
self._project: str | None = None
97+
self._entity: str | None = None
9998

100-
def init(self, project: str, entity: Optional[str] = None, name: Optional[str] = None, **kwargs) -> None:
99+
def init(self, project: str, entity: str | None = None, name: str | None = None, **kwargs: Any) -> None:
101100
"""Initialize logger (manual WandB setup or use pipeline's).
102101
103102
Args:
@@ -125,7 +124,7 @@ def init(self, project: str, entity: Optional[str] = None, name: Optional[str] =
125124
except ImportError:
126125
print("WandB not available, logging to console only")
127126

128-
def log(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
127+
def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
129128
"""Log metrics to WandB or console.
130129
131130
Args:
@@ -174,7 +173,7 @@ def get_device(self) -> torch.device:
174173
if _pipeline_context:
175174
device_perm = _pipeline_context.get("device")
176175
if device_perm:
177-
return device_perm.device
176+
return torch.device(device_perm.device)
178177

179178
# Standalone: Select best device
180179
if torch.cuda.is_available():
@@ -189,9 +188,9 @@ def get_device(self) -> torch.device:
189188

190189

191190
__all__ = [
192-
"progress_bar",
193-
"logger",
191+
"clear_pipeline_context",
194192
"device_manager",
193+
"logger",
194+
"progress_bar",
195195
"set_pipeline_context",
196-
"clear_pipeline_context",
197196
]

0 commit comments

Comments
 (0)