import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
-import transformers

import torch
+import transformers
from torch import nn

from xturing.engines.causal import CausalEngine, CausalLoraEngine
from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from xturing.engines.lora_engine import prepare_model_for_int8_training
-from xturing.engines.quant_utils import make_quant, autotune_warmup
+from xturing.engines.quant_utils import autotune_warmup, make_quant
from xturing.utils.hub import ModelHub

+
class LLamaEngine(CausalEngine):
    config_name: str = "llama_engine"

@@ -102,24 +103,28 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
            target_modules=["q_proj", "v_proj"],
        )

-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+
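+# Recursively collect every nn.Conv2d / nn.Linear submodule, keyed by its dotted attribute path.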
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
-        res.update(find_layers(
-            child, layers=layers, name=name + '.' + name1 if name != '' else name1
-        ))
+        res.update(
+            find_layers(
+                child, layers=layers, name=name + "." + name1 if name != "" else name1
+            )
+        )
    return res

+
class LlamaLoraInt4Engine(CausalLoraEngine):
    config_name: str = "llama_lora_int4_engine"

    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
-        model_name = "decapoda-research/llama-7b-hf"
+        model_name = "decapoda-research/llama-7b-hf"

        if weights_path is None:
-            weights_path = ModelHub().load("x/llama_lora_int4")
+            weights_path = ModelHub().load("x/llama_lora_int4")

        config = LlamaConfig.from_pretrained(model_name)

@@ -129,10 +134,10 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):

        def noop(*args, **kwargs):
            pass
-
-        torch.nn.init.kaiming_uniform_ = noop
-        torch.nn.init.uniform_ = noop
-        torch.nn.init.normal_ = noop
+
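+        # Replace torch's weight-init routines with no-ops so from_pretrained skips
+        # random initialization; the real weights come from the quantized checkpoint below.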
+        torch.nn.init.kaiming_uniform_ = noop
+        torch.nn.init.uniform_ = noop
+        torch.nn.init.normal_ = noop

        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
@@ -143,18 +148,23 @@ def noop(*args, **kwargs):

        layers = find_layers(model)

-        for name in ['lm_head']:
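+        # Keep lm_head out of the layer map so it stays in full precision.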
+        for name in ["lm_head"]:
            if name in layers:
                del layers[name]
-
+
        wbits = 4
        groupsize = 128
-        warmup_autotune = True
-
+        warmup_autotune = True
+
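+        # Replace the collected Linear layers with 4-bit quantized modules (group size 128).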
        make_quant(model, layers, wbits, groupsize)
-

-        model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
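+        # Load the checkpoint on CPU and strip the first six characters of every key
+        # (presumably a leading "model." prefix) so the names match the quantized modules.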
+        state_dict = torch.load(
+            weights_path / Path("pytorch_model.bin"), map_location="cpu"
+        )
+        new_state_dict = {}
+        for key, value in state_dict.items():
+            new_state_dict[key[6:]] = value
+        model.load_state_dict(new_state_dict, strict=False)

        if warmup_autotune:
            autotune_warmup(model)
@@ -171,12 +181,12 @@ def noop(*args, **kwargs):
        tokenizer.pad_token_id = tokenizer.eos_token_id

        super().__init__(
-            model=model,
+            model=model,
            tokenizer=tokenizer,
            target_modules=[
                "q_proj",
                "v_proj",
-            ]
+            ],
        )

        torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
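For reference, a minimal usage sketch of the engine this commit touches. Only the constructor signature and the "x/llama_lora_int4" hub fallback are taken from the diff above; the import path is assumed from the package layout visible in the imports.

```python
# Sketch only: the module path below is assumed, not confirmed by this diff.
from xturing.engines.llama_engine import LlamaLoraInt4Engine

# With weights_path=None, __init__ fetches the pre-quantized int4 checkpoint via
# ModelHub().load("x/llama_lora_int4") and loads pytorch_model.bin with strict=False.
engine = LlamaLoraInt4Engine(weights_path=None)
```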