Skip to content

Commit d09a920

Browse files
committed
proper mcp sse handling
1 parent a098c75 commit d09a920

File tree

1 file changed

+54
-30
lines changed

1 file changed

+54
-30
lines changed

pkg/transport/http.go

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,64 @@ func NewHTTP(address string) (*HTTP, error) {
7878
return nil, fmt.Errorf("timeout waiting for SSE response")
7979
}
8080

81-
return &HTTP{
81+
client := &HTTP{
8282
// Use the SSE message address as the base address for the HTTP transport
83-
address: address + messageAddress,
83+
address: address + "/sse" + messageAddress,
8484
nextID: 1,
8585
debug: debug,
8686
eventCh: eventCh,
87-
}, nil
87+
}
88+
89+
// Send initialize request
90+
_, err = client.Execute("initialize", map[string]any{
91+
"clientInfo": map[string]any{
92+
"name": "mcp-client",
93+
"version": "0.1.0",
94+
},
95+
"capabilities": map[string]any{},
96+
"protocolVersion": "2024-11-05",
97+
})
98+
if err != nil {
99+
return nil, fmt.Errorf("error sending initialize request: %w", err)
100+
}
101+
102+
// Send intialized notification
103+
if err := client.send("notifications/initialized", nil); err != nil {
104+
return nil, fmt.Errorf("error sending initialized notification: %w", err)
105+
}
106+
107+
return client, nil
88108
}
89109

90110
// Execute implements the Transport via JSON-RPC over HTTP.
91111
func (t *HTTP) Execute(method string, params any) (map[string]any, error) {
112+
if err := t.send(method, params); err != nil {
113+
return nil, err
114+
}
115+
116+
// After sending the request, we listen the SSE channel for the response
117+
var response Response
118+
select {
119+
case msg := <-t.eventCh:
120+
if unmarshalErr := json.Unmarshal([]byte(msg), &response); unmarshalErr != nil {
121+
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, msg)
122+
}
123+
case <-time.After(10 * time.Second):
124+
return nil, fmt.Errorf("timeout waiting for SSE response")
125+
}
126+
127+
if response.Error != nil {
128+
return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message)
129+
}
130+
131+
if t.debug {
132+
fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n")
133+
}
134+
135+
return response.Result, nil
136+
}
137+
138+
func (t *HTTP) send(method string, params any) error {
92139
if t.debug {
93140
fmt.Fprintf(os.Stderr, "DEBUG: Connecting to server: %s\n", t.address)
94141
}
@@ -103,7 +150,7 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) {
103150

104151
requestJSON, err := json.Marshal(request)
105152
if err != nil {
106-
return nil, fmt.Errorf("error marshaling request: %w", err)
153+
return fmt.Errorf("error marshaling request: %w", err)
107154
}
108155

109156
requestJSON = append(requestJSON, '\n')
@@ -114,7 +161,7 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) {
114161

115162
resp, err := http.Post(t.address, "application/json", bytes.NewBuffer(requestJSON))
116163
if err != nil {
117-
return nil, fmt.Errorf("error sending request: %w", err)
164+
return fmt.Errorf("error sending request: %w", err)
118165
}
119166

120167
if t.debug {
@@ -129,35 +176,12 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) {
129176

130177
body, err := io.ReadAll(resp.Body)
131178
if err != nil {
132-
return nil, fmt.Errorf("error reading response: %w", err)
179+
return fmt.Errorf("error reading response: %w", err)
133180
}
134181

135182
if t.debug {
136183
fmt.Fprintf(os.Stderr, "DEBUG: Read from server: %s\n", string(body))
137184
}
138185

139-
if len(body) == 0 {
140-
return nil, fmt.Errorf("no response from server")
141-
}
142-
143-
// After sending the request, we listen the SSE channel for the response
144-
var response Response
145-
select {
146-
case msg := <-t.eventCh:
147-
if unmarshalErr := json.Unmarshal([]byte(msg), &response); unmarshalErr != nil {
148-
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, msg)
149-
}
150-
case <-time.After(10 * time.Second):
151-
return nil, fmt.Errorf("timeout waiting for SSE response")
152-
}
153-
154-
if response.Error != nil {
155-
return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message)
156-
}
157-
158-
if t.debug {
159-
fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n")
160-
}
161-
162-
return response.Result, nil
186+
return nil
163187
}

0 commit comments

Comments
 (0)