diff --git a/pkgs/dart_mcp/CHANGELOG.md b/pkgs/dart_mcp/CHANGELOG.md index d76e0a41..011464e9 100644 --- a/pkgs/dart_mcp/CHANGELOG.md +++ b/pkgs/dart_mcp/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.2.3-wip + +- Added error checking to required fields of all `Request` subclasses so that + they will throw helpful errors when accessed and not set. + ## 0.2.2 - Refactor `ClientImplementation` and `ServerImplementation` to the shared diff --git a/pkgs/dart_mcp/lib/src/api/completions.dart b/pkgs/dart_mcp/lib/src/api/completions.dart index a8768b29..c5072f93 100644 --- a/pkgs/dart_mcp/lib/src/api/completions.dart +++ b/pkgs/dart_mcp/lib/src/api/completions.dart @@ -27,11 +27,22 @@ extension type CompleteRequest.fromMap(Map _value) /// /// In the case of a [ResourceReference], it must refer to a /// [ResourceTemplate]. - Reference get ref => _value['ref'] as Reference; + Reference get ref { + final ref = _value['ref'] as Reference?; + if (ref == null) { + throw ArgumentError('Missing ref field in $CompleteRequest.'); + } + return ref; + } /// The argument's information. - CompletionArgument get argument => - (_value['argument'] as Map).cast() as CompletionArgument; + CompletionArgument get argument { + final argument = _value['argument'] as CompletionArgument?; + if (argument == null) { + throw ArgumentError('Missing argument field in $CompleteRequest.'); + } + return argument; + } } /// The server's response to a completion/complete request diff --git a/pkgs/dart_mcp/lib/src/api/initialization.dart b/pkgs/dart_mcp/lib/src/api/initialization.dart index ec6acba9..80da37c4 100644 --- a/pkgs/dart_mcp/lib/src/api/initialization.dart +++ b/pkgs/dart_mcp/lib/src/api/initialization.dart @@ -30,10 +30,21 @@ extension type InitializeRequest._fromMap(Map _value) ProtocolVersion? get protocolVersion => ProtocolVersion.tryParse(_value['protocolVersion'] as String); - ClientCapabilities get capabilities => - _value['capabilities'] as ClientCapabilities; + ClientCapabilities get capabilities { + final capabilities = _value['capabilities'] as ClientCapabilities?; + if (capabilities == null) { + throw ArgumentError('Missing capabilities field in $InitializeRequest.'); + } + return capabilities; + } - Implementation get clientInfo => _value['clientInfo'] as Implementation; + Implementation get clientInfo { + final clientInfo = _value['clientInfo'] as Implementation?; + if (clientInfo == null) { + throw ArgumentError('Missing clientInfo field in $InitializeRequest.'); + } + return clientInfo; + } } /// After receiving an initialize request from the client, the server sends diff --git a/pkgs/dart_mcp/lib/src/api/logging.dart b/pkgs/dart_mcp/lib/src/api/logging.dart index ccee6105..596478fc 100644 --- a/pkgs/dart_mcp/lib/src/api/logging.dart +++ b/pkgs/dart_mcp/lib/src/api/logging.dart @@ -26,8 +26,18 @@ extension type SetLevelRequest.fromMap(Map _value) /// /// The server should send all logs at this level and higher (i.e., more /// severe) to the client as notifications/message. - LoggingLevel get level => - LoggingLevel.values.firstWhere((level) => level.name == _value['level']); + LoggingLevel get level { + final levelName = _value['level']; + final foundLevel = LoggingLevel.values.firstWhereOrNull( + (level) => level.name == levelName, + ); + if (foundLevel == null) { + throw ArgumentError( + "Invalid level field in $SetLevelRequest: didn't find level $levelName", + ); + } + return foundLevel; + } } /// Notification of a log message passed from server to client. diff --git a/pkgs/dart_mcp/lib/src/api/prompts.dart b/pkgs/dart_mcp/lib/src/api/prompts.dart index 49ab537e..a658725b 100644 --- a/pkgs/dart_mcp/lib/src/api/prompts.dart +++ b/pkgs/dart_mcp/lib/src/api/prompts.dart @@ -49,7 +49,13 @@ extension type GetPromptRequest.fromMap(Map _value) }); /// The name of the prompt or prompt template. - String get name => _value['name'] as String; + String get name { + final name = _value['name'] as String?; + if (name == null) { + throw ArgumentError('Missing name field in $GetPromptRequest.'); + } + return name; + } /// Arguments to use for templating the prompt. Map? get arguments => diff --git a/pkgs/dart_mcp/lib/src/api/resources.dart b/pkgs/dart_mcp/lib/src/api/resources.dart index 0a4baa87..5b6da075 100644 --- a/pkgs/dart_mcp/lib/src/api/resources.dart +++ b/pkgs/dart_mcp/lib/src/api/resources.dart @@ -80,7 +80,13 @@ extension type ReadResourceRequest.fromMap(Map _value) /// The URI of the resource to read. The URI can use any protocol; it is /// up to the server how to interpret it. - String get uri => _value['uri'] as String; + String get uri { + final uri = _value['uri'] as String?; + if (uri == null) { + throw ArgumentError('Missing uri field in $ReadResourceRequest.'); + } + return uri; + } } /// The server's response to a resources/read request from the client. @@ -128,7 +134,13 @@ extension type SubscribeRequest.fromMap(Map _value) /// The URI of the resource to subscribe to. The URI can use any protocol; /// it is up to the server how to interpret it. - String get uri => _value['uri'] as String; + String get uri { + final uri = _value['uri'] as String?; + if (uri == null) { + throw ArgumentError('Missing uri field in $SubscribeRequest.'); + } + return uri; + } } /// Sent from the client to request cancellation of resources/updated @@ -146,7 +158,13 @@ extension type UnsubscribeRequest.fromMap(Map _value) UnsubscribeRequest.fromMap({'uri': uri, if (meta != null) '_meta': meta}); /// The URI of the resource to unsubscribe from. - String get uri => _value['uri'] as String; + String get uri { + final uri = _value['uri'] as String?; + if (uri == null) { + throw ArgumentError('Missing uri field in $UnsubscribeRequest.'); + } + return uri; + } } /// A notification from the server to the client, informing it that a resource diff --git a/pkgs/dart_mcp/lib/src/api/roots.dart b/pkgs/dart_mcp/lib/src/api/roots.dart index c8221eff..f7194a36 100644 --- a/pkgs/dart_mcp/lib/src/api/roots.dart +++ b/pkgs/dart_mcp/lib/src/api/roots.dart @@ -33,7 +33,13 @@ extension type ListRootsResult.fromMap(Map _value) if (meta != null) '_meta': meta, }); - List get roots => (_value['roots'] as List).cast(); + List get roots { + final roots = _value['roots'] as List?; + if (roots == null) { + throw ArgumentError('Missing roots field in $ListRootsResult.'); + } + return roots.cast(); + } } /// Represents a root directory or file that the server can operate on. diff --git a/pkgs/dart_mcp/lib/src/api/sampling.dart b/pkgs/dart_mcp/lib/src/api/sampling.dart index 311795de..8cbe985b 100644 --- a/pkgs/dart_mcp/lib/src/api/sampling.dart +++ b/pkgs/dart_mcp/lib/src/api/sampling.dart @@ -39,8 +39,13 @@ extension type CreateMessageRequest.fromMap(Map _value) }); /// The messages to send to the LLM. - List get messages => - (_value['messages'] as List).cast(); + List get messages { + final messages = _value['messages'] as List?; + if (messages == null) { + throw ArgumentError('Missing messages field in $CreateMessageRequest.'); + } + return messages.cast(); + } /// The server's preferences for which model to select. /// @@ -69,7 +74,13 @@ extension type CreateMessageRequest.fromMap(Map _value) /// The maximum number of tokens to sample, as requested by the server. /// /// The client MAY choose to sample fewer tokens than requested. - int get maxTokens => _value['maxTokens'] as int; + int get maxTokens { + final maxTokens = _value['maxTokens'] as int?; + if (maxTokens == null) { + throw ArgumentError('Missing maxTokens field in $CreateMessageRequest.'); + } + return maxTokens; + } /// Note: This has no documentation in the specification or schema. List? get stopSequences => diff --git a/pkgs/dart_mcp/lib/src/api/tools.dart b/pkgs/dart_mcp/lib/src/api/tools.dart index d67b012c..6b90833b 100644 --- a/pkgs/dart_mcp/lib/src/api/tools.dart +++ b/pkgs/dart_mcp/lib/src/api/tools.dart @@ -80,7 +80,13 @@ extension type CallToolRequest._fromMap(Map _value) }); /// The name of the method to invoke. - String get name => _value['name'] as String; + String get name { + final name = _value['name'] as String?; + if (name == null) { + throw ArgumentError('Missing name field in $CallToolRequest'); + } + return name; + } /// The arguments to pass to the method. Map? get arguments => diff --git a/pkgs/dart_mcp/lib/src/shared.dart b/pkgs/dart_mcp/lib/src/shared.dart index fa6a5b52..3b348ec9 100644 --- a/pkgs/dart_mcp/lib/src/shared.dart +++ b/pkgs/dart_mcp/lib/src/shared.dart @@ -79,10 +79,15 @@ base class MCPBase { void registerRequestHandler( String name, FutureOr Function(T) impl, - ) => _peer.registerMethod( - name, - (Parameters p) => impl((p.value as Map?)?.cast() as T), - ); + ) => _peer.registerMethod(name, (Parameters p) { + if (p.value != null && p.value is! Map) { + throw ArgumentError( + 'Request to $name must be a Map or null. Instead, got ' + '${p.value.runtimeType}', + ); + } + return impl((p.value as Map?)?.cast() as T); + }); /// Registers a notification handler named [name] on this server. void registerNotificationHandler( diff --git a/pkgs/dart_mcp/pubspec.yaml b/pkgs/dart_mcp/pubspec.yaml index 0a09b1c8..3a98683f 100644 --- a/pkgs/dart_mcp/pubspec.yaml +++ b/pkgs/dart_mcp/pubspec.yaml @@ -1,5 +1,5 @@ name: dart_mcp -version: 0.2.2 +version: 0.2.3-wip description: A package for making MCP servers and clients. repository: https://github.com/dart-lang/ai/tree/main/pkgs/dart_mcp issue_tracker: https://github.com/dart-lang/ai/issues?q=is%3Aissue+is%3Aopen+label%3Apackage%3Adart_mcp @@ -10,7 +10,7 @@ environment: dependencies: async: ^2.13.0 collection: ^1.19.1 - json_rpc_2: '>=3.0.3 <5.0.0' + json_rpc_2: ">=3.0.3 <5.0.0" meta: ^1.16.0 stream_channel: ^2.1.4 stream_transform: ^2.1.1 diff --git a/pkgs/dart_mcp/test/api/api_test.dart b/pkgs/dart_mcp/test/api/api_test.dart index 8867f98e..faddb6b6 100644 --- a/pkgs/dart_mcp/test/api/api_test.dart +++ b/pkgs/dart_mcp/test/api/api_test.dart @@ -58,4 +58,51 @@ void main() { false, ); }); + + group('API object validation', () { + test('throws when required fields are missing', () { + final empty = {}; + + // Initialization + expect( + () => (empty as InitializeRequest).capabilities, + throwsArgumentError, + ); + expect( + () => (empty as InitializeRequest).clientInfo, + throwsArgumentError, + ); + + // Tools + expect(() => (empty as CallToolRequest).name, throwsArgumentError); + + // Resources + expect(() => (empty as ReadResourceRequest).uri, throwsArgumentError); + expect(() => (empty as SubscribeRequest).uri, throwsArgumentError); + expect(() => (empty as UnsubscribeRequest).uri, throwsArgumentError); + + // Roots + expect(() => (empty as ListRootsResult).roots, throwsArgumentError); + + // Prompts + expect(() => (empty as GetPromptRequest).name, throwsArgumentError); + + // Completions + expect(() => (empty as CompleteRequest).ref, throwsArgumentError); + expect(() => (empty as CompleteRequest).argument, throwsArgumentError); + + // Logging + expect(() => (empty as SetLevelRequest).level, throwsArgumentError); + + // Sampling + expect( + () => (empty as CreateMessageRequest).messages, + throwsArgumentError, + ); + expect( + () => (empty as CreateMessageRequest).maxTokens, + throwsArgumentError, + ); + }); + }); }