Skip to content

Commit 7207b29

Browse files
committed
added unit tests for model
1 parent dac7135 commit 7207b29

File tree

3 files changed

+85
-4
lines changed

3 files changed

+85
-4
lines changed

src/utils/logger.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
class PipelineLogger(logging.Logger):
66
def __init__(self, name, log_file=None, level=logging.INFO):
7-
self.logger = logging.getLogger(name)
8-
self.logger.setLevel(level)
7+
super().__init__(name, level)
98

109
# Configure handlers based on parameters
1110
self._setup_handlers(log_file, level)
@@ -26,6 +25,6 @@ def _setup_handlers(self, log_file, level):
2625
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
2726
)
2827
f_handler.setFormatter(f_format)
29-
self.logger.addHandler(f_handler)
28+
self.addHandler(f_handler)
3029

31-
self.logger.addHandler(c_handler)
30+
self.addHandler(c_handler)

tests/test_data_loader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ def main():
7777
print(f"Test failed: {e}", file=sys.stderr)
7878
sys.exit(1)
7979

80+
try:
81+
test_preprocessor()
82+
except Exception as e:
83+
print(f"Test failed: {e}", file=sys.stderr)
84+
sys.exit(1)
85+
8086
print("All DataLoader & Preprocessor tests passed.")
8187
sys.exit(0)
8288

tests/test_model.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
5+
6+
import pandas as pd
7+
from src.data.loader import DataLoader
8+
from src.data.preprocessor import DataPreprocessor
9+
from src.model.model import FraudDetectionModel
10+
import numpy as np
11+
12+
13+
def detect_date_columns(df):
    """Return the names of the columns in *df* that look like dates.

    A column qualifies when it is not numeric and at least one of its values
    parses with ``pd.to_datetime(..., errors="coerce")``.

    Numeric and boolean columns are skipped on purpose: ``pd.to_datetime``
    interprets plain numbers as offsets from the epoch, so without this guard
    every numeric feature would be reported as a date column.

    Args:
        df: DataFrame to inspect. It is not modified.

    Returns:
        A list of column labels, in ``df.columns`` order.
    """
    date_columns = []
    for col in df.columns:
        series = df[col]

        # Numbers always "parse" as epoch offsets, which is never what a
        # caller looking for date columns wants.
        if pd.api.types.is_numeric_dtype(series):
            continue

        try:
            parsed = pd.to_datetime(series, errors="coerce")
        except (TypeError, ValueError, OverflowError):
            # The column cannot be interpreted as datetimes at all
            # (e.g. an unsupported dtype); treat it as "not a date".
            continue

        if parsed.notna().any():
            date_columns.append(col)
    return date_columns
23+
24+
25+
def test_model() -> float:
    """Train a FraudDetectionModel on the processed data and return its accuracy.

    Loads ``train.csv`` / ``test.csv`` from ``data/processed``, splits off a
    validation set, fits a ``DataPreprocessor`` on the training split, trains
    the model on the transformed training split and returns the ``"accuracy"``
    entry of the evaluation result.

    NOTE(review): the model is evaluated on the same ``X``/``y`` it was
    trained on; the validation and test splits are unpacked but never used.
    Confirm that this is intended.
    NOTE(review): the numeric column selection is not filtered, so it may
    include the ``"is_fraud"`` target and any numeric column reported by
    ``detect_date_columns`` -- verify how ``DataPreprocessor`` handles that.
    """
    loader = DataLoader("data/processed")
    loader.load_data("train.csv", "test.csv")

    # Three-way unpacking is kept so an unexpected split shape still fails.
    train_df, _val_df, _test_df = loader.train_valid_split(0.2)

    date_cols = detect_date_columns(train_df.copy())

    # Object-typed columns that were not recognised as dates are categorical.
    categorical_cols = []
    for name in train_df.select_dtypes(include=["object"]).columns:
        if name not in date_cols:
            categorical_cols.append(name)

    numerical_cols = train_df.select_dtypes(include=["float64", "int64"]).columns

    preprocessor = DataPreprocessor(
        categorical_cols, numerical_cols, "is_fraud", date_cols
    )
    preprocessor = preprocessor.fit(train_df)

    features, target = preprocessor.transform(train_df)
    target = np.array(target)
    target = target.astype(np.float64)

    # The model is sized from the number of transformed feature columns.
    n_features = features.shape[1]
    model = FraudDetectionModel(n_features)
    model.train(features, target)

    metrics = model.evaluate(features, target)
    return metrics["accuracy"]
56+
57+
58+
def main():
    """Run the model regression check and exit with a process status code.

    Exits with status 1 when ``test_model`` raises or when the returned
    accuracy is below 0.8 (a message is written to stderr in both cases);
    otherwise prints a success message and exits with status 0.
    """
    try:
        score = test_model()
    except Exception as err:
        print(f"Test failed: {err}", file=sys.stderr)
        sys.exit(1)
    else:
        # Regression test: fail the run if accuracy drops under the floor.
        if score < 0.8:
            print("Model accuracy too low:", score, file=sys.stderr)
            sys.exit(1)

    print("All model tests passed.")
    sys.exit(0)
73+
74+
75+
if __name__ == "__main__":
76+
main()

0 commit comments

Comments
 (0)