@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
3434 '_data': {
3535 'block_size': [1,32],
3636 ...
37- }
37+ },
38+ '_tensor_data_names': ['qdata', 'scale']
3839 }
3940 '0.bias': {
4041 '_type': 'torch.Tensor',
@@ -66,33 +67,53 @@ def unflatten_tensor_state_dict(
6667
6768 tensor_names = json .loads (metadata ["tensor_names" ])
6869 result = {}
69-
70+ leftover_state_dict = tensors_data_dict .copy ()
71+ print (tensors_data_dict .keys ())
7072 for tensor_name in tensor_names :
73+ processed_tensors = []
74+
7175 module_fqn , weight_name = tensor_name .rsplit ("." , 1 )
7276
7377 prefix = f"{ module_fqn } ._{ weight_name } _"
7478 tensor_tensors = {}
79+
7580 for key , value in combined_data .items ():
7681 if key .startswith (prefix ):
7782 # Remove the prefix
7883 tensor_tensors [key [len (prefix ) :]] = value
84+ full_tensor_name_in_state_dict = key
85+ processed_tensors .append (
86+ full_tensor_name_in_state_dict
87+ ) # for tensor subclass
7988
8089 tensor_metadata = json .loads (metadata .get (tensor_name ))
8190 tensor_type = tensor_metadata .get ("_type" )
91+ complete_tensor_data = tensor_metadata .get ("_tensor_data_names" )
8292
8393 if tensor_type in ALLOWED_TENSORS_SUBCLASSES :
84- if not tensor_tensors :
85- # we allow the option of loading in state_dict info for a single tensor
86- # if tensor state dict info is not loaded in yet, we wait for it to be provided
87- # in a future call
94+ # if not all tensor data is present (ie missing qdata) we wait for it
95+ # to be loaded in from a future call
96+ if not len (tensor_tensors ) is len (complete_tensor_data ):
8897 continue
8998 tensor_metadata ["_data" ].update (tensor_tensors )
9099 result [tensor_name ] = object_from_dict (tensor_metadata )
91100 elif tensor_type == torch .Tensor .__name__ :
101+ # we allow the option of loading in state_dict info for a single tensor
102+ # if tensor state dict info is not loaded in yet, we wait for it to be provided
103+ # in a future call
104+ if tensor_name not in tensors_data_dict .keys ():
105+ continue
92106 result [tensor_name ] = tensors_data_dict [tensor_name ]
107+ processed_tensors .append (
108+ tensor_name
109+ ) # add here because key for torch.Tensor has no prefix
93110 else :
94111 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
95- return result
112+
113+ for tensor_name in processed_tensors :
114+ del leftover_state_dict [tensor_name ]
115+
116+ return leftover_state_dict , result
96117
97118
98119def flatten_tensor_state_dict (
@@ -125,7 +146,8 @@ def flatten_tensor_state_dict(
125146 '_data': {
126147 'block_size': [1,32],
127148 ...
128- }
149+ },
150+ '_tensor_data_names': ['qdata', 'scale']
129151 }
130152 '0.bias': {
131153 '_type': 'torch.Tensor',
0 commit comments