Skip to content

Commit 32c46e4

Browse files
committed
Fix for anthropic
1 parent 053e6c3 commit 32c46e4

File tree

3 files changed

+230
-2
lines changed

3 files changed

+230
-2
lines changed
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
package anthropic
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"strings"
10+
11+
einoclaude "github.com/cloudwego/eino-ext/components/model/claude"
12+
"github.com/cloudwego/eino/components/model"
13+
"github.com/cloudwego/eino/schema"
14+
)
15+
16+
// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling
17+
type CustomChatModel struct {
18+
wrapped *einoclaude.ChatModel
19+
}
20+
21+
// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas
22+
type CustomRoundTripper struct {
23+
wrapped http.RoundTripper
24+
}
25+
26+
// NewCustomChatModel creates a new custom Anthropic chat model
27+
func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*CustomChatModel, error) {
28+
// Create a custom HTTP client that intercepts requests
29+
if config.HTTPClient == nil {
30+
config.HTTPClient = &http.Client{}
31+
}
32+
33+
// Wrap the transport with our custom round tripper
34+
if config.HTTPClient.Transport == nil {
35+
config.HTTPClient.Transport = http.DefaultTransport
36+
}
37+
config.HTTPClient.Transport = &CustomRoundTripper{
38+
wrapped: config.HTTPClient.Transport,
39+
}
40+
41+
// Create the wrapped model
42+
wrapped, err := einoclaude.NewChatModel(ctx, config)
43+
if err != nil {
44+
return nil, err
45+
}
46+
47+
return &CustomChatModel{
48+
wrapped: wrapped,
49+
}, nil
50+
}
51+
52+
// RoundTrip implements http.RoundTripper to intercept and fix requests
53+
func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
54+
// Only process Anthropic API requests
55+
if !strings.Contains(req.URL.Host, "anthropic.com") {
56+
return rt.wrapped.RoundTrip(req)
57+
}
58+
59+
// Read the request body
60+
body, err := io.ReadAll(req.Body)
61+
if err != nil {
62+
return nil, err
63+
}
64+
req.Body = io.NopCloser(bytes.NewReader(body))
65+
66+
// Apply string-based fixes BEFORE JSON parsing for malformed patterns
67+
bodyStr := string(body)
68+
69+
// Replace common malformed patterns - be more specific about context
70+
replacements := []struct {
71+
old string
72+
new string
73+
}{
74+
// Handle input field in tool_use objects
75+
{`"input":,"name"`, `"input":{},"name"`},
76+
{`"input":,"type"`, `"input":{},"type"`},
77+
{`"input":}`, `"input":{}}`},
78+
// Handle arguments field in function calls
79+
{`"arguments":,"name"`, `"arguments":"{}","name"`},
80+
{`"arguments":,"type"`, `"arguments":"{}","type"`},
81+
{`"arguments":}`, `"arguments":"{}"`},
82+
// Fallback patterns (less specific)
83+
{`"input":,`, `"input":{}`},
84+
{`"arguments":,`, `"arguments":"{}"`},
85+
}
86+
87+
for _, r := range replacements {
88+
if strings.Contains(bodyStr, r.old) {
89+
bodyStr = strings.ReplaceAll(bodyStr, r.old, r.new)
90+
}
91+
}
92+
93+
// Parse the JSON request (after string fixes)
94+
var requestData map[string]interface{}
95+
if err := json.Unmarshal([]byte(bodyStr), &requestData); err != nil {
96+
// Return the original request to avoid panic
97+
req.Body = io.NopCloser(bytes.NewReader(body))
98+
req.ContentLength = int64(len(body))
99+
return rt.wrapped.RoundTrip(req)
100+
}
101+
102+
// Fix tool schemas if present
103+
if tools, ok := requestData["tools"].([]interface{}); ok {
104+
for _, tool := range tools {
105+
if toolMap, ok := tool.(map[string]interface{}); ok {
106+
if inputSchema, ok := toolMap["input_schema"].(map[string]interface{}); ok {
107+
// Ensure properties exists and is not null
108+
if properties, exists := inputSchema["properties"]; !exists || properties == nil {
109+
inputSchema["properties"] = map[string]interface{}{}
110+
} else if propertiesMap, ok := properties.(map[string]interface{}); ok {
111+
// Ensure each property has a type
112+
for _, propValue := range propertiesMap {
113+
if propMap, ok := propValue.(map[string]interface{}); ok {
114+
if _, hasType := propMap["type"]; !hasType {
115+
propMap["type"] = "string"
116+
}
117+
}
118+
}
119+
}
120+
}
121+
}
122+
}
123+
}
124+
125+
// Fix tool_use content in messages if present
126+
if messages, ok := requestData["messages"].([]interface{}); ok {
127+
for _, message := range messages {
128+
if msgMap, ok := message.(map[string]interface{}); ok {
129+
if content, ok := msgMap["content"].([]interface{}); ok {
130+
for _, contentItem := range content {
131+
if contentMap, ok := contentItem.(map[string]interface{}); ok {
132+
if contentType, ok := contentMap["type"].(string); ok && contentType == "tool_use" {
133+
// Ensure tool_use input is valid JSON
134+
if input, exists := contentMap["input"]; exists {
135+
// If input is nil or empty, set it to an empty object
136+
if input == nil {
137+
contentMap["input"] = map[string]interface{}{}
138+
} else if inputBytes, ok := input.(json.RawMessage); ok {
139+
if len(inputBytes) == 0 {
140+
contentMap["input"] = map[string]interface{}{}
141+
} else {
142+
// Validate that it's valid JSON
143+
var temp interface{}
144+
if err := json.Unmarshal(inputBytes, &temp); err != nil {
145+
contentMap["input"] = map[string]interface{}{}
146+
}
147+
}
148+
} else if inputStr, ok := input.(string); ok {
149+
// Handle string inputs that might be empty or invalid JSON
150+
if inputStr == "" || inputStr == "{}" {
151+
contentMap["input"] = map[string]interface{}{}
152+
} else {
153+
// Try to parse as JSON
154+
var temp interface{}
155+
if err := json.Unmarshal([]byte(inputStr), &temp); err != nil {
156+
contentMap["input"] = map[string]interface{}{}
157+
}
158+
}
159+
}
160+
} else {
161+
// If input field doesn't exist, add it as empty object
162+
contentMap["input"] = map[string]interface{}{}
163+
}
164+
}
165+
}
166+
}
167+
}
168+
}
169+
}
170+
}
171+
172+
// Marshal the fixed request back to JSON
173+
fixedBody, err := json.Marshal(requestData)
174+
if err != nil {
175+
return nil, err
176+
}
177+
178+
// Use the fixed body from JSON marshaling
179+
finalBodyStr := string(fixedBody)
180+
181+
// Validate the final JSON
182+
var finalCheck interface{}
183+
if err := json.Unmarshal([]byte(finalBodyStr), &finalCheck); err != nil {
184+
return nil, err
185+
}
186+
187+
// Create new request with fixed body
188+
req.Body = io.NopCloser(strings.NewReader(finalBodyStr))
189+
req.ContentLength = int64(len(finalBodyStr))
190+
// Make the actual request
191+
return rt.wrapped.RoundTrip(req)
192+
}
193+
194+
// Generate implements the model.BaseChatModel interface
195+
func (m *CustomChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
196+
return m.wrapped.Generate(ctx, input, opts...)
197+
}
198+
199+
// Stream implements the model.BaseChatModel interface
200+
func (m *CustomChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
201+
return m.wrapped.Stream(ctx, input, opts...)
202+
}
203+
204+
// WithTools implements the model.ToolCallingChatModel interface
205+
func (m *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
206+
wrappedWithTools, err := m.wrapped.WithTools(tools)
207+
if err != nil {
208+
return nil, err
209+
}
210+
211+
return &CustomChatModel{
212+
wrapped: wrappedWithTools.(*einoclaude.ChatModel),
213+
}, nil
214+
}

internal/models/providers.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/cloudwego/eino-ext/components/model/ollama"
1616
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
1717
"github.com/cloudwego/eino/components/model"
18+
"github.com/mark3labs/mcphost/internal/models/anthropic"
1819
"github.com/mark3labs/mcphost/internal/models/openai"
1920
"github.com/mark3labs/mcphost/internal/ui/progress"
2021
"github.com/ollama/ollama/api"
@@ -273,7 +274,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
273274
claudeConfig.StopSequences = config.StopSequences
274275
}
275276

276-
return einoclaude.NewChatModel(ctx, claudeConfig)
277+
return anthropic.NewCustomChatModel(ctx, claudeConfig)
277278
}
278279

279280
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {

internal/tools/mcp.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,19 @@ func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
162162

163163
// InvokableRun executes the tool by mapping back to the original name and server
164164
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
165+
// Handle empty or invalid JSON arguments
166+
var arguments any
167+
if argumentsInJSON == "" || argumentsInJSON == "{}" {
168+
arguments = nil
169+
} else {
170+
// Validate that argumentsInJSON is valid JSON before using it
171+
var temp any
172+
if err := json.Unmarshal([]byte(argumentsInJSON), &temp); err != nil {
173+
return "", fmt.Errorf("invalid JSON arguments: %w", err)
174+
}
175+
arguments = json.RawMessage(argumentsInJSON)
176+
}
177+
165178
result, err := t.mapping.client.CallTool(ctx, mcp.CallToolRequest{
166179
Request: mcp.Request{
167180
Method: "tools/call",
@@ -172,7 +185,7 @@ func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string,
172185
Meta *mcp.Meta `json:"_meta,omitempty"`
173186
}{
174187
Name: t.mapping.originalName, // Use original name, not prefixed
175-
Arguments: json.RawMessage(argumentsInJSON),
188+
Arguments: arguments,
176189
},
177190
})
178191
if err != nil {

0 commit comments

Comments
 (0)