Skip to content

Commit bed2d12

Browse files
authored
[feat!] Advanced stop logic (#184)
[breaking] Remove `PrtChatStop` for the general `PrtStop` command.
1 parent e527fe9 commit bed2d12

File tree

9 files changed

+497
-15
lines changed

9 files changed

+497
-15
lines changed

README.md

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ Additional useful commands are implemented through hooks (see below).
204204
| `PrtChatFinder` | Fuzzy search chat files using fzf |
205205
| `PrtChatDelete` | Delete the current chat file |
206206
| `PrtChatRespond` | Trigger chat respond (in chat file) |
207-
| `PrtStop` | Interrupt ongoing respond |
207+
| `PrtStop` | Interrupt any ongoing Parrot generation (works everywhere) |
208208
| `PrtProvider <provider>` | Switch the provider (empty arg triggers fzf) |
209209
| `PrtModel <model>` | Switch the interactive command model (empty arg triggers fzf). Note: Chat model must be changed from within the chat buffer. |
210210
| `PrtStatus` | Prints current provider and model selection |
@@ -403,7 +403,7 @@ This plugin provides the following default key mappings:
403403
|--------------|-------------------------------------------------------------|
404404
| `<C-g>c` | Opens a new chat via `PrtChatNew` |
405405
| `<C-g><C-g>` | Trigger the API to generate a response via `PrtChatRespond` |
406-
| `<C-g>s` | Stop the current text generation via `PrtStop` |
406+
| `<C-g>s` | Stop any ongoing Parrot generation via `PrtStop` |
407407
| `<C-g>d` | Delete the current chat file via `PrtChatDelete` |
408408

409409
### Provider Configuration Examples
@@ -942,6 +942,52 @@ or have suggestions for improving provider support.
942942
}
943943
```
944944

945+
## Cancellation
946+
947+
You can stop any ongoing Parrot generation at any time using multiple methods:
948+
949+
### Methods
950+
951+
1. **Keybinding**: `<C-g>s` (configurable via `chat_shortcut_stop`)
952+
2. **Command**: `:PrtStop` (works everywhere)
953+
954+
### Behavior
955+
956+
When you cancel a generation:
957+
958+
- **Immediate Termination**: The API request is stopped immediately
959+
- **Preserves Generated Text**: The text generated so far remains in the buffer
960+
- **Visual Feedback**: You receive a notification confirming the cancellation
961+
- **Preview Mode**: If cancelled during streaming, the preview won't be shown
962+
- **Multiple Jobs**: If multiple generations are running, all are stopped
963+
964+
### Autocommand Event
965+
966+
A `User PrtCancelled` event is fired when generation is cancelled, allowing you to create custom hooks:
967+
968+
```lua
969+
vim.api.nvim_create_autocmd("User", {
970+
pattern = "PrtCancelled",
971+
callback = function()
972+
-- Your custom logic here
973+
print("Parrot generation was cancelled")
974+
end,
975+
})
976+
```
977+
978+
### Advanced Usage
979+
980+
For buffer-specific cancellation in custom code:
981+
982+
```lua
983+
-- Stop only jobs for current buffer
984+
local chat_handler = require("parrot").chat_handler
985+
chat_handler:stop({ buffer = vim.api.nvim_get_current_buf() })
986+
987+
-- Stop without notification
988+
chat_handler:stop({ notify = false })
989+
```
990+
945991
## Bonus
946992

947993
Access parrot.nvim directly from your terminal:

lua/parrot/chat_handler.lua

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,19 +377,81 @@ function ChatHandler:Cmd(params)
377377
end
378378

379379
--- Stops all ongoing processes by killing associated jobs.
380-
---@param signal number | nil Signal to send to the processes.
381-
function ChatHandler:stop(signal)
380+
---@param options table|number|nil Options table or signal number for backwards compatibility
381+
--- - options.signal: Signal to send to processes (default: 15)
382+
--- - options.buffer: Buffer number to stop jobs for (nil = stop all)
383+
--- - options.notify: Show notification after stopping (default: true)
384+
function ChatHandler:stop(options)
385+
-- Backwards compatibility: if options is a number, treat it as signal
386+
if type(options) == "number" then
387+
options = { signal = options }
388+
end
389+
options = options or {}
390+
local signal = options.signal or 15
391+
local target_buf = options.buffer
392+
local show_notification = options.notify ~= false
393+
382394
if self.pool:is_empty() then
395+
if show_notification then
396+
logger.warning("No active Parrot processes to stop")
397+
end
383398
return
384399
end
385400

386-
for _, process_info in self.pool:ipairs() do
387-
if process_info.job.handle ~= nil and not process_info.job.handle:is_closing() then
388-
vim.uv.kill(process_info.job.pid, signal or 15)
401+
local stopped_count = 0
402+
local cancelled_queries = {}
403+
404+
-- Collect jobs to stop (either for specific buffer or all)
405+
for i = #self.pool._processes, 1, -1 do
406+
local process_info = self.pool._processes[i]
407+
local should_stop = target_buf == nil or process_info.buf == target_buf
408+
409+
if should_stop then
410+
-- Mark associated query as cancelled
411+
if process_info.qid then
412+
self.queries:mark_cancelled(process_info.qid, "user")
413+
table.insert(cancelled_queries, process_info.qid)
414+
end
415+
416+
-- Kill the job
417+
if process_info.job.handle ~= nil and not process_info.job.handle:is_closing() then
418+
vim.uv.kill(process_info.job.pid, signal)
419+
stopped_count = stopped_count + 1
420+
end
421+
422+
-- Remove from pool
423+
table.remove(self.pool._processes, i)
389424
end
390425
end
391426

392-
self.pool = Pool:new()
427+
-- Clean up extmarks only (preserve generated text)
428+
vim.schedule(function()
429+
for _, qid in ipairs(cancelled_queries) do
430+
local qt = self.queries:get(qid)
431+
if qt then
432+
-- Clear namespace/extmarks only
433+
if qt.ns_id and qt.buf and vim.api.nvim_buf_is_valid(qt.buf) then
434+
pcall(vim.api.nvim_buf_clear_namespace, qt.buf, qt.ns_id, 0, -1)
435+
end
436+
end
437+
end
438+
439+
-- Show notification
440+
if show_notification then
441+
if stopped_count > 0 then
442+
local msg = stopped_count == 1 and "Stopped 1 process" or string.format("Stopped %d processes", stopped_count)
443+
if target_buf then
444+
msg = msg .. " in current buffer"
445+
end
446+
logger.info(msg)
447+
else
448+
logger.warning("No active processes found" .. (target_buf and " in current buffer" or ""))
449+
end
450+
end
451+
452+
-- Fire autocmd event for user hooks
453+
vim.cmd("doautocmd User PrtCancelled")
454+
end)
393455
end
394456

395457
--- Context command
@@ -1889,9 +1951,21 @@ function ChatHandler:query(buf, provider, payload, handler, on_exit)
18891951
end,
18901952
})
18911953
job:start()
1892-
self.pool:add(job, buf)
1954+
1955+
-- Determine target type for better tracking
1956+
local target_type = "unknown"
1957+
if buf and vim.api.nvim_buf_is_valid(buf) then
1958+
local file_name = vim.api.nvim_buf_get_name(buf)
1959+
local utils_module = require("parrot.utils")
1960+
if utils_module.is_chat(buf, file_name, self.options.chat_dir) then
1961+
target_type = "chat"
1962+
end
1963+
end
1964+
1965+
self.pool:add(job, buf, qid, target_type)
18931966
logger.debug("ChatHandler:query pool updated", {
18941967
pool = self.pool,
1968+
qid = qid,
18951969
})
18961970
end
18971971

lua/parrot/config.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ function M.setup(opts)
371371

372372
M.cmd = {
373373
ChatFinder = "chat_finder",
374-
ChatStop = "stop",
374+
Stop = "stop",
375375
ChatNew = "chat_new",
376376
ChatToggle = "chat_toggle",
377377
ChatPaste = "chat_paste",

lua/parrot/pool.lua

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@ end
1010
--- Adds a process to the pool.
1111
--- @param job table # A plenary job.
1212
--- @param buf number|nil # The buffer number (optional)
13-
function Pool:add(job, buf)
14-
table.insert(self._processes, { job = job, buf = buf })
13+
--- @param qid string|nil # The query ID (optional)
14+
--- @param target_type string|nil # The target type (optional) - "chat", "rewrite", "popup", etc.
15+
function Pool:add(job, buf, qid, target_type)
16+
table.insert(self._processes, {
17+
job = job,
18+
buf = buf,
19+
qid = qid,
20+
target_type = target_type,
21+
timestamp = os.time(),
22+
})
1523
end
1624

1725
--- Checks if there is no other process running for the given buffer.
@@ -51,4 +59,61 @@ function Pool:ipairs()
5159
return ipairs(self._processes)
5260
end
5361

62+
--- Gets processes for a specific buffer.
63+
--- @param buf number # The buffer number
64+
--- @return table # List of process info tables for the buffer
65+
function Pool:get_for_buffer(buf)
66+
local result = {}
67+
for _, process_info in ipairs(self._processes) do
68+
if process_info.buf == buf then
69+
table.insert(result, process_info)
70+
end
71+
end
72+
return result
73+
end
74+
75+
--- Gets the most recent active job (by timestamp).
76+
--- @return table|nil # The most recent process info, or nil if pool is empty
77+
function Pool:get_active_job()
78+
if self:is_empty() then
79+
return nil
80+
end
81+
82+
local most_recent = nil
83+
for _, process_info in ipairs(self._processes) do
84+
if most_recent == nil or process_info.timestamp > most_recent.timestamp then
85+
most_recent = process_info
86+
end
87+
end
88+
return most_recent
89+
end
90+
91+
--- Stops jobs for a specific buffer.
92+
--- @param buf number # The buffer number
93+
--- @param signal number|nil # Signal to send (default 15)
94+
--- @return number # Number of jobs stopped
95+
function Pool:stop_buffer(buf, signal)
96+
signal = signal or 15
97+
local stopped_count = 0
98+
99+
for i = #self._processes, 1, -1 do
100+
local process_info = self._processes[i]
101+
if process_info.buf == buf then
102+
if process_info.job.handle ~= nil and not process_info.job.handle:is_closing() then
103+
vim.uv.kill(process_info.job.pid, signal)
104+
stopped_count = stopped_count + 1
105+
end
106+
table.remove(self._processes, i)
107+
end
108+
end
109+
110+
return stopped_count
111+
end
112+
113+
--- Checks if there are any active jobs in the pool.
114+
--- @return boolean # True if there are active jobs, false otherwise
115+
function Pool:has_active_jobs()
116+
return not self:is_empty()
117+
end
118+
54119
return Pool

lua/parrot/preview_response_handler.lua

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ function PreviewResponseHandler:handle_chunk(qid, chunk)
9999
return
100100
end
101101

102+
-- Check if query was cancelled
103+
if qt.cancelled then
104+
-- For preview, we don't show the preview if cancelled
105+
logger.debug("PreviewResponseHandler: Query cancelled, aborting preview")
106+
return
107+
end
108+
102109
if chunk and chunk ~= "" then
103110
self.response = self.response .. chunk
104111
qt.response = self.response
@@ -269,7 +276,14 @@ end
269276
--- Creates a completion handler that shows the preview
270277
---@return function
271278
function PreviewResponseHandler:create_completion_handler()
272-
return vim.schedule_wrap(function(_)
279+
return vim.schedule_wrap(function(qid)
280+
-- Check if cancelled before showing preview
281+
local qt = self.queries:get(qid)
282+
if qt and qt.cancelled then
283+
logger.debug("PreviewResponseHandler: Completion cancelled, not showing preview")
284+
return
285+
end
286+
273287
-- Clean up the response (remove code fences, etc.)
274288
self.response = self
275289
.response

lua/parrot/queries.lua

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ end
1313
--- @param qid number # Query ID.
1414
--- @param data table # Query data.
1515
function Queries:add(qid, data)
16+
-- Initialize cancellation state
17+
data.cancelled = false
18+
data.cancellation_reason = nil
19+
data.cancellation_time = nil
1620
self._queries[qid] = data
1721
end
1822

@@ -38,6 +42,39 @@ function Queries:get(qid)
3842
return self._queries[qid]
3943
end
4044

45+
--- Gets queries for a specific buffer.
46+
--- @param buf number # The buffer number
47+
--- @return table # List of query IDs for the buffer
48+
function Queries:get_for_buffer(buf)
49+
local result = {}
50+
for qid, query_data in pairs(self._queries) do
51+
if query_data.buf == buf then
52+
table.insert(result, qid)
53+
end
54+
end
55+
return result
56+
end
57+
58+
--- Marks a query as cancelled.
59+
--- @param qid string # Query ID.
60+
--- @param reason string|nil # Cancellation reason (optional)
61+
function Queries:mark_cancelled(qid, reason)
62+
local query = self._queries[qid]
63+
if query then
64+
query.cancelled = true
65+
query.cancellation_reason = reason or "user"
66+
query.cancellation_time = os.time()
67+
end
68+
end
69+
70+
--- Checks if a query was cancelled.
71+
--- @param qid string # Query ID.
72+
--- @return boolean # True if cancelled, false otherwise.
73+
function Queries:is_cancelled(qid)
74+
local query = self._queries[qid]
75+
return query and query.cancelled or false
76+
end
77+
4178
--- Cleans up old queries from the collection based on the specified criteria.
4279
--- @param N number # Number of queries to keep.
4380
--- @param age number # Age of queries to keep in seconds.

lua/parrot/response_handler.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ function ResponseHandler:handle_chunk(qid, chunk)
6060
return
6161
end
6262

63+
-- Check if query was cancelled - just stop processing, preserve existing text
64+
if qt.cancelled then
65+
return
66+
end
67+
6368
if not self.skip_first_undojoin then
6469
utils.undojoin(self.buffer)
6570
end

tests/parrot/pool_spec.lua

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ describe("Pool", function()
1818
local job = { pid = 1 }
1919
local buf = 10
2020
pool:add(job, buf)
21-
assert.are.same({ { job = job, buf = buf } }, pool._processes)
21+
assert.equals(1, #pool._processes)
22+
assert.are.same(job, pool._processes[1].job)
23+
assert.equals(buf, pool._processes[1].buf)
24+
assert.is_not_nil(pool._processes[1].timestamp)
2225
end)
2326
end)
2427
end)
@@ -57,7 +60,9 @@ describe("Pool", function()
5760
pool:add(job1, 10)
5861
pool:add(job2, 20)
5962
pool:remove(1)
60-
assert.are.same({ { job = job2, buf = 20 } }, pool._processes)
63+
assert.equals(1, #pool._processes)
64+
assert.are.same(job2, pool._processes[1].job)
65+
assert.equals(20, pool._processes[1].buf)
6166
end)
6267
end)
6368
end)

0 commit comments

Comments
 (0)