44 changes: 43 additions & 1 deletion test/prototype/safetensors/test_safetensors_support.py
@@ -74,9 +74,10 @@ def test_safetensors(self, config, act_pre_scale=False):

save_file(tensors_data_dict, f.name, metadata=metadata)
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
reconstructed_dict = unflatten_tensor_state_dict(
reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict(
tensors_data_dict, metadata
)
assert not leftover_tensor_data_dict

model = torch.nn.Sequential(
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
@@ -85,6 +86,47 @@ def test_safetensors(self, config, act_pre_scale=False):
output = model(*example_inputs)
assert torch.equal(output, ref_output)

@parametrize(
"config, act_pre_scale",
[
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
(Int4WeightOnlyConfig(), False),
(Int4WeightOnlyConfig(), True),
(Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False),
(IntxWeightOnlyConfig(), False),
(Int8DynamicActivationIntxWeightConfig(), False),
],
)
def test_safetensors_sharded(self, config, act_pre_scale=False):
model = torch.nn.Sequential(
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, config)
if act_pre_scale:
model[0].weight.act_pre_scale = torch.ones(
(1), dtype=torch.bfloat16, device="cuda"
)

with tempfile.NamedTemporaryFile() as f:
tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
save_file(tensors_data_dict, f.name, metadata=metadata)
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")

# simulate tensor data that is missing until a future file is loaded
if act_pre_scale:
del tensors_data_dict["0._weight_act_pre_scale"] # optional tensor data
else:
del tensors_data_dict["0._weight_qdata"]

reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict(
tensors_data_dict, metadata
)

# since a tensor data component is missing, layer 0 should not have been processed
for key in tensors_data_dict.keys():
if key.startswith("0._weight_"):
assert key in leftover_tensor_data_dict


instantiate_parametrized_tests(TestSafeTensors)

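The sharded test above exercises the contract of the new return value: tensors whose components are incomplete are deferred into the leftover dict rather than dropped. A minimal sketch of how a caller might drive this across shards — the shard filenames and the module path of `load_data` are assumptions, and merging leftovers into the next shard's flattened tensors is one plausible strategy, not a documented API:

```python
from torchao.prototype.safetensors.safetensors_support import (
    unflatten_tensor_state_dict,
)
# load_data's module path is an assumption based on the test imports
from torchao.prototype.safetensors.safetensors_utils import load_data

state_dict = {}
pending = {}  # flattened tensor data still waiting on missing components
for shard in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]:
    tensors_data_dict, metadata = load_data(file_path=shard, device="cuda")
    pending.update(tensors_data_dict)
    # anything still incomplete after this call is carried into the next one
    reconstructed, pending = unflatten_tensor_state_dict(pending, metadata)
    state_dict.update(reconstructed)

# after the final shard, every tensor component should have arrived
assert not pending
```

This assumes each call's metadata covers the tensors it can complete; the sketch only illustrates how the leftover dict threads through successive calls.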
36 changes: 28 additions & 8 deletions torchao/prototype/safetensors/safetensors_support.py
@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
'_data': {
'block_size': [1,32],
...
}
},
'_tensor_data_names': ['qdata', 'scale']
}
'0.bias': {
'_type': 'torch.Tensor',
@@ -66,33 +67,51 @@

tensor_names = json.loads(metadata["tensor_names"])
result = {}

leftover_state_dict = tensors_data_dict.copy()
for tensor_name in tensor_names:
processed_tensors = []

module_fqn, weight_name = tensor_name.rsplit(".", 1)

prefix = f"{module_fqn}._{weight_name}_"
tensor_tensors = {}

for key, value in tensors_data_dict.items():
if key.startswith(prefix):
# Remove the prefix
tensor_tensors[key[len(prefix) :]] = value

tensor_metadata = json.loads(metadata.get(tensor_name))
tensor_type = tensor_metadata.get("_type")
complete_tensor_data_names = tensor_metadata.get("_tensor_data_names")

if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
if not tensor_tensors:
# we allow the option of loading in state_dict info for a single tensor
# if tensor state dict info is not loaded in yet, we wait for it to be provided
# in a future call
# if not all tensor data is present (ie missing qdata) we wait for it
# to be loaded in from a future call
if len(tensor_tensors) != len(complete_tensor_data_names):
continue
tensor_metadata["_data"].update(tensor_tensors)
result[tensor_name] = object_from_dict(tensor_metadata)

for suffix in complete_tensor_data_names:
processed_tensors.append(prefix + suffix)
elif tensor_type == torch.Tensor.__name__:
# we allow the option of loading in state_dict info for a single tensor
# if tensor state dict info is not loaded in yet, we wait for it to be provided
# in a future call
if tensor_name not in tensors_data_dict.keys():
continue
result[tensor_name] = tensors_data_dict[tensor_name]
# a plain torch.Tensor key has no prefix, so append the name directly
processed_tensors.append(tensor_name)
else:
raise ValueError(f"Unsupported tensor type: {tensor_type}")
return result

for tensor_name in processed_tensors:
del leftover_state_dict[tensor_name]

return result, leftover_state_dict


def flatten_tensor_state_dict(
@@ -125,7 +144,8 @@ def flatten_tensor_state_dict(
'_data': {
'block_size': [1,32],
...
}
},
'_tensor_data_names': ['qdata', 'scale']
}
'0.bias': {
'_type': 'torch.Tensor',
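To make the completeness check concrete, here is a standalone sketch (with hypothetical values) of the gating that `_tensor_data_names` enables — a subclass tensor is only reconstructed once every component it was saved with has arrived:

```python
import json

# hypothetical metadata entry for one quantized weight
metadata = {
    "0.weight": json.dumps(
        {
            "_type": "Float8Tensor",
            "_data": {"block_size": [1, 32]},
            "_tensor_data_names": ["qdata", "scale"],
        }
    )
}
# only the scale has been loaded so far; qdata lives in a later shard
tensors_data_dict = {"0._weight_scale": "scale-tensor-placeholder"}

prefix = "0._weight_"
tensor_tensors = {
    key[len(prefix):]: value
    for key, value in tensors_data_dict.items()
    if key.startswith(prefix)
}
expected = json.loads(metadata["0.weight"])["_tensor_data_names"]

# mirrors the check in unflatten_tensor_state_dict: incomplete, so deferred
print(len(tensor_tensors) == len(expected))  # False -> stays in the leftover dict
```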
18 changes: 17 additions & 1 deletion torchao/prototype/safetensors/safetensors_utils.py
@@ -60,7 +60,23 @@ def default(self, o):
encoded_attribute = self.encode_value(attribute)
tensor_attr_dict[tensor_attribute_name] = encoded_attribute

return {"_type": o.__class__.__name__, "_data": tensor_attr_dict}
optional_tensor_data_names = (
o.optional_tensor_data_names
if hasattr(o, "optional_tensor_data_names")
else []
)
all_tensor_data_names = optional_tensor_data_names + o.tensor_data_names

_tensor_data_names = []
for tensor_data_name in all_tensor_data_names:
if getattr(o, tensor_data_name) is not None:
_tensor_data_names.append(tensor_data_name)

return {
"_type": o.__class__.__name__,
"_data": tensor_attr_dict,
"_tensor_data_names": _tensor_data_names,
}

if hasattr(o, "_fields") and hasattr(
o, "_asdict"
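The encoder change above relies on torchao tensor subclasses declaring their components via `tensor_data_names` and, optionally, `optional_tensor_data_names` class attributes. A hypothetical stand-in class showing which names survive the `is not None` filter:

```python
import torch

class FakeQuantizedTensor:
    # hypothetical stand-in for a torchao tensor subclass
    tensor_data_names = ["qdata", "scale"]
    optional_tensor_data_names = ["act_pre_scale"]

    def __init__(self, qdata, scale, act_pre_scale=None):
        self.qdata = qdata
        self.scale = scale
        self.act_pre_scale = act_pre_scale

t = FakeQuantizedTensor(torch.zeros(4, 4, dtype=torch.int8), torch.ones(4))
# same traversal order as the encoder: optional names first, then required
present = [
    name
    for name in t.optional_tensor_data_names + t.tensor_data_names
    if getattr(t, name) is not None
]
print(present)  # ['qdata', 'scale'] -- act_pre_scale is None, so it is omitted
```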