Skip to content

Commit 3faf46f

Browse files
committed
better token tracking and support for more openai models
1 parent: cda80f1 · commit: 3faf46f

File tree

7 files changed: +62 additions, −254 deletions

cmd/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
538538
break
539539
}
540540
}
541-
cli.UpdateUsage(lastUserMessage, response.Content)
541+
cli.UpdateUsageFromResponse(response, lastUserMessage)
542542
}
543543
} else if config.Quiet {
544544
// In quiet mode, only output the final response content to stdout

internal/models/providers.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
8787

8888
// validateModelConfig validates configuration parameters against model capabilities
8989
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
90-
// Check if temperature is supported
90+
// Omit temperature if not supported by the model
9191
if config.Temperature != nil && !modelInfo.Temperature {
92-
return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID)
92+
config.Temperature = nil
9393
}
9494

9595
// Warn about context limits if MaxTokens is set too high
@@ -162,16 +162,35 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
162162
openaiConfig.BaseURL = config.ProviderURL
163163
}
164164

165-
if config.MaxTokens > 0 {
166-
openaiConfig.MaxTokens = &config.MaxTokens
165+
// Check if this is a reasoning model to handle beta limitations
166+
registry := GetGlobalRegistry()
167+
isReasoningModel := false
168+
if modelInfo, err := registry.ValidateModel("openai", modelName); err == nil && modelInfo.Reasoning {
169+
isReasoningModel = true
167170
}
168171

169-
if config.Temperature != nil {
170-
openaiConfig.Temperature = config.Temperature
172+
if config.MaxTokens > 0 {
173+
if isReasoningModel {
174+
// For reasoning models, use MaxCompletionTokens instead of MaxTokens
175+
if openaiConfig.ExtraFields == nil {
176+
openaiConfig.ExtraFields = make(map[string]any)
177+
}
178+
openaiConfig.ExtraFields["max_completion_tokens"] = config.MaxTokens
179+
} else {
180+
// For non-reasoning models, use MaxTokens as usual
181+
openaiConfig.MaxTokens = &config.MaxTokens
182+
}
171183
}
172184

173-
if config.TopP != nil {
174-
openaiConfig.TopP = config.TopP
185+
// For reasoning models, skip temperature and top_p due to beta limitations
186+
if !isReasoningModel {
187+
if config.Temperature != nil {
188+
openaiConfig.Temperature = config.Temperature
189+
}
190+
191+
if config.TopP != nil {
192+
openaiConfig.TopP = config.TopP
193+
}
175194
}
176195

177196
if len(config.StopSequences) > 0 {

internal/tokens/anthropic.go

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1 @@
1-
package tokens
2-
3-
import (
4-
"bytes"
5-
"context"
6-
"encoding/json"
7-
"fmt"
8-
"io"
9-
"net/http"
10-
"strings"
11-
"time"
12-
)
13-
14-
// AnthropicTokenCounter implements token counting for Anthropic models
15-
type AnthropicTokenCounter struct {
16-
apiKey string
17-
httpClient *http.Client
18-
}
19-
20-
// NewAnthropicTokenCounter creates a new Anthropic token counter
21-
func NewAnthropicTokenCounter(apiKey string) *AnthropicTokenCounter {
22-
return &AnthropicTokenCounter{
23-
apiKey: apiKey,
24-
httpClient: &http.Client{
25-
Timeout: 10 * time.Second,
26-
},
27-
}
28-
}
29-
30-
// AnthropicTokenRequest represents the request payload for Anthropic token counting
31-
type AnthropicTokenRequest struct {
32-
Messages []Message `json:"messages"`
33-
Model string `json:"model"`
34-
}
35-
36-
// AnthropicTokenResponse represents the response from Anthropic token counting API
37-
type AnthropicTokenResponse struct {
38-
InputTokens int `json:"input_tokens"`
39-
}
40-
41-
// CountTokens counts tokens using Anthropic's token counting API
42-
func (a *AnthropicTokenCounter) CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error) {
43-
if a.apiKey == "" {
44-
return nil, fmt.Errorf("anthropic API key not provided")
45-
}
46-
47-
// Strip the anthropic: prefix if present
48-
actualModel := model
49-
if strings.HasPrefix(model, "anthropic:") {
50-
actualModel = strings.TrimPrefix(model, "anthropic:")
51-
}
52-
53-
// Prepare request payload
54-
request := AnthropicTokenRequest{
55-
Messages: messages,
56-
Model: actualModel,
57-
}
58-
59-
jsonData, err := json.Marshal(request)
60-
if err != nil {
61-
return nil, fmt.Errorf("failed to marshal request: %w", err)
62-
}
63-
64-
// Create HTTP request
65-
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewReader(jsonData))
66-
if err != nil {
67-
return nil, fmt.Errorf("failed to create request: %w", err)
68-
}
69-
70-
// Set headers
71-
req.Header.Set("Content-Type", "application/json")
72-
req.Header.Set("x-api-key", a.apiKey)
73-
req.Header.Set("anthropic-version", "2023-06-01")
74-
75-
// Make the request
76-
resp, err := a.httpClient.Do(req)
77-
if err != nil {
78-
return nil, fmt.Errorf("failed to make request: %w", err)
79-
}
80-
defer resp.Body.Close()
81-
82-
// Read response body
83-
body, err := io.ReadAll(resp.Body)
84-
if err != nil {
85-
return nil, fmt.Errorf("failed to read response: %w", err)
86-
}
87-
88-
// Check for HTTP errors
89-
if resp.StatusCode != http.StatusOK {
90-
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
91-
}
92-
93-
// Parse response
94-
var tokenResponse AnthropicTokenResponse
95-
if err := json.Unmarshal(body, &tokenResponse); err != nil {
96-
return nil, fmt.Errorf("failed to parse response: %w", err)
97-
}
98-
99-
return &TokenCount{
100-
InputTokens: tokenResponse.InputTokens,
101-
}, nil
102-
}
103-
104-
// SupportsModel returns true if this counter supports the given model
105-
func (a *AnthropicTokenCounter) SupportsModel(model string) bool {
106-
// Support all Anthropic models
107-
return strings.HasPrefix(model, "anthropic:")
108-
}
109-
110-
// ProviderName returns the name of the provider
111-
func (a *AnthropicTokenCounter) ProviderName() string {
112-
return "anthropic"
113-
}
1+
package tokens

internal/tokens/counter.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,7 @@
11
package tokens
22

3-
import (
4-
"context"
5-
)
6-
7-
// Message represents a message for token counting
8-
type Message struct {
9-
Role string `json:"role"`
10-
Content string `json:"content"`
11-
}
12-
13-
// TokenCount represents the result of token counting
14-
type TokenCount struct {
15-
InputTokens int `json:"input_tokens"`
16-
}
17-
18-
// TokenCounter interface for provider-specific token counting
19-
type TokenCounter interface {
20-
// CountTokens counts tokens for the given messages and model
21-
CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error)
22-
// SupportsModel returns true if this counter supports the given model
23-
SupportsModel(model string) bool
24-
// ProviderName returns the name of the provider this counter is for
25-
ProviderName() string
26-
}
27-
283
// EstimateTokens provides a rough estimate of tokens in text
29-
// This is a fallback when no provider-specific counter is available
304
func EstimateTokens(text string) int {
315
// Rough approximation: ~4 characters per token for most models
32-
// This is not accurate but gives a reasonable estimate
336
return len(text) / 4
34-
}
35-
36-
// EstimateTokensFromMessages estimates tokens from a slice of messages
37-
func EstimateTokensFromMessages(messages []Message) int {
38-
totalChars := 0
39-
for _, msg := range messages {
40-
totalChars += len(msg.Content)
41-
totalChars += len(msg.Role) + 10 // Add some overhead for role and formatting
42-
}
43-
return EstimateTokens(string(rune(totalChars)))
44-
}
45-
46-
// Registry holds all registered token counters
47-
type Registry struct {
48-
counters map[string]TokenCounter
49-
}
50-
51-
// NewRegistry creates a new token counter registry
52-
func NewRegistry() *Registry {
53-
return &Registry{
54-
counters: make(map[string]TokenCounter),
55-
}
56-
}
57-
58-
// Register adds a token counter to the registry
59-
func (r *Registry) Register(counter TokenCounter) {
60-
r.counters[counter.ProviderName()] = counter
61-
}
62-
63-
// GetCounter returns a token counter for the given provider
64-
func (r *Registry) GetCounter(provider string) (TokenCounter, bool) {
65-
counter, exists := r.counters[provider]
66-
return counter, exists
67-
}
68-
69-
// CountTokens attempts to count tokens using a provider-specific counter,
70-
// falling back to estimation if no counter is available
71-
func (r *Registry) CountTokens(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
72-
if counter, exists := r.GetCounter(provider); exists && counter.SupportsModel(model) {
73-
return counter.CountTokens(ctx, messages, model)
74-
}
75-
76-
// Fallback to estimation
77-
estimatedTokens := EstimateTokensFromMessages(messages)
78-
return &TokenCount{
79-
InputTokens: estimatedTokens,
80-
}, nil
81-
}
82-
83-
// Global registry instance
84-
var globalRegistry = NewRegistry()
85-
86-
// GetGlobalRegistry returns the global token counter registry
87-
func GetGlobalRegistry() *Registry {
88-
return globalRegistry
89-
}
90-
91-
// RegisterCounter registers a token counter with the global registry
92-
func RegisterCounter(counter TokenCounter) {
93-
globalRegistry.Register(counter)
94-
}
95-
96-
// CountTokensGlobal counts tokens using the global registry
97-
func CountTokensGlobal(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
98-
return globalRegistry.CountTokens(ctx, provider, messages, model)
997
}

internal/tokens/init.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,11 @@
11
package tokens
22

3-
import (
4-
"os"
5-
)
6-
73
// InitializeTokenCounters registers all available token counters
84
func InitializeTokenCounters() {
9-
// Register Anthropic token counter if API key is available
10-
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
11-
RegisterCounter(NewAnthropicTokenCounter(apiKey))
12-
}
5+
// Future provider-specific counters can be registered here
136
}
147

158
// InitializeTokenCountersWithKeys registers token counters with provided API keys
16-
func InitializeTokenCountersWithKeys(anthropicKey string) {
17-
if anthropicKey != "" {
18-
RegisterCounter(NewAnthropicTokenCounter(anthropicKey))
19-
}
9+
func InitializeTokenCountersWithKeys() {
10+
// Future provider-specific counters can be registered here
2011
}

internal/ui/cli.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package ui
22

33
import (
4-
"context"
54
"errors"
65
"fmt"
76
"io"
@@ -12,7 +11,6 @@ import (
1211
"github.com/charmbracelet/huh"
1312
"github.com/charmbracelet/lipgloss"
1413
"github.com/cloudwego/eino/schema"
15-
"github.com/mark3labs/mcphost/internal/tokens"
1614
"golang.org/x/term"
1715
)
1816

@@ -324,10 +322,30 @@ func (c *CLI) UpdateUsage(inputText, outputText string) {
324322
}
325323
}
326324

327-
// UpdateUsageWithMessages updates the usage tracker using custom token counting for messages
328-
func (c *CLI) UpdateUsageWithMessages(ctx context.Context, messages []tokens.Message, outputText string) {
329-
if c.usageTracker != nil {
330-
c.usageTracker.CountAndUpdateUsage(ctx, messages, outputText)
325+
326+
327+
// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata
328+
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) {
329+
if c.usageTracker == nil {
330+
return
331+
}
332+
333+
// Try to extract token usage from response metadata
334+
if response.ResponseMeta != nil && response.ResponseMeta.Usage != nil {
335+
usage := response.ResponseMeta.Usage
336+
337+
// Use actual token counts from the response
338+
inputTokens := int(usage.PromptTokens)
339+
outputTokens := int(usage.CompletionTokens)
340+
341+
// Handle cache tokens if available (some providers support this)
342+
cacheReadTokens := 0
343+
cacheWriteTokens := 0
344+
345+
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
346+
} else {
347+
// Fallback to estimation if no metadata is available
348+
c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content)
331349
}
332350
}
333351

0 commit comments

Comments (0)