diff --git a/src/bindings/rust/src/lib.rs b/src/bindings/rust/src/lib.rs index adce69543..0e0bd4601 100644 --- a/src/bindings/rust/src/lib.rs +++ b/src/bindings/rust/src/lib.rs @@ -67,7 +67,8 @@ use bindings::{ nixl_capi_query_resp_list_get_params, nixl_capi_prep_xfer_dlist, nixl_capi_release_xfer_dlist_handle, nixl_capi_make_xfer_req, nixl_capi_get_local_partial_md, nixl_capi_send_local_partial_md, nixl_capi_query_xfer_backend, nixl_capi_opt_args_set_ip_addr, - nixl_capi_opt_args_set_port, nixl_capi_get_xfer_telemetry + nixl_capi_opt_args_set_port, nixl_capi_get_xfer_telemetry, + nixl_capi_create_params, nixl_capi_params_add }; // Re-export status codes diff --git a/src/bindings/rust/src/utils/params.rs b/src/bindings/rust/src/utils/params.rs index 9d4d13293..23f4da6ae 100644 --- a/src/bindings/rust/src/utils/params.rs +++ b/src/bindings/rust/src/utils/params.rs @@ -52,9 +52,9 @@ impl<'a> Iterator for ParamIterator<'a> { }; match status { - 0 if !has_next => None, + 0 if key_ptr.is_null() => None, 0 => { - // SAFETY: If status is 0, both pointers are valid null-terminated strings + // SAFETY: If status is 0 and key_ptr is not null, both pointers are valid null-terminated strings let result = unsafe { let key = CStr::from_ptr(key_ptr).to_str().unwrap(); let value = CStr::from_ptr(value_ptr).to_str().unwrap(); @@ -82,6 +82,86 @@ impl Params { Self { inner } } + /// Creates a new empty Params object + pub(crate) fn create() -> Result { + let mut params = ptr::null_mut(); + + let status = unsafe { nixl_capi_create_params(&mut params) }; + + match status { + 0 => { + let inner = unsafe { NonNull::new_unchecked(params) }; + Ok(Self { inner }) + } + -1 => Err(NixlError::InvalidParam), + _ => Err(NixlError::BackendError), + } + } + + /// Creates a new Params object from an iteratable + /// + /// # Example + /// ```ignore + /// use std::collections::HashMap; + /// + /// let map = HashMap::from([ + /// ("access_key", "*********"), + /// ("secret_key", "*********"), + /// ("bucket", "my-bucket"), + /// ]); + /// + /// let params = Params::try_from_iter(map.iter().map(|(k, v)| (*k, *v)))?; + /// ``` + pub fn try_from_iter(iter: I) -> Result + where + I: IntoIterator, + K: AsRef, + V: AsRef, + { + let mut params = Self::create()?; + for (key, value) in iter { + params.set(key.as_ref(), value.as_ref())?; + } + Ok(params) + } + + /// Creates a new Params object by copying from another Params + /// + /// # Example + /// ```ignore + /// let original_params = agent.get_plugin_params("OBJ")?.1; + /// let mut modified_params = original_params.try_clone()?; + /// modified_params.set("bucket", "my-custom-bucket")?; + /// ``` + pub fn try_clone(&self) -> Result { + let mut params = Self::create()?; + + if let Ok(iter) = self.iter() { + for pair in iter { + let pair = pair?; + params.set(pair.key, pair.value)?; + } + } + + Ok(params) + } + + /// Sets a key-value pair in the parameters (overwrites if exists) + pub fn set(&mut self, key: &str, value: &str) -> Result<(), NixlError> { + let c_key = CString::new(key)?; + let c_value = CString::new(value)?; + + let status = unsafe { + nixl_capi_params_add(self.inner.as_ptr(), c_key.as_ptr(), c_value.as_ptr()) + }; + + match status { + 0 => Ok(()), + -1 => Err(NixlError::InvalidParam), + _ => Err(NixlError::BackendError), + } + } + /// Returns true if the parameters are empty pub fn is_empty(&self) -> Result { let mut is_empty = false; diff --git a/src/bindings/rust/stubs.cpp b/src/bindings/rust/stubs.cpp index ab20b3a3e..0be33763b 100644 --- a/src/bindings/rust/stubs.cpp +++ b/src/bindings/rust/stubs.cpp @@ -249,6 +249,16 @@ nixl_capi_opt_args_set_port(nixl_capi_opt_args_t args, uint16_t port) { return nixl_capi_stub_abort(); } +nixl_capi_status_t +nixl_capi_create_params(nixl_capi_params_t *params) { + return nixl_capi_stub_abort(); +} + +nixl_capi_status_t +nixl_capi_params_add(nixl_capi_params_t params, const char *key, const char *value) { + return nixl_capi_stub_abort(); +} + nixl_capi_status_t nixl_capi_params_is_empty(nixl_capi_params_t params, bool* is_empty) { diff --git a/src/bindings/rust/tests/tests.rs b/src/bindings/rust/tests/tests.rs index c8a30bedc..d7128c1ce 100644 --- a/src/bindings/rust/tests/tests.rs +++ b/src/bindings/rust/tests/tests.rs @@ -220,6 +220,64 @@ fn test_params_iteration() { } } +#[test] +fn test_params_try_from_iter() { + use std::collections::HashMap; + + let map = HashMap::from([ + ("key1", "value1"), + ("key2", "value2"), + ("key3", "value3"), + ]); + + let params = Params::try_from_iter(map.iter().map(|(k, v)| (*k, *v))) + .expect("Failed to create params from iterator"); + + assert!(!params.is_empty().unwrap(), "Params should not be empty"); + + let mut found_keys = HashMap::new(); + for param in params.iter().unwrap() { + let param = param.unwrap(); + found_keys.insert(param.key.to_string(), param.value.to_string()); + } + + assert_eq!(found_keys.len(), 3, "Should have 3 key-value pairs"); + assert_eq!(found_keys.get("key1"), Some(&"value1".to_string())); + assert_eq!(found_keys.get("key2"), Some(&"value2".to_string())); + assert_eq!(found_keys.get("key3"), Some(&"value3".to_string())); +} + +#[test] +fn test_params_try_clone() { + let agent = Agent::new("test_agent").expect("Failed to create agent"); + let (_mems, original_params) = agent + .get_plugin_params("UCX") + .expect("Failed to get plugin params"); + + let copied_params = original_params.try_clone() + .expect("Failed to copy params"); + + assert_eq!( + original_params.is_empty().unwrap(), + copied_params.is_empty().unwrap(), + "Copied params should have same empty state" + ); + + let mut original_map = std::collections::HashMap::new(); + for param in original_params.iter().unwrap() { + let param = param.unwrap(); + original_map.insert(param.key.to_string(), param.value.to_string()); + } + + let mut copied_map = std::collections::HashMap::new(); + for param in copied_params.iter().unwrap() { + let param = param.unwrap(); + copied_map.insert(param.key.to_string(), param.value.to_string()); + } + + assert_eq!(original_map, copied_map, "Copied params should match original"); +} + // #[test] // fn test_get_backend_params() { // let agent = Agent::new("test_agent").unwrap(); diff --git a/src/bindings/rust/wrapper.cpp b/src/bindings/rust/wrapper.cpp index fa880b7b6..400e8ff48 100644 --- a/src/bindings/rust/wrapper.cpp +++ b/src/bindings/rust/wrapper.cpp @@ -730,6 +730,37 @@ nixl_capi_opt_args_set_port(nixl_capi_opt_args_t args, uint16_t port) { return NIXL_CAPI_SUCCESS; } +nixl_capi_status_t +nixl_capi_create_params(nixl_capi_params_t *params) { + if (!params) { + return NIXL_CAPI_ERROR_INVALID_PARAM; + } + + try { + auto param_list = new nixl_capi_params_s; + *params = param_list; + return NIXL_CAPI_SUCCESS; + } + catch (...) { + return NIXL_CAPI_ERROR_BACKEND; + } +} + +nixl_capi_status_t +nixl_capi_params_add(nixl_capi_params_t params, const char *key, const char *value) { + if (!params || !key || !value) { + return NIXL_CAPI_ERROR_INVALID_PARAM; + } + + try { + params->params[key] = value; + return NIXL_CAPI_SUCCESS; + } + catch (...) { + return NIXL_CAPI_ERROR_BACKEND; + } +} + nixl_capi_status_t nixl_capi_params_is_empty(nixl_capi_params_t params, bool* is_empty) { diff --git a/src/bindings/rust/wrapper.h b/src/bindings/rust/wrapper.h index c6c848d24..bc40928b7 100644 --- a/src/bindings/rust/wrapper.h +++ b/src/bindings/rust/wrapper.h @@ -196,6 +196,10 @@ nixl_capi_status_t nixl_capi_opt_args_set_port(nixl_capi_opt_args_t args, uint16_t port); // Parameter access functions +nixl_capi_status_t +nixl_capi_create_params(nixl_capi_params_t *params); +nixl_capi_status_t +nixl_capi_params_add(nixl_capi_params_t params, const char *key, const char *value); nixl_capi_status_t nixl_capi_params_is_empty(nixl_capi_params_t params, bool* is_empty); nixl_capi_status_t nixl_capi_params_create_iterator(nixl_capi_params_t params, nixl_capi_param_iter_t* iter); nixl_capi_status_t nixl_capi_params_iterator_next(