Skip to content

Commit 1e6029d

Browse files
committed
The api should not start an oauth callback server
In api mode we are not the ones who manage the callback server, up to the caller to do it for us and tell us when the auth is granted Signed-off-by: Djordje Lukic <[email protected]>
1 parent facbc8d commit 1e6029d

File tree

3 files changed

+47
-26
lines changed

3 files changed

+47
-26
lines changed

pkg/oauth/manager.go

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type manager struct {
2020
serverMutex sync.Mutex
2121
redirectURI string
2222
port int
23+
managedServer bool
2324
}
2425

2526
// NewManager creates a new OAuth manager with optional port configuration
@@ -29,6 +30,7 @@ func NewManager(emitAuthRequired func(serverURL, serverType, status string), opt
2930
resumeAuthorizeOauthFlow: make(chan bool),
3031
resumeOauthCodeReceived: make(chan string),
3132
port: 8083,
33+
managedServer: true,
3234
}
3335

3436
// Apply options
@@ -61,6 +63,12 @@ func WithRedirectURI(uri string) ManagerOption {
6163
}
6264
}
6365

66+
func WithManagedServer(managed bool) ManagerOption {
67+
return func(m *manager) {
68+
m.managedServer = managed
69+
}
70+
}
71+
6472
// HandleAuthorizationFlow handles a single OAuth authorization flow
6573
func (m *manager) HandleAuthorizationFlow(ctx context.Context, sessionID string, oauthErr *AuthorizationRequiredError) error {
6674
m.emitAuthRequired(oauthErr.ServerURL, oauthErr.ServerType, "pending")
@@ -176,37 +184,41 @@ func (m *manager) performOAuthAuthorization(ctx context.Context, sessionID strin
176184
slog.Warn("Failed to start callback server, falling back to manual input", "error", err)
177185
}
178186

179-
// Check if we have a callback server running (either global or our own)
180-
if callbackServer := m.getCallbackServer(); callbackServer != nil {
181-
slog.Debug("Using callback server for OAuth authorization")
182-
// Wait for callback from the browser
183-
callbackCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
184-
defer cancel()
187+
if m.managedServer {
188+
// Check if we have a callback server running (either global or our own)
189+
if callbackServer := m.getCallbackServer(); callbackServer != nil {
190+
slog.Debug("Using callback server for OAuth authorization")
191+
// Wait for callback from the browser
192+
callbackCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
193+
defer cancel()
194+
195+
result, err := callbackServer.WaitForCallback(callbackCtx)
196+
if err != nil {
197+
if err == context.DeadlineExceeded {
198+
return fmt.Errorf("OAuth authorization timed out after 5 minutes")
199+
}
200+
return fmt.Errorf("failed to wait for OAuth callback: %w", err)
201+
}
185202

186-
result, err := callbackServer.WaitForCallback(callbackCtx)
187-
if err != nil {
188-
if err == context.DeadlineExceeded {
189-
return fmt.Errorf("OAuth authorization timed out after 5 minutes")
203+
if result.Error != "" {
204+
return fmt.Errorf("OAuth authorization error: %s", result.Error)
190205
}
191-
return fmt.Errorf("failed to wait for OAuth callback: %w", err)
192-
}
193206

194-
if result.Error != "" {
195-
return fmt.Errorf("OAuth authorization error: %s", result.Error)
196-
}
207+
if result.Code == "" {
208+
return fmt.Errorf("no authorization code received from OAuth callback")
209+
}
197210

198-
if result.Code == "" {
199-
return fmt.Errorf("no authorization code received from OAuth callback")
200-
}
211+
// Verify state parameter matches
212+
receivedState := result.State
213+
if receivedState != state {
214+
slog.Warn("OAuth state mismatch", "expected", state, "received", receivedState)
215+
}
201216

202-
// Verify state parameter matches
203-
receivedState := result.State
204-
if receivedState != state {
205-
slog.Warn("OAuth state mismatch", "expected", state, "received", receivedState)
217+
code = result.Code
218+
slog.Debug("Received OAuth code via callback server", "code_present", code != "")
219+
} else {
220+
return fmt.Errorf("no callback server available for OAuth authorization")
206221
}
207-
208-
code = result.Code
209-
slog.Debug("Received OAuth code via callback server", "code_present", code != "")
210222
} else {
211223
// Fallback to manual input
212224
slog.Debug("No callback server available, waiting for manual input")

pkg/runtime/runtime.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ type runtime struct {
6767
tracer trace.Tracer
6868
modelsStore modelStore
6969
sessionCompaction bool
70+
managedOAuth bool
7071
}
7172

7273
type Opt func(*runtime)
@@ -77,6 +78,12 @@ func WithCurrentAgent(agentName string) Opt {
7778
}
7879
}
7980

81+
func WithManagedOAuth(managed bool) Opt {
82+
return func(r *runtime) {
83+
r.managedOAuth = managed
84+
}
85+
}
86+
8087
// WithTracer sets a custom OpenTelemetry tracer; if not provided, tracing is disabled (no-op).
8188
func WithTracer(t trace.Tracer) Opt {
8289
return func(r *runtime) {
@@ -110,6 +117,7 @@ func New(agents *team.Team, opts ...Opt) (Runtime, error) {
110117
resumeChan: make(chan ResumeType),
111118
modelsStore: modelsStore,
112119
sessionCompaction: true,
120+
managedOAuth: true,
113121
}
114122

115123
for _, opt := range opts {
@@ -139,7 +147,7 @@ func (r *runtime) handleOAuthAuthorizationFlow(ctx context.Context, sess *sessio
139147
emitAuthRequired := func(serverURL, serverType, status string) {
140148
events <- AuthorizationRequired(serverURL, serverType, status, r.currentAgent)
141149
}
142-
r.oauthManager = oauth.NewManager(emitAuthRequired)
150+
r.oauthManager = oauth.NewManager(emitAuthRequired, oauth.WithManagedServer(r.managedOAuth))
143151
defer func() {
144152
if cleanupErr := r.oauthManager.Cleanup(ctx); cleanupErr != nil {
145153
slog.Error("Failed to cleanup OAuth manager", "error", cleanupErr)

pkg/server/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ func (s *Server) runAgent(c echo.Context) error {
882882
if !exists {
883883
var opts []runtime.Opt = []runtime.Opt{
884884
runtime.WithCurrentAgent(currentAgent),
885+
runtime.WithManagedOAuth(false),
885886
}
886887
rt, err = runtime.New(t, opts...)
887888
if err != nil {

0 commit comments

Comments
 (0)