|
| 1 | +package anthropic |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
| 7 | + "io" |
| 8 | + "net/http" |
| 9 | + "strings" |
| 10 | + |
| 11 | + einoclaude "github.com/cloudwego/eino-ext/components/model/claude" |
| 12 | + "github.com/cloudwego/eino/components/model" |
| 13 | + "github.com/cloudwego/eino/schema" |
| 14 | +) |
| 15 | + |
| 16 | +// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling |
| 17 | +type CustomChatModel struct { |
| 18 | + wrapped *einoclaude.ChatModel |
| 19 | +} |
| 20 | + |
| 21 | +// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas |
| 22 | +type CustomRoundTripper struct { |
| 23 | + wrapped http.RoundTripper |
| 24 | +} |
| 25 | + |
| 26 | +// NewCustomChatModel creates a new custom Anthropic chat model |
| 27 | +func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*CustomChatModel, error) { |
| 28 | + // Create a custom HTTP client that intercepts requests |
| 29 | + if config.HTTPClient == nil { |
| 30 | + config.HTTPClient = &http.Client{} |
| 31 | + } |
| 32 | + |
| 33 | + // Wrap the transport with our custom round tripper |
| 34 | + if config.HTTPClient.Transport == nil { |
| 35 | + config.HTTPClient.Transport = http.DefaultTransport |
| 36 | + } |
| 37 | + config.HTTPClient.Transport = &CustomRoundTripper{ |
| 38 | + wrapped: config.HTTPClient.Transport, |
| 39 | + } |
| 40 | + |
| 41 | + // Create the wrapped model |
| 42 | + wrapped, err := einoclaude.NewChatModel(ctx, config) |
| 43 | + if err != nil { |
| 44 | + return nil, err |
| 45 | + } |
| 46 | + |
| 47 | + return &CustomChatModel{ |
| 48 | + wrapped: wrapped, |
| 49 | + }, nil |
| 50 | +} |
| 51 | + |
| 52 | +// RoundTrip implements http.RoundTripper to intercept and fix requests |
| 53 | +func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
| 54 | + // Only process Anthropic API requests |
| 55 | + if !strings.Contains(req.URL.Host, "anthropic.com") { |
| 56 | + return rt.wrapped.RoundTrip(req) |
| 57 | + } |
| 58 | + |
| 59 | + // Read the request body |
| 60 | + body, err := io.ReadAll(req.Body) |
| 61 | + if err != nil { |
| 62 | + return nil, err |
| 63 | + } |
| 64 | + req.Body = io.NopCloser(bytes.NewReader(body)) |
| 65 | + |
| 66 | + // Apply string-based fixes BEFORE JSON parsing for malformed patterns |
| 67 | + bodyStr := string(body) |
| 68 | + |
| 69 | + // Replace common malformed patterns - be more specific about context |
| 70 | + replacements := []struct { |
| 71 | + old string |
| 72 | + new string |
| 73 | + }{ |
| 74 | + // Handle input field in tool_use objects |
| 75 | + {`"input":,"name"`, `"input":{},"name"`}, |
| 76 | + {`"input":,"type"`, `"input":{},"type"`}, |
| 77 | + {`"input":}`, `"input":{}}`}, |
| 78 | + // Handle arguments field in function calls |
| 79 | + {`"arguments":,"name"`, `"arguments":"{}","name"`}, |
| 80 | + {`"arguments":,"type"`, `"arguments":"{}","type"`}, |
| 81 | + {`"arguments":}`, `"arguments":"{}"`}, |
| 82 | + // Fallback patterns (less specific) |
| 83 | + {`"input":,`, `"input":{}`}, |
| 84 | + {`"arguments":,`, `"arguments":"{}"`}, |
| 85 | + } |
| 86 | + |
| 87 | + for _, r := range replacements { |
| 88 | + if strings.Contains(bodyStr, r.old) { |
| 89 | + bodyStr = strings.ReplaceAll(bodyStr, r.old, r.new) |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + // Parse the JSON request (after string fixes) |
| 94 | + var requestData map[string]interface{} |
| 95 | + if err := json.Unmarshal([]byte(bodyStr), &requestData); err != nil { |
| 96 | + // Return the original request to avoid panic |
| 97 | + req.Body = io.NopCloser(bytes.NewReader(body)) |
| 98 | + req.ContentLength = int64(len(body)) |
| 99 | + return rt.wrapped.RoundTrip(req) |
| 100 | + } |
| 101 | + |
| 102 | + // Fix tool schemas if present |
| 103 | + if tools, ok := requestData["tools"].([]interface{}); ok { |
| 104 | + for _, tool := range tools { |
| 105 | + if toolMap, ok := tool.(map[string]interface{}); ok { |
| 106 | + if inputSchema, ok := toolMap["input_schema"].(map[string]interface{}); ok { |
| 107 | + // Ensure properties exists and is not null |
| 108 | + if properties, exists := inputSchema["properties"]; !exists || properties == nil { |
| 109 | + inputSchema["properties"] = map[string]interface{}{} |
| 110 | + } else if propertiesMap, ok := properties.(map[string]interface{}); ok { |
| 111 | + // Ensure each property has a type |
| 112 | + for _, propValue := range propertiesMap { |
| 113 | + if propMap, ok := propValue.(map[string]interface{}); ok { |
| 114 | + if _, hasType := propMap["type"]; !hasType { |
| 115 | + propMap["type"] = "string" |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + |
| 125 | + // Fix tool_use content in messages if present |
| 126 | + if messages, ok := requestData["messages"].([]interface{}); ok { |
| 127 | + for _, message := range messages { |
| 128 | + if msgMap, ok := message.(map[string]interface{}); ok { |
| 129 | + if content, ok := msgMap["content"].([]interface{}); ok { |
| 130 | + for _, contentItem := range content { |
| 131 | + if contentMap, ok := contentItem.(map[string]interface{}); ok { |
| 132 | + if contentType, ok := contentMap["type"].(string); ok && contentType == "tool_use" { |
| 133 | + // Ensure tool_use input is valid JSON |
| 134 | + if input, exists := contentMap["input"]; exists { |
| 135 | + // If input is nil or empty, set it to an empty object |
| 136 | + if input == nil { |
| 137 | + contentMap["input"] = map[string]interface{}{} |
| 138 | + } else if inputBytes, ok := input.(json.RawMessage); ok { |
| 139 | + if len(inputBytes) == 0 { |
| 140 | + contentMap["input"] = map[string]interface{}{} |
| 141 | + } else { |
| 142 | + // Validate that it's valid JSON |
| 143 | + var temp interface{} |
| 144 | + if err := json.Unmarshal(inputBytes, &temp); err != nil { |
| 145 | + contentMap["input"] = map[string]interface{}{} |
| 146 | + } |
| 147 | + } |
| 148 | + } else if inputStr, ok := input.(string); ok { |
| 149 | + // Handle string inputs that might be empty or invalid JSON |
| 150 | + if inputStr == "" || inputStr == "{}" { |
| 151 | + contentMap["input"] = map[string]interface{}{} |
| 152 | + } else { |
| 153 | + // Try to parse as JSON |
| 154 | + var temp interface{} |
| 155 | + if err := json.Unmarshal([]byte(inputStr), &temp); err != nil { |
| 156 | + contentMap["input"] = map[string]interface{}{} |
| 157 | + } |
| 158 | + } |
| 159 | + } |
| 160 | + } else { |
| 161 | + // If input field doesn't exist, add it as empty object |
| 162 | + contentMap["input"] = map[string]interface{}{} |
| 163 | + } |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | + } |
| 168 | + } |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + // Marshal the fixed request back to JSON |
| 173 | + fixedBody, err := json.Marshal(requestData) |
| 174 | + if err != nil { |
| 175 | + return nil, err |
| 176 | + } |
| 177 | + |
| 178 | + // Use the fixed body from JSON marshaling |
| 179 | + finalBodyStr := string(fixedBody) |
| 180 | + |
| 181 | + // Validate the final JSON |
| 182 | + var finalCheck interface{} |
| 183 | + if err := json.Unmarshal([]byte(finalBodyStr), &finalCheck); err != nil { |
| 184 | + return nil, err |
| 185 | + } |
| 186 | + |
| 187 | + // Create new request with fixed body |
| 188 | + req.Body = io.NopCloser(strings.NewReader(finalBodyStr)) |
| 189 | + req.ContentLength = int64(len(finalBodyStr)) |
| 190 | + // Make the actual request |
| 191 | + return rt.wrapped.RoundTrip(req) |
| 192 | +} |
| 193 | + |
| 194 | +// Generate implements the model.BaseChatModel interface |
| 195 | +func (m *CustomChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { |
| 196 | + return m.wrapped.Generate(ctx, input, opts...) |
| 197 | +} |
| 198 | + |
| 199 | +// Stream implements the model.BaseChatModel interface |
| 200 | +func (m *CustomChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { |
| 201 | + return m.wrapped.Stream(ctx, input, opts...) |
| 202 | +} |
| 203 | + |
| 204 | +// WithTools implements the model.ToolCallingChatModel interface |
| 205 | +func (m *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { |
| 206 | + wrappedWithTools, err := m.wrapped.WithTools(tools) |
| 207 | + if err != nil { |
| 208 | + return nil, err |
| 209 | + } |
| 210 | + |
| 211 | + return &CustomChatModel{ |
| 212 | + wrapped: wrappedWithTools.(*einoclaude.ChatModel), |
| 213 | + }, nil |
| 214 | +} |
0 commit comments