diff --git a/.changelog/1764888150.md b/.changelog/1764888150.md new file mode 100644 index 00000000000..8a50cc11a4f --- /dev/null +++ b/.changelog/1764888150.md @@ -0,0 +1,14 @@ +--- +applies_to: +- aws-sdk-rust +- client +authors: +- rcoh +- ysaito1001 +references: +- smithy-rs#4429 +breaking: false +new_feature: false +bug_fix: true +--- +Fix bug where initial-request messages in event stream operations are not signed. diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index ff7d4e66587..2d2e8c708b2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -64,6 +64,10 @@ private fun eventStreamWithInitialRequest( return null } + val smithyHttp = RuntimeType.smithyHttp(codegenContext.runtimeConfig) + val eventOrInitial = smithyHttp.resolve("event_stream::EventOrInitial") + val eventOrInitialMarshaller = smithyHttp.resolve("event_stream::EventOrInitialMarshaller") + return writable { rustTemplate( """ @@ -71,21 +75,32 @@ private fun eventStreamWithInitialRequest( use #{futures_util}::StreamExt; let body = #{parser}(&input)?; let initial_message = #{initial_message}(body); - let mut buffer = #{Vec}::new(); - #{write_message_to}(&initial_message, &mut buffer)?; - let initial_message_stream = futures_util::stream::iter(vec![Ok(buffer.into())]); - let adapter = #{message_stream_adaptor:W}; - initial_message_stream.chain(adapter) + + // Wrap the marshaller to handle both initial and regular messages + let wrapped_marshaller = #{EventOrInitialMarshaller}::new(marshaller); + + // Create stream with initial message + let initial_stream = #{futures_util}::stream::once(async move { + #{Ok}(#{EventOrInitial}::InitialMessage(initial_message)) + }); + + // Extract inner stream and map events + let event_stream = ${params.outerName}.${params.memberName}.into_inner() + .map(|result| result.map(#{EventOrInitial}::Event)); + + // Chain streams and convert to EventStreamSender + let combined = initial_stream.chain(event_stream); + #{EventStreamSender}::from(combined) + .into_body_stream(wrapped_marshaller, error_marshaller, signer) } """, *preludeScope, "futures_util" to CargoDependency.FuturesUtil.toType(), "initial_message" to params.eventStreamMarshallerGenerator.renderInitialRequestGenerator(params.payloadContentType), - "message_stream_adaptor" to messageStreamAdaptor(params.outerName, params.memberName), "parser" to parser, - "write_message_to" to - RuntimeType.smithyEventStream(codegenContext.runtimeConfig) - .resolve("frame::write_message_to"), + "EventOrInitial" to eventOrInitial, + "EventOrInitialMarshaller" to eventOrInitialMarshaller, + "EventStreamSender" to smithyHttp.resolve("event_stream::EventStreamSender"), ) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt index 8450fbb30fe..1dbe4f3950a 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt @@ -178,6 +178,134 @@ class ClientEventStreamMarshallerGeneratorTest { } } } + + @ParameterizedTest + @ArgumentsSource(RpcEventStreamTestCasesProvider::class) + fun signedInitialMessageTest(rpcEventStreamTestCase: RpcEventStreamTestCase) { + val testCase = rpcEventStreamTestCase.inner + // Filter out tests that do not have initial message + if (rpcEventStreamTestCase.nonEventStreamMember != NonEventStreamMemberInOutput.NONE) { + clientIntegrationTest( + testCase.model, + IntegrationTestParams(service = "test#TestService"), + ) { codegenContext, rustCrate -> + rustCrate.testModule { + tokioTest("initial_message_and_event_are_signed") { + rustTemplate( + """ + use crate::types::*; + use aws_smithy_eventstream::frame::{DeferredSignerSender, SignMessage, SignMessageError}; + use aws_smithy_http::event_stream::EventStreamSender; + use aws_smithy_runtime_api::box_error::BoxError; + use aws_smithy_runtime_api::client::interceptors::Intercept; + use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; + use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; + use aws_smithy_types::config_bag::ConfigBag; + use aws_smithy_types::event_stream::Message; + + const FAKE_SIGNATURE: &str = "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"; + + ##[derive(Debug)] + struct TestSigner; + + impl SignMessage for TestSigner { + fn sign(&mut self, message: Message) -> Result { + let mut new_payload = message.payload().to_vec(); + new_payload.extend_from_slice(FAKE_SIGNATURE.as_bytes()); + + let mut signed_msg = Message::new(bytes::Bytes::from(new_payload)); + for header in message.headers() { + signed_msg = signed_msg.add_header(header.clone()); + } + Ok(signed_msg) + } + + fn sign_empty(&mut self) -> Option> { + None + } + } + + ##[derive(Debug)] + struct TestSignerInterceptor; + + impl Intercept for TestSignerInterceptor { + fn name(&self) -> &'static str { + "TestSignerInterceptor" + } + + fn modify_before_signing( + &self, + _context: &mut BeforeTransmitInterceptorContextMut<'_>, + _runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + if let Some(signer_sender) = cfg.load::() { + signer_sender + .send(Box::new(TestSigner)) + .expect("failed to send test signer"); + } + Ok(()) + } + } + + let (http_client, rx) = #{capture_request}(None); + let conf = crate::Config::builder() + .endpoint_url("http://localhost:1234") + .http_client(http_client.clone()) + .interceptor(TestSignerInterceptor) + .behavior_version_latest() + .build(); + let client = crate::Client::from_conf(conf); + + let event = TestStream::MessageWithString( + MessageWithString::builder().data("hello, world!").build(), + ); + let stream = ::futures_util::stream::iter(vec![Ok(event)]); + let _ = client + .test_stream_op() + .test_string("this is test") + .value(EventStreamSender::from(stream)) + .send() + .await + .unwrap(); + + let mut request = rx.expect_request(); + + let mut body = ::aws_smithy_types::body::SdkBody::taken(); + std::mem::swap(&mut body, request.body_mut()); + + let unmarshaller = crate::event_stream_serde::TestStreamUnmarshaller::new(); + let mut event_receiver = crate::event_receiver::EventReceiver::new( + ::aws_smithy_http::event_stream::Receiver::new(unmarshaller, body), + ); + + // Check initial message has signature + let initial_msg = event_receiver + .try_recv_initial_request() + .await + .unwrap() + .expect("should receive initial-request"); + assert!(initial_msg.payload().ends_with(FAKE_SIGNATURE.as_bytes())); + + // Check event payload has signature + if let Some(event) = event_receiver.recv().await.unwrap() { + match event { + TestStream::MessageWithString(message_with_string) => { + assert!(message_with_string.data().unwrap().ends_with(FAKE_SIGNATURE)); + } + otherwise => panic!("matched on unexpected variant {otherwise:?}"), + } + } else { + panic!("should receive at least one frame"); + } + """, + "capture_request" to RuntimeType.captureRequest(codegenContext.runtimeConfig), + ) + } + } + } + } + } } class TestCasesProvider : ArgumentsProvider { diff --git a/rust-runtime/aws-smithy-http/src/event_stream.rs b/rust-runtime/aws-smithy-http/src/event_stream.rs index b74b85f474e..612e137bcbd 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream.rs @@ -14,7 +14,10 @@ mod sender; pub type BoxError = Box; #[doc(inline)] -pub use sender::{EventStreamSender, MessageStreamAdapter, MessageStreamError}; +pub use sender::{ + EventOrInitial, EventOrInitialMarshaller, EventStreamSender, MessageStreamAdapter, + MessageStreamError, +}; #[doc(inline)] pub use receiver::{InitialMessageType, Receiver, ReceiverError}; diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index 7749abedada..81133754bbf 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -7,6 +7,7 @@ use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessa use aws_smithy_eventstream::message_size_hint::MessageSizeHint; use aws_smithy_runtime_api::client::result::SdkError; use aws_smithy_types::error::ErrorMetadata; +use aws_smithy_types::event_stream::Message; use bytes::Bytes; use futures_core::Stream; use std::error::Error as StdError; @@ -17,6 +18,17 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tracing::trace; +/// Wrapper for event stream items that may include an initial-request message. +/// This is used internally to allow initial messages to flow through the signing pipeline. +#[doc(hidden)] +#[derive(Debug)] +pub enum EventOrInitial { + /// A regular event that needs marshalling and signing + Event(T), + /// An initial-request message that's already marshalled, just needs signing + InitialMessage(Message), +} + /// Input type for Event Streams. pub struct EventStreamSender { input_stream: Pin> + Send + Sync>>, @@ -47,6 +59,12 @@ impl EventStreamSender { ) -> MessageStreamAdapter { MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream) } + + /// Extract the inner stream. This is used internally for composing streams. + #[doc(hidden)] + pub fn into_inner(self) -> Pin> + Send + Sync>> { + self.input_stream + } } impl From for EventStreamSender @@ -200,6 +218,38 @@ impl Stream for MessageStreamAdapter { + inner: M, +} + +impl EventOrInitialMarshaller { + #[doc(hidden)] + pub fn new(inner: M) -> Self { + Self { inner } + } +} + +impl MarshallMessage for EventOrInitialMarshaller +where + M: MarshallMessage, +{ + type Input = EventOrInitial; + + fn marshall( + &self, + input: Self::Input, + ) -> Result { + match input { + EventOrInitial::Event(event) => self.inner.marshall(event), + EventOrInitial::InitialMessage(message) => Ok(message), + } + } +} + #[cfg(test)] mod tests { use super::MarshallMessage;