Skip to content

Commit 3375277

Browse files
authored
Merge pull request #564 from rumpl/fix-shell
Fix shell process cleanup
2 parents 5b2bcd6 + 3fca38f commit 3375277

File tree

8 files changed

+132
-105
lines changed

8 files changed

+132
-105
lines changed

.gitattributes

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Auto detect text files and normalize to LF
2+
* text=auto eol=lf
3+
4+
# Go source files
5+
*.go text eol=lf
6+
7+
# Shell scripts
8+
*.sh text eol=lf
9+
10+
# Windows specific files
11+
*.bat text eol=crlf
12+
*.cmd text eol=crlf
13+
*.ps1 text eol=crlf
14+
15+
# Binary files
16+
*.exe binary
17+
*.dll binary
18+
*.so binary
19+
*.dylib binary
20+
*.png binary
21+
*.jpg binary
22+
*.jpeg binary
23+
*.gif binary
24+
*.ico binary
25+
*.pdf binary
26+
*.zip binary
27+
*.tar binary
28+
*.gz binary

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ require (
3535
go.opentelemetry.io/otel/sdk v1.38.0
3636
go.opentelemetry.io/otel/trace v1.38.0
3737
golang.org/x/oauth2 v0.32.0
38+
golang.org/x/sys v0.37.0
3839
golang.org/x/term v0.36.0
3940
google.golang.org/genai v1.31.0
4041
modernc.org/sqlite v1.39.1
@@ -116,7 +117,6 @@ require (
116117
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect
117118
golang.org/x/net v0.46.0 // indirect
118119
golang.org/x/sync v0.17.0 // indirect
119-
golang.org/x/sys v0.37.0 // indirect
120120
golang.org/x/text v0.30.0 // indirect
121121
golang.org/x/time v0.14.0 // indirect
122122
google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 // indirect

pkg/runtime/remote_runtime.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,6 @@ func (r *RemoteRuntime) CurrentAgent() *agent.Agent {
6565
return agent.New(r.currentAgent, fmt.Sprintf("Remote agent: %s", r.currentAgent))
6666
}
6767

68-
// StopPendingProcesses stops all pending tool operations for the remote runtime
69-
func (r *RemoteRuntime) StopPendingProcesses(ctx context.Context) error {
70-
// For remote runtime, stop the team's toolsets
71-
// This will kill any spawned processes from shell tools
72-
return r.team.StopToolSets(ctx)
73-
}
74-
7568
// RunStream starts the agent's interaction loop and returns a channel of events
7669
func (r *RemoteRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event {
7770
slog.Debug("Starting remote runtime stream", "agent", r.currentAgent, "session_id", r.sessionID)

pkg/runtime/runtime.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ type ElicitationRequestHandler func(ctx context.Context, message string, schema
6565
type Runtime interface {
6666
// CurrentAgent returns the currently active agent
6767
CurrentAgent() *agent.Agent
68-
// StopPendingProcesses stops all pending tool operations (e.g., running shell commands)
69-
StopPendingProcesses(ctx context.Context) error
7068
// RunStream starts the agent's interaction loop and returns a channel of events
7169
RunStream(ctx context.Context, sess *session.Session) <-chan Event
7270
// Run starts the agent's interaction loop and returns the final messages
@@ -180,10 +178,6 @@ func (r *runtime) CurrentAgent() *agent.Agent {
180178
return current
181179
}
182180

183-
func (r *runtime) StopPendingProcesses(ctx context.Context) error {
184-
return r.team.StopToolSets(ctx)
185-
}
186-
187181
// registerDefaultTools registers the default tool handlers
188182
func (r *runtime) registerDefaultTools() {
189183
slog.Debug("Registering default tools")

pkg/server/server.go

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ func New(sessionStore session.Store, runConfig config.RuntimeConfig, teams map[s
126126
group.POST("/sessions/:id/resume", s.resumeSession)
127127
// Create a new session and run an agent loop
128128
group.POST("/sessions", s.createSession)
129-
// Stop a running session
130-
group.POST("/sessions/:id/stop", s.stopSession)
131129
// Delete a session
132130
group.DELETE("/sessions/:id", s.deleteSession)
133131

@@ -984,36 +982,6 @@ func (s *Server) resumeSession(c echo.Context) error {
984982
return c.JSON(http.StatusOK, map[string]string{"message": "session resumed"})
985983
}
986984

987-
func (s *Server) stopSession(c echo.Context) error {
988-
sessionID := c.Param("id")
989-
990-
// Get the runtime for this session to access its team
991-
rt, rtExists := s.runtimes[sessionID]
992-
993-
// Cancel the runtime context if it's still running
994-
s.cancelsMu.Lock()
995-
if cancel, exists := s.runtimeCancels[sessionID]; exists {
996-
slog.Info("Stopping session execution", "session_id", sessionID)
997-
cancel()
998-
delete(s.runtimeCancels, sessionID)
999-
s.cancelsMu.Unlock()
1000-
1001-
// Stop all pending tool operations (including killing shell-spawned processes)
1002-
if rtExists {
1003-
if err := rt.StopPendingProcesses(c.Request().Context()); err != nil {
1004-
slog.Error("Failed to stop pending tools for session", "session_id", sessionID, "error", err)
1005-
// Don't return error here, as we still want to report success for stopping the session
1006-
}
1007-
}
1008-
1009-
return c.JSON(http.StatusOK, map[string]string{"message": "session stopped successfully"})
1010-
}
1011-
s.cancelsMu.Unlock()
1012-
1013-
slog.Debug("No active runtime found for session", "session_id", sessionID)
1014-
return c.JSON(http.StatusNotFound, map[string]string{"error": "no active session found"})
1015-
}
1016-
1017985
func (s *Server) deleteSession(c echo.Context) error {
1018986
sessionID := c.Param("id")
1019987

pkg/tools/builtin/cmd_unix.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@ import (
77
"syscall"
88
)
99

10+
type processGroup struct {
11+
// Unix doesn't need to store handles, process group is managed by kernel
12+
}
13+
1014
func platformSpecificSysProcAttr() *syscall.SysProcAttr {
1115
return &syscall.SysProcAttr{
1216
Setpgid: true,
1317
}
1418
}
1519

16-
func kill(proc *os.Process) error {
20+
func createProcessGroup(proc *os.Process) (*processGroup, error) {
21+
return &processGroup{}, nil
22+
}
23+
24+
func kill(proc *os.Process, pg *processGroup) error {
1725
return syscall.Kill(-proc.Pid, syscall.SIGTERM)
1826
}

pkg/tools/builtin/cmd_windows.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,70 @@ package builtin
33
import (
44
"os"
55
"syscall"
6+
"unsafe"
7+
8+
"golang.org/x/sys/windows"
69
)
710

11+
type processGroup struct {
12+
jobHandle windows.Handle
13+
processHandle windows.Handle
14+
}
15+
816
func platformSpecificSysProcAttr() *syscall.SysProcAttr {
917
return nil
1018
}
1119

12-
func kill(proc *os.Process) error {
20+
func createProcessGroup(proc *os.Process) (*processGroup, error) {
21+
job, err := windows.CreateJobObject(nil, nil)
22+
if err != nil {
23+
return nil, err
24+
}
25+
26+
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
27+
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
28+
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
29+
},
30+
}
31+
if _, err := windows.SetInformationJobObject(
32+
job,
33+
windows.JobObjectExtendedLimitInformation,
34+
uintptr(unsafe.Pointer(&info)),
35+
uint32(unsafe.Sizeof(info))); err != nil {
36+
_ = windows.CloseHandle(job)
37+
return nil, err
38+
}
39+
40+
handle, err := windows.OpenProcess(windows.PROCESS_SET_QUOTA|windows.PROCESS_TERMINATE, false, uint32(proc.Pid))
41+
if err != nil {
42+
_ = windows.CloseHandle(job)
43+
return nil, err
44+
}
45+
46+
if err := windows.AssignProcessToJobObject(job, handle); err != nil {
47+
_ = windows.CloseHandle(handle)
48+
_ = windows.CloseHandle(job)
49+
return nil, err
50+
}
51+
52+
return &processGroup{
53+
jobHandle: job,
54+
processHandle: handle,
55+
}, nil
56+
}
57+
58+
func kill(proc *os.Process, pg *processGroup) error {
59+
if pg != nil {
60+
// Close handles to trigger JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
61+
// which will terminate all processes in the job
62+
if pg.processHandle != 0 {
63+
_ = windows.CloseHandle(pg.processHandle)
64+
}
65+
if pg.jobHandle != 0 {
66+
_ = windows.CloseHandle(pg.jobHandle)
67+
}
68+
}
69+
70+
// Also call Kill on the process as a fallback
1371
return proc.Kill()
1472
}

pkg/tools/builtin/shell.go

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"os"
99
"os/exec"
1010
"runtime"
11-
"sync"
11+
"strings"
1212

1313
"github.com/docker/cagent/pkg/tools"
1414
)
@@ -25,8 +25,6 @@ type shellHandler struct {
2525
shell string
2626
shellArgsPrefix []string
2727
env []string
28-
mu sync.Mutex
29-
processes []*os.Process
3028
}
3129

3230
type RunShellArgs struct {
@@ -40,74 +38,67 @@ func (h *shellHandler) RunShell(ctx context.Context, toolCall tools.ToolCall) (*
4038
return nil, fmt.Errorf("invalid arguments: %w", err)
4139
}
4240

43-
cmd := exec.CommandContext(ctx, h.shell, append(h.shellArgsPrefix, params.Cmd)...)
41+
cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
4442
cmd.Env = h.env
4543
if params.Cwd != "" {
4644
cmd.Dir = params.Cwd
4745
} else {
48-
// Use the current working directory; avoid PWD on Windows (may be MSYS-style like /c/...)
4946
if wd, err := os.Getwd(); err == nil {
5047
cmd.Dir = wd
5148
}
5249
}
5350

54-
// Set up process group for proper cleanup
55-
// On Unix: create new process group so we can kill the entire tree
5651
cmd.SysProcAttr = platformSpecificSysProcAttr()
5752

58-
// Note: On Windows, we would set CreationFlags, but that requires
59-
// platform-specific code in a _windows.go file
60-
61-
// Capture output using buffers
62-
var outBuf, errBuf bytes.Buffer
53+
var outBuf bytes.Buffer
6354
cmd.Stdout = &outBuf
64-
cmd.Stderr = &errBuf
55+
cmd.Stderr = &outBuf
6556

66-
// Start the command so we can track it
6757
if err := cmd.Start(); err != nil {
6858
return &tools.ToolCallResult{
6959
Output: fmt.Sprintf("Error starting command: %s", err),
7060
}, nil
7161
}
7262

73-
// Track the process for cleanup
74-
h.mu.Lock()
75-
h.processes = append(h.processes, cmd.Process)
76-
h.mu.Unlock()
77-
78-
// Remove from tracking once complete
79-
defer func() {
80-
h.mu.Lock()
81-
for i, p := range h.processes {
82-
if p != nil && p.Pid == cmd.Process.Pid {
83-
h.processes = append(h.processes[:i], h.processes[i+1:]...)
84-
break
85-
}
86-
}
87-
h.mu.Unlock()
88-
}()
89-
90-
// Wait for the command to complete and get the result
91-
err := cmd.Wait()
92-
93-
// Combine stdout and stderr
94-
output := outBuf.String() + errBuf.String()
95-
63+
pg, err := createProcessGroup(cmd.Process)
9664
if err != nil {
9765
return &tools.ToolCallResult{
98-
Output: fmt.Sprintf("Error executing command: %s\nOutput: %s", err, output),
66+
Output: fmt.Sprintf("Error creating process group: %s", err),
9967
}, nil
10068
}
10169

102-
if output == "" {
70+
done := make(chan error, 1)
71+
go func() {
72+
done <- cmd.Wait()
73+
}()
74+
75+
select {
76+
case <-ctx.Done():
77+
if cmd.Process != nil {
78+
_ = kill(cmd.Process, pg)
79+
}
10380
return &tools.ToolCallResult{
104-
Output: "<no output>",
81+
Output: "Command cancelled",
10582
}, nil
106-
}
83+
case err := <-done:
84+
output := outBuf.String()
10785

108-
return &tools.ToolCallResult{
109-
Output: output,
110-
}, nil
86+
if err != nil {
87+
return &tools.ToolCallResult{
88+
Output: fmt.Sprintf("Error executing command: %s\nOutput: %s", err, output),
89+
}, nil
90+
}
91+
92+
if strings.TrimSpace(output) == "" {
93+
return &tools.ToolCallResult{
94+
Output: "<no output>",
95+
}, nil
96+
}
97+
98+
return &tools.ToolCallResult{
99+
Output: fmt.Sprintf("Output: %s", output),
100+
}, nil
101+
}
111102
}
112103

113104
func NewShellTool(env []string) *ShellTool {
@@ -236,18 +227,5 @@ func (t *ShellTool) Start(context.Context) error {
236227
}
237228

238229
func (t *ShellTool) Stop(context.Context) error {
239-
t.handler.mu.Lock()
240-
defer t.handler.mu.Unlock()
241-
242-
// Kill all tracked processes
243-
for _, proc := range t.handler.processes {
244-
if proc != nil {
245-
_ = kill(proc)
246-
}
247-
}
248-
249-
// Clear the processes list
250-
t.handler.processes = nil
251-
252230
return nil
253231
}

0 commit comments

Comments
 (0)