Skip to content

Commit 42c18e7

Browse files
committed
reduce size of LLM prompts
* truncate text/binary sample data fields to 1024 characters (or smaller if judged to be needed) * truncate entire tables from schema representation if the representation is very large * for latency improvement, cache sample data and schema representation, passing the dbname in both cases to invalidate the cache if changing the db * add separate progress message when generating sample data The target_size values are chosen somewhat arbitrarily. We could also apply final size limits to the prompt string, though meaning-preserving truncation at that point is harder. Addresses #1348.
1 parent 643e410 commit 42c18e7

File tree

4 files changed

+73
-31
lines changed

4 files changed

+73
-31
lines changed

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Upcoming (TBD)
22
==============
33

4+
Features
5+
--------
6+
* Limit size of LLM prompts and cache LLM prompt data.
7+
8+
49
Internal
510
--------
611
* Include LLM dependencies in tox configuration.

mycli/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,9 +795,10 @@ def one_iteration(text: str | None = None) -> None:
795795
while special.is_llm_command(text):
796796
start = time()
797797
try:
798+
assert isinstance(self.sqlexecute, SQLExecute)
798799
assert sqlexecute.conn is not None
799800
cur = sqlexecute.conn.cursor()
800-
context, sql, duration = special.handle_llm(text, cur)
801+
context, sql, duration = special.handle_llm(text, cur, sqlexecute.dbname or '')
801802
if context:
802803
click.echo("LLM Response:")
803804
click.echo(context)

mycli/packages/special/llm.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def cli_commands() -> list[str]:
212212
return list(cli.commands.keys())
213213

214214

215-
def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
215+
def handle_llm(text: str, cur: Cursor, dbname: str) -> tuple[str, str | None, float]:
216216
_, verbosity, arg = parse_special_command(text)
217217
if not LLM_IMPORTED:
218218
output = [(None, None, None, NEED_DEPENDENCIES)]
@@ -261,7 +261,7 @@ def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
261261
try:
262262
ensure_mycli_template()
263263
start = time()
264-
context, sql = sql_using_llm(cur=cur, question=arg)
264+
context, sql = sql_using_llm(cur=cur, question=arg, dbname=dbname)
265265
end = time()
266266
if verbosity == Verbosity.SUCCINCT:
267267
context = ""
@@ -275,45 +275,81 @@ def is_llm_command(command: str) -> bool:
275275
return cmd in ("\\llm", "\\ai")
276276

277277

278-
def sql_using_llm(
279-
cur: Cursor | None,
280-
question: str | None = None,
281-
) -> tuple[str, str | None]:
282-
if cur is None:
283-
raise RuntimeError("Connect to a database and try again.")
284-
schema_query = """
285-
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
278+
def truncate_list_elements(row: list) -> list:
279+
target_size = 100000
280+
width = 1024
281+
while width >= 0:
282+
truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
283+
if sum(sys.getsizeof(x) for x in truncated_row) <= target_size:
284+
break
285+
width -= 100
286+
return truncated_row
287+
288+
289+
def truncate_table_lines(table: list[str]) -> list[str]:
290+
target_size = 100000
291+
truncated_table = []
292+
running_sum = 0
293+
while table and running_sum <= target_size:
294+
line = table.pop(0)
295+
running_sum += sys.getsizeof(line)
296+
truncated_table.append(line)
297+
return truncated_table
298+
299+
300+
@functools.cache
301+
def get_schema(cur: Cursor, dbname: str) -> str:
302+
click.echo("Preparing schema information to feed the LLM")
303+
schema_query = f"""
304+
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS schema
286305
FROM information_schema.columns
287-
WHERE table_schema = DATABASE()
306+
WHERE table_schema = '{dbname}'
288307
GROUP BY table_name
289308
ORDER BY table_name
290309
"""
291-
tables_query = "SHOW TABLES"
292-
sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
293-
click.echo("Preparing schema information to feed the llm")
294310
cur.execute(schema_query)
295-
db_schema = "\n".join([row[0] for (row,) in cur.fetchall()])
311+
db_schema = [row[0] for (row,) in cur.fetchall()]
312+
return '\n'.join(truncate_table_lines(db_schema))
313+
314+
315+
@functools.cache
316+
def get_sample_data(cur: Cursor, dbname: str) -> dict[str, Any]:
317+
click.echo("Preparing sample data to feed the LLM")
318+
tables_query = "SHOW TABLES"
319+
sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
296320
cur.execute(tables_query)
297321
sample_data = {}
298322
for (table_name,) in cur.fetchall():
299323
try:
300-
cur.execute(sample_row_query.format(table=table_name))
324+
cur.execute(sample_row_query.format(dbname=dbname, table=table_name))
301325
except Exception:
302326
continue
303327
cols = [desc[0] for desc in cur.description]
304328
row = cur.fetchone()
305329
if row is None:
306330
continue
307-
sample_data[table_name] = list(zip(cols, row))
331+
sample_data[table_name] = list(zip(cols, truncate_list_elements(list(row))))
332+
return sample_data
333+
334+
335+
def sql_using_llm(
336+
cur: Cursor | None,
337+
question: str | None,
338+
dbname: str = '',
339+
) -> tuple[str, str | None]:
340+
if cur is None:
341+
raise RuntimeError("Connect to a database and try again.")
342+
if dbname == '':
343+
raise RuntimeError("Choose a schema and try again.")
308344
args = [
309345
"--template",
310346
LLM_TEMPLATE_NAME,
311347
"--param",
312348
"db_schema",
313-
db_schema,
349+
get_schema(cur, dbname),
314350
"--param",
315351
"sample_data",
316-
sample_data,
352+
get_sample_data(cur, dbname),
317353
"--param",
318354
"question",
319355
question,

test/test_llm_special.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
2626
assert mock_llm is not None
2727
test_text = r"\llm"
2828
with pytest.raises(FinishIteration) as exc_info:
29-
handle_llm(test_text, executor)
29+
handle_llm(test_text, executor, 'mysql')
3030
# Should return usage message when no args provided
3131
assert exc_info.value.args[0] == [(None, None, None, USAGE)]
3232

@@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
3838
mock_run_cmd.return_value = (0, "Hello, no SQL today.")
3939
test_text = r"\llm -c 'Something?'"
4040
with pytest.raises(FinishIteration) as exc_info:
41-
handle_llm(test_text, executor)
41+
handle_llm(test_text, executor, 'mysql')
4242
# Expect raw output when no SQL fence found
4343
assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
4444

@@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
5151
fenced = f"Here you go:\n```sql\n{sql_text}\n```"
5252
mock_run_cmd.return_value = (0, fenced)
5353
test_text = r"\llm -c 'Rewrite SQL'"
54-
result, sql, duration = handle_llm(test_text, executor)
54+
result, sql, duration = handle_llm(test_text, executor, 'mysql')
5555
# Without verbose, result is empty, sql extracted
5656
assert sql == sql_text
5757
assert result == ""
@@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
6464
# 'models' is a known subcommand
6565
test_text = r"\llm models"
6666
with pytest.raises(FinishIteration) as exc_info:
67-
handle_llm(test_text, executor)
67+
handle_llm(test_text, executor, 'mysql')
6868
mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
6969
assert exc_info.value.args[0] is None
7070

@@ -74,7 +74,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
7474
def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
7575
test_text = r"\llm --help"
7676
with pytest.raises(FinishIteration) as exc_info:
77-
handle_llm(test_text, executor)
77+
handle_llm(test_text, executor, 'mysql')
7878
mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
7979
assert exc_info.value.args[0] is None
8080

@@ -84,7 +84,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
8484
def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
8585
test_text = r"\llm install openai"
8686
with pytest.raises(FinishIteration) as exc_info:
87-
handle_llm(test_text, executor)
87+
handle_llm(test_text, executor, 'mysql')
8888
mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
8989
assert exc_info.value.args[0] is None
9090

@@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
9898
"""
9999
mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
100100
test_text = r"\llm prompt 'Test?'"
101-
context, sql, duration = handle_llm(test_text, executor)
101+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
102102
mock_ensure_template.assert_called_once()
103103
mock_sql_using_llm.assert_called()
104104
assert context == "CTX"
@@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
115115
"""
116116
mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
117117
test_text = r"\llm 'Top 10?'"
118-
context, sql, duration = handle_llm(test_text, executor)
118+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
119119
mock_ensure_template.assert_called_once()
120120
mock_sql_using_llm.assert_called()
121121
assert context == "CTX2"
@@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
132132
"""
133133
mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
134134
test_text = r"\llm- 'Succinct?'"
135-
context, sql, duration = handle_llm(test_text, executor)
135+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
136136
assert context == ""
137137
assert sql == "SELECT 42;"
138138
assert isinstance(duration, float)
@@ -181,7 +181,7 @@ def fetchone(self):
181181
sql_text = "SELECT 1, 'abc';"
182182
fenced = f"Note\n```sql\n{sql_text}\n```"
183183
mock_run_cmd.return_value = (0, fenced)
184-
result, sql = sql_using_llm(dummy_cur, question="dummy")
184+
result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql')
185185
assert result == fenced
186186
assert sql == sql_text
187187

@@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
194194

195195
monkeypatch.setattr(llm_module, "llm", object())
196196
with pytest.raises(FinishIteration) as exc_info:
197-
handle_llm(prefix, executor)
197+
handle_llm(prefix, executor, 'mysql')
198198
assert exc_info.value.args[0] == [(None, None, None, USAGE)]

0 commit comments

Comments
 (0)