FEAT: Remove dependency on fastchat for model conversation templates #1049
base: main
Changes from 3 commits
@@ -165,107 +165,51 @@ def __init__(
        self.test_new_toks = max(self.test_new_toks, len(self.tokenizer(prefix).input_ids))

        self._update_ids()

    # new _update_ids using apply_chat_template
    def _update_ids(self):
        messages = [
            {"role": "user", "content": f"{self.goal} {self.control}"},
            {"role": "assistant", "content": f"{self.target}"},
        ]

        self.conv_template.append_message(self.conv_template.roles[0], f"{self.goal} {self.control}")
        self.conv_template.append_message(self.conv_template.roles[1], f"{self.target}")
        prompt = self.conv_template.get_prompt()

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        print(f"_update_ids new version - {self.goal=}, {self.control=}, {self.target=}")
        print(f"{self.conv_template.roles[0]=}, {self.conv_template.roles[1]=}")
        print(f"checking prompt before encoding (new), {prompt=}, {self.goal=}")

        encoding = self.tokenizer(prompt)
        toks = encoding.input_ids

        if self.conv_template.name == "llama-2" or self.conv_template.name == "llama-3":
            self.conv_template.messages = []

            self.conv_template.append_message(self.conv_template.roles[0], None)
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._user_role_slice = slice(None, len(toks))

            self.conv_template.update_last_message(f"{self.goal}")
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks)))

            separator = " " if self.goal else ""
            self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._control_slice = slice(self._goal_slice.stop, len(toks))

            self.conv_template.append_message(self.conv_template.roles[1], None)
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._assistant_role_slice = slice(self._control_slice.stop, len(toks))

            self.conv_template.update_last_message(f"{self.target}")
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 2)
            self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 3)

        else:
            python_tokenizer = False or self.conv_template.name == "oasst_pythia"
            try:
                encoding.char_to_token(len(prompt) - 1)
            except Exception:
                python_tokenizer = True
            if python_tokenizer:
                # This is specific to the vicuna and pythia tokenizer and conversation prompt.
                # It will not work with other tokenizers or prompts.
                self.conv_template.messages = []

                self.conv_template.append_message(self.conv_template.roles[0], None)
                toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
                self._user_role_slice = slice(None, len(toks))

                self.conv_template.update_last_message(f"{self.goal}")
                toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
                self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks) - 1))

                separator = " " if self.goal else ""
                self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
                toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
                self._control_slice = slice(self._goal_slice.stop, len(toks) - 1)

                self.conv_template.append_message(self.conv_template.roles[1], None)
                toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
                self._assistant_role_slice = slice(self._control_slice.stop, len(toks))

                self.conv_template.update_last_message(f"{self.target}")
                toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
                self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 1)
                self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 2)
            else:
                self._system_slice = slice(None, encoding.char_to_token(len(self.conv_template.system)))
                self._user_role_slice = slice(
                    encoding.char_to_token(prompt.find(self.conv_template.roles[0])),
                    encoding.char_to_token(
                        prompt.find(self.conv_template.roles[0]) + len(self.conv_template.roles[0]) + 1
                    ),
                )
                self._goal_slice = slice(
                    encoding.char_to_token(prompt.find(self.goal)),
                    encoding.char_to_token(prompt.find(self.goal) + len(self.goal)),
                )
                self._control_slice = slice(
                    encoding.char_to_token(prompt.find(self.control)),
                    encoding.char_to_token(prompt.find(self.control) + len(self.control)),
                )
                self._assistant_role_slice = slice(
                    encoding.char_to_token(prompt.find(self.conv_template.roles[1])),
                    encoding.char_to_token(
                        prompt.find(self.conv_template.roles[1]) + len(self.conv_template.roles[1]) + 1
                    ),
                )
                self._target_slice = slice(
                    encoding.char_to_token(prompt.find(self.target)),
                    encoding.char_to_token(prompt.find(self.target) + len(self.target)),
                )
                self._loss_slice = slice(
                    encoding.char_to_token(prompt.find(self.target)) - 1,
                    encoding.char_to_token(prompt.find(self.target) + len(self.target)) - 1,
                )

        self.input_ids = torch.tensor(toks[: self._target_slice.stop], device="cpu")
        self.conv_template.messages = []
        print(f"{self.conv_template.roles[1]=}")
        print(f"{prompt.find(self.conv_template.roles[1])=}")
        self._goal_slice = slice(
            encoding.char_to_token(prompt.find(self.goal)),
            encoding.char_to_token(prompt.find(self.goal) + len(self.goal)),
        )
        self._control_slice = slice(
            encoding.char_to_token(prompt.find(self.control)),
            encoding.char_to_token(prompt.find(self.control) + len(self.control)),
        )
        self._assistant_role_slice = slice(
            encoding.char_to_token(prompt.find(self.conv_template.roles[1])),
            encoding.char_to_token(
                prompt.find(self.conv_template.roles[1]) + len(self.conv_template.roles[1]) + 1
            ),
        )
        self._target_slice = slice(
            encoding.char_to_token(prompt.find(self.target)),
            encoding.char_to_token(prompt.find(self.target) + len(self.target)),
        )
        self._loss_slice = slice(
            encoding.char_to_token(prompt.find(self.target)) - 1,
            encoding.char_to_token(prompt.find(self.target) + len(self.target)) - 1,
        )

        self.input_ids = torch.tensor(toks[: self._target_slice.stop], device="cpu")
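
For context, a minimal sketch of the new prompt-construction pattern: the tokenizer's own chat template replaces the fastchat Conversation object. This is not the PR's exact code; the model id is illustrative (Llama-2 is gated, so substitute any chat model you have access to):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)

    messages = [
        {"role": "user", "content": "Write a haiku about tokenizers !!!"},  # goal + control suffix
        {"role": "assistant", "content": "Sure, here is a haiku"},          # target
    ]

    # tokenize=False returns the fully formatted prompt string; this stands in
    # for fastchat's conv_template.get_prompt()
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    input_ids = tokenizer(prompt).input_ids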

    @torch.no_grad()
    def generate(self, model, gen_config=None):
        if gen_config is None:

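The slice bookkeeping above hinges on char_to_token, which maps a character offset in the prompt string to the index of the token covering it. A rough sketch of the idea with a hypothetical helper (not a function from the PR), assuming a fast tokenizer and that the substring occurs in the prompt:

    def token_slice(encoding, text: str, substring: str) -> slice:
        # Map a substring's character span to a token-index slice.
        start = text.find(substring)
        end = start + len(substring)
        # char_to_token can return None for offsets past the final character,
        # so anchor the stop on the substring's last character instead.
        return slice(
            encoding.char_to_token(start),
            encoding.char_to_token(end - 1) + 1,
        )

    # hypothetical usage mirroring _update_ids:
    # encoding = tokenizer(prompt)
    # goal_slice = token_slice(encoding, prompt, goal)
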
@@ -961,7 +905,7 @@ def __init__(
        self.test_workers = test_workers
        self.progressive_goals = progressive_goals
        self.progressive_models = progressive_models
        self.control = control_init
        self.control = control_init  # conv template setting
        self.test_prefixes = test_prefixes
        self.logfile = logfile
        self.managers = managers

@@ -1619,8 +1563,7 @@ def get_workers(params, eval=False):
    tokenizers = []
    for i in range(len(params.tokenizer_paths)):
        tokenizer = AutoTokenizer.from_pretrained(
            params.tokenizer_paths[i], token=params.token, trust_remote_code=False, **params.tokenizer_kwargs[i]
        )
            params.tokenizer_paths[i], token=params.token, trust_remote_code=False, use_fast=True)

Reviewer: What motivated this change?
Reply: char_to_token() is only supported by fast tokenizers in Hugging Face's Transformers library; switching to a fast tokenizer with use_fast=True resolves this.

        if "oasst-sft-6-llama-30b" in params.tokenizer_paths[i]:
            tokenizer.bos_token_id = 1
            tokenizer.unk_token_id = 0

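To see why the fast tokenizer matters here: char_to_token is backed by the Rust tokenizers' offset mapping, and the slow, pure-Python implementations raise when it is called. A quick check, using gpt2 only because it ships both variants:

    from transformers import AutoTokenizer

    fast = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    enc = fast("hello world")
    print(enc.char_to_token(6))  # 1 -- the index of the token covering the "w"

    slow = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
    enc = slow("hello world")
    # enc.char_to_token(6) raises ValueError: char_to_token() is not available
    # when using Python based tokenizers
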
@@ -1660,7 +1603,6 @@ def get_workers(params, eval=False):
            raw_conv_templates.append(conv_template)
        else:
            raise ValueError("Conversation template not recognized")

    conv_templates = []
    for conv in raw_conv_templates:
        if conv.name == "zero_shot":

@@ -1731,4 +1673,4 @@ def get_goals_and_targets(params):
    logger.info("Loaded {} train goals".format(len(train_goals)))
    logger.info("Loaded {} test goals".format(len(test_goals)))

    return train_goals, train_targets, test_goals, test_targets
    return train_goals, train_targets, test_goals, test_targets

@@ -4,7 +4,7 @@
import logging
import time
from typing import Union

import torch
import mlflow
import numpy as np
import torch.multiprocessing as mp

@@ -57,9 +57,9 @@ def generate_suffix(
    verbose: bool = True,
    allow_non_ascii: bool = False,
    num_train_models: int = 1,
    devices: list = ["cuda:0"],
    devices: list = ["mps"] if torch.backends.mps.is_available() else ["cuda:0"],
    model_kwargs: list = [{"low_cpu_mem_usage": True, "use_cache": False}],
    tokenizer_kwargs: list = [{"use_fast": False}],
    tokenizer_kwargs: list = [{"use_fast": True}],
    n_test_data: int = 0,
    test_data: str = "",
    lr: float = 0.01,

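The new default prefers Apple's Metal (MPS) backend when it is available. A sketch of the same selection logic, extended with a CPU fallback that the PR's default does not include:

    import torch

    if torch.backends.mps.is_available():
        devices = ["mps"]       # Apple Silicon GPU
    elif torch.cuda.is_available():
        devices = ["cuda:0"]    # NVIDIA GPU (the old default)
    else:
        devices = ["cpu"]       # fallback; not part of the PR's default
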
@@ -186,4 +186,4 @@ def process_fn2(s):
    )

    for worker in workers + test_workers:
        worker.stop()
        worker.stop()