diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index f774d36..de433e2 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -1,6 +1,7 @@ package transport import ( + "bufio" "bytes" "encoding/json" "fmt" @@ -31,12 +32,27 @@ func NewStdio(command []string) *Stdio { // Execute implements the Transport interface by spawning a subprocess // and communicating with it via JSON-RPC over stdin/stdout. func (t *Stdio) Execute(method string, params any) (map[string]any, error) { - if len(t.command) == 0 { - return nil, fmt.Errorf("no command specified for stdio transport") + stdin, stdout, cmd, stderrBuf, err := t.setupCommand() + if err != nil { + return nil, err } if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Executing command: %v\n", t.command) + fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n") + } + + if initErr := t.initialize(stdin, stdout); initErr != nil { + if t.debug { + fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr) + if stderrBuf.Len() > 0 { + fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", stderrBuf.String()) + } + } + return nil, initErr + } + + if t.debug { + fmt.Fprintf(os.Stderr, "DEBUG: Initialization successful, sending method request\n") } request := Request{ @@ -47,77 +63,155 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) { } t.nextID++ - requestJSON, err := json.Marshal(request) + if sendErr := t.sendRequest(stdin, request); sendErr != nil { + return nil, sendErr + } + _ = stdin.Close() + + response, err := t.readResponse(stdout) if err != nil { - return nil, fmt.Errorf("error marshaling request: %w", err) + return nil, err } - requestJSON = append(requestJSON, '\n') + waitErr := cmd.Wait() if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Sending request: %s\n", string(requestJSON)) + fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr) + if stderrBuf.Len() > 0 { + fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", stderrBuf.String()) + } } - cmd := exec.Command(t.command[0], t.command[1:]...) // #nosec G204 + if waitErr != nil && stderrBuf.Len() > 0 { + return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String()) + } + + return response.Result, nil +} - stdin, stdinErr := cmd.StdinPipe() - if stdinErr != nil { - return nil, fmt.Errorf("error getting stdin pipe: %w", stdinErr) +// setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error. +func (t *Stdio) setupCommand() (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, stderrBuf *bytes.Buffer, err error) { + if len(t.command) == 0 { + return nil, nil, nil, nil, fmt.Errorf("no command specified for stdio transport") } - stdout, stdoutErr := cmd.StdoutPipe() - if stdoutErr != nil { - return nil, fmt.Errorf("error getting stdout pipe: %w", stdoutErr) + if t.debug { + fmt.Fprintf(os.Stderr, "DEBUG: Executing command: %v\n", t.command) } - var stderrBuf bytes.Buffer - cmd.Stderr = &stderrBuf + cmd = exec.Command(t.command[0], t.command[1:]...) // #nosec G204 + + stdin, err = cmd.StdinPipe() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("error getting stdin pipe: %w", err) + } - if startErr := cmd.Start(); startErr != nil { - return nil, fmt.Errorf("error starting command: %w", startErr) + stdout, err = cmd.StdoutPipe() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("error getting stdout pipe: %w", err) } - if _, writeErr := stdin.Write(requestJSON); writeErr != nil { - return nil, fmt.Errorf("error writing to stdin: %w", writeErr) + stderrBuf = &bytes.Buffer{} + cmd.Stderr = stderrBuf + + if err = cmd.Start(); err != nil { + return nil, nil, nil, nil, fmt.Errorf("error starting command: %w", err) + } + + return stdin, stdout, cmd, stderrBuf, nil +} + +// initialize sends the initialization request and waits for response and then sends the initialized +// notification. +func (t *Stdio) initialize(stdin io.WriteCloser, stdout io.ReadCloser) error { + initRequest := Request{ + JSONRPC: "2.0", + Method: "initialize", + ID: t.nextID, + Params: map[string]any{ + "clientInfo": map[string]any{ + "name": "f/mcptools", + "version": "beta", + }, + "protocolVersion": protocolVersion, + "capabilities": map[string]any{}, + }, + } + t.nextID++ + + if err := t.sendRequest(stdin, initRequest); err != nil { + return fmt.Errorf("init request failed: %w", err) } - _ = stdin.Close() + + _, err := t.readResponse(stdout) + if err != nil { + return fmt.Errorf("init response failed: %w", err) + } + + initNotification := Request{ + JSONRPC: "2.0", + Method: "notifications/initialized", + } + + if sendErr := t.sendRequest(stdin, initNotification); sendErr != nil { + return fmt.Errorf("init notification failed: %w", sendErr) + } + + return nil +} + +// sendRequest sends a JSON-RPC request and returns the marshaled request. +func (t *Stdio) sendRequest(stdin io.WriteCloser, request Request) error { + requestJSON, err := json.Marshal(request) + if err != nil { + return fmt.Errorf("error marshaling request: %w", err) + } + requestJSON = append(requestJSON, '\n') if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Wrote request to stdin\n") + fmt.Fprintf(os.Stderr, "DEBUG: Preparing to send request: %s\n", string(requestJSON)) } - var respBytes bytes.Buffer - if _, copyErr := io.Copy(&respBytes, stdout); copyErr != nil { - return nil, fmt.Errorf("error reading from stdout: %w", copyErr) + writer := bufio.NewWriter(stdin) + n, err := writer.Write(requestJSON) + if err != nil { + return fmt.Errorf("error writing bytes to stdin: %w", err) } if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s\n", respBytes.String()) + fmt.Fprintf(os.Stderr, "DEBUG: Wrote %d bytes\n", n) } - waitErr := cmd.Wait() + if flushErr := writer.Flush(); flushErr != nil { + return fmt.Errorf("error flushing bytes to stdin: %w", flushErr) + } if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr) - if stderrBuf.Len() > 0 { - fmt.Fprintf(os.Stderr, "DEBUG: stderr output: %s\n", stderrBuf.String()) - } + fmt.Fprintf(os.Stderr, "DEBUG: Successfully flushed bytes\n") } - if waitErr != nil && stderrBuf.Len() > 0 { - return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String()) + return nil +} + +// readResponse reads and parses a JSON-RPC response. +func (t *Stdio) readResponse(stdout io.ReadCloser) (*Response, error) { + reader := bufio.NewReader(stdout) + line, err := reader.ReadBytes('\n') + if err != nil { + return nil, fmt.Errorf("error reading from stdout: %w", err) } - if respBytes.Len() == 0 { - if stderrBuf.Len() > 0 { - return nil, fmt.Errorf("no response from command, stderr: %s", stderrBuf.String()) - } + if t.debug { + fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s", string(line)) + } + + if len(line) == 0 { return nil, fmt.Errorf("no response from command") } var response Response - if unmarshalErr := json.Unmarshal(respBytes.Bytes(), &response); unmarshalErr != nil { - return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, respBytes.String()) + if unmarshalErr := json.Unmarshal(line, &response); unmarshalErr != nil { + return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, string(line)) } if response.Error != nil { @@ -128,5 +222,5 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) { fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n") } - return response.Result, nil + return &response, nil } diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 6341aab..b27b7a8 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -6,6 +6,10 @@ import ( "io" ) +const ( + protocolVersion = "2024-11-05" +) + // Transport defines the interface for communicating with MCP servers. // Implementations should handle the specifics of communication protocols. type Transport interface { @@ -17,7 +21,7 @@ type Request struct { Params any `json:"params,omitempty"` JSONRPC string `json:"jsonrpc"` Method string `json:"method"` - ID int `json:"id"` + ID int `json:"id,omitempty"` } // Response represents a JSON-RPC 2.0 response.