This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree: 4 files changed, +15 −1 lines changed.

@@ -18,6 +18,7 @@ def is_true(value):
 PORT = int(os.getenv("PORT", "80"))

 # Model-related arguments:
+MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
@@ -23,6 +23,7 @@
 from . import MODEL_LOAD_IN_4BIT
 from . import MODEL_4BIT_QUANT_TYPE
 from . import MODEL_4BIT_DOUBLE_QUANT
+from . import MODEL_PEFT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
     name_or_path=MODEL,
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
+    is_peft=MODEL_PEFT,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
     load_in_4bit=MODEL_LOAD_IN_4BIT,
     quant_type=MODEL_4BIT_QUANT_TYPE,
@@ -14,6 +14,10 @@
     TopPLogitsWarper,
     BitsAndBytesConfig
 )
+from peft import (
+    PeftConfig,
+    PeftModel
+)

 from .choice import map_choice
 from .tokenizer import StreamTokenizer
@@ -310,6 +314,7 @@ def load_model(
     name_or_path,
     revision=None,
     cache_dir=None,
+    is_peft=False,
     load_in_8bit=False,
     load_in_4bit=False,
     quant_type="fp4",
@@ -327,7 +332,6 @@ def load_model(
         kwargs["revision"] = revision
     if cache_dir:
         kwargs["cache_dir"] = cache_dir
-    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)

     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
@@ -356,6 +360,12 @@
     if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16

+    if is_peft:
+        peft_config = PeftConfig.from_pretrained(name_or_path)
+        name_or_path = peft_config.base_model_name_or_path
+
+    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
     # Support both decoder-only and encoder-decoder models.
     try:
         model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
@@ -12,3 +12,4 @@ safetensors~=0.3.1
 torch>=1.12.1
 transformers[sentencepiece]~=4.30.1
 waitress~=2.1.2
+peft~=0.3.0
You can’t perform that action at this time.
0 commit comments