Skip to content

Commit 201dde7

Browse files
committed
update datamodule and tutorial
1 parent 9fa19b1 commit 201dde7

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

claymodel/datamodule.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__( # noqa: PLR0913
244244
self.batch_size = batch_size
245245
self.num_workers = num_workers
246246
self.prefetch_factor = prefetch_factor
247-
self.split_ratio = 0.8
247+
self.split_ratio = 0.7
248248

249249
def setup(self, stage: Literal["fit", "predict"] | None = None) -> None:
250250
# Get list of GeoTIFF filepaths from s3 bucket or data/ folder
@@ -258,13 +258,24 @@ def setup(self, stage: Literal["fit", "predict"] | None = None) -> None:
258258
print(f"Total number of chips: {len(chips_path)}")
259259

260260
if stage == "fit":
261-
trn_paths, val_paths = train_test_split(
262-
chips_path,
263-
test_size=(1 - self.split_ratio),
264-
stratify=chips_platform,
265-
shuffle=True,
266-
)
267-
261+
# Check how many unique platforms we have
262+
unique_platforms = set(chips_platform)
263+
264+
if len(unique_platforms) > 1:
265+
# Use stratification if multiple platforms exist
266+
trn_paths, val_paths = train_test_split(
267+
chips_path,
268+
test_size=(1 - self.split_ratio),
269+
stratify=chips_platform,
270+
shuffle=True,
271+
)
272+
else:
273+
# Disable stratification if only one platform (e.g. NAIP)
274+
trn_paths, val_paths = train_test_split(
275+
chips_path,
276+
test_size=(1 - self.split_ratio),
277+
shuffle=True,
278+
)
268279
self.trn_ds = EODataset(
269280
chips_path=trn_paths,
270281
size=self.size,

docs/tutorials/wall-to-wall.ipynb

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 45,
5+
"execution_count": 1,
66
"id": "b958c15a",
77
"metadata": {},
88
"outputs": [],
@@ -13,10 +13,19 @@
1313
},
1414
{
1515
"cell_type": "code",
16-
"execution_count": 46,
16+
"execution_count": 3,
1717
"id": "648742f7",
1818
"metadata": {},
19-
"outputs": [],
19+
"outputs": [
20+
{
21+
"name": "stderr",
22+
"output_type": "stream",
23+
"text": [
24+
"/opt/anaconda3/envs/claymodel/lib/python3.11/site-packages/pyproj/network.py:59: UserWarning: pyproj unable to set PROJ database path.\n",
25+
" _set_context_ca_bundle_path(ca_bundle_path)\n"
26+
]
27+
}
28+
],
2029
"source": [
2130
"\n",
2231
"#import os\n",
@@ -30,10 +39,19 @@
3039
},
3140
{
3241
"cell_type": "code",
33-
"execution_count": 47,
42+
"execution_count": 4,
3443
"id": "de0b279c",
3544
"metadata": {},
36-
"outputs": [],
45+
"outputs": [
46+
{
47+
"name": "stderr",
48+
"output_type": "stream",
49+
"text": [
50+
"/opt/anaconda3/envs/claymodel/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
51+
" from .autonotebook import tqdm as notebook_tqdm\n"
52+
]
53+
}
54+
],
3755
"source": [
3856
"import math\n",
3957
"import geopandas as gpd\n",

0 commit comments

Comments
 (0)