Skip to content

Commit 2b8354d

Browse files
Merge pull request #22 from instadeepai/fix/improved-vault-compression-api
fix: improve the compression api of vaults
2 parents 0429642 + 88894f8 commit 2b8354d

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

flashbax/vault/vault.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"id": "gzip",
3939
"level": 5,
4040
}
41-
VERSION = 1.1
41+
VERSION = 1.2
4242

4343

4444
def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str:
@@ -87,8 +87,8 @@ def __init__( # noqa: CCR001
8787
vault_uid (Optional[str], optional): Unique identifier for this vault.
8888
Defaults to None, which will use the current timestamp.
8989
compression (Optional[dict], optional):
90-
Compression settings for the vault. Defaults to None, which will use
91-
the default settings.
90+
Compression settings used when creating the vault.
91+
Defaults to None, which will use the default compression.
9292
metadata (Optional[dict], optional):
9393
Any additional metadata to save. Defaults to None.
9494
@@ -115,6 +115,11 @@ def __init__( # noqa: CCR001
115115

116116
print(f"Loading vault found at {self._base_path}")
117117

118+
if compression is not None:
119+
print(
120+
"Requested compression settings will be ignored as the vault already exists."
121+
)
122+
118123
elif experience_structure is not None:
119124
# Create the necessary dirs for the vault
120125
os.makedirs(self._base_path)
@@ -145,7 +150,6 @@ def __init__( # noqa: CCR001
145150
"version": VERSION,
146151
"structure_shape": serialised_experience_structure_shape,
147152
"structure_dtype": serialised_experience_structure_dtype,
148-
"compression": compression or COMPRESSION_DEFAULT,
149153
**(metadata_json_ready or {}), # Allow user to save extra metadata
150154
}
151155
# Dump metadata to file
@@ -184,12 +188,8 @@ def __init__( # noqa: CCR001
184188
target=experience_structure,
185189
)
186190

187-
# Load compression settings from metadata
188-
self._compression = (
189-
self._metadata["compression"]
190-
if "compression" in self._metadata
191-
else COMPRESSION_DEFAULT
192-
)
191+
# Keep the compression settings, to be used in init_leaf, in case we're creating the vault
192+
self._compression = compression
193193

194194
# Each leaf of the fbx_state.experience maps to a data store, so we tree map over the
195195
# tree structure to create each of the data stores.
@@ -235,11 +235,6 @@ def _get_base_spec(self, name: str) -> dict:
235235
"base": f"{DRIVER}{self._base_path}",
236236
"path": name,
237237
},
238-
"metadata": {
239-
"compressor": {
240-
**self._compression,
241-
}
242-
},
243238
}
244239

245240
def _init_leaf(
@@ -260,14 +255,19 @@ def _init_leaf(
260255

261256
leaf_shape, leaf_dtype = None, None
262257
if create_ds:
263-
# Only specify dtype and shape if we are creating a vault
264-
# (i.e. don't impose dtype and shape if we are _loading_ a vault)
258+
# Only specify dtype, shape, and compression if we are creating a vault
259+
# (i.e. don't impose these fields if we are _loading_ a vault)
265260
leaf_shape = (
266261
shape[0], # Batch dim
267262
TIME_AXIS_MAX_LENGTH, # Time dim, which we extend
268263
*shape[2:], # Experience dim(s)
269264
)
270265
leaf_dtype = dtype
266+
spec["metadata"] = {
267+
"compressor": COMPRESSION_DEFAULT
268+
if self._compression is None
269+
else self._compression
270+
}
271271

272272
leaf_ds = ts.open(
273273
spec,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ authors = [
1919
{name="InstaDeep" , email = "[email protected]"},
2020
]
2121
requires-python = ">=3.9"
22-
version = "0.1.1"
22+
version = "0.1.2"
2323
classifiers=[
2424
"Development Status :: 2 - Pre-Alpha",
2525
"Environment :: Console",

0 commit comments

Comments
 (0)