Skip to content

Commit ae774e7

Browse files
committed
changes
1 parent 6c78c4d commit ae774e7

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ def test_safetensors(self, config, act_pre_scale=False):
7474

7575
save_file(tensors_data_dict, f.name, metadata=metadata)
7676
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
77-
reconstructed_dict = unflatten_tensor_state_dict(
77+
leftover_tensor_data_dict, reconstructed_dict = unflatten_tensor_state_dict(
7878
tensors_data_dict, metadata
7979
)
80+
assert not leftover_tensor_data_dict
8081

8182
model = torch.nn.Sequential(
8283
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
3434
'_data': {
3535
'block_size': [1,32],
3636
...
37-
}
37+
},
38+
'_tensor_data_names': ['qdata', 'scale']
3839
}
3940
'0.bias': {
4041
'_type': 'torch.Tensor',
@@ -66,33 +67,53 @@ def unflatten_tensor_state_dict(
6667

6768
tensor_names = json.loads(metadata["tensor_names"])
6869
result = {}
69-
70+
leftover_state_dict = tensors_data_dict.copy()
71+
print(tensors_data_dict.keys())
7072
for tensor_name in tensor_names:
73+
processed_tensors = []
74+
7175
module_fqn, weight_name = tensor_name.rsplit(".", 1)
7276

7377
prefix = f"{module_fqn}._{weight_name}_"
7478
tensor_tensors = {}
79+
7580
for key, value in combined_data.items():
7681
if key.startswith(prefix):
7782
# Remove the prefix
7883
tensor_tensors[key[len(prefix) :]] = value
84+
full_tensor_name_in_state_dict = key
85+
processed_tensors.append(
86+
full_tensor_name_in_state_dict
87+
) # for tensor subclass
7988

8089
tensor_metadata = json.loads(metadata.get(tensor_name))
8190
tensor_type = tensor_metadata.get("_type")
91+
complete_tensor_data = tensor_metadata.get("_tensor_data_names")
8292

8393
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
84-
if not tensor_tensors:
85-
# we allow the option of loading in state_dict info for a single tensor
86-
# if tensor state dict info is not loaded in yet, we wait for it to be provided
87-
# in a future call
94+
# if not all tensor data is present (ie missing qdata) we wait for it
95+
# to be loaded in from a future call
96+
if not len(tensor_tensors) is len(complete_tensor_data):
8897
continue
8998
tensor_metadata["_data"].update(tensor_tensors)
9099
result[tensor_name] = object_from_dict(tensor_metadata)
91100
elif tensor_type == torch.Tensor.__name__:
101+
# we allow the option of loading in state_dict info for a single tensor
102+
# if tensor state dict info is not loaded in yet, we wait for it to be provided
103+
# in a future call
104+
if tensor_name not in tensors_data_dict.keys():
105+
continue
92106
result[tensor_name] = tensors_data_dict[tensor_name]
107+
processed_tensors.append(
108+
tensor_name
109+
) # add here because key for torch.Tensor has no prefix
93110
else:
94111
raise ValueError(f"Unsupported tensor type: {tensor_type}")
95-
return result
112+
113+
for tensor_name in processed_tensors:
114+
del leftover_state_dict[tensor_name]
115+
116+
return leftover_state_dict, result
96117

97118

98119
def flatten_tensor_state_dict(
@@ -125,7 +146,8 @@ def flatten_tensor_state_dict(
125146
'_data': {
126147
'block_size': [1,32],
127148
...
128-
}
149+
},
150+
'_tensor_data_names': ['qdata', 'scale']
129151
}
130152
'0.bias': {
131153
'_type': 'torch.Tensor',

torchao/prototype/safetensors/safetensors_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,23 @@ def default(self, o):
6060
encoded_attribute = self.encode_value(attribute)
6161
tensor_attr_dict[tensor_attribute_name] = encoded_attribute
6262

63-
return {"_type": o.__class__.__name__, "_data": tensor_attr_dict}
63+
optional_tensor_data = (
64+
o.optional_tensor_data_names
65+
if hasattr(o, "optional_tensor_data_names")
66+
else []
67+
)
68+
all_tensor_data = optional_tensor_data + o.tensor_data_names
69+
70+
_tensor_data_names = []
71+
for tensor_data_name in all_tensor_data:
72+
if getattr(o, tensor_data_name) is not None:
73+
_tensor_data_names.append(tensor_data_name)
74+
75+
return {
76+
"_type": o.__class__.__name__,
77+
"_data": tensor_attr_dict,
78+
"_tensor_data_names": _tensor_data_names,
79+
}
6480

6581
if hasattr(o, "_fields") and hasattr(
6682
o, "_asdict"

0 commit comments

Comments
 (0)