diff --git a/cmd/root/acp.go b/cmd/root/acp.go new file mode 100644 index 00000000..7d472767 --- /dev/null +++ b/cmd/root/acp.go @@ -0,0 +1,50 @@ +package root + +import ( + "log/slog" + "os" + + acpsdk "github.com/coder/acp-go-sdk" + "github.com/spf13/cobra" + + "github.com/docker/cagent/pkg/acp" + "github.com/docker/cagent/pkg/telemetry" +) + +// NewACPCmd creates a new acp command +func NewACPCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "acp ", + Short: "Start an ACP (Agent Client Protocol) server", + Long: `Start an ACP server that exposes the agent via the Agent Client Protocol`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + telemetry.TrackCommand("acp", args) + return runACP(cmd, args) + }, + } + + addGatewayFlags(cmd) + addRuntimeConfigFlags(cmd) + + return cmd +} + +func runACP(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + agentFilename := args[0] + + slog.Debug("Starting ACP server", "agent_file", agentFilename, "debug_mode", debugMode) + + acpAgent := acp.NewAgent(agentFilename, runConfig) + conn := acpsdk.NewAgentSideConnection(acpAgent, os.Stdout, os.Stdin) + conn.SetLogger(slog.Default()) + acpAgent.SetAgentConnection(conn) + defer acpAgent.Stop(ctx) + + slog.Debug("acp started, waiting for conn") + + <-conn.Done() + + return nil +} diff --git a/cmd/root/root.go b/cmd/root/root.go index f4d211f4..e47330f8 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -114,6 +114,7 @@ func NewRootCmd() *cobra.Command { cmd.AddCommand(NewExecCmd()) cmd.AddCommand(NewNewCmd()) cmd.AddCommand(NewAPICmd()) + cmd.AddCommand(NewACPCmd()) cmd.AddCommand(NewEvalCmd()) cmd.AddCommand(NewPushCmd()) cmd.AddCommand(NewPullCmd()) diff --git a/docs/USAGE.md b/docs/USAGE.md index f57b8b0e..4cb07335 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -57,6 +57,9 @@ $ cagent exec config.yaml --yolo # Run the agent once and auto-accept a $ cagent api config.yaml $ cagent api config.yaml --listen :8080 +# ACP 
Server (Agent Client Protocol via stdio) +$ cagent acp config.yaml # Start ACP server on stdio + # Other commands $ cagent new # Initialize new project $ cagent new --model openai/gpt-5-mini --max-tokens 32000 # Override max tokens during generation diff --git a/go.mod b/go.mod index c4db41e7..7ed61817 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250930175933-4cafc092c5e7 github.com/charmbracelet/glamour/v2 v2.0.0-20250811143442-a27abb32f018 github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250917201909-41ff0bf215ea + github.com/coder/acp-go-sdk v0.4.9 github.com/dop251/goja v0.0.0-20251008123653-cf18d89f3cf6 github.com/fatih/color v1.18.0 github.com/goccy/go-yaml v1.18.0 diff --git a/go.sum b/go.sum index cd4c9361..dc6d2e6e 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2 github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/coder/acp-go-sdk v0.4.9 h1:F4sKT2up4sMqNYt6yt2L9g4MaE09VPgt3eRqDFnoY5k= +github.com/coder/acp-go-sdk v0.4.9/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko= github.com/containerd/stargz-snapshotter/estargz v0.17.0 h1:+TyQIsR/zSFI1Rm31EQBwpAA1ovYgIKHy7kctL3sLcE= github.com/containerd/stargz-snapshotter/estargz v0.17.0/go.mod h1:s06tWAiJcXQo9/8AReBCIo/QxcXFZ2n4qfsRnpl71SM= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= diff --git a/pkg/acp/agent.go b/pkg/acp/agent.go new file mode 100644 index 00000000..e3e25a3c --- /dev/null +++ b/pkg/acp/agent.go @@ -0,0 +1,422 @@ +package acp + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sync" + + "github.com/coder/acp-go-sdk" + 
// SetAgentConnection stores conn as the connection the agent uses to reach the
// ACP client (session updates, permission requests and file operations all go
// through it). The assignment is not guarded by a.mu.
func (a *Agent) SetAgentConnection(conn *acp.AgentSideConnection) {
	a.conn = conn
}
"agent_count", t.Size()) + + return acp.InitializeResponse{ + ProtocolVersion: acp.ProtocolVersionNumber, + AgentCapabilities: acp.AgentCapabilities{ + LoadSession: false, + PromptCapabilities: acp.PromptCapabilities{ + EmbeddedContext: true, + }, + }, + }, nil +} + +// NewSession implements [acp.Agent] +func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { + sid := uuid.New().String() + slog.Debug("ACP NewSession called", "session_id", sid) + + rt, err := runtime.New(a.team, runtime.WithCurrentAgent("root")) + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("failed to create runtime: %w", err) + } + + a.mu.Lock() + a.sessions[sid] = &Session{ + id: sid, + sess: session.New(session.WithTitle("ACP Session " + sid)), + rt: rt, + } + a.mu.Unlock() + + return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil +} + +// Authenticate implements [acp.Agent] +func (a *Agent) Authenticate(ctx context.Context, params acp.AuthenticateRequest) (acp.AuthenticateResponse, error) { + slog.Debug("ACP Authenticate called") + return acp.AuthenticateResponse{}, nil +} + +// LoadSession implements [acp.Agent] (optional, not supported) +func (a *Agent) LoadSession(ctx context.Context, params acp.LoadSessionRequest) (acp.LoadSessionResponse, error) { + slog.Debug("ACP LoadSession called (not supported)") + return acp.LoadSessionResponse{}, fmt.Errorf("load session not supported") +} + +// Cancel implements [acp.Agent] +func (a *Agent) Cancel(ctx context.Context, params acp.CancelNotification) error { + sid := string(params.SessionId) + slog.Debug("ACP Cancel called", "session_id", sid) + + a.mu.Lock() + acpSess, ok := a.sessions[sid] + a.mu.Unlock() + + if ok && acpSess != nil && acpSess.cancel != nil { + acpSess.cancel() + } + + return nil +} + +// Prompt implements [acp.Agent] +func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { + sid := 
string(params.SessionId) + slog.Debug("ACP Prompt called", "session_id", sid) + + a.mu.Lock() + acpSess, ok := a.sessions[sid] + a.mu.Unlock() + + if !ok { + return acp.PromptResponse{}, fmt.Errorf("session %s not found", sid) + } + + // Cancel any previous turn + a.mu.Lock() + if acpSess.cancel != nil { + prev := acpSess.cancel + a.mu.Unlock() + prev() + } else { + a.mu.Unlock() + } + + // Create a new context for this turn + turnCtx, cancel := context.WithCancel(context.Background()) + a.mu.Lock() + acpSess.cancel = cancel + a.mu.Unlock() + + // Add the user message to the session + var userContent string + for _, content := range params.Prompt { + if content.Text != nil { + userContent += content.Text.Text + } + if content.ResourceLink != nil { + slog.Debug("resource link", "link", content.ResourceLink) + } + if content.Resource != nil { + slog.Debug("embedded context", "context", content.Resource) + slog.Debug(content.Resource.Resource.TextResourceContents.Text) + } + } + + if userContent != "" { + acpSess.sess.AddMessage(session.UserMessage(a.agentFilename, userContent)) + } + + // Run the agent and stream updates + if err := a.runAgent(turnCtx, acpSess); err != nil { + if turnCtx.Err() != nil { + return acp.PromptResponse{StopReason: acp.StopReasonCancelled}, nil + } + return acp.PromptResponse{}, err + } + + a.mu.Lock() + acpSess.cancel = nil + a.mu.Unlock() + + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil +} + +// SetSessionMode implements acp.Agent (optional) +func (a *Agent) SetSessionMode(ctx context.Context, params acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) { + // We don't implement session modes, cagent agents have only one mode (for now? ;) ). 
+ return acp.SetSessionModeResponse{}, nil +} + +// runAgent runs a single agent loop and streams updates to the ACP client +func (a *Agent) runAgent(ctx context.Context, acpSess *Session) error { + slog.Debug("Running agent turn", "session_id", acpSess.id) + + ctx = withSessionID(ctx, acpSess.id) + + eventsChan := acpSess.rt.RunStream(ctx, acpSess.sess) + + for event := range eventsChan { + if ctx.Err() != nil { + return ctx.Err() + } + + switch e := event.(type) { + case *runtime.AgentChoiceEvent: + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: acp.SessionId(acpSess.id), + Update: acp.UpdateAgentMessageText(e.Content), + }); err != nil { + return err + } + + case *runtime.ToolCallConfirmationEvent: + if err := a.handleToolCallConfirmation(ctx, acpSess, e); err != nil { + return err + } + + case *runtime.ToolCallEvent: + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: acp.SessionId(acpSess.id), + Update: buildToolCallStart(e.ToolCall, e.ToolDefinition), + }); err != nil { + return err + } + + case *runtime.ToolCallResponseEvent: + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: acp.SessionId(acpSess.id), + Update: buildToolCallComplete(e.ToolCall, e.Response), + }); err != nil { + return err + } + + case *runtime.ErrorEvent: + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: acp.SessionId(acpSess.id), + Update: acp.UpdateAgentMessageText(fmt.Sprintf("\n\nError: %s\n", e.Error)), + }); err != nil { + return err + } + + case *runtime.MaxIterationsReachedEvent: + if err := a.handleMaxIterationsReached(ctx, acpSess, e); err != nil { + return err + } + } + } + + return nil +} + +// handleToolCallConfirmation handles tool call permission requests +func (a *Agent) handleToolCallConfirmation(ctx context.Context, acpSess *Session, e *runtime.ToolCallConfirmationEvent) error { + toolCallUpdate := buildToolCallUpdate(e.ToolCall, e.ToolDefinition, 
acp.ToolCallStatusPending) + + permResp, err := a.conn.RequestPermission(ctx, acp.RequestPermissionRequest{ + SessionId: acp.SessionId(acpSess.id), + ToolCall: toolCallUpdate, + Options: []acp.PermissionOption{ + { + Kind: acp.PermissionOptionKindAllowOnce, + Name: "Allow this action", + OptionId: acp.PermissionOptionId("allow"), + }, + { + Kind: acp.PermissionOptionKindAllowAlways, + Name: "Allow and remember my choice", + OptionId: acp.PermissionOptionId("allow-always"), + }, + { + Kind: acp.PermissionOptionKindRejectOnce, + Name: "Skip this action", + OptionId: acp.PermissionOptionId("reject"), + }, + }, + }) + if err != nil { + return err + } + + // Handle permission outcome + if permResp.Outcome.Cancelled != nil { + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeReject)) + return nil + } + + if permResp.Outcome.Selected == nil { + return fmt.Errorf("unexpected permission outcome") + } + + switch string(permResp.Outcome.Selected.OptionId) { + case "allow": + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeApprove)) + case "allow-always": + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeApproveSession)) + case "reject": + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeReject)) + default: + return fmt.Errorf("unexpected permission option: %s", permResp.Outcome.Selected.OptionId) + } + + return nil +} + +// handleMaxIterationsReached handles max iterations events +func (a *Agent) handleMaxIterationsReached(ctx context.Context, acpSess *Session, e *runtime.MaxIterationsReachedEvent) error { + permResp, err := a.conn.RequestPermission(ctx, acp.RequestPermissionRequest{ + SessionId: acp.SessionId(acpSess.id), + ToolCall: acp.ToolCallUpdate{ + ToolCallId: acp.ToolCallId("max_iterations"), + Title: acp.Ptr(fmt.Sprintf("Maximum iterations (%d) reached", e.MaxIterations)), + Kind: acp.Ptr(acp.ToolKindExecute), + Status: acp.Ptr(acp.ToolCallStatusPending), + }, + Options: []acp.PermissionOption{ + { + Kind: acp.PermissionOptionKindAllowOnce, + Name: "Continue", + 
OptionId: acp.PermissionOptionId("continue"), + }, + { + Kind: acp.PermissionOptionKindRejectOnce, + Name: "Stop", + OptionId: acp.PermissionOptionId("stop"), + }, + }, + }) + if err != nil { + return err + } + + if permResp.Outcome.Cancelled != nil || permResp.Outcome.Selected == nil || + string(permResp.Outcome.Selected.OptionId) == "stop" { + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeReject)) + } else { + acpSess.rt.Resume(ctx, string(runtime.ResumeTypeApprove)) + } + + return nil +} + +// buildToolCallStart creates a tool call start update +func buildToolCallStart(toolCall tools.ToolCall, tool tools.Tool) acp.SessionUpdate { + kind := acp.ToolKindExecute + title := tool.Annotations.Title + if title == "" { + title = toolCall.Function.Name + } + + // Determine tool kind from tool annotations + if tool.Annotations.ReadOnlyHint { + kind = acp.ToolKindRead + } + + return acp.StartToolCall( + acp.ToolCallId(toolCall.ID), + title, + acp.WithStartKind(kind), + acp.WithStartStatus(acp.ToolCallStatusPending), + acp.WithStartRawInput(parseToolCallArguments(toolCall.Function.Arguments)), + ) +} + +// buildToolCallComplete creates a tool call completion update +func buildToolCallComplete(toolCall tools.ToolCall, output string) acp.SessionUpdate { + return acp.UpdateToolCall( + acp.ToolCallId(toolCall.ID), + acp.WithUpdateStatus(acp.ToolCallStatusCompleted), + acp.WithUpdateContent([]acp.ToolCallContent{acp.ToolContent(acp.TextBlock(output))}), + acp.WithUpdateRawOutput(map[string]any{"content": output}), + ) +} + +// buildToolCallUpdate creates a tool call update for permission requests +func buildToolCallUpdate(toolCall tools.ToolCall, tool tools.Tool, status acp.ToolCallStatus) acp.ToolCallUpdate { + kind := acp.ToolKindExecute + title := tool.Annotations.Title + if title == "" { + title = toolCall.Function.Name + } + + if tool.Annotations.ReadOnlyHint { + kind = acp.ToolKindRead + } + + return acp.ToolCallUpdate{ + ToolCallId: acp.ToolCallId(toolCall.ID), + Title: 
// parseToolCallArguments decodes the JSON object carried by a tool call into a
// generic map.
//
// An empty argument string means the tool was called without arguments and
// yields an empty map. Input that is not valid JSON for a map is logged and
// returned under the "raw" key, so the caller still receives the original text.
func parseToolCallArguments(argsJSON string) map[string]any {
	// A call without arguments is not malformed input: skip the decoder so
	// it neither logs a warning nor produces a bogus {"raw": ""} payload.
	if argsJSON == "" {
		return map[string]any{}
	}

	var args map[string]any
	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
		slog.Warn("Failed to parse tool call arguments", "error", err)
		return map[string]any{"raw": argsJSON}
	}
	return args
}
tool definitions with ACP-specific overrides +func (t *FilesystemToolset) Tools(ctx context.Context) ([]tools.Tool, error) { + baseTools, err := t.FilesystemTool.Tools(ctx) + if err != nil { + return nil, err + } + + for i := range baseTools { + switch baseTools[i].Name { + case "read_file": + baseTools[i].Handler = t.handleReadFile + case "write_file": + baseTools[i].Handler = t.handleWriteFile + case "edit_file": + baseTools[i].Handler = t.handleEditFile + } + } + + return baseTools, nil +} + +func (t *FilesystemToolset) handleReadFile(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var args builtin.ReadFileArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + return nil, fmt.Errorf("failed to parse arguments: %w", err) + } + + sessionID, ok := getSessionID(ctx) + if !ok { + return &tools.ToolCallResult{Output: "Error: session ID not found in context"}, nil + } + + resp, err := t.agent.conn.ReadTextFile(ctx, acp.ReadTextFileRequest{ + SessionId: acp.SessionId(sessionID), + Path: filepath.Join(t.workindgDir, args.Path), + }) + if err != nil { + return &tools.ToolCallResult{Output: fmt.Sprintf("Error reading file: %s", err)}, nil + } + + return &tools.ToolCallResult{Output: resp.Content}, nil +} + +func (t *FilesystemToolset) handleWriteFile(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var args builtin.WriteFileArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + return nil, fmt.Errorf("failed to parse arguments: %w", err) + } + + sessionID, ok := getSessionID(ctx) + if !ok { + return &tools.ToolCallResult{Output: "Error: session ID not found in context"}, nil + } + + _, err := t.agent.conn.WriteTextFile(ctx, acp.WriteTextFileRequest{ + SessionId: acp.SessionId(sessionID), + Path: args.Path, + Content: args.Content, + }) + if err != nil { + return &tools.ToolCallResult{Output: fmt.Sprintf("Error writing file: %s", err)}, 
nil + } + + return &tools.ToolCallResult{Output: "File written successfully"}, nil +} + +func (t *FilesystemToolset) handleEditFile(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + var args builtin.EditFileArgs + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + return nil, fmt.Errorf("failed to parse arguments: %w", err) + } + + sessionID, ok := getSessionID(ctx) + if !ok { + return &tools.ToolCallResult{Output: "Error: session ID not found in context"}, nil + } + + resp, err := t.agent.conn.ReadTextFile(ctx, acp.ReadTextFileRequest{ + SessionId: acp.SessionId(sessionID), + Path: filepath.Join(t.workindgDir, args.Path), + }) + if err != nil { + return &tools.ToolCallResult{Output: fmt.Sprintf("Error reading file: %s", err)}, nil + } + + modifiedContent := resp.Content + + for i, edit := range args.Edits { + if !strings.Contains(modifiedContent, edit.OldText) { + return &tools.ToolCallResult{Output: fmt.Sprintf("Edit %d failed: old text not found", i+1)}, nil + } + modifiedContent = strings.Replace(modifiedContent, edit.OldText, edit.NewText, 1) + } + + _, err = t.agent.conn.WriteTextFile(ctx, acp.WriteTextFileRequest{ + SessionId: acp.SessionId(sessionID), + Path: filepath.Join(t.workindgDir, args.Path), + Content: modifiedContent, + }) + if err != nil { + return &tools.ToolCallResult{Output: fmt.Sprintf("Error writing file: %s", err)}, nil + } + + return &tools.ToolCallResult{Output: "File edited successfully"}, nil +} diff --git a/pkg/acp/registry.go b/pkg/acp/registry.go new file mode 100644 index 00000000..c4d3651a --- /dev/null +++ b/pkg/acp/registry.go @@ -0,0 +1,32 @@ +package acp + +import ( + "context" + "os" + + "github.com/docker/cagent/pkg/config" + latest "github.com/docker/cagent/pkg/config/v2" + "github.com/docker/cagent/pkg/environment" + "github.com/docker/cagent/pkg/teamloader" + "github.com/docker/cagent/pkg/tools" +) + +// createToolsetRegistry creates a custom toolset registry 
with ACP-specific filesystem toolset +func createToolsetRegistry(agent *Agent) *teamloader.ToolsetRegistry { + registry := teamloader.NewDefaultToolsetRegistry() + + registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + wd := runtimeConfig.WorkingDir + if wd == "" { + var err error + wd, err = os.Getwd() + if err != nil { + return nil, err + } + } + + return NewFilesystemToolset(agent, wd), nil + }) + + return registry +} diff --git a/pkg/session/session.go b/pkg/session/session.go index 6a2d2122..a16a8e5f 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -229,6 +229,12 @@ func WithWorkingDir(workingDir string) Opt { } } +func WithTitle(title string) Opt { + return func(s *Session) { + s.Title = title + } +} + // New creates a new agent session func New(opts ...Opt) *Session { sessionID := uuid.New().String() diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 81e5cea2..ab73aaf1 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -25,6 +25,191 @@ import ( "github.com/docker/cagent/pkg/tools/mcp" ) +// ToolsetCreator is a function that creates a toolset based on the provided configuration +type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) + +// ToolsetRegistry manages the registration of toolset creators by type +type ToolsetRegistry struct { + creators map[string]ToolsetCreator +} + +// NewToolsetRegistry creates a new empty toolset registry +func NewToolsetRegistry() *ToolsetRegistry { + return &ToolsetRegistry{ + creators: make(map[string]ToolsetCreator), + } +} + +// Register adds a new toolset creator for the given type +func (r *ToolsetRegistry) Register(toolsetType string, creator ToolsetCreator) { + 
r.creators[toolsetType] = creator +} + +// Get retrieves a toolset creator for the given type +func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) { + creator, ok := r.creators[toolsetType] + return creator, ok +} + +// CreateTool creates a toolset using the registered creator for the given type +func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + creator, ok := r.Get(toolset.Type) + if !ok { + return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type) + } + return creator(ctx, toolset, parentDir, envProvider, runtimeConfig) +} + +func NewDefaultToolsetRegistry() *ToolsetRegistry { + r := NewToolsetRegistry() + // Register all built-in toolset creators + r.Register("todo", createTodoTool) + r.Register("memory", createMemoryTool) + r.Register("think", createThinkTool) + r.Register("shell", createShellTool) + r.Register("script", createScriptTool) + r.Register("filesystem", createFilesystemTool) + r.Register("fetch", createFetchTool) + r.Register("mcp", createMCPTool) + return r +} + +func createTodoTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + if toolset.Shared { + return builtin.NewSharedTodoTool(), nil + } + return builtin.NewTodoTool(), nil +} + +func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + var memoryPath string + if filepath.IsAbs(toolset.Path) { + memoryPath = "" + } else if wd, err := os.Getwd(); err == nil { + memoryPath = wd + } else { + memoryPath = parentDir + } + + validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath) + if err != nil { + return nil, fmt.Errorf("invalid memory database path: %w", err) + } + 
if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil { + return nil, fmt.Errorf("failed to create memory database directory: %w", err) + } + + db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) + if err != nil { + return nil, fmt.Errorf("failed to create memory database: %w", err) + } + + return builtin.NewMemoryTool(db), nil +} + +func createThinkTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + return builtin.NewThinkTool(), nil +} + +func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) + return builtin.NewShellTool(env), nil +} + +func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + if len(toolset.Shell) == 0 { + return nil, fmt.Errorf("shell is required for script toolset") + } + + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) 
+ return builtin.NewScriptShellTool(toolset.Shell, env), nil +} + +func createFilesystemTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + wd := runtimeConfig.WorkingDir + if wd == "" { + var err error + wd, err = os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get working directory: %w", err) + } + } + + var opts []builtin.FileSystemOpt + if len(toolset.PostEdit) > 0 { + postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit)) + for i, pe := range toolset.PostEdit { + postEditConfigs[i] = builtin.PostEditConfig{ + Path: pe.Path, + Cmd: pe.Cmd, + } + } + opts = append(opts, builtin.WithPostEditCommands(postEditConfigs)) + } + + return builtin.NewFilesystemTool([]string{wd}, opts...), nil +} + +func createFetchTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + var opts []builtin.FetchToolOption + if toolset.Timeout > 0 { + timeout := time.Duration(toolset.Timeout) * time.Second + opts = append(opts, builtin.WithTimeout(timeout)) + } + return builtin.NewFetchTool(opts...), nil +} + +func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { + // MCP tool has three different modes: ref, command, and remote + if toolset.Ref != "" { + mcpServerName := gateway.ParseServerRef(toolset.Ref) + serverSpec, err := gateway.ServerSpec(ctx, mcpServerName) + if err != nil { + return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err) + } + + // TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it. 
+ if serverSpec.Type == "remote" { + return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil + } + + return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider) + } + + if toolset.Command != "" { + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) + return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil + } + + if toolset.Remote.URL != "" { + headers := map[string]string{} + for k, v := range toolset.Remote.Headers { + expanded, err := environment.Expand(ctx, v, envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand header '%s': %w", k, err) + } + + headers[k] = expanded + } + + return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil + } + + return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration") +} + // LoadTeams loads all agent teams from the given directory or file path func LoadTeams(ctx context.Context, agentsPathOrDirectory string, runtimeConfig config.RuntimeConfig) (map[string]*team.Team, error) { teams := make(map[string]*team.Team) @@ -95,7 +280,8 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme } type loadOptions struct { - modelOverrides []string + modelOverrides []string + toolsetRegistry *ToolsetRegistry } type Opt func(*loadOptions) error @@ -107,13 +293,24 @@ func WithModelOverrides(overrides []string) Opt { } } +// WithToolsetRegistry allows using a custom toolset registry instead of the default +func WithToolsetRegistry(registry *ToolsetRegistry) Opt { + return func(opts *loadOptions) error { + opts.toolsetRegistry = registry + return nil + } +} + func Load(ctx context.Context, p string, runtimeConfig 
config.RuntimeConfig, opts ...Opt) (*team.Team, error) { - var loadOptions loadOptions + var loadOpts loadOptions + loadOpts.toolsetRegistry = NewDefaultToolsetRegistry() + for _, o := range opts { - if err := o(&loadOptions); err != nil { + if err := o(&loadOpts); err != nil { return nil, err } } + fileName := filepath.Base(p) parentDir := filepath.Dir(p) @@ -141,7 +338,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt } // Apply model overrides from CLI flags before checking required env vars - if err := config.ApplyModelOverrides(cfg, loadOptions.modelOverrides); err != nil { + if err := config.ApplyModelOverrides(cfg, loadOpts.modelOverrides); err != nil { return nil, err } @@ -174,7 +371,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt opts = append(opts, agent.WithModel(model)) } - agentTools, err := getToolsForAgent(ctx, &agentConfig, parentDir, env, runtimeConfig) + agentTools, err := getToolsForAgent(ctx, &agentConfig, parentDir, env, runtimeConfig, loadOpts.toolsetRegistry) if err != nil { return nil, fmt.Errorf("failed to get tools: %w", err) } @@ -239,13 +436,13 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC } // getToolsForAgent returns the tool definitions for an agent based on its configuration -func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) ([]tools.ToolSet, error) { +func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig, registry *ToolsetRegistry) ([]tools.ToolSet, error) { var t []tools.ToolSet for i := range a.Toolsets { toolset := a.Toolsets[i] - tool, err := createTool(ctx, toolset, parentDir, envProvider, runtimeConfig) + tool, err := registry.CreateTool(ctx, toolset, parentDir, envProvider, runtimeConfig) if err != nil { return nil, 
err } @@ -267,122 +464,3 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri codemode.Wrap(t...), }, nil } - -func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) { - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - env = append(env, os.Environ()...) - - switch { - case toolset.Type == "todo": - if toolset.Shared { - return builtin.NewSharedTodoTool(), nil - } - return builtin.NewTodoTool(), nil - - case toolset.Type == "memory": - var memoryPath string - if filepath.IsAbs(toolset.Path) { - memoryPath = "" - } else if wd, err := os.Getwd(); err == nil { - memoryPath = wd - } else { - memoryPath = parentDir - } - - validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath) - if err != nil { - return nil, fmt.Errorf("invalid memory database path: %w", err) - } - if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil { - return nil, fmt.Errorf("failed to create memory database directory: %w", err) - } - - db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) - if err != nil { - return nil, fmt.Errorf("failed to create memory database: %w", err) - } - - return builtin.NewMemoryTool(db), nil - - case toolset.Type == "think": - return builtin.NewThinkTool(), nil - - case toolset.Type == "shell": - return builtin.NewShellTool(env), nil - - case toolset.Type == "script": - if len(toolset.Shell) == 0 { - return nil, fmt.Errorf("shell is required for script toolset") - } - - return builtin.NewScriptShellTool(toolset.Shell, env), nil - - case toolset.Type == "filesystem": - wd := runtimeConfig.WorkingDir - if wd == "" { - var err error - wd, err = os.Getwd() - if err != nil { - return nil, fmt.Errorf("failed to get working directory: 
%w", err) - } - } - - var opts []builtin.FileSystemOpt - if len(toolset.PostEdit) > 0 { - postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit)) - for i, pe := range toolset.PostEdit { - postEditConfigs[i] = builtin.PostEditConfig{ - Path: pe.Path, - Cmd: pe.Cmd, - } - } - opts = append(opts, builtin.WithPostEditCommands(postEditConfigs)) - } - - return builtin.NewFilesystemTool([]string{wd}, opts...), nil - - case toolset.Type == "fetch": - var opts []builtin.FetchToolOption - if toolset.Timeout > 0 { - timeout := time.Duration(toolset.Timeout) * time.Second - opts = append(opts, builtin.WithTimeout(timeout)) - } - return builtin.NewFetchTool(opts...), nil - - case toolset.Type == "mcp" && toolset.Ref != "": - mcpServerName := gateway.ParseServerRef(toolset.Ref) - serverSpec, err := gateway.ServerSpec(ctx, mcpServerName) - if err != nil { - return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err) - } - - // TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it. - if serverSpec.Type == "remote" { - return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil - } - - return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider) - - case toolset.Type == "mcp" && toolset.Command != "": - return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil - - case toolset.Type == "mcp" && toolset.Remote.URL != "": - headers := map[string]string{} - for k, v := range toolset.Remote.Headers { - expanded, err := environment.Expand(ctx, v, envProvider) - if err != nil { - return nil, fmt.Errorf("failed to expand header '%s': %w", k, err) - } - - headers[k] = expanded - } - - return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil - - default: - return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type) - } -}