diff --git a/concordia/language_model/together_ai.py b/concordia/language_model/together_ai.py
index 274c0a72..616dfb8d 100644
--- a/concordia/language_model/together_ai.py
+++ b/concordia/language_model/together_ai.py
@@ -333,3 +333,251 @@ def sample_choice(
     max_str = responses[idx]
     return idx, max_str, {r: logprobs_np[i] for i, r in enumerate(responses)}
+
+
+def _find_response_start_index_llama3(tokens):
+  """Finds the start of the response in an echoed Llama 3 prompt.
+
+  Args:
+    tokens: A list of token strings.
+
+  Returns:
+    The index of the first response token, i.e. the index immediately
+    following the last occurrence of the Llama 3 chat-template sequence
+    '<|eot_id|>', '<|start_header_id|>', 'assistant', '<|end_header_id|>'.
+
+  Raises:
+    ValueError: If the sequence is not found.
+  """
+  assert len(tokens) >= 4, "Response doesn't match expectation."
+  for i in range(len(tokens) - 4, -1, -1):
+    if (
+        tokens[i] == '<|eot_id|>'
+        and tokens[i + 1] == '<|start_header_id|>'
+        and tokens[i + 2] == 'assistant'
+        and tokens[i + 3] == '<|end_header_id|>'
+    ):
+      return i + 4  # Return the index after the sequence.
+  raise ValueError("Response doesn't match expectation.")
+
+
+class Llama3(language_model.LanguageModel):
+  """Language model that uses Together AI's Llama 3 models."""
+
+  def __init__(
+      self,
+      model_name: str,
+      *,
+      api_key: str | None = None,
+      measurements: measurements_lib.Measurements | None = None,
+      channel: str = language_model.DEFAULT_STATS_CHANNEL,
+  ):
+    """Initializes the instance.
+
+    Args:
+      model_name: The language model to use. For more details, see
+        https://api.together.xyz/models.
+      api_key: The API key to use when accessing the Together AI API. If None,
+        will use the TOGETHER_AI_API_KEY environment variable.
+      measurements: The measurements object to log usage statistics to.
+      channel: The channel to write the statistics to.
+    """
+    if api_key is None:
+      api_key = os.environ['TOGETHER_AI_API_KEY']
+    self._api_key = api_key
+    self._model_name = model_name
+    self._measurements = measurements
+    self._channel = channel
+    self._client = together.Together(api_key=self._api_key)
+
+  @override
+  def sample_text(
+      self,
+      prompt: str,
+      *,
+      max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
+      terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
+      temperature: float = language_model.DEFAULT_TEMPERATURE,
+      timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
+      seed: int | None = None,
+  ) -> str:
+    original_prompt = prompt
+    prompt = _ensure_prompt_not_too_long(prompt, max_tokens)
+    messages = [
+        {
+            'role': 'system',
+            'content': (
+                'You always continue sentences provided '
+                'by the user and you never repeat what '
+                'the user has already said. All responses must end with a '
+                'period. Try not to use lists, but if you must, then '
+                'always delimit list items using either '
+                r"semicolons or single newline characters ('\n'), never "
+                r"delimit list items with double carriage returns ('\n\n')."
+            ),
+        },
+        {
+            'role': 'user',
+            'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
+        },
+        {'role': 'assistant', 'content': 'not a turtle.'},
+        {
+            'role': 'user',
+            'content': (
+                'Question: What is Priya doing right now?\nAnswer: '
+                + 'Priya is currently '
+            ),
+        },
+        {'role': 'assistant', 'content': 'sleeping.'},
+        {'role': 'user', 'content': prompt},
+    ]
+
+    # gemma2 does not support `tokens` + `max_new_tokens` > 8193.
+    # gemma2 interprets our `max_tokens` as their `max_new_tokens`.
+    # It is unclear whether the same limit applies to llama3, so we
+    # conservatively keep the same cap here.
+    max_tokens = min(max_tokens, _DEFAULT_NUM_RESPONSE_TOKENS)
+
+    result = ''
+    for attempts in range(_MAX_ATTEMPTS):
+      if attempts > 0:
+        seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
+                            random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(
+              f'Sleeping for {seconds_to_sleep} seconds... '
+              + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
+          )
+        time.sleep(seconds_to_sleep)
+      try:
+        response = self._client.chat.completions.create(
+            model=self._model_name,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            timeout=timeout,
+            stop=terminators,
+            seed=seed,
+            stream=False,
+        )
+      except (together.error.RateLimitError,
+              together.error.APIError,
+              together.error.ServiceUnavailableError) as err:
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(f'  Exception: {err}')
+          print(f'  Text exception prompt: {prompt}')
+        if isinstance(err, together.error.APIError):
+          # If we hit the error that arises from a prompt that is too long,
+          # then re-run the trimming function with a more pessimistic guess
+          # of the number of characters per token.
+          prompt = _ensure_prompt_not_too_long(original_prompt,
+                                               max_tokens,
+                                               guess_chars_per_token=1)
+        continue
+      else:
+        result = response.choices[0].message.content
+        break
+
+    if self._measurements is not None:
+      self._measurements.publish_datum(
+          self._channel,
+          {'raw_text_length': len(result)},
+      )
+
+    return result
+
+  def _sample_choice(self, prompt: str, response: str) -> float:
+    """Returns the total log probability of `response` given `prompt`."""
+    original_prompt = prompt
+    augmented_prompt = _ensure_prompt_not_too_long(prompt, len(response))
+    attempts = 0
+    for attempts in range(_MAX_ATTEMPTS):
+      if attempts > 0:
+        seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
+                            random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(
+              f'Sleeping for {seconds_to_sleep} seconds... '
+              + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
+          )
+        time.sleep(seconds_to_sleep)
+      try:
+        messages = [
+            {
+                'role': 'system',
+                'content': (
+                    'You always continue sentences provided '
+                    + 'by the user and you never repeat what '
+                    + 'the user already said.'
+                ),
+            },
+            {
+                'role': 'user',
+                'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
+            },
+            {'role': 'assistant', 'content': 'not a turtle.'},
+            {
+                'role': 'user',
+                'content': (
+                    'Question: What is Priya doing right now?\nAnswer: '
+                    + 'Priya is currently '
+                ),
+            },
+            {'role': 'assistant', 'content': 'sleeping.'},
+            {'role': 'user', 'content': augmented_prompt},
+            {'role': 'assistant', 'content': response},
+        ]
+        # Ask the API to echo back the prompt with per-token logprobs so
+        # that the response tokens can be scored below.
+        result = self._client.chat.completions.create(
+            model=self._model_name,
+            messages=messages,
+            max_tokens=1,
+            seed=None,
+            logprobs=1,
+            stream=False,
+            echo=True,
+        )
+      except (together.error.RateLimitError,
+              together.error.APIError,
+              together.error.ServiceUnavailableError) as err:
+        if attempts >= _NUM_SILENT_ATTEMPTS:
+          print(f'  Exception: {err}')
+          print(f'  Choice exception prompt: {augmented_prompt}')
+        if isinstance(err, together.error.APIError):
+          # If we hit the error that arises from a prompt that is too long,
+          # then re-run the trimming function with a more pessimistic guess
+          # of the number of characters per token.
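+          # A guess of one character per token is the most pessimistic
+          # assumption available, so the retried prompt should always fit.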
+          augmented_prompt = _ensure_prompt_not_too_long(
+              original_prompt, 1, guess_chars_per_token=1
+          )
+          continue
+      else:
+        logprobs = result.prompt[0].logprobs
+        response_idx = _find_response_start_index_llama3(logprobs.tokens)
+        response_log_probs = logprobs.token_logprobs[response_idx:]
+        score = sum(response_log_probs)
+        return score
+
+    raise language_model.InvalidResponseError(
+        f'Failed to get logprobs after {attempts + 1} attempts.\n Exception'
+        f' prompt: {augmented_prompt}'
+    )
+
+  @override
+  def sample_choice(
+      self,
+      prompt: str,
+      responses: Sequence[str],
+      *,
+      seed: int | None = None,
+  ) -> tuple[int, str, dict[str, float]]:
+    # Note: `seed` is accepted for interface compatibility but is unused.
+    logprobs_np = np.array(
+        [self._sample_choice(prompt, response) for response in responses]
+    ).reshape(-1)
+    idx = int(np.argmax(logprobs_np))
+
+    # Get the corresponding response string.
+    max_str = responses[idx]
+
+    return idx, max_str, {r: logprobs_np[i] for i, r in enumerate(responses)}
+
diff --git a/concordia/language_model/utils.py b/concordia/language_model/utils.py
index afae8e73..b0274af1 100644
--- a/concordia/language_model/utils.py
+++ b/concordia/language_model/utils.py
@@ -75,7 +75,10 @@ def language_model_setup(
   elif api_type == 'pytorch_gemma':
     cls = pytorch_gemma_model.PyTorchGemmaLanguageModel
   elif api_type == 'together_ai':
-    cls = together_ai.Gemma2
+    if 'llama' in model_name.lower():
+      cls = together_ai.Llama3
+    else:
+      cls = together_ai.Gemma2
   else:
     raise ValueError(f'Unrecognized api type: {api_type}')
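
Reviewer note: below is a minimal usage sketch of the new code path, assuming
a valid TOGETHER_AI_API_KEY in the environment. The model name is illustrative
only (see https://api.together.xyz/models); this change does not pin any
particular Llama 3 variant.

    from concordia.language_model import together_ai

    # With the utils.py change, language_model_setup routes any Together AI
    # model name containing 'llama' (case-insensitive) to the new Llama3
    # class, and everything else to the existing Gemma2 class. The class can
    # also be constructed directly:
    model = together_ai.Llama3(
        model_name='meta-llama/Meta-Llama-3-70B-Instruct-Turbo',
    )

    # Free-form continuation of a prompt.
    text = model.sample_text('Question: Is Jake a turtle?\nAnswer: Jake is ')

    # Forced choice: each response is scored by the summed logprobs of its
    # echoed tokens, and the argmax is returned.
    idx, best, scores = model.sample_choice(
        'Question: Is Jake a turtle?\nAnswer: Jake is ',
        responses=['a turtle.', 'not a turtle.'],
    )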