Commits (35)
09f29f1
feat: Add Ice Chunk support for cloud data access
Dakshbir Jul 18, 2025
7cf48cc
resolve conflicts in ocf_data_sampler/load/load_dataset.py and ocf_da…
Dakshbir Jul 22, 2025
58eb2ac
resolve conflicts 2.0 in ocf_data_sampler/load/load_dataset.py and o…
Dakshbir Jul 22, 2025
90bd8aa
feat: Integrate Ice Chunk for optimized satellite loading
Dakshbir Jul 27, 2025
5773715
Final implementation of Sol's architectural feedback
Dakshbir Jul 31, 2025
113b9ce
modified satellite.py
Dakshbir Aug 6, 2025
e6cb9d5
modified satellite.py and model.py
Dakshbir Aug 6, 2025
dd348c7
modified satellite.py again
Dakshbir Aug 9, 2025
e39072b
refactor the test cases 2 and 3 from test_loading into new test_<test-…
Dakshbir Aug 13, 2025
f7fc65b
deleted unnecessary files
Dakshbir Aug 14, 2025
6a6d009
final changes
Dakshbir Aug 20, 2025
2f19b3f
final changes 2.0
Dakshbir Aug 20, 2025
a96b412
Added gist of the gsoc project
Dakshbir Aug 24, 2025
8fac972
final changes 2.0
Dakshbir Aug 28, 2025
fe61946
remove comments
Dakshbir Aug 29, 2025
c521625
removed comments 2.0
Dakshbir Aug 30, 2025
4fb5b69
done with the final changes
Dakshbir Sep 2, 2025
ada0181
deleted all benchmark scripts
Dakshbir Sep 2, 2025
311e54f
last changes
Dakshbir Sep 2, 2025
d883142
updated model.py
Dakshbir Sep 2, 2025
a550ccb
Delete ocf_data_sampler/config/model.py
Dakshbir Sep 2, 2025
e6eabe4
restored model.py
Dakshbir Sep 2, 2025
bea29ff
tried fixing model.py
Dakshbir Sep 2, 2025
7084005
changes done on model.py
Dakshbir Sep 2, 2025
635b476
restored model.py
Dakshbir Sep 2, 2025
a41b7f2
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 2, 2025
4051471
passed all linting checks
Dakshbir Sep 3, 2025
93bf294
just for running the checks
Dakshbir Sep 4, 2025
199d705
Revert "just for running the checks"
Dakshbir Sep 4, 2025
37cfa22
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 15, 2025
fddd143
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 17, 2025
fc081b5
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 19, 2025
8a88ad8
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 22, 2025
f002487
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 24, 2025
2248ec7
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 25, 2025
2 changes: 1 addition & 1 deletion ocf_data_sampler/config/model.py
@@ -368,4 +368,4 @@ class Configuration(Base):
"""Configuration model for the dataset."""

general: General = General()
input_data: InputData = InputData()
input_data: InputData = InputData()
16 changes: 7 additions & 9 deletions ocf_data_sampler/load/load_dataset.py
@@ -1,11 +1,9 @@
"""Loads all data sources."""

import xarray as xr

from ocf_data_sampler.config import InputData
from ocf_data_sampler.load import open_gsp, open_nwp, open_sat_data, open_site


def get_dataset_dict(
input_config: InputData,
gsp_ids: list[int] | None = None,
@@ -46,25 +44,25 @@ def get_dataset_dict(
)

da_nwp = da_nwp.sel(channel=list(nwp_config.channels))

datasets_dict["nwp"][nwp_source] = da_nwp

# Load satellite data if in config
if input_config.satellite:
sat_config = input_config.satellite

da_sat = open_sat_data(sat_config.zarr_path)

da_sat = da_sat.sel(channel=list(sat_config.channels))

# open_sat_data will now internally decide whether to use
# the standard Zarr loader or the Ice Chunk loader.
da_sat = open_sat_data(
zarr_path=sat_config.zarr_path,
channels=list(sat_config.channels),
)
datasets_dict["sat"] = da_sat

# Load site data if in config
if input_config.site:
da_sites = open_site(
generation_file_path=input_config.site.file_path,
metadata_file_path=input_config.site.metadata_file_path,
)

datasets_dict["site"] = da_sites

return datasets_dict
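
For reviewers, a minimal usage sketch of the new call: plain Zarr paths and Ice Chunk paths (with an optional @<commit> suffix) go through the same entry point, which routes internally. The bucket, store names, and channel are illustrative, not real.

from ocf_data_sampler.load import open_sat_data

# Plain Zarr store: handled by the standard tensorstore-backed loader.
da = open_sat_data(zarr_path="gs://my-bucket/satellite.zarr", channels=["IR_016"])

# Ice Chunk repository pinned to a commit: handled by the Ice Chunk loader.
da = open_sat_data(zarr_path="gs://my-bucket/satellite.icechunk@abc123", channels=["IR_016"])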
207 changes: 173 additions & 34 deletions ocf_data_sampler/load/satellite.py
@@ -1,59 +1,198 @@
"""Satellite loader."""
import numpy as np

import logging
import os
import re
from contextlib import contextmanager

import dask
import icechunk
import xarray as xr
from xarray_tensorstore import open_zarr
from ocf_data_sampler.load.open_tensorstore_zarrs import open_zarrs

from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)

from .open_tensorstore_zarrs import open_zarrs

logger = logging.getLogger(__name__)

def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
"""Lazily opens the zarr store and validates data types.
# Optimal values from research, now hardcoded as per Sol's feedback.
OPTIMAL_BLOCK_SIZE_MB = 64
OPTIMAL_THREADS = 2

Args:
zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
it must start with 'gs://'
"""
# Open the data
def open_sat_data(zarr_path: str | list[str], channels: list[str] | None = None) -> xr.DataArray:
Contributor: After [1], this should return to just the one input parameter zarr_path

"""Lazily opens the zarr store and validates data types."""

if isinstance(zarr_path, list | tuple):
ds = open_zarrs(zarr_path, concat_dim="time")
else:
ds = open_zarr(zarr_path)
# Parse path components using Sol's regex approach
path_info = _parse_zarr_path(zarr_path)

# Sol's requested match/case pattern for path routing
match path_info:
# Updated case to handle local icechunk paths correctly
case {"protocol": protocol, "bucket": bucket, "prefix": prefix, "sha1": sha1} if prefix.endswith(".icechunk"):
# Single case for both local and cloud Ice Chunk
Contributor (suggested change): remove the comment "# Single case for both local and cloud Ice Chunk"

ds = _open_sat_data_icechunk(protocol, bucket, prefix, sha1)

case {"protocol": _, "bucket": _, "prefix": _, "sha1": None}:
# Note: open_zarr (xarray_tensorstore) does not work for blosc2-compressed stores;
# for blosc2, use: ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
ds = open_zarr(zarr_path)

case _:
# Raise error on unhandled path
Contributor (suggested change): remove the comment "# Raise error on unhandled path"

raise ValueError(f"Unhandled path format: {zarr_path}")

check_time_unique_increasing(ds.time)

ds = ds.rename(
{
"variable": "channel",
"time": "time_utc",
},
)

check_time_unique_increasing(ds.time_utc)

# Select channels if provided (before renaming variables)
if channels:
ds = ds.sel(variable=channels)
Contributor: This can go after [1]


ds = ds.rename({"variable": "channel", "time": "time_utc"})
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

data_array = get_xr_data_array_from_xr_dataset(ds)

# Validate data types directly loading function
if not np.issubdtype(data_array.dtype, np.number):
# Validate data types directly in loading function
if data_array.dtype.kind not in "bifc":  # boolean, int, float, complex
raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}")


# Updated coordinate validation - more flexible for datetime64 subtypes
coord_dtypes = {
"time_utc": np.datetime64,
"channel": np.str_,
"x_geostationary": np.floating,
"y_geostationary": np.floating,
"time_utc": "M", # datetime64 (any precision)
"channel": "U", # Unicode string
"x_geostationary": "f", # floating
"y_geostationary": "f", # floating
Contributor (suggested change): drop the trailing "# floating" comment from the y_geostationary entry

}
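# numpy dtype.kind reference: "M" covers every datetime64 precision, "U" is
# a unicode string, and "f" any floating dtype, so comparing kinds accepts
# e.g. datetime64[ns] and datetime64[us] alike (np.dtype("datetime64[ns]").kind == "M").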

for coord, expected_kind in coord_dtypes.items():
actual_kind = data_array.coords[coord].dtype.kind
if actual_kind != expected_kind:
raise TypeError(f"Coordinate {coord} should be kind {expected_kind}, not {actual_kind}")

return data_array

@contextmanager
def _setup_optimal_environment():
"""Apply optimization settings for cloud data streaming with context management."""
# Store original values
Contributor (suggested change): remove the comment "# Store original values"

original_values = {}
env_vars = {
"GCSFS_CACHE_TIMEOUT": "3600",
"GCSFS_BLOCK_SIZE": str(OPTIMAL_BLOCK_SIZE_MB * 1024 * 1024),
"GCSFS_DEFAULT_CACHE_TYPE": "readahead",
"GOOGLE_CLOUD_DISABLE_GRPC": "true"
}

# Store original environment values
for key in env_vars:
original_values[key] = os.environ.get(key)
os.environ[key] = env_vars[key]

# Snapshot the current dask config so it can be restored on exit
original_dask_config = dict(dask.config.config)

dask.config.set({
"distributed.worker.memory.target": 0.7,
"array.chunk-size": "512MB",
"distributed.comm.compression": None,
"distributed.worker.threads": OPTIMAL_THREADS,
})

try:
yield
finally:
# Restore original environment values
for key, original_value in original_values.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value

# Restore the original dask config snapshot
dask.config.config.clear()
dask.config.config.update(original_dask_config)
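
# Usage sketch: the tuned GCSFS/dask settings apply only inside the block,
# and the previous environment and dask config are restored on exit:
#     with _setup_optimal_environment():
#         ds = xr.open_zarr("gs://my-bucket/satellite.zarr")  # path is illustrative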

for coord, expected_dtype in coord_dtypes.items():
if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype):
dtype = data_array.coords[coord].dtype
raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")
def _parse_zarr_path(path: str) -> dict:
"""Parse a path into its components, supporting both local and cloud paths."""

return data_array
# Sol's recommended regex pattern - handles optional protocol and wildcards
Contributor (suggested change): remove the comment "# Sol's recommended regex pattern - handles optional protocol and wildcards"

pattern = r"^(?:(?P<protocol>[\w]{2,6}):\/\/)?(?P<bucket>\/?[\w-]+)\/(?P<prefix>[\w*.\/-]+?)(?:@(?P<sha1>[\w]+))?$"
match = re.match(pattern, path)
if not match:
raise ValueError(f"Invalid path format: {path}")

components = match.groupdict()

# Validation checks moved from match block
Contributor (suggested change): remove the comment "# Validation checks moved from match block"

if components["sha1"] is not None and not components["prefix"].endswith(".icechunk"):
raise ValueError("Commit syntax (@commit) not supported for non-icechunk stores")

if components["protocol"] == "gs" and components["prefix"] is not None and "*" in components["prefix"]:
raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")

return components
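
# Illustrative parses handled by _parse_zarr_path (paths are hypothetical):
#   "gs://my-bucket/sat.icechunk@abc123"
#     -> {"protocol": "gs", "bucket": "my-bucket", "prefix": "sat.icechunk", "sha1": "abc123"}
#   "data/satellite.zarr"
#     -> {"protocol": None, "bucket": "data", "prefix": "satellite.zarr", "sha1": None}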

def _open_sat_data_icechunk(
protocol: str | None, bucket: str, prefix: str, sha1: str | None
) -> xr.Dataset:
"""Open satellite data from an Ice Chunk repository with optimized settings."""

# Get storage according to protocol
if protocol is None:
logger.info(f"Opening local Ice Chunk repository: {prefix}")
storage = icechunk.local_filesystem_storage(prefix)
elif protocol == "gs":
logger.info(f"Opening Ice Chunk repository: {protocol}://{bucket}/{prefix}")
with _setup_optimal_environment(): # Use context manager
Contributor (suggested change): drop the trailing "# Use context manager" comment

# Ensure the prefix ends with a trailing slash
if not prefix.endswith("/"):
prefix = prefix + "/"

logger.info(f"Accessing Ice Chunk repository: {protocol}://{bucket}/{prefix}")
storage = icechunk.gcs_storage(bucket=bucket, prefix=prefix, from_env=True)
else:
raise ValueError(f"Unsupported protocol: {protocol}")

# Get repo from storage (single try/catch)
Contributor (suggested change): remove the comment "# Get repo from storage (single try/catch)"

try:
repo = icechunk.Repository.open(storage)
except Exception:
logger.error(f"Failed to open Ice Chunk repository at {protocol or 'local'}://{bucket or ''}/{prefix}")
raise

# CORRECT - uses proper Ice Chunk API
Contributor (suggested change): remove the comment "# CORRECT - uses proper Ice Chunk API"

try:
if sha1:
session = repo.readonly_session(snapshot_id=sha1)
else:
session = repo.readonly_session("main")
except Exception as e:
target = sha1 or "main"
raise ValueError(f"Failed to open session for '{target}': {e}") from e

# Open the dataset from the Ice Chunk session store
Contributor (suggested change): remove the comment "# Open the dataset from the Ice Chunk session store"

ds = xr.open_zarr(session.store, consolidated=True, chunks="auto")

# Convert Ice Chunk format to standard format
if len(ds.data_vars) > 1:
data_arrays = [ds[var] for var in sorted(ds.data_vars)]
combined_da = xr.concat(data_arrays, dim="variable")
combined_da = combined_da.assign_coords(variable=sorted(ds.data_vars))
ds = xr.Dataset({"data": combined_da})

return ds
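
For context, a self-contained sketch of the Ice Chunk read path exercised above, assuming credentials come from the environment (from_env=True); the bucket and prefix are placeholders.

import icechunk
import xarray as xr

# Open the repository and pin a readonly session to a branch or snapshot.
storage = icechunk.gcs_storage(bucket="my-bucket", prefix="satellite.icechunk/", from_env=True)
repo = icechunk.Repository.open(storage)
session = repo.readonly_session("main")  # or: repo.readonly_session(snapshot_id="<sha1>")

# Read the store lazily as an xarray Dataset with dask-backed chunks.
ds = xr.open_zarr(session.store, consolidated=True, chunks="auto")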