diff --git a/Cargo.lock b/Cargo.lock index 4703dd345..a4949b4fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -803,6 +803,7 @@ dependencies = [ "metrics", "opentelemetry", "parking_lot", + "probabilistic-set", "quinn", "quinn-plaintext", "quinn-proto", @@ -1016,6 +1017,7 @@ dependencies = [ "once_cell", "opentelemetry", "parking_lot", + "probabilistic-set", "rand 0.9.2", "rangemap", "rcgen", @@ -3368,6 +3370,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "probabilistic-set" +version = "0.1.0" +dependencies = [ + "rand 0.9.2", + "serde", + "speedy", + "uuid", +] + [[package]] name = "proc-macro-error" version = "1.0.4" diff --git a/crates/corro-agent/Cargo.toml b/crates/corro-agent/Cargo.toml index 6ebd4b829..012cbcca0 100644 --- a/crates/corro-agent/Cargo.toml +++ b/crates/corro-agent/Cargo.toml @@ -28,6 +28,7 @@ itertools = { workspace = true } metrics = { workspace = true } opentelemetry = { workspace = true } parking_lot = { workspace = true } +probabilistic-set = { path = "../probabilistic-set" } quinn = { workspace = true } quinn-proto = { workspace = true } quinn-plaintext = { workspace = true } diff --git a/crates/corro-agent/src/agent/handlers.rs b/crates/corro-agent/src/agent/handlers.rs index 439eb704c..9180b8be5 100644 --- a/crates/corro-agent/src/agent/handlers.rs +++ b/crates/corro-agent/src/agent/handlers.rs @@ -24,7 +24,7 @@ use corro_types::{ actor::{Actor, ActorId}, agent::{Agent, Bookie, SplitPool}, base::CrsqlSeq, - broadcast::{BroadcastInput, BroadcastV1, ChangeSource, ChangeV1, FocaInput}, + broadcast::{BroadcastInput, BroadcastV1, BroadcastV2, ChangeSource, ChangeV1, FocaInput}, channel::CorroReceiver, members::MemberAddedResult, sync::generate_sync, @@ -690,7 +690,7 @@ pub async fn handle_changes( continue; } - let src_str: &'static str = src.into(); + let src_str: &'static str = (&src).into(); let recv_lag = change.ts().and_then(|ts| { let mut our_ts = Timestamp::from(agent.clock().new_timestamp()); if ts > our_ts { @@ -705,7 +705,10 
@@ pub async fn handle_changes( Some((our_ts.0 - ts.0).to_duration()) }); - if matches!(src, ChangeSource::Broadcast) { + if matches!( + src, + ChangeSource::Broadcast | ChangeSource::BroadcastV2(_, _) + ) { counter!("corro.broadcast.recv.count", "kind" => "change").increment(1); } @@ -766,20 +769,32 @@ pub async fn handle_changes( } } - assert_sometimes!( - matches!(src, ChangeSource::Sync), - "Corrosion receives changes through sync" - ); - if matches!(src, ChangeSource::Broadcast) && !change.is_empty() { - assert_sometimes!(true, "Corrosion rebroadcasts changes"); - if let Err(_e) = - agent - .tx_bcast() - .try_send(BroadcastInput::Rebroadcast(BroadcastV1::Change( + if !change.is_empty() { + let bcast = match src.clone() { + ChangeSource::Broadcast => { + assert_sometimes!(true, "Corrosion rebroadcasts changes"); + Some(BroadcastInput::Rebroadcast(BroadcastV1::Change( change.clone(), ))) - { - debug!("broadcasts are full or done!"); + } + ChangeSource::BroadcastV2(set, num_broadcasts) => { + assert_sometimes!(true, "Corrosion rebroadcasts changes"); + Some(BroadcastInput::RebroadcastV2(BroadcastV2 { + change: BroadcastV1::Change(change.clone()), + set, + num_broadcasts, + })) + } + ChangeSource::Sync => { + assert_sometimes!(true, "Corrosion receives changes through sync"); + None + } + }; + + if let Some(bcast) = bcast { + if let Err(_e) = agent.tx_bcast().try_send(bcast) { + debug!("broadcasts are full or done!"); + } } } diff --git a/crates/corro-agent/src/agent/uni.rs b/crates/corro-agent/src/agent/uni.rs index 369925a4f..1b8cb7142 100644 --- a/crates/corro-agent/src/agent/uni.rs +++ b/crates/corro-agent/src/agent/uni.rs @@ -1,6 +1,6 @@ use corro_types::{ actor::ClusterId, - broadcast::{BroadcastV1, ChangeSource, ChangeV1, UniPayload, UniPayloadV1}, + broadcast::{BroadcastV1, BroadcastV2, ChangeSource, ChangeV1, UniPayload, UniPayloadV1}, channel::CorroSender, }; use metrics::counter; @@ -66,16 +66,38 @@ pub fn spawn_unipayload_handler( match payload { 
UniPayload::V1 { - data: - UniPayloadV1::Broadcast(BroadcastV1::Change( - change, - )), + data: payload_data, cluster_id: payload_cluster_id, } => { if cluster_id != payload_cluster_id { continue; } - changes.push((change, ChangeSource::Broadcast)); + + match payload_data { + UniPayloadV1::Broadcast( + BroadcastV1::Change(change), + ) => { + changes.push(( + change, + ChangeSource::Broadcast, + )); + } + UniPayloadV1::BroadcastV2( + BroadcastV2 { + change: BroadcastV1::Change(change), + set, + num_broadcasts, + }, + ) => { + changes.push(( + change, + ChangeSource::BroadcastV2( + set, + num_broadcasts, + ), + )); + } + } } } } diff --git a/crates/corro-agent/src/broadcast/mod.rs b/crates/corro-agent/src/broadcast/mod.rs index 9d88b7601..e690b70ae 100644 --- a/crates/corro-agent/src/broadcast/mod.rs +++ b/crates/corro-agent/src/broadcast/mod.rs @@ -20,6 +20,7 @@ use futures::{ use governor::{Quota, RateLimiter}; use metrics::{counter, gauge}; use parking_lot::RwLock; +use probabilistic_set::ProbSet; use rand::{rngs::StdRng, seq::IteratorRandom, SeedableRng}; use rusqlite::params; use spawn::spawn_counted; @@ -38,7 +39,10 @@ use tripwire::Tripwire; use corro_types::{ actor::{Actor, ActorId}, agent::Agent, - broadcast::{BroadcastInput, DispatchRuntime, FocaCmd, FocaInput, UniPayload, UniPayloadV1}, + broadcast::{ + BroadcastInput, BroadcastV1, BroadcastV2, ChangeV1, DispatchRuntime, FocaCmd, FocaInput, + UniPayload, UniPayloadV1, + }, channel::{bounded, CorroReceiver, CorroSender}, }; @@ -426,8 +430,6 @@ async fn handle_broadcasts( let mut bcast_buf = BytesMut::new(); let mut local_bcast_buf = BytesMut::new(); - let mut single_bcast_buf = BytesMut::new(); - let mut metrics_interval = interval(Duration::from_secs(10)); let mut rng = StdRng::from_os_rng(); @@ -435,12 +437,16 @@ async fn handle_broadcasts( let mut idle_pendings = FuturesUnordered:: + Send + 'static>>>::new(); + let mut broadcast_v2_pendings = + FuturesUnordered:: + Send + 'static>>>::new(); + let mut 
bcast_interval = interval(opts.interval); enum Branch { Broadcast(BroadcastInput), BroadcastDeadline, WokePendingBroadcast(PendingBroadcast), + WokePendingBroadcastV2(BroadcastV2), Tripped, Metrics, } @@ -452,10 +458,13 @@ async fn handle_broadcasts( let max_queue_len = agent.config().perf.processing_queue_len; const MAX_INFLIGHT_BROADCAST: usize = 500; let mut to_broadcast = VecDeque::new(); + let mut to_broadcast_v2 = VecDeque::new(); let mut to_local_broadcast = VecDeque::new(); let mut log_count = 0; + let cfg_max_transmissions = agent.config().perf.max_broadcast_transmissions; let mut limited_log_count = 0; + let mut prev_set_size = 0; let bytes_per_sec: BroadcastRateLimiter = RateLimiter::direct(Quota::per_second(unsafe { NonZeroU32::new_unchecked(10 * 1024 * 1024) @@ -491,6 +500,14 @@ async fn handle_broadcasts( continue; } }, + maybe_woke = broadcast_v2_pendings.next(), if !broadcast_v2_pendings.is_terminated() => match maybe_woke { + Some(woke) => Branch::WokePendingBroadcastV2(woke), + None => { + trace!("broadcast v2 pendings returned None"); + // I guess? + continue; + } + }, _ = &mut tripwire, if !tripped => { tripped = true; @@ -518,62 +535,59 @@ async fn handle_broadcasts( } } Branch::Broadcast(input) => { - trace!("handling Branch::Broadcast"); - - let (bcast, is_local) = match input { - BroadcastInput::Rebroadcast(bcast) => (bcast, false), - BroadcastInput::AddBroadcast(bcast) => (bcast, true), - }; - trace!("adding broadcast: {bcast:?}, local? 
{is_local}"); - - if let Err(e) = (UniPayload::V1 { - data: UniPayloadV1::Broadcast(bcast.clone()), - cluster_id: agent.cluster_id(), - }) - .write_to_stream((&mut ser_buf).writer()) - { - error!("could not encode UniPayload::V1 Broadcast: {e}"); - ser_buf.clear(); - continue; - } - trace!("ser buf len: {}", ser_buf.len()); - - if is_local { - if let Err(e) = - bcast_codec.encode(ser_buf.split().freeze(), &mut single_bcast_buf) - { - error!("could not encode local broadcast: {e}"); - single_bcast_buf.clear(); - continue; + trace!("handling Branch::Broadcast {input:?}"); + match input { + BroadcastInput::AddBroadcast(bcast) => { + let BroadcastV1::Change(change) = bcast; + to_local_broadcast.push_front(change); } - - let payload = single_bcast_buf.split().freeze(); - - local_bcast_buf.extend_from_slice(&payload); - - to_local_broadcast.push_front(payload); - - if local_bcast_buf.len() >= broadcast_cutoff { - to_broadcast.push_front(PendingBroadcast::new_local( - local_bcast_buf.split().freeze(), - )); - } - } else { - if let Err(e) = bcast_codec.encode(ser_buf.split().freeze(), &mut bcast_buf) { - error!("could not encode broadcast: {e}"); - bcast_buf.clear(); - continue; + BroadcastInput::AddBroadcastV2(bcast) => { + // todo: rebroadcast immediates too?? + let BroadcastV2 { + change: BroadcastV1::Change(change), + .. 
+ } = bcast; + to_local_broadcast.push_front(change); } + // for old rebroadcast, treat as normal + BroadcastInput::Rebroadcast(bcast) => { + let payload = UniPayload::V1 { + data: UniPayloadV1::Broadcast(bcast.clone()), + cluster_id: agent.cluster_id(), + }; + if let Err(e) = payload.write_to_stream((&mut ser_buf).writer()) { + error!("could not encode UniPayload::V1 Broadcast: {e}"); + ser_buf.clear(); + continue; + } + if let Err(e) = bcast_codec.encode(ser_buf.split().freeze(), &mut bcast_buf) + { + error!("could not encode broadcast: {e}"); + bcast_buf.clear(); + continue; + } - if bcast_buf.len() >= broadcast_cutoff { - to_broadcast.push_front(PendingBroadcast::new(bcast_buf.split().freeze())); + if bcast_buf.len() >= broadcast_cutoff { + to_broadcast + .push_front(PendingBroadcast::new(bcast_buf.split().freeze())); + } } - } + BroadcastInput::RebroadcastV2(bcast) => { + let max_transmissions = config.read().max_transmissions.get(); + if bcast.num_broadcasts < max_transmissions { + to_broadcast_v2.push_front(bcast); + } + } + }; } Branch::WokePendingBroadcast(pending) => { trace!("handling Branch::WokePendingBroadcast"); to_broadcast.push_front(pending); } + Branch::WokePendingBroadcastV2(pending) => { + trace!("handling Branch::WokePendingBroadcastV2"); + to_broadcast_v2.push_front(pending); + } Branch::Metrics => { trace!("handling Branch::Metrics"); gauge!("corro.broadcast.pending.count").set(idle_pendings.len() as f64); @@ -581,31 +595,91 @@ async fn handle_broadcasts( gauge!("corro.broadcast.buffer.capacity").set(bcast_buf.capacity() as f64); gauge!("corro.broadcast.serialization.buffer.capacity") .set(ser_buf.capacity() as f64); + gauge!("corro.broadcast.prob_set.size").set(prev_set_size as f64); } } let prev_rate_limited = rate_limited; + let (members_count, ring0_count) = { + let members = agent.members().read(); + let members_count = members.states.len(); + let ring0_count = members.ring0(agent.cluster_id()).count(); + (members_count, ring0_count) + 
}; + + let (choose_count, dynamic_count, max_transmissions) = { + let config = config.read(); + let gossip_max_txns = config.max_transmissions.get(); + let max_transmissions = cmp::min( + gossip_max_txns, + cfg_max_transmissions.unwrap_or(gossip_max_txns), + ); + let dynamic_count = (members_count - ring0_count) / (max_transmissions as usize * 10); + let count = cmp::max(config.num_indirect_probes.get(), dynamic_count); + + if prev_rate_limited { + // we've been rate limited on the last loop, try sending to less nodes... + ( + cmp::min(count, dynamic_count / 2), + dynamic_count / 2, + max_transmissions / 2, + ) + } else { + (count, dynamic_count, max_transmissions) + } + }; + // start with local broadcasts, they're higher priority - let mut ring0 = HashSet::new(); + let mut ring0: HashSet = HashSet::new(); + + let prob_set = { + // setting the size of the set to the number of times it'll broadcasted + number of members + // it will be broadcasted to with some padding + let dynamic_size = (dynamic_count + 1) * (max_transmissions + 1) as usize; + // clamp size to 500 + let size = cmp::min(500, dynamic_size); + + prev_set_size = size; + let mut set = ProbSet::new(size, 4); + let members = agent.members().read(); + let members_ring0 = members.ring0(agent.cluster_id()); + for (actor_id, addr) in members_ring0 { + set.insert(actor_id.to_u128()); + ring0.insert(addr); + } + set + }; + + debug!("sending local broadcasts to ring0 nodes: {:?}", ring0); while !to_local_broadcast.is_empty() && join_set.len() < MAX_INFLIGHT_BROADCAST { // UNWRAP: we just checked that it wasn't empty - let payload = to_local_broadcast.pop_front().unwrap(); + let change = to_local_broadcast.pop_front().unwrap(); - let members = agent.members().read(); - let mut spawn_count = 0; - let mut ring0_count = 0; - for addr in members.ring0(agent.cluster_id()) { - if join_set.len() >= MAX_INFLIGHT_BROADCAST { - debug!( - "breaking, max inflight broadcast reached: {}", - MAX_INFLIGHT_BROADCAST - ); - 
break; + let bcast_change = BroadcastV2 { + change: BroadcastV1::Change(change.clone()), + set: prob_set.clone(), + num_broadcasts: 1, + }; + + let uni_payload = UniPayload::V1 { + data: UniPayloadV1::BroadcastV2(bcast_change.clone()), + cluster_id: agent.cluster_id(), + }; + + let payload = { + match encode_framed(&mut bcast_codec, &mut ser_buf, &mut bcast_buf, &uni_payload) { + Ok(payload) => payload, + Err(e) => { + error!("could not encode UniPayload::V1 BroadcastV2: {e}"); + continue; + } } - ring0_count += 1; - ring0.insert(addr); + }; + let mut spawn_count = 0; + for addr in ring0.clone() { + debug!("sending broadcast to ring0 node: {addr}"); match try_transmit_broadcast( &bytes_per_sec, payload.clone(), @@ -640,39 +714,19 @@ async fn handle_broadcasts( } // couldn't send it anywhere! - if rate_limited && spawn_count == 0 && ring0_count > 0 { + if rate_limited && spawn_count == 0 && !ring0.is_empty() { // push it back in front since this got nowhere and it's still the // freshest item we have in the queue - to_local_broadcast.push_front(payload); + to_local_broadcast.push_front(change); break; + } else { + // TODO: test whether we still want to re-queue + to_broadcast_v2.push_front(bcast_change); } - counter!("corro.broadcast.spawn", "type" => "local").increment(spawn_count); } - if !rate_limited && !to_broadcast.is_empty() && join_set.len() < MAX_INFLIGHT_BROADCAST { - let (members_count, ring0_count) = { - let members = agent.members().read(); - let members_count = members.states.len(); - let ring0_count = members.ring0(agent.cluster_id()).count(); - (members_count, ring0_count) - }; - - let (choose_count, max_transmissions) = { - let config = config.read(); - let max_transmissions = config.max_transmissions.get(); - let dynamic_count = - (members_count - ring0_count) / (max_transmissions as usize * 10); - let count = cmp::max(config.num_indirect_probes.get(), dynamic_count); - - if prev_rate_limited { - // we've been rate limited on the last loop, try 
sending to less nodes... - (cmp::min(count, dynamic_count / 2), max_transmissions / 2) - } else { - (count, max_transmissions) - } - }; - + if !rate_limited { debug!( "choosing {} broadcasts, ring0 count: {}, MAX_INFLIGHT_BROADCAST: {}", choose_count, ring0_count, MAX_INFLIGHT_BROADCAST @@ -738,6 +792,7 @@ async fn handle_broadcasts( "broadcasts rate limited", &mut limited_log_count, ); + rate_limited = true; break; } } @@ -778,8 +833,114 @@ async fn handle_broadcasts( } } - if drop_oldest_broadcast(&mut to_broadcast, &mut to_local_broadcast, max_queue_len) - .is_some() + if !rate_limited { + let member_states = { + let members = agent.members().read(); + members.states.clone() + }; + + while !to_broadcast_v2.is_empty() && join_set.len() < MAX_INFLIGHT_BROADCAST { + let mut bcast_v2 = to_broadcast_v2.pop_front().unwrap(); + let count = cmp::min( + choose_count, + MAX_INFLIGHT_BROADCAST.saturating_sub(join_set.len()), + ); + let num_broadcasts = bcast_v2.num_broadcasts; + // let prev_set = bcast_change.set.clone(); + let broadcast_to = member_states + .iter() + .filter_map(|(member_id, state)| { + // don't broadcast to ourselves... 
or a member that's already in the set + if *member_id == actor_id + || state.cluster_id != agent.cluster_id() + || bcast_v2.set.contains(member_id.to_u128()) + { + None + } else { + Some((member_id, state.addr)) + } + }) + .choose_multiple(&mut rng, count); + + bcast_v2.num_broadcasts += 1; + for (member_id, _) in broadcast_to.clone() { + bcast_v2.set.insert(member_id.to_u128()); + } + + let uni_payload = UniPayload::V1 { + data: UniPayloadV1::BroadcastV2(bcast_v2.clone()), + cluster_id: agent.cluster_id(), + }; + + let payload = match encode_framed( + &mut bcast_codec, + &mut ser_buf, + &mut bcast_buf, + &uni_payload, + ) { + Ok(payload) => payload, + Err(e) => { + error!("could not encode UniPayload::V1 BroadcastV2: {e}"); + continue; + } + }; + + let mut spawn_count = 0; + for (_, addr) in broadcast_to { + match try_transmit_broadcast( + &bytes_per_sec, + payload.clone(), + transport.clone(), + addr, + ) { + Err(e) => { + warn!("could not spawn broadcast transmission: {e}"); + match e { + TransmitError::TooBig(_) + | TransmitError::InsufficientCapacity(_) => { + // not sure this would ever happen + continue; + } + TransmitError::QuotaExceeded(_) => { + // exceeded our quota, stop trying to send this through + counter!("corro.broadcast.rate_limited").increment(1); + log_at_pow_10( + "broadcasts rate limited", + &mut limited_log_count, + ); + rate_limited = true; + break; + } + } + } + Ok(fut) => { + join_set.spawn(fut); + spawn_count += 1; + } + } + } + + let sleep_ms_base = if prev_rate_limited { 500 } else { 100 }; + + if spawn_count > 0 && num_broadcasts < max_transmissions { + broadcast_v2_pendings.push(Box::pin(async move { + tokio::time::sleep(Duration::from_millis( + sleep_ms_base * num_broadcasts as u64, + )) + .await; + bcast_v2 + })); + } + } + } + + if drop_oldest_broadcast( + &mut to_broadcast, + &mut to_broadcast_v2, + &mut to_local_broadcast, + max_queue_len, + ) + .is_some() { log_at_pow_10("dropped old change from broadcast queue", &mut 
log_count); counter!("corro.broadcast.dropped").increment(1); @@ -789,23 +950,63 @@ async fn handle_broadcasts( info!("broadcasts are done"); } -// Drop the oldest, most sent item or the oldest local item +#[derive(Debug, thiserror::Error)] +pub enum EncodeFramedError { + #[error(transparent)] + IO(#[from] std::io::Error), + #[error(transparent)] + Speedy(#[from] speedy::Error), +} + +fn encode_framed( + codec: &mut LengthDelimitedCodec, + ser_buf: &mut BytesMut, + frame_buf: &mut BytesMut, + payload: &UniPayload, +) -> Result { + if let Err(e) = payload.write_to_stream((&mut *ser_buf).writer()) { + ser_buf.clear(); + return Err(e.into()); + } + if let Err(e) = codec.encode(ser_buf.split().freeze(), frame_buf) { + frame_buf.clear(); + return Err(e.into()); + } + Ok(frame_buf.split().freeze()) +} + +// Drop the oldest, most sent item or the oldest local item fn drop_oldest_broadcast( queue: &mut VecDeque, - local_queue: &mut VecDeque, + v2_queue: &mut VecDeque, + local_queue: &mut VecDeque, max: usize, -) -> Option { - if queue.len() + local_queue.len() > max { +) -> Option { + if queue.len() + v2_queue.len() + local_queue.len() > max { // start by dropping from global queue let max_sent: Option<(_, _)> = queue .iter() .enumerate() .max_by_key(|(_, val)| val.send_count); - return if let Some((i, _)) = max_sent { - queue.remove(i) - } else { - local_queue.pop_back().map(PendingBroadcast::new_local) - }; + + if let Some((i, _)) = max_sent { + let removed = queue.remove(i); + return removed.map(|v| v.send_count); + } + + let max_sent: Option<(_, _)> = v2_queue + .iter() + .enumerate() + .max_by_key(|(_, val)| val.num_broadcasts); + + if let Some((i, _)) = max_sent { + let removed = v2_queue.remove(i); + return removed.map(|v| v.num_broadcasts); + } + + if local_queue.pop_back().is_some() { + return Some(0); + } } None @@ -1056,38 +1257,49 @@ mod tests { fn test_behaviour_when_queue_is_full() -> eyre::Result<()> { let max = 4; let mut queue = VecDeque::new(); + let mut 
v2_queue = VecDeque::new(); let mut local_queue = VecDeque::new(); - assert!(drop_oldest_broadcast(&mut queue, &mut local_queue, max).is_none()); + assert!(drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).is_none()); - queue.push_front(build_broadcast(1, 0)); + queue.push_front(build_broadcast(1, 1)); queue.push_front(build_broadcast(2, 3)); - queue.push_front(build_broadcast(3, 1)); - queue.push_front(build_broadcast(4, 1)); - queue.push_front(build_broadcast(5, 2)); - queue.push_front(build_broadcast(6, 1)); queue.push_front(build_broadcast(7, 3)); - queue.push_front(build_broadcast(8, 0)); - + queue.push_front(build_broadcast(5, 2)); + v2_queue.push_front(build_v2_uni_payload(1)); + v2_queue.push_front(build_v2_uni_payload(2)); + v2_queue.push_front(build_v2_uni_payload(3)); + v2_queue.push_front(build_v2_uni_payload(4)); + v2_queue.push_front(build_v2_uni_payload(4)); + v2_queue.push_front(build_v2_uni_payload(7)); // drop oldest item with highest send count - let dropped = drop_oldest_broadcast(&mut queue, &mut local_queue, max).unwrap(); - assert_eq!(dropped.send_count, 3); - assert_eq!(2_i64.to_be_bytes(), dropped.payload.as_ref()); + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 3); + + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 3); + + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 2); - let dropped = drop_oldest_broadcast(&mut queue, &mut local_queue, max).unwrap(); - assert_eq!(dropped.send_count, 3); - assert_eq!(7_i64.to_be_bytes(), dropped.payload.as_ref()); + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 1); - let dropped = drop_oldest_broadcast(&mut queue, &mut local_queue, max).unwrap(); - assert_eq!(dropped.send_count, 2); - 
assert_eq!(5_i64.to_be_bytes(), dropped.payload.as_ref()); + // we drop from v2_queue next + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 7); - let dropped = drop_oldest_broadcast(&mut queue, &mut local_queue, max).unwrap(); - assert_eq!(dropped.send_count, 1); - assert_eq!(3_i64.to_be_bytes(), dropped.payload.as_ref()); + let dropped = + drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).unwrap(); + assert_eq!(dropped, 4); // queue is still at max now, no item gets dropped - assert!(drop_oldest_broadcast(&mut queue, &mut local_queue, max).is_none()); + assert!(drop_oldest_broadcast(&mut queue, &mut v2_queue, &mut local_queue, max).is_none()); Ok(()) } @@ -1101,6 +1313,23 @@ mod tests { } } + fn build_v2_uni_payload(send_count: u8) -> BroadcastV2 { + BroadcastV2 { + change: BroadcastV1::Change(ChangeV1 { + actor_id: ActorId(Uuid::new_v4()), + changeset: Changeset::Full { + version: CrsqlDbVersion(0), + changes: vec![], + seqs: dbsr!(0, 0), + last_seq: CrsqlSeq(0), + ts: Default::default(), + }, + }), + set: ProbSet::new(10, 4), + num_broadcasts: send_count, + } + } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_broadcast_order() -> eyre::Result<()> { let _ = tracing_subscriber::fmt() diff --git a/crates/corro-types/Cargo.toml b/crates/corro-types/Cargo.toml index e6875cc5c..e88662bb8 100644 --- a/crates/corro-types/Cargo.toml +++ b/crates/corro-types/Cargo.toml @@ -19,6 +19,7 @@ config = { workspace = true } consul-client = { version = "0.1.0-alpha.0", path = "../consul-client" } corro-api-types = { version = "0.1.0-alpha.1", path = "../corro-api-types" } corro-base-types = { version = "0.1.0-alpha.1", path = "../corro-base-types" } +probabilistic-set = { version = "0.1.0", path = "../probabilistic-set" } deadpool = { workspace = true } enquote = { workspace = true } fallible-iterator = { workspace = true } diff --git 
a/crates/corro-types/src/actor.rs b/crates/corro-types/src/actor.rs index a9412112d..d450938b8 100644 --- a/crates/corro-types/src/actor.rs +++ b/crates/corro-types/src/actor.rs @@ -37,6 +37,10 @@ impl ActorId { pub fn from_bytes(bytes: [u8; 16]) -> Self { Self(Uuid::from_bytes(bytes)) } + + pub fn to_u128(&self) -> u128 { + self.0.as_u128() + } } impl TryFrom for uhlc::ID { diff --git a/crates/corro-types/src/broadcast.rs b/crates/corro-types/src/broadcast.rs index 62fc1d164..512b31f86 100644 --- a/crates/corro-types/src/broadcast.rs +++ b/crates/corro-types/src/broadcast.rs @@ -10,6 +10,7 @@ use corro_base_types::{CrsqlDbVersionRange, CrsqlSeqRange}; use foca::{Identity, Member, Notification, Runtime, Timer}; use indexmap::{map::Entry, IndexMap}; use metrics::counter; +use probabilistic_set::ProbSet; use rusqlite::{ types::{FromSql, FromSqlError}, ToSql, @@ -48,6 +49,7 @@ pub enum UniPayload { #[derive(Debug, Clone, Readable, Writable)] pub enum UniPayloadV1 { Broadcast(BroadcastV1), + BroadcastV2(BroadcastV2), } #[derive(Debug, Clone, Readable, Writable)] @@ -94,6 +96,13 @@ pub enum BroadcastV1 { Change(ChangeV1), } +#[derive(Clone, Debug, Readable, Writable)] +pub struct BroadcastV2 { + pub change: BroadcastV1, + pub set: ProbSet, + pub num_broadcasts: u8, +} + #[derive(Debug, Clone, PartialEq, Readable, Writable)] pub struct ColumnChange { pub cid: ColumnName, @@ -103,11 +112,12 @@ pub struct ColumnChange { pub cl: i64, } -#[derive(Debug, Clone, Copy, strum::IntoStaticStr)] +#[derive(Debug, Clone, strum::IntoStaticStr)] #[strum(serialize_all = "snake_case")] pub enum ChangeSource { Broadcast, Sync, + BroadcastV2(ProbSet, u8), } #[derive(Debug, Clone, PartialEq, Readable, Writable)] @@ -586,6 +596,8 @@ pub enum BroadcastDecodeError { pub enum BroadcastInput { Rebroadcast(BroadcastV1), AddBroadcast(BroadcastV1), + AddBroadcastV2(BroadcastV2), + RebroadcastV2(BroadcastV2), } pub struct DispatchRuntime { diff --git a/crates/corro-types/src/config.rs 
b/crates/corro-types/src/config.rs index ccea8dc55..b505a642f 100755 --- a/crates/corro-types/src/config.rs +++ b/crates/corro-types/src/config.rs @@ -231,6 +231,7 @@ pub struct PerfConfig { pub min_sync_backoff: u32, #[serde(default = "default_max_sync_backoff")] pub max_sync_backoff: u32, + pub max_broadcast_transmissions: Option, } impl Default for PerfConfig { @@ -252,6 +253,7 @@ impl Default for PerfConfig { sql_tx_timeout: default_sql_tx_timeout(), min_sync_backoff: default_min_sync_backoff(), max_sync_backoff: default_max_sync_backoff(), + max_broadcast_transmissions: None, } } } diff --git a/crates/corro-types/src/members.rs b/crates/corro-types/src/members.rs index 2cf2f74b7..fe93d0fa4 100644 --- a/crates/corro-types/src/members.rs +++ b/crates/corro-types/src/members.rs @@ -170,10 +170,10 @@ impl Members { /// Get member addresses where the ring index is `0` (meaning a /// very small RTT) - pub fn ring0(&self, cluster_id: ClusterId) -> impl Iterator + '_ { - self.states.values().filter_map(move |v| { + pub fn ring0(&self, cluster_id: ClusterId) -> impl Iterator + '_ { + self.states.iter().filter_map(move |(id, v)| { v.ring - .and_then(|ring| (v.cluster_id == cluster_id && ring == 0).then_some(v.addr)) + .and_then(|ring| (v.cluster_id == cluster_id && ring == 0).then_some((*id, v.addr))) }) } } diff --git a/crates/probabilistic-set/Cargo.toml b/crates/probabilistic-set/Cargo.toml new file mode 100644 index 000000000..ce5f69999 --- /dev/null +++ b/crates/probabilistic-set/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "probabilistic-set" +version = "0.1.0" +edition = "2024" + +[dependencies] +rand = { workspace = true } +serde = { workspace = true, features = ["derive"] } +speedy = { workspace = true } + +[dev-dependencies] +uuid = { version = "1.0", features = ["v4"] } + +[lints] +workspace = true diff --git a/crates/probabilistic-set/src/lib.rs b/crates/probabilistic-set/src/lib.rs new file mode 100644 index 000000000..36afead4e --- /dev/null +++ 
b/crates/probabilistic-set/src/lib.rs @@ -0,0 +1,203 @@ +use rand::Rng; +use serde::{Deserialize, Serialize}; +use speedy::{Readable, Writable}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +#[derive(Serialize, Debug, Deserialize, Clone, Readable, Writable)] +pub struct ProbSet { + // todo: consider making bits and array? + bits: Vec, + size_bits: usize, + seed: u64, // Store the random seed +} + +impl ProbSet { + pub fn new(expected_items: usize, bits_per_item: usize) -> Self { + if expected_items == 0 || bits_per_item == 0 { + panic!("expected_items and bits_per_item must be greater than 0"); + } + + let size_bits = expected_items * bits_per_item; + let seed = rand::rng().random(); // Generate random seed + + ProbSet { + bits: vec![0; size_bits.div_ceil(8)], + size_bits, + seed, + } + } + + // Create with a specific seed (useful for testing or controlled randomness) + pub fn with_seed(expected_items: usize, bits_per_item: usize, seed: u64) -> Self { + let size_bits: usize = expected_items * bits_per_item; + ProbSet { + bits: vec![0; size_bits.div_ceil(8)], + size_bits, + seed, + } + } + + pub fn insert(&mut self, item: u128) { + let idx = self.hash(item) % self.size_bits; + let byte_idx = idx / 8; + let bit_idx = idx % 8; + self.bits[byte_idx] |= 1 << bit_idx; + } + + pub fn contains(&self, item: u128) -> bool { + let idx = self.hash(item) % self.size_bits; + let byte_idx = idx / 8; + let bit_idx = idx % 8; + self.bits[byte_idx] & (1 << bit_idx) != 0 + } + + fn hash(&self, item: u128) -> usize { + let mut hasher = DefaultHasher::new(); + // Hash both the seed and the item + self.seed.hash(&mut hasher); + item.hash(&mut hasher); + hasher.finish() as usize + } + + // Get current seed (useful for debugging) + pub fn seed(&self) -> u64 { + self.seed + } + + // Change the seed (and clear the bits since the hash function effectively changes) + pub fn reseed(&mut self) { + self.seed = rand::rng().random(); + 
self.bits.fill(0); + } + + pub fn size_bytes(&self) -> usize { + // Include size of seed in total + self.bits.len() + std::mem::size_of::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + #[test] + fn test_uuid_collision_rates() { + // Test different configurations with UUIDs + let configs = vec![(1000, 2), (1000, 4)]; + + for (expected_items, bits_per_item) in configs { + let collision_rate = measure_uuid_collision_rate(expected_items, bits_per_item, 1000); + println!( + "Config: {} items, {} bits/item -> Collision rate: {:.2}%", + expected_items, + bits_per_item, + collision_rate * 100.0 + ); + } + } + + #[test] + fn test_uuid_false_positive_rate() { + let mut set = ProbSet::new(1000, 4); + let mut inserted_uuids = Vec::new(); + + for _ in 0..500 { + let uuid = Uuid::new_v4(); + let uuid_u128 = uuid.as_u128(); + set.insert(uuid_u128); + inserted_uuids.push(uuid_u128); + } + + for &uuid in &inserted_uuids { + assert!(set.contains(uuid), "False negative detected!"); + } + + // Test false positive rate with 10,000 random UUIDs + let test_size = 10000; + let mut false_positives = 0; + + for _ in 0..test_size { + let test_uuid = Uuid::new_v4().as_u128(); + + if inserted_uuids.contains(&test_uuid) { + continue; + } + + if set.contains(test_uuid) { + false_positives += 1; + } + } + + let false_positive_rate = false_positives as f64 / test_size as f64; + println!("False positive rate: {:.2}%", false_positive_rate * 100.0); + + assert!( + false_positive_rate < 0.20, + "False positive rate too high: {:.2}%", + false_positive_rate * 100.0 + ); + } + + #[test] + fn test_capacity_vs_collision_rate() { + let bits_per_item = 16; + let max_capacity = 1000; + + // Test collision rates at different capacity utilizations + for utilization in [0.25, 0.5, 0.75, 1.0, 1.25, 1.5] { + let num_items = (max_capacity as f64 * utilization) as usize; + let collision_rate = + measure_uuid_collision_rate(max_capacity, bits_per_item, num_items); + + println!( + 
"Utilization: {:.0}% -> Collision rate: {:.2}%", + utilization * 100.0, + collision_rate * 100.0 + ); + } + } + + #[test] + fn test_deterministic_behavior() { + // Test that same seed produces same results + let seed = 12345; + let mut set1 = ProbSet::with_seed(1000, 16, seed); + let mut set2 = ProbSet::with_seed(1000, 16, seed); + + let test_uuids: Vec = (0..100).map(|_| Uuid::new_v4().as_u128()).collect(); + + for &uuid in &test_uuids { + set1.insert(uuid); + set2.insert(uuid); + } + + for _ in 0..1000 { + let test_uuid = Uuid::new_v4().as_u128(); + assert_eq!( + set1.contains(test_uuid), + set2.contains(test_uuid), + "Sets with same seed should behave identically" + ); + } + } + // Helper function to measure collision rate + fn measure_uuid_collision_rate( + expected_items: usize, + bits_per_item: usize, + num_test_items: usize, + ) -> f64 { + let set = ProbSet::new(expected_items, bits_per_item); + let mut unique_positions = std::collections::HashMap::new(); + + for _ in 0..num_test_items { + let uuid = Uuid::new_v4().as_u128(); + let position = set.hash(uuid) % set.size_bits; + *unique_positions.entry(position).or_insert(0) += 1; + } + + let collisions = num_test_items - unique_positions.len(); + collisions as f64 / num_test_items as f64 + } +}