diff --git a/src/executors/web_socket_executors.rs b/src/executors/web_socket_executors.rs index 3d71bc2f1..a3b28d7e1 100644 --- a/src/executors/web_socket_executors.rs +++ b/src/executors/web_socket_executors.rs @@ -1,15 +1,39 @@ +use std::borrow::Cow; + use actix::prelude::*; use actix::AsyncContext; use actix_web_actors::ws; use pyo3::prelude::*; +use pyo3::types::PyString; use pyo3_asyncio::TaskLocals; use crate::types::function_info::FunctionInfo; use crate::websockets::WebSocketConnector; +pub enum WsMsgIn<'a> { + String(String), + Bytes(Cow<'a, [u8]>), +} + +impl <'a>Default for WsMsgIn<'a> { + fn default() -> Self { + WsMsgIn::String(Default::default()) + } +} + + +impl <'a>IntoPy for WsMsgIn<'a> { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + WsMsgIn::String(val) => val.into_py(py), + WsMsgIn::Bytes(val) => val.into_py(py), + } + } +} + fn get_function_output<'a>( function: &'a FunctionInfo, - fn_msg: Option, + fn_msg: Option, py: Python<'a>, ws: &WebSocketConnector, ) -> Result<&'a PyAny, PyErr> { @@ -57,9 +81,10 @@ fn get_function_output<'a>( } } + pub fn execute_ws_function( function: &FunctionInfo, - text: Option, + text: Option, task_locals: &TaskLocals, ctx: &mut ws::WebsocketContext, ws: &WebSocketConnector, diff --git a/src/websockets/mod.rs b/src/websockets/mod.rs index 51f51a47c..aa45794c9 100644 --- a/src/websockets/mod.rs +++ b/src/websockets/mod.rs @@ -1,6 +1,6 @@ pub mod registry; -use crate::executors::web_socket_executors::execute_ws_function; +use crate::executors::web_socket_executors::{execute_ws_function, WsMsgIn}; use crate::types::function_info::FunctionInfo; use crate::types::multimap::QueryParams; use registry::{Close, SendMessageToAll, SendText}; @@ -17,6 +17,7 @@ use pyo3_asyncio::TaskLocals; use uuid::Uuid; use registry::{Register, WebSocketRegistry}; +use std::borrow::Cow; use std::collections::HashMap; /// Define HTTP actor @@ -90,13 +91,23 @@ impl StreamHandler> for WebSocketConnecto let function = self.router.get("message").unwrap(); execute_ws_function( function, - Some(text.to_string()), + Some(WsMsgIn::String(text.to_string())), + &self.task_locals, + ctx, + self, + ); + } + Ok(ws::Message::Binary(bin)) => { + debug!("Bin data received"); + let function = self.router.get("message").unwrap(); + execute_ws_function( + function, + Some(WsMsgIn::Bytes(Cow::from(bin.to_vec()))), &self.task_locals, ctx, self, ); } - Ok(ws::Message::Binary(bin)) => ctx.binary(bin), Ok(ws::Message::Close(_close_reason)) => { debug!("Socket was closed"); let function = self.router.get("close").unwrap();