Skip to content

Commit ccc8a26

Browse files
teryltTeryl Taylorcrivetimihai
authored
feat: add context sharing across hook type. (#1514)
* feat: add context sharing across hook type. Signed-off-by: Teryl Taylor <[email protected]> * docs: updated docs, added logging and tests. Signed-off-by: Teryl Taylor <[email protected]> * fix: update test assertions for cross-hook context sharing Update unit test assertions to use ANY for plugin_global_context parameter since the HttpAuthMiddleware now correctly creates and stores a GlobalContext in request.state for cross-hook context sharing. Also fix integration test bugs where service constructors were incorrectly passed plugin_manager as a keyword argument (services get the plugin manager from a global singleton, not constructor). Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: Teryl Taylor <[email protected]> Signed-off-by: Mihai Criveti <[email protected]> Co-authored-by: Teryl Taylor <[email protected]> Co-authored-by: Mihai Criveti <[email protected]>
1 parent 4b4eb1c commit ccc8a26

File tree

14 files changed

+790
-59
lines changed

14 files changed

+790
-59
lines changed

docs/docs/using/plugins/index.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,7 @@ class MyPlugin(Plugin):
863863

864864
### Plugin Context and State
865865

866-
Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another (across pre/post hook pairs) or for a plugin to pass state information to itself across pre/post hook pairs. The plugin context looks as follows:
866+
Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another across all hook types in a request, or for a plugin to pass state information to itself across different hooks. The plugin context looks as follows:
867867

868868
```python
869869
class GlobalContext(BaseModel):
@@ -900,10 +900,9 @@ class PluginContext(BaseModel):
900900
metadata: dict[str, Any] = Field(default_factory=dict)
901901
```
902902

903-
As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state between pre/post hook pairs by using the
904-
the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information.
903+
As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state across all hooks in a request by using the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information.
905904

906-
The following shows how plugins can maintain state between pre/post hooks:
905+
The following shows how plugins can maintain state across different hooks:
907906

908907
```python
909908
async def prompt_pre_fetch(self, payload, context):
@@ -926,7 +925,7 @@ async def prompt_post_fetch(self, payload, context):
926925

927926
#### Tool and Gateway Metadata
928927

929-
Currently, the tool pre/post hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows:
928+
Tool hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows:
930929

931930
It can be accessed inside of the tool hooks through:
932931

mcpgateway/auth.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,25 @@ async def get_current_user(
170170
headers = dict(request.headers)
171171

172172
# Get request ID from request state (set by middleware) or generate new one
173-
request_id = None
174-
if request and hasattr(request, "state") and hasattr(request.state, "request_id"):
175-
request_id = request.state.request_id
176-
else:
173+
request_id = getattr(request.state, "request_id", None) if request else None
174+
if not request_id:
177175
request_id = uuid.uuid4().hex
178176

179-
# Create global context
180-
global_context = GlobalContext(
181-
request_id=request_id,
182-
server_id=None,
183-
tenant_id=None,
184-
)
177+
# Get plugin contexts from request state if available
178+
global_context = getattr(request.state, "plugin_global_context", None) if request else None
179+
if not global_context:
180+
# Create global context
181+
global_context = GlobalContext(
182+
request_id=request_id,
183+
server_id=None,
184+
tenant_id=None,
185+
)
186+
187+
context_table = getattr(request.state, "plugin_context_table", None) if request else None
185188

186189
# Invoke custom auth resolution hook
187190
# violations_as_exceptions=True so PluginViolationError is raised for explicit denials
188-
auth_result, _ = await plugin_manager.invoke_hook(
191+
auth_result, context_table_result = await plugin_manager.invoke_hook(
189192
HttpHookType.HTTP_AUTH_RESOLVE_USER,
190193
payload=HttpAuthResolveUserPayload(
191194
credentials=credentials_dict,
@@ -194,7 +197,7 @@ async def get_current_user(
194197
client_port=client_port,
195198
),
196199
global_context=global_context,
197-
local_contexts=None,
200+
local_contexts=context_table,
198201
violations_as_exceptions=True, # Raise PluginViolationError for auth denials
199202
)
200203

@@ -215,12 +218,17 @@ async def get_current_user(
215218
)
216219

217220
# Store auth_method in request.state so it can be accessed by RBAC middleware
218-
if request and hasattr(request, "state") and auth_result.metadata:
221+
if request and auth_result.metadata:
219222
auth_method = auth_result.metadata.get("auth_method")
220223
if auth_method:
221224
request.state.auth_method = auth_method
222225
logger.debug(f"Stored auth_method '{auth_method}' in request.state")
223226

227+
if request and context_table_result:
228+
request.state.plugin_context_table = context_table_result
229+
230+
if request and global_context:
231+
request.state.plugin_global_context = global_context
224232
return user
225233
# If continue_processing=True (no payload), fall through to standard auth
226234

@@ -294,7 +302,7 @@ async def get_current_user(
294302

295303
# Check team level token, if applicable. If public token, then will be defaulted to personal team.
296304
team_id = await get_team_from_token(payload, db)
297-
if request and hasattr(request, "state"):
305+
if request:
298306
request.state.team_id = team_id
299307

300308
except HTTPException:

mcpgateway/main.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2960,9 +2960,21 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend
29602960
if cached := resource_cache.get(resource_id):
29612961
return cached
29622962

2963+
# Get plugin contexts from request.state for cross-hook sharing
2964+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
2965+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
2966+
29632967
try:
29642968
# Call service with context for plugin support
2965-
content = await resource_service.read_resource(db, resource_id=resource_id, request_id=request_id, user=user, server_id=server_id)
2969+
content = await resource_service.read_resource(
2970+
db,
2971+
resource_id=resource_id,
2972+
request_id=request_id,
2973+
user=user,
2974+
server_id=server_id,
2975+
plugin_context_table=plugin_context_table,
2976+
plugin_global_context=plugin_global_context,
2977+
)
29662978
except (ResourceNotFoundError, ResourceError) as exc:
29672979
# Translate to FastAPI HTTP error
29682980
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
@@ -3289,6 +3301,7 @@ async def create_prompt(
32893301
@prompt_router.post("/{prompt_id}")
32903302
@require_permission("prompts.read")
32913303
async def get_prompt(
3304+
request: Request,
32923305
prompt_id: str,
32933306
args: Dict[str, str] = Body({}),
32943307
db: Session = Depends(get_db),
@@ -3301,6 +3314,7 @@ async def get_prompt(
33013314
33023315
33033316
Args:
3317+
request: FastAPI request object.
33043318
prompt_id: ID of the prompt.
33053319
args: Template arguments.
33063320
db: Database session.
@@ -3314,9 +3328,19 @@ async def get_prompt(
33143328
"""
33153329
logger.debug(f"User: {user} requested prompt: {prompt_id} with args={args}")
33163330

3331+
# Get plugin contexts from request.state for cross-hook sharing
3332+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
3333+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
3334+
33173335
try:
33183336
PromptExecuteArgs(args=args)
3319-
result = await prompt_service.get_prompt(db, prompt_id, args)
3337+
result = await prompt_service.get_prompt(
3338+
db,
3339+
prompt_id,
3340+
args,
3341+
plugin_context_table=plugin_context_table,
3342+
plugin_global_context=plugin_global_context,
3343+
)
33203344
logger.debug(f"Prompt execution successful for '{prompt_id}'")
33213345
except Exception as ex:
33223346
logger.error(f"Could not retrieve prompt {prompt_id}: {ex}")
@@ -3334,6 +3358,7 @@ async def get_prompt(
33343358
@prompt_router.get("/{prompt_id}")
33353359
@require_permission("prompts.read")
33363360
async def get_prompt_no_args(
3361+
request: Request,
33373362
prompt_id: str,
33383363
db: Session = Depends(get_db),
33393364
user=Depends(get_current_user_with_permissions),
@@ -3343,6 +3368,7 @@ async def get_prompt_no_args(
33433368
This endpoint is for convenience when no arguments are needed.
33443369
33453370
Args:
3371+
request: FastAPI request object.
33463372
prompt_id: The ID of the prompt to retrieve
33473373
db: Database session
33483374
user: Authenticated user
@@ -3354,7 +3380,18 @@ async def get_prompt_no_args(
33543380
Exception: Re-raised from prompt service.
33553381
"""
33563382
logger.debug(f"User: {user} requested prompt: {prompt_id} with no arguments")
3357-
return await prompt_service.get_prompt(db, prompt_id, {})
3383+
3384+
# Get plugin contexts from request.state for cross-hook sharing
3385+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
3386+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
3387+
3388+
return await prompt_service.get_prompt(
3389+
db,
3390+
prompt_id,
3391+
{},
3392+
plugin_context_table=plugin_context_table,
3393+
plugin_global_context=plugin_global_context,
3394+
)
33583395

33593396

33603397
@prompt_router.put("/{prompt_id}", response_model=PromptRead)
@@ -3921,8 +3958,18 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
39213958
raise JSONRPCError(-32602, "Missing resource URI in parameters", params)
39223959
# Get user email for OAuth token selection
39233960
user_email = get_user_email(user)
3961+
# Get plugin contexts from request.state for cross-hook sharing
3962+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
3963+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
39243964
try:
3925-
result = await resource_service.read_resource(db, uri, request_id=request_id, user=user_email)
3965+
result = await resource_service.read_resource(
3966+
db,
3967+
resource_uri=uri,
3968+
request_id=request_id,
3969+
user=user_email,
3970+
plugin_context_table=plugin_context_table,
3971+
plugin_global_context=plugin_global_context,
3972+
)
39263973
if hasattr(result, "model_dump"):
39273974
result = {"contents": [result.model_dump(by_alias=True, exclude_none=True)]}
39283975
else:
@@ -3966,7 +4013,16 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
39664013
arguments = params.get("arguments", {})
39674014
if not name:
39684015
raise JSONRPCError(-32602, "Missing prompt name in parameters", params)
3969-
result = await prompt_service.get_prompt(db, name, arguments)
4016+
# Get plugin contexts from request.state for cross-hook sharing
4017+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
4018+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
4019+
result = await prompt_service.get_prompt(
4020+
db,
4021+
name,
4022+
arguments,
4023+
plugin_context_table=plugin_context_table,
4024+
plugin_global_context=plugin_global_context,
4025+
)
39704026
if hasattr(result, "model_dump"):
39714027
result = result.model_dump(by_alias=True, exclude_none=True)
39724028
elif method == "ping":
@@ -3981,8 +4037,19 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
39814037
raise JSONRPCError(-32602, "Missing tool name in parameters", params)
39824038
# Get user email for OAuth token selection
39834039
user_email = get_user_email(user)
4040+
# Get plugin contexts from request.state for cross-hook sharing
4041+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
4042+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
39844043
try:
3985-
result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=headers, app_user_email=user_email)
4044+
result = await tool_service.invoke_tool(
4045+
db=db,
4046+
name=name,
4047+
arguments=arguments,
4048+
request_headers=headers,
4049+
app_user_email=user_email,
4050+
plugin_context_table=plugin_context_table,
4051+
plugin_global_context=plugin_global_context,
4052+
)
39864053
if hasattr(result, "model_dump"):
39874054
result = result.model_dump(by_alias=True, exclude_none=True)
39884055
except ValueError:

mcpgateway/middleware/http_auth_middleware.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ async def dispatch(self, request: Request, call_next):
9595
violations_as_exceptions=False, # Don't block on pre-request violations
9696
)
9797

98+
if context_table:
99+
request.state.plugin_context_table = context_table
100+
101+
if global_context:
102+
request.state.plugin_global_context = global_context
103+
98104
# Apply modified headers if plugin returned them
99105
if pre_result.modified_payload:
100106
# Modify request headers by updating request.scope["headers"]

mcpgateway/middleware/rbac.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ async def protected_route(user = Depends(get_current_user_with_permissions)):
150150
request_id = getattr(request.state, "request_id", None)
151151
team_id = getattr(request.state, "team_id", None)
152152

153+
# Read plugin context data from request.state for cross-hook context sharing
154+
# (set by HttpAuthMiddleware for passing contexts between different hook types)
155+
plugin_context_table = getattr(request.state, "plugin_context_table", None)
156+
plugin_global_context = getattr(request.state, "plugin_global_context", None)
157+
153158
# Add request context for permission auditing
154159
return {
155160
"email": user.email,
@@ -161,6 +166,8 @@ async def protected_route(user = Depends(get_current_user_with_permissions)):
161166
"auth_method": auth_method, # Include auth_method from plugin
162167
"request_id": request_id, # Include request_id from middleware
163168
"team_id": team_id, # Include team_id from token
169+
"plugin_context_table": plugin_context_table, # Plugin contexts for cross-hook sharing
170+
"plugin_global_context": plugin_global_context, # Global context for consistency
164171
}
165172
except Exception as e:
166173
logger.error(f"Authentication failed: {type(e).__name__}: {e}")
@@ -256,18 +263,24 @@ async def wrapper(*args, **kwargs):
256263

257264
plugin_manager = get_plugin_manager()
258265
if plugin_manager:
259-
# Get request_id from user_context (passed from get_current_user_with_permissions)
260-
# Generate a fallback if not present
261-
request_id = user_context.get("request_id") or uuid.uuid4().hex
262-
263-
# Create global context for plugin invocation
264-
global_context = GlobalContext(
265-
request_id=request_id,
266-
server_id=None,
267-
tenant_id=None,
268-
)
266+
# Get plugin contexts from user_context (stored in request.state by HttpAuthMiddleware)
267+
# These enable cross-hook context sharing between HTTP_PRE_REQUEST and HTTP_AUTH_CHECK_PERMISSION
268+
plugin_context_table = user_context.get("plugin_context_table")
269+
plugin_global_context = user_context.get("plugin_global_context")
270+
271+
# Reuse existing global context from middleware if available for consistency
272+
# Otherwise create a new one (fallback for cases where middleware didn't run)
273+
if plugin_global_context:
274+
global_context = plugin_global_context
275+
else:
276+
request_id = user_context.get("request_id") or uuid.uuid4().hex
277+
global_context = GlobalContext(
278+
request_id=request_id,
279+
server_id=None,
280+
tenant_id=None,
281+
)
269282

270-
# Invoke permission check hook
283+
# Invoke permission check hook, passing plugin contexts from HTTP_PRE_REQUEST hook
271284
result, _ = await plugin_manager.invoke_hook(
272285
HttpHookType.HTTP_AUTH_CHECK_PERMISSION,
273286
payload=HttpAuthCheckPermissionPayload(
@@ -281,6 +294,7 @@ async def wrapper(*args, **kwargs):
281294
user_agent=user_context.get("user_agent"),
282295
),
283296
global_context=global_context,
297+
local_contexts=plugin_context_table, # Pass context table for cross-hook state
284298
)
285299

286300
# If a plugin made a decision, respect it

mcpgateway/plugins/framework/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
PluginCondition,
5555
PluginConfig,
5656
PluginContext,
57+
PluginContextTable,
5758
PluginErrorModel,
5859
PluginMode,
5960
PluginPayload,
@@ -120,6 +121,7 @@ def get_plugin_manager() -> Optional[PluginManager]:
120121
"PluginCondition",
121122
"PluginConfig",
122123
"PluginContext",
124+
"PluginContextTable",
123125
"PluginError",
124126
"PluginErrorModel",
125127
"PluginLoader",

0 commit comments

Comments
 (0)