Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .changelog/1764888150.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
applies_to:
- aws-sdk-rust
- client
authors:
- rcoh
- ysaito1001
references:
- smithy-rs#4429
breaking: false
new_feature: false
bug_fix: true
---
Fix bug where initial-request messages in event stream operations are not signed.
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,43 @@ private fun eventStreamWithInitialRequest(
return null
}

val smithyHttp = RuntimeType.smithyHttp(codegenContext.runtimeConfig)
val eventOrInitial = smithyHttp.resolve("event_stream::EventOrInitial")
val eventOrInitialMarshaller = smithyHttp.resolve("event_stream::EventOrInitialMarshaller")

return writable {
rustTemplate(
"""
{
use #{futures_util}::StreamExt;
let body = #{parser}(&input)?;
let initial_message = #{initial_message}(body);
let mut buffer = #{Vec}::new();
#{write_message_to}(&initial_message, &mut buffer)?;
let initial_message_stream = futures_util::stream::iter(vec![Ok(buffer.into())]);
let adapter = #{message_stream_adaptor:W};
initial_message_stream.chain(adapter)

// Wrap the marshaller to handle both initial and regular messages
let wrapped_marshaller = #{EventOrInitialMarshaller}::new(marshaller);

// Create stream with initial message
let initial_stream = #{futures_util}::stream::once(async move {
#{Ok}(#{EventOrInitial}::InitialMessage(initial_message))
});

// Extract inner stream and map events
let event_stream = ${params.outerName}.${params.memberName}.into_inner()
.map(|result| result.map(#{EventOrInitial}::Event));

// Chain streams and convert to EventStreamSender
let combined = initial_stream.chain(event_stream);
#{EventStreamSender}::from(combined)
.into_body_stream(wrapped_marshaller, error_marshaller, signer)
}
""",
*preludeScope,
"futures_util" to CargoDependency.FuturesUtil.toType(),
"initial_message" to params.eventStreamMarshallerGenerator.renderInitialRequestGenerator(params.payloadContentType),
"message_stream_adaptor" to messageStreamAdaptor(params.outerName, params.memberName),
"parser" to parser,
"write_message_to" to
RuntimeType.smithyEventStream(codegenContext.runtimeConfig)
.resolve("frame::write_message_to"),
"EventOrInitial" to eventOrInitial,
"EventOrInitialMarshaller" to eventOrInitialMarshaller,
"EventStreamSender" to smithyHttp.resolve("event_stream::EventStreamSender"),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,134 @@ class ClientEventStreamMarshallerGeneratorTest {
}
}
}

@ParameterizedTest
@ArgumentsSource(RpcEventStreamTestCasesProvider::class)
fun signedInitialMessageTest(rpcEventStreamTestCase: RpcEventStreamTestCase) {
val testCase = rpcEventStreamTestCase.inner
// Filter out tests that do not have initial message
if (rpcEventStreamTestCase.nonEventStreamMember != NonEventStreamMemberInOutput.NONE) {
clientIntegrationTest(
testCase.model,
IntegrationTestParams(service = "test#TestService"),
) { codegenContext, rustCrate ->
rustCrate.testModule {
tokioTest("initial_message_and_event_are_signed") {
rustTemplate(
"""
use crate::types::*;
use aws_smithy_eventstream::frame::{DeferredSignerSender, SignMessage, SignMessageError};
use aws_smithy_http::event_stream::EventStreamSender;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::event_stream::Message;

const FAKE_SIGNATURE: &str = "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890";

##[derive(Debug)]
struct TestSigner;

impl SignMessage for TestSigner {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
let mut new_payload = message.payload().to_vec();
new_payload.extend_from_slice(FAKE_SIGNATURE.as_bytes());

let mut signed_msg = Message::new(bytes::Bytes::from(new_payload));
for header in message.headers() {
signed_msg = signed_msg.add_header(header.clone());
}
Ok(signed_msg)
}

fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
None
}
}

##[derive(Debug)]
struct TestSignerInterceptor;

impl Intercept for TestSignerInterceptor {
fn name(&self) -> &'static str {
"TestSignerInterceptor"
}

fn modify_before_signing(
&self,
_context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if let Some(signer_sender) = cfg.load::<DeferredSignerSender>() {
signer_sender
.send(Box::new(TestSigner))
.expect("failed to send test signer");
}
Ok(())
}
}

let (http_client, rx) = #{capture_request}(None);
let conf = crate::Config::builder()
.endpoint_url("http://localhost:1234")
.http_client(http_client.clone())
.interceptor(TestSignerInterceptor)
.behavior_version_latest()
.build();
let client = crate::Client::from_conf(conf);

let event = TestStream::MessageWithString(
MessageWithString::builder().data("hello, world!").build(),
);
let stream = ::futures_util::stream::iter(vec![Ok(event)]);
let _ = client
.test_stream_op()
.test_string("this is test")
.value(EventStreamSender::from(stream))
.send()
.await
.unwrap();

let mut request = rx.expect_request();

let mut body = ::aws_smithy_types::body::SdkBody::taken();
std::mem::swap(&mut body, request.body_mut());

let unmarshaller = crate::event_stream_serde::TestStreamUnmarshaller::new();
let mut event_receiver = crate::event_receiver::EventReceiver::new(
::aws_smithy_http::event_stream::Receiver::new(unmarshaller, body),
);

// Check initial message has signature
let initial_msg = event_receiver
.try_recv_initial_request()
.await
.unwrap()
.expect("should receive initial-request");
assert!(initial_msg.payload().ends_with(FAKE_SIGNATURE.as_bytes()));

// Check event payload has signature
if let Some(event) = event_receiver.recv().await.unwrap() {
match event {
TestStream::MessageWithString(message_with_string) => {
assert!(message_with_string.data().unwrap().ends_with(FAKE_SIGNATURE));
}
otherwise => panic!("matched on unexpected variant {otherwise:?}"),
}
} else {
panic!("should receive at least one frame");
}
""",
"capture_request" to RuntimeType.captureRequest(codegenContext.runtimeConfig),
)
}
}
}
}
}
}

class TestCasesProvider : ArgumentsProvider {
Expand Down
5 changes: 4 additions & 1 deletion rust-runtime/aws-smithy-http/src/event_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ mod sender;
pub type BoxError = Box<dyn StdError + Send + Sync + 'static>;

#[doc(inline)]
pub use sender::{EventStreamSender, MessageStreamAdapter, MessageStreamError};
pub use sender::{
EventOrInitial, EventOrInitialMarshaller, EventStreamSender, MessageStreamAdapter,
MessageStreamError,
};

#[doc(inline)]
pub use receiver::{InitialMessageType, Receiver, ReceiverError};
50 changes: 50 additions & 0 deletions rust-runtime/aws-smithy-http/src/event_stream/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessa
use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_types::error::ErrorMetadata;
use aws_smithy_types::event_stream::Message;
use bytes::Bytes;
use futures_core::Stream;
use std::error::Error as StdError;
Expand All @@ -17,6 +18,17 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::trace;

/// Wrapper for event stream items that may include an initial-request message.
/// This is used internally to allow initial messages to flow through the signing pipeline.
#[doc(hidden)]
#[derive(Debug)]
pub enum EventOrInitial<T> {
/// A regular event that needs marshalling and signing
Event(T),
/// An initial-request message that's already marshalled, just needs signing
InitialMessage(Message),
}

/// Input type for Event Streams.
pub struct EventStreamSender<T, E> {
input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
Expand Down Expand Up @@ -47,6 +59,12 @@ impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
) -> MessageStreamAdapter<T, E> {
MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
}

/// Extract the inner stream. This is used internally for composing streams.
#[doc(hidden)]
pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
self.input_stream
}
}

impl<T, E, S> From<S> for EventStreamSender<T, E>
Expand Down Expand Up @@ -200,6 +218,38 @@ impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T,
}
}

/// Marshaller wrapper that handles both regular events and initial messages.
/// This is used internally to support initial-request messages in event streams.
#[doc(hidden)]
#[derive(Debug)]
pub struct EventOrInitialMarshaller<M> {
inner: M,
}

impl<M> EventOrInitialMarshaller<M> {
#[doc(hidden)]
pub fn new(inner: M) -> Self {
Self { inner }
}
}

impl<M, T> MarshallMessage for EventOrInitialMarshaller<M>
where
M: MarshallMessage<Input = T>,
{
type Input = EventOrInitial<T>;

fn marshall(
&self,
input: Self::Input,
) -> Result<Message, aws_smithy_eventstream::error::Error> {
match input {
EventOrInitial::Event(event) => self.inner.marshall(event),
EventOrInitial::InitialMessage(message) => Ok(message),
}
}
}

#[cfg(test)]
mod tests {
use super::MarshallMessage;
Expand Down
Loading