Commits (35)
09f29f1
feat: Add Ice Chunk support for cloud data access
Dakshbir Jul 18, 2025
7cf48cc
resolve conflicts in ocf_data_sampler/load/load_dataset.py and ocf_da…
Dakshbir Jul 22, 2025
58eb2ac
resolve conflicts 2.0 in ocf_data_sampler/load/load_dataset.py and o…
Dakshbir Jul 22, 2025
90bd8aa
feat: Integrate Ice Chunk for optimized satellite loading
Dakshbir Jul 27, 2025
5773715
Final implementation of Sol's architectural feedback
Dakshbir Jul 31, 2025
113b9ce
modified satellite.py
Dakshbir Aug 6, 2025
e6cb9d5
modified satellite.py and model.py
Dakshbir Aug 6, 2025
dd348c7
modified satellite.py again
Dakshbir Aug 9, 2025
e39072b
Refactor the test cases 2 and 3 from test_loading into new test_<test-…
Dakshbir Aug 13, 2025
f7fc65b
deleted unnecessary files
Dakshbir Aug 14, 2025
6a6d009
final changes
Dakshbir Aug 20, 2025
2f19b3f
final changes 2.0
Dakshbir Aug 20, 2025
a96b412
Added gist of the gsoc project
Dakshbir Aug 24, 2025
8fac972
final changes 2.0
Dakshbir Aug 28, 2025
fe61946
remove comments
Dakshbir Aug 29, 2025
c521625
removed comments 2.0
Dakshbir Aug 30, 2025
4fb5b69
done with the final changes
Dakshbir Sep 2, 2025
ada0181
deleted all benchmark scripts
Dakshbir Sep 2, 2025
311e54f
last changes
Dakshbir Sep 2, 2025
d883142
updated model.py
Dakshbir Sep 2, 2025
a550ccb
Delete ocf_data_sampler/config/model.py
Dakshbir Sep 2, 2025
e6eabe4
restored model.py
Dakshbir Sep 2, 2025
bea29ff
tried fixing model.py
Dakshbir Sep 2, 2025
7084005
changes done on model.py
Dakshbir Sep 2, 2025
635b476
restored model.py
Dakshbir Sep 2, 2025
a41b7f2
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 2, 2025
4051471
passed all linting checks
Dakshbir Sep 3, 2025
93bf294
just for running the checks
Dakshbir Sep 4, 2025
199d705
Revert "just for running the checks"
Dakshbir Sep 4, 2025
37cfa22
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 15, 2025
fddd143
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 17, 2025
fc081b5
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 19, 2025
8a88ad8
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 22, 2025
f002487
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 24, 2025
2248ec7
Merge branch 'main' into feat/ice-chunk-support
devsjc Sep 25, 2025
84 changes: 61 additions & 23 deletions ocf_data_sampler/config/model.py
@@ -1,63 +1,79 @@
"""Configuration model for the dataset.


Absolute or relative zarr filepath(s).
Prefix with a protocol like s3:// to read from alternative filesystems.
"""


from collections.abc import Iterator
from typing import Literal


from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from typing_extensions import override



NWP_PROVIDERS = [
"ukv",
"ecmwf",
"mo_global",
"gfs",
"icon_eu",
"cloudcasting",
]




class Base(BaseModel):
"""Pydantic Base model where no extras can be added."""


class Config:
"""Config class."""


extra = "forbid" # forbid use of extra kwargs




class General(Base):
"""General pydantic model."""


name: str = Field("example", description="The name of this configuration file")
description: str = Field(
"example configuration",
description="Description of this configuration file",
)




class TimeWindowMixin(Base):
"""Mixin class, to add interval start, end and resolution minutes."""


time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)


interval_start_minutes: int = Field(
...,
description="Data interval starts at `t0 + interval_start_minutes`",
)


interval_end_minutes: int = Field(
...,
description="Data interval ends at `t0 + interval_end_minutes`",
)


@model_validator(mode="after")
def validate_intervals(self) -> "TimeWindowMixin":
"""Validator for time interval fields."""
@@ -104,7 +120,6 @@ def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
raise ValueError("Dropout timedeltas must be negative")
return v


@field_validator("dropout_fraction")
def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
"""Validate 'dropout_frac'."""
@@ -128,12 +143,10 @@ def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]
if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
raise ValueError("Sum of all floats in the list must be 1.0")


else:
raise TypeError("Must be either a float or a list of floats")
return dropout_frac


@model_validator(mode="after")
def dropout_instructions_consistent(self) -> "DropoutMixin":
"""Validator for dropout instructions."""
@@ -145,33 +158,39 @@ def dropout_instructions_consistent(self) -> "DropoutMixin":
raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
return self


class SpatialWindowMixin(Base):
"""Mixin class, to add path and image size."""


image_size_pixels_height: int = Field(
...,
ge=0,
description="The number of pixels of the height of the region of interest",
)


image_size_pixels_width: int = Field(
...,
ge=0,
description="The number of pixels of the width of the region of interest",
)




class NormalisationValues(Base):
"""Normalisation mean and standard deviation."""
mean: float = Field(..., description="Mean value for normalization")
std: float = Field(..., gt=0, description="Standard deviation (must be positive)")




class NormalisationConstantsMixin(Base):
"""Normalisation constants for multiple channels."""
normalisation_constants: dict[str, NormalisationValues]


@property
def channel_means(self) -> dict[str, float]:
"""Return the channel means."""
@@ -181,6 +200,8 @@ def channel_means(self) -> dict[str, float]:
}




@property
def channel_stds(self) -> dict[str, float]:
"""Return the channel standard deviations."""
@@ -190,8 +211,10 @@ def channel_stds(self) -> dict[str, float]:
}




class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
"""Satellite configuration model."""
"""Satellite configuration model with Ice Chunk support."""

zarr_path: str | tuple[str] | list[str] = Field(
...,
@@ -220,39 +243,33 @@ def check_all_channel_have_normalisation_constants(self) -> "Satellite":
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
"""NWP configuration model."""


zarr_path: str | tuple[str] | list[str] = Field(
...,
description="Absolute or relative zarr filepath(s). Prefix with a protocol like s3:// "
"to read from alternative filesystems.",
)


channels: list[str] = Field(
...,
description="the channels used in the nwp data",
)


provider: str = Field(..., description="The provider of the NWP data")


accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")


max_staleness_minutes: int | None = Field(
None,
description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
" used to construct an example. If set to None, then the max staleness is set according to"
" the maximum forecast horizon of the NWP and the requested forecast length.",
)
public: bool = Field(False, description="Whether the NWP data is public or private")

@model_validator(mode="after")
def validate_accum_channels_subset(self) -> "NWP":
"""Validate accum_channels is subset of channels."""
invalid_channels = set(self.accum_channels) - set(self.channels)
if invalid_channels:
raise ValueError(
f"NWP provider '{self.provider}': all values in 'accum_channels' should "
f"be present in 'channels'. Extra values found: {invalid_channels}",
)
return self

@field_validator("provider")
def validate_provider(cls, v: str) -> str:
@@ -262,20 +279,24 @@ def validate_provider(cls, v: str) -> str:
return v




@model_validator(mode="after")
def check_all_channel_have_normalisation_constants(self) -> "NWP":
"""Check that all the channels have normalisation constants."""
normalisation_channels = set(self.normalisation_constants.keys())
non_accum_channels = [c for c in self.channels if c not in self.accum_channels]
accum_channel_names = [f"diff_{c}" for c in self.accum_channels]


missing_norm_values = set(non_accum_channels) - set(normalisation_channels)
if len(missing_norm_values)>0:
raise ValueError(
"Normalsation constants must be provided for all channels. Missing values for "
f"channels: {missing_norm_values}",
)


missing_norm_values = set(accum_channel_names) - set(normalisation_channels)
if len(missing_norm_values)>0:
raise ValueError(
@@ -287,56 +308,64 @@ def check_all_channel_have_normalisation_constants(self) -> "NWP":
return self




class MultiNWP(RootModel):
"""Configuration for multiple NWPs."""


root: dict[str, NWP]


@override
def __getattr__(self, item: str) -> NWP:
return self.root[item]


@override
def __getitem__(self, item: str) -> NWP:
return self.root[item]


@override
def __len__(self) -> int:
return len(self.root)


@override
def __iter__(self) -> Iterator:
return iter(self.root)


def keys(self) -> Iterator[str]:
"""Returns dictionary-like keys."""
return self.root.keys()


def items(self) -> Iterator[tuple[str, NWP]]:
"""Returns dictionary-like items."""
return self.root.items()




class GSP(TimeWindowMixin, DropoutMixin):
"""GSP configuration model."""


zarr_path: str = Field(
...,
description="Absolute or relative zarr filepath. Prefix with a protocol like s3:// "
"to read from alternative filesystems.",
)

boundaries_version: Literal["20220314", "20250109"] = Field(
"20220314",
description="Version of the GSP boundaries to use. Options are '20220314' or '20250109'.",
)

public: bool = Field(False, description="Whether the GSP data is public or private")


class Site(TimeWindowMixin, DropoutMixin):
"""Site configuration model."""


file_path: str = Field(
...,
description="The NetCDF files holding the power timeseries.",
@@ -346,26 +375,35 @@ class Site(TimeWindowMixin, DropoutMixin):
description="The CSV files describing power system",
)


# TODO validate the netcdf for sites
# TODO validate the csv for metadata




class SolarPosition(TimeWindowMixin):
"""Solar position configuration model."""




class InputData(Base):
"""Input data model."""


satellite: Satellite | None = None
nwp: MultiNWP | None = None
gsp: GSP | None = None
site: Site | None = None
solar_position: SolarPosition | None = None




class Configuration(Base):
"""Configuration model for the dataset."""


general: General = General()
input_data: InputData = InputData()
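
For orientation, below is a minimal sketch of how the updated Satellite model could be instantiated, showing the normalisation-constants plumbing added in this file. It is illustrative only: the path, channel names, and numeric values are invented, part of the Satellite field list is collapsed in the diff above, and the explicit no-dropout arguments assume an empty timedelta list with a 0.0 fraction passes the dropout validators.

from ocf_data_sampler.config.model import NormalisationValues, Satellite

satellite = Satellite(
    zarr_path="s3://some-bucket/satellite.icechunk",  # hypothetical path
    channels=["IR_016", "VIS008"],                    # hypothetical channel names
    time_resolution_minutes=5,
    interval_start_minutes=-60,
    interval_end_minutes=0,
    image_size_pixels_height=24,
    image_size_pixels_width=24,
    dropout_timedeltas_minutes=[],  # assumed to mean "no dropout"
    dropout_fraction=0.0,
    # Every channel must have constants, or the model validator raises.
    normalisation_constants={
        "IR_016": NormalisationValues(mean=0.5, std=0.2),
        "VIS008": NormalisationValues(mean=0.3, std=0.1),
    },
)
print(satellite.channel_means)  # {'IR_016': 0.5, 'VIS008': 0.3}
print(satellite.channel_stds)   # {'IR_016': 0.2, 'VIS008': 0.1}
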
19 changes: 11 additions & 8 deletions ocf_data_sampler/load/load_dataset.py
@@ -1,10 +1,12 @@
"""Loads all data sources."""

import logging
import xarray as xr

from ocf_data_sampler.config import InputData
from ocf_data_sampler.load import open_gsp, open_nwp, open_sat_data, open_site

logger = logging.getLogger(__name__)

def get_dataset_dict(
input_config: InputData,
@@ -25,7 +27,7 @@ def get_dataset_dict(
zarr_path=input_config.gsp.zarr_path,
boundaries_version=input_config.gsp.boundaries_version,
public=input_config.gsp.public,
)
).compute()

if gsp_ids is None:
# Remove national (gsp_id=0)
@@ -46,25 +48,26 @@
)

da_nwp = da_nwp.sel(channel=list(nwp_config.channels))

datasets_dict["nwp"][nwp_source] = da_nwp

# Load satellite data if in config
if input_config.satellite:
sat_config = input_config.satellite

da_sat = open_sat_data(sat_config.zarr_path)

da_sat = da_sat.sel(channel=list(sat_config.channels))

logger.info(f"Loading satellite data from: {sat_config.zarr_path}")
# open_sat_data will now internally decide whether to use
# the standard Zarr loader or the Ice Chunk loader.
da_sat = open_sat_data(
zarr_path=sat_config.zarr_path,
channels=list(sat_config.channels),
)
datasets_dict["sat"] = da_sat

# Load site data if in config
if input_config.site:
da_sites = open_site(
generation_file_path=input_config.site.file_path,
metadata_file_path=input_config.site.metadata_file_path,
)

datasets_dict["site"] = da_sites

return datasets_dict
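
The comment in the diff says open_sat_data now decides internally between the standard Zarr loader and the Ice Chunk loader. As a rough sketch of what that dispatch could look like (not this PR's actual satellite.py: the suffix-based store detection, the variable name "data", and the local-filesystem icechunk calls are all assumptions based on icechunk's documented API):

import xarray as xr

def open_sat_data_sketch(zarr_path: str, channels: list[str] | None = None) -> xr.DataArray:
    """Sketch of a loader that picks Ice Chunk or plain Zarr per path."""
    if str(zarr_path).endswith(".icechunk"):
        # Hypothetical Ice Chunk branch: open the repository and read
        # the latest snapshot on "main" as an ordinary xarray dataset.
        import icechunk

        storage = icechunk.local_filesystem_storage(zarr_path)
        repo = icechunk.Repository.open(storage)
        session = repo.readonly_session(branch="main")
        ds = xr.open_zarr(session.store, consolidated=False)
    else:
        # Standard Zarr branch, as before.
        ds = xr.open_zarr(zarr_path)

    da = ds["data"]  # assumes the satellite variable is named "data"
    if channels is not None:
        da = da.sel(channel=channels)
    return da
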