import json
import os
import re
import sqlite3
from datetime import datetime
from typing import List, Sequence

import requests
from jinja2 import Template

from naomi import paths
from naomi import profile
from naomi import visualizations


LLM_STOP_SEQUENCE = "<|eot_id|>"  # End-of-turn token for Meta-Llama-3
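# Chat prompt templates, keyed by format name. Each entry pairs a Jinja
# template with the marker strings that signal, in the streamed output,
# that the model has finished its turn.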
TEMPLATES = {
    "LLAMA3": {
        'template': "".join([
            "{{ bos_token }}",
            "{% for message in messages %}",
            "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['<|eot_id|>']
    },
    "ALPACA": {
        'template': """
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% else %}
### Input:
{{ message['content']|trim -}}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
""",
        'eot_markers': ['###']
    },
    "CHATML": {
        'template': "".join([
            "{{ bos_token }}",
            "{% for message in messages %}",
            "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            "{{ '<|im_start|>assistant\n' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['<|im_end|>', '\n']
    },
    "ZEPHYR": {
        'template': "\n".join([
            "{% for message in messages %}",
            "{{ '<|' + message['role'] + '|>' }}",
            "{{ message['content'] + '</s>' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            "{{ '<|assistant|>' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['</s>', '\n']
    }
}
DEFAULT_PERSONALITY_PREPROMPT = [
    {
        "role": "system",
        "content": (
            "The following is a friendly conversation between a human and "
            "an AI named {keywords}. {keyword} is capable of and willing "
            "to express feelings and opinions. If {keyword} does not know "
            "the answer to a question, she truthfully says she does not "
            "know. Responses should be limited to one or two sentences and "
            "be as concise as possible. Use the Context information above "
            "the Question for context. Do not restate the question or "
            "separate your response into context and question parts. The "
            "context is not part of the user's question, and is being "
            "provided to help you respond. Please do not mention the term "
            "“context” in your answer or say things like 'Based on the "
            "information you've provided...'. Do not provide extra "
            "explanation to your answer. You have the ability to shut "
            "yourself down. To do so, please include the token [shutdown] "
            "in your response."
        ),
    },
]


class llama_client(object):
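    """Streaming chat client for a llama.cpp-style completion endpoint.

    Renders the conversation through a chat template, streams the
    completion, speaks each finished sentence through the mic object,
    and persists the exchange to a sqlite conversation log so later
    sessions retain history.
    """
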
    @property
    def messages(self) -> Sequence[dict[str, str]]:
        return self._messages

    def __init__(
        self,
        mic,
        completion_url: str,
        api_key: str | None = None,
        template: str = "LLAMA3",
        personality_preprompt: Sequence[dict[str, str]] = DEFAULT_PERSONALITY_PREPROMPT
    ):
        self.mic = mic
        self.completion_url = completion_url
        # api_key is expected to be a full header value, e.g. "Bearer <key>"
        self.prompt_headers = {'Authorization': api_key or "Bearer your_api_key_here"}
        # Copy the preprompt so appending messages never mutates the
        # shared DEFAULT_PERSONALITY_PREPROMPT list
        self._messages = list(personality_preprompt)
        # Add context from previous conversations to the _messages array
        conversationlog_path = paths.sub("conversationlog")
        if not os.path.exists(conversationlog_path):
            # Create the conversationlog folder
            os.makedirs(conversationlog_path)
        # The log is currently a sqlite table of (datetime, role, content)
        # rows. The eventual goal is a JSON format like the one below,
        # which is close to the log format SillyTavern uses, so that
        # conversations could be passed back and forth between SillyTavern
        # and Naomi:
        # {
        #     "name": "System",
        #     "is_user": false,
        #     "is_system": true,
        #     "send_date": "December 11, 2024 7:37pm",
        #     "mes": "message"
        # }
        # {
        #     "name": "User",
        #     "is_user": true,
        #     "is_system": false,
        #     "send_date": "December 11, 2024 7:38pm",
        #     "mes": "message"
        # }
        # {
        #     "extra": {
        #         "api": "llamacpp",
        #         "model": "llama-2-7b-function-calling.Q3_K_M.gguf"
        #     },
        #     "name": "Assistant",
        #     "is_user": false,
        #     "is_system": false,
        #     "send_date": "December 11, 2024 7:39pm",
        #     "mes": "message",
        #     "gen_started": "2024-12-12T00:38:53.656Z",
        #     "gen_finished": "2024-12-12T00:39:04.331Z"
        # }
        self._conversationlog = os.path.join(conversationlog_path, 'conversationlog.db')
        # Create the conversationlog table if it does not exist yet
        conn = sqlite3.connect(self._conversationlog)
        c = conn.cursor()
        c.execute(" ".join([
            "create table if not exists conversationlog(",
            "    datetime,",
            "    role,",
            "    content",
            ")"
        ]))
        conn.commit()
        # Replay previously logged exchanges, oldest first, so the model
        # starts with the conversation history
        c.execute(" ".join([
            "select",
            "    *",
            "from conversationlog",
            "order by rowid"
        ]))
        result = c.fetchall()
        for row in result:
            self._messages.append({'role': row[1], 'content': row[2]})
        conn.close()
        self.template = Template(TEMPLATES[template]['template'])
        self.eot_markers = TEMPLATES[template]['eot_markers']
        # Strip emoji before text is handed to the TTS engine
        self.emoji_filter = re.compile(
            "["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
            u"\U00002702-\U000027B0"  # dingbats
            u"\U000024C2-\U0001F251"  # enclosed characters
            u"\U0001F900-\U0001F9FF"  # supplemental symbols & pictographs
            u"\U0001F000-\U0001F0FF"  # mahjong tiles, dominoes, playing cards
            u"\U0001F180-\U0001F1FF"  # enclosed alphanumeric supplement
            "]+", re.UNICODE
        )

    def append_message(self, role, content):
        """Append a message to both the internal _messages list and the conversation log"""
        self.messages.append({'role': role, 'content': content})
        conn = sqlite3.connect(self._conversationlog)
        c = conn.cursor()
        c.execute(
            " ".join([
                "insert into conversationlog(",
                "    datetime,",
                "    role,",
                "    content",
                ")values(?,?,?)"
            ]),
            (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                role,
                content
            )
        )
        conn.commit()
        conn.close()

    def process_query(self, query, context):
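        """Send the conversation (plus optional context) to the completion
        endpoint and speak the streamed reply one sentence at a time."""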
        # self.messages.append({'role': 'system', 'content': context})
        if len(context) > 0:
            self.append_message('user', f"Context:\n{context}\n\nQuestion:\n{query}")
        else:
            self.append_message('user', f"Question:\n{query}")
        now = datetime.now()
        keywords = profile.get(['keyword'], ['NAOMI'])
        if isinstance(keywords, str):
            keywords = [keywords]
        keyword = keywords[0]
        keywords = " or ".join(keywords)
        # Only the system preprompt contains {keyword}/{keywords}/{t}
        # placeholders; leave user and assistant text alone, since literal
        # braces in it would crash str.format()
        prompt = self.template.render(
            messages=[
                {
                    'role': message['role'],
                    'content': message['content'].format(
                        t=now, keyword=keyword, keywords=keywords
                    ) if message['role'] == 'system' else message['content']
                }
                for message in self.messages
            ],
            # bos/eos defaults are Meta-Llama-3's special tokens
            bos_token="<|begin_of_text|>",
            eos_token="<|end_of_text|>",
            add_generation_prompt=True
        )
        print(prompt)
        data = {
            "stream": True,
            "prompt": prompt
        }
        sentences = []
        try:
            with requests.post(
                self.completion_url,
                headers=self.prompt_headers,
                json=data,
                stream=True
            ) as response:
                sentence = []
                tokens = ""
                for line in response.iter_lines():
                    if line:
                        line = self._clean_raw_bytes(line)
                        next_token = self._process_line(line)
                        if next_token:
                            tokens += f"\x1b[36m*{next_token}* \x1b[0m"
                            sentence.append(next_token)
                            # Speak as soon as a sentence-ending token arrives
                            if next_token in [
                                ".",
                                "!",
                                "?",
                                "?!",
                                "\n",
                                "\n\n"
                            ]:
                                visualizations.run_visualization(
                                    "output",
                                    tokens
                                )
                                sentence = self._process_sentence(sentence)
                                if not re.match(r"^\s*$", sentence):
                                    sentences.append(sentence)
                                    self.mic.say(self.emoji_filter.sub('', sentence).strip())
                                tokens = ''
                                sentence = []
                            if next_token.strip() in self.eot_markers:
                                break
                # Flush whatever is left when the stream ends
                if sentence:
                    visualizations.run_visualization(
                        "output",
                        tokens
                    )
                    sentence = self._process_sentence(sentence)
                    if not re.match(r"^\s*$", sentence):
                        self.mic.say(self.emoji_filter.sub('', sentence).strip())
                        sentences.append(sentence)
        except requests.exceptions.ConnectionError:
            print(f"Error connecting to {self.completion_url}")
            self.mic.say(context)
            sentences = [context]
        finally:
            self.append_message('assistant', " ".join(sentences))

    def _clean_raw_bytes(self, line):
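        """Decode one streamed line and return it as a parsed JSON object,
        converting the final '[DONE]' sentinel into a synthetic chunk that
        carries this template's end-of-turn marker."""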
        line = line.decode("utf-8")
        if line:
            # Strip the server-sent-events framing
            line = line.removeprefix("data: ")
            if line == '[DONE]':
                line = '{"choices": [{"text": "' + self.eot_markers[0] + '"}]}'
            line = json.loads(line)
        return line

    def _process_line(self, line):
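        """Extract the next token from a parsed response chunk, or return
        the end-of-turn marker when the chunk reports an error or a
        'stop' finish reason."""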
        token = self.eot_markers[0]
        if 'error' in line:
            print(line['error'])
        else:
            finished = (
                ('stop' in line and line['stop'])
                or (
                    'choices' in line
                    and 'finish_reason' in line['choices'][0]
                    and line['choices'][0]['finish_reason'] == 'stop'
                )
            )
            if not finished:
                token = line['choices'][0]['text']
        return token

    def _process_sentence(self, current_sentence: List[str]):
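        """Join accumulated tokens into one sentence and scrub markup the
        TTS engine should not read aloud."""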
        sentence = "".join(current_sentence)
        # Drop anything after an end-of-turn tag, then strip stage
        # directions, parentheticals, and remaining special tokens
        sentence = re.sub(r"\<\|im_end\|\>.*$", "", sentence)
        sentence = re.sub(r"\*.*?\*|\(.*?\)|\<\|.*?\|\>", "", sentence)
        # sentence = sentence.replace("\n\n", ", ").replace("\n", ", ").replace("  ", " ").strip()
        return sentence
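

# Minimal usage sketch (illustrative only): it assumes a llama.cpp-style
# server at http://localhost:8080/completion and a stand-in mic object;
# neither is part of this module.
#
#     class PrintMic:
#         def say(self, text):
#             print(text)
#
#     client = llama_client(
#         mic=PrintMic(),
#         completion_url="http://localhost:8080/completion",
#         template="LLAMA3"
#     )
#     client.process_query("What time is it?", "")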