Skip to content

Commit 7efe4d5

Browse files
Merge pull request #441 from aaronchantrill/LLM
Allow Naomi to use a local LLM
2 parents 67a42c3 + 6ee080d commit 7efe4d5

File tree

21 files changed

+8291
-125
lines changed

21 files changed

+8291
-125
lines changed

naomi/application.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,32 @@ def settings(self):
696696
"default": True
697697
}
698698
),
699+
(
700+
("LLM", "enabled"), {
701+
"type": "boolean",
702+
"title": _("Should I use a large language model?"),
703+
"description": _("With this option enabled, I will attempt to run my output through a large language model before speaking. This can make me much more conversational, but can interfere with my functioning and make me take longer to respond."),
704+
"default": False
705+
}
706+
),
707+
(
708+
("LLM", "completion_url"), {
709+
"title": _("LLM endpoint URL"),
710+
"description": _("The url used for requesting text completion from your LLM"),
711+
"default": "http://localhost:8080/v1/completions",
712+
"active": lambda: profile.get(['LLM', 'enabled'], False)
713+
}
714+
),
715+
(
716+
("LLM", "template"), {
717+
"title": _("LLM Template type"),
718+
"type": "listbox",
719+
"description": _("The jinja2 template to use for formatting prompts for your LLM model."),
720+
"options": ['LLAMA3', 'LLAMA2', 'CHATML'],
721+
"default": "LLAMA3",
722+
"active": lambda: profile.get(['LLM', 'enabled'], False)
723+
}
724+
),
699725
(
700726
("email", "address"), {
701727
"type": "encrypted",

naomi/brain.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ def add_plugin(self, plugin):
2222
self._plugins.append(plugin)
2323
# print("Checking {} for intents".format(plugin._plugin_info.name))
2424
if hasattr(plugin, "intents"):
25-
self._intentparser.add_intents(plugin.intents())
25+
# pdb.set_trace()
26+
# Make sure every intent has an "allow_llm" property.
27+
# If not, then initialize it to True
28+
intents = plugin.intents()
29+
for intent in intents:
30+
if 'allow_llm' not in intents[intent]:
31+
intents[intent]['allow_llm'] = True
32+
self._intentparser.add_intents(intents)
2633

2734
    def train(self):
        """Train the intent parser on all intents registered via add_plugin."""
        self._intentparser.train()

naomi/llama_client.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
import json
2+
import os
3+
import re
4+
import requests
5+
import sqlite3
6+
from datetime import datetime
7+
from jinja2 import Template
8+
from naomi import paths
9+
from naomi import profile
10+
from naomi import visualizations
11+
from typing import List, Sequence
12+
13+
14+
# NOTE(review): this constant appears unused in this module — end-of-turn
# detection goes through the per-template 'eot_markers' lists below. Confirm
# before removing.
LLM_STOP_SEQUENCE = "<|eot_id|>"  # End of sentence token for Meta-Llama-3
# Prompt-format registry. Each entry supplies:
#   'template'    - a jinja2 chat-template string, rendered with the variables
#                   messages, bos_token, eos_token and add_generation_prompt
#                   (see llama_client.process_query).
#   'eot_markers' - literal strings that mark the end of the model's turn in
#                   the streamed output.
# NOTE(review): the settings UI in naomi/application.py offers the options
# ['LLAMA3', 'LLAMA2', 'CHATML']; 'LLAMA2' has no entry here while ALPACA and
# ZEPHYR are implemented but not offered — the two lists should be reconciled.
TEMPLATES = {
    "LLAMA3": {
        'template': "".join([
            "{{ bos_token }}",
            "{% for message in messages %}",
            " {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            " {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['<|eot_id|>']
    },
    "ALPACA": {
        # Instruction/Response layout; relies on jinja2 whitespace trimming.
        'template': "".join(["""
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% else %}
### Input:
{{ message['content']|trim -}}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
"""]),
        'eot_markers': ['###']
    },
    "CHATML": {
        'template': "".join([
            "{{ bos_token }}",
            "{% for message in messages %}",
            " {{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            " {{ '<|im_start|>assistant\n' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['<|im_end|>', '\n']
    },
    "ZEPHYR": {
        'template': "\n".join([
            "{% for message in messages %}",
            "{{ '<|' + message['role'] + '|>' }}",
            "{{ message['content'] + '</s>' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            "{{ '<|assistant|>' }}",
            "{% endif %}"
        ]),
        'eot_markers': ['</s>', '\n']
    }
}
75+
DEFAULT_PERSONALITY_PREPROMPT = [
76+
{
77+
"role": "system",
78+
"content": "The following is a friendly conversation between a human and an AI named {keywords}. {keyword} is capable of and willing to express feelings and opinions. If {keyword} does not know the answer to a question, she truthfully says she does not know. Responses should be limited to one or two sentences and be as concise as possible. Use the Context information above the Question for context. Do not restate the question or separate your response into context and question parts. The context is not part of the user's question, and is being provided to help you respond. Please do not mention the term “context” in your answer or say things like 'Based on the information you've provided...'. Do not provide extra explanation to your answer. You have the ability to shut yourself down. To do so, please include the token [shutdown] in your response.",
79+
},
80+
]
81+
82+
83+
class llama_client(object):
84+
@property
85+
def messages(self) -> Sequence[dict[str, str]]:
86+
return self._messages
87+
88+
def __init__(
89+
self,
90+
mic,
91+
completion_url: str,
92+
api_key: str | None = None,
93+
template: str = "LLAMA3",
94+
personality_preprompt: Sequence[dict[str, str]] = DEFAULT_PERSONALITY_PREPROMPT
95+
):
96+
self.mic = mic
97+
self.completion_url = completion_url
98+
self.prompt_headers = {'Authorization': api_key or "Bearer your_api_key_here"}
99+
self._messages = personality_preprompt
100+
# Add context from previous conversations to the _messages array
101+
conversationlog_path = paths.sub("conversationlog")
102+
if not os.path.exists(conversationlog_path):
103+
# Create the conversationlog folder
104+
os.makedirs(conversationlog_path)
105+
# Make sure the conversation log exists
106+
# The format of the conversation log will be json in the form:
107+
# {
108+
# "name": "System",
109+
# "is_user":false,
110+
# "is_system":true,
111+
# "send_date":"December 11, 2024 7:37pm",
112+
# "mes":"message"
113+
# }
114+
# {
115+
# "name":"User",
116+
# "is_user":true,
117+
# "is_system":false,
118+
# "send_date":"December 11, 2024 7:38pm",
119+
# "mes":"message"
120+
# }
121+
# {
122+
# "extra":{
123+
# "api":"llamacpp",
124+
# "model":"llama-2-7b-function-calling.Q3_K_M.gguf"
125+
# },
126+
# "name":"Assistant",
127+
# "is_user":false,
128+
# "is_system":false,
129+
# "send_date":"December 11, 2024 7:39pm",
130+
# "mes":"message",
131+
# "gen_started":"2024-12-12T00:38:53.656Z",
132+
# "gen_finished":"2024-12-12T00:39:04.331Z"
133+
# }
134+
# This is pretty close to the log format that sillytavern uses.
135+
# If I can make it so conversations can be passed back and forth
136+
# between SillyTavern and Naomi, I will.
137+
self._conversationlog = os.path.join(conversationlog_path, 'conversationlog.db')
138+
# Create the conversationlog table
139+
conn = sqlite3.connect(self._conversationlog)
140+
c = conn.cursor()
141+
c.execute(" ".join([
142+
"create table if not exists conversationlog(",
143+
" datetime,",
144+
" role,",
145+
" content",
146+
")"
147+
]))
148+
conn.commit()
149+
# Read in the last 10 active records
150+
c.execute(" ".join([
151+
"select",
152+
" *",
153+
"from conversationlog",
154+
"order by rowid"
155+
]))
156+
result = c.fetchall()
157+
print(result)
158+
for row in result:
159+
self._messages.append({'role': row[1], 'content': row[2]})
160+
conn.close()
161+
self.template = Template(TEMPLATES[template]['template'])
162+
self.eot_markers = TEMPLATES[template]['eot_markers']
163+
self.emoji_filter = re.compile("["
164+
U"\U0001F600-\U0001F64F" # emoticons
165+
U"\U0001F300-\U0001F5FF" # symbols & pictographs
166+
U"\U0001F680-\U0001F6FF" # transport & map symbols
167+
U"\U0001F1E0-\U0001F1FF" # flags (iOS)
168+
U"\U00002702-\U000027B0"
169+
U"\U000024C2-\U0001F251"
170+
U"\U0001F900-\U0001F9FF" # symbols and pictographs, extended
171+
U"\U0001F000-\U0001F0FF" # flags
172+
U"\U0001F180-\U0001F1FF" # flags
173+
"]+", re.UNICODE)
174+
175+
def append_message(self, role, content):
176+
"""Append a message to both internal _messages list and the log file"""
177+
self.messages.append({'role': role, 'content': content})
178+
conn = sqlite3.connect(self._conversationlog)
179+
c = conn.cursor()
180+
c.execute(
181+
" ".join([
182+
"insert into conversationlog(",
183+
" datetime,",
184+
" role,",
185+
" content",
186+
")values(?,?,?)"
187+
]),
188+
(
189+
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
190+
role,
191+
content
192+
)
193+
)
194+
conn.commit()
195+
conn.close()
196+
197+
def process_query(self, query, context):
198+
# self.messages.append({'role': 'system', 'content': context})
199+
if len(context) > 0:
200+
self.append_message('user', f"Context:\n{context}\n\nQuestion:\n{query}")
201+
else:
202+
self.append_message('user', f"Question:\n{query}")
203+
now = datetime.now()
204+
keywords = profile.get(['keyword'], ['NAOMI'])
205+
if isinstance(keywords, str):
206+
keywords = [keywords]
207+
keyword = keywords[0]
208+
keywords = " or ".join(keywords)
209+
# print(self.messages)
210+
prompt = self.template.render(
211+
messages=[{"role": message['role'], 'content': message['content'].format(t=now, keyword=keyword, keywords=keywords)} for message in self.messages],
212+
bos_token="<|begin_of_text|>",
213+
eos_token="<|end_of_text|>",
214+
add_generation_prompt=True
215+
)
216+
print(prompt)
217+
data = {
218+
"stream": True,
219+
"prompt": prompt
220+
}
221+
sentences = []
222+
try:
223+
with requests.post(
224+
self.completion_url,
225+
headers=self.prompt_headers,
226+
json=data,
227+
stream=True
228+
) as response:
229+
sentence = []
230+
tokens = ""
231+
for line in response.iter_lines():
232+
# print(f"Line: {line}")
233+
if line:
234+
line = self._clean_raw_bytes(line)
235+
# print(f"Line: {line}")
236+
next_token = self._process_line(line)
237+
if next_token:
238+
tokens += f"\x1b[36m*{next_token}* \x1b[0m"
239+
sentence.append(next_token)
240+
if next_token in [
241+
".",
242+
"!",
243+
"?",
244+
"?!",
245+
"\n",
246+
"\n\n"
247+
]:
248+
visualizations.run_visualization(
249+
"output",
250+
tokens
251+
)
252+
sentence = self._process_sentence(sentence)
253+
if not re.match("^\s*$", sentence):
254+
sentences.append(sentence)
255+
self.mic.say(self.emoji_filter.sub(r'', sentence).strip())
256+
tokens = ''
257+
sentence = []
258+
if next_token.strip() in self.eot_markers:
259+
break
260+
if sentence:
261+
visualizations.run_visualization(
262+
"output",
263+
tokens
264+
)
265+
sentence = self._process_sentence(sentence)
266+
if not re.match("^\s*$", sentence):
267+
self.mic.say(self.emoji_filter.sub(r'', sentence).strip())
268+
sentences.append(sentence)
269+
except requests.exceptions.ConnectionError:
270+
print(f"Error connecting to {self.completion_url}")
271+
self.mic.say(context)
272+
sentences = [context]
273+
finally:
274+
self.append_message('assistant', " ".join(sentences))
275+
276+
def _clean_raw_bytes(self, line):
277+
line = line.decode("utf-8")
278+
if line:
279+
line = line.removeprefix("data: ")
280+
# print(f"Line: {line}")
281+
if line == '[DONE]':
282+
line = '{"choices": [{"text": "' + self.eot_markers[0] + '"}]}'
283+
# print(f"Line: {line}")
284+
line = json.loads(line)
285+
return line
286+
287+
def _process_line(self, line):
288+
token = self.eot_markers[0]
289+
if 'error' in line:
290+
print(line['error'])
291+
else:
292+
if not (('stop' in line and line['stop']) or ('choices' in line and 'finish_reason' in line['choices'][0] and line['choices'][0]['finish_reason'] == 'stop')):
293+
token = line['choices'][0]['text']
294+
return token
295+
296+
def _process_sentence(self, current_sentence: List[str]):
297+
sentence = "".join(current_sentence)
298+
sentence = re.sub(r"\<\|im_end\|\>.*$", "", sentence)
299+
sentence = re.sub(r"\*.*?\*|\(.*?\)|\<\|.*?\|\>", "", sentence)
300+
# sentence = sentence.replace("\n\n", ", ").replace("\n", ", ").replace(" ", " ").strip()
301+
return sentence

0 commit comments

Comments
 (0)