提交 5a8f2e47 authored 作者: Serhij S's avatar Serhij S

send blocking/recv blocking in async channel

上级 056dbecd
...@@ -12,7 +12,7 @@ use std::{ ...@@ -12,7 +12,7 @@ use std::{
use crate::{pdeque::Deque, DataDeliveryPolicy, Error, Result}; use crate::{pdeque::Deque, DataDeliveryPolicy, Error, Result};
use object_id::UniqueId; use object_id::UniqueId;
use parking_lot::Mutex; use parking_lot::{Condvar, Mutex};
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
type ClientId = usize; type ClientId = usize;
...@@ -46,15 +46,22 @@ struct ChannelInner<T: DataDeliveryPolicy> { ...@@ -46,15 +46,22 @@ struct ChannelInner<T: DataDeliveryPolicy> {
id: UniqueId, id: UniqueId,
pc: Mutex<PolicyChannel<T>>, pc: Mutex<PolicyChannel<T>>,
next_op_id: AtomicUsize, next_op_id: AtomicUsize,
space_available: Arc<Condvar>,
data_available: Arc<Condvar>,
} }
impl<T: DataDeliveryPolicy> Channel<T> { impl<T: DataDeliveryPolicy> Channel<T> {
fn new(capacity: usize, ordering: bool) -> Self { fn new(capacity: usize, ordering: bool) -> Self {
let pc = PolicyChannel::new(capacity, ordering);
let space_available = pc.space_available.clone();
let data_available = pc.data_available.clone();
Self( Self(
ChannelInner { ChannelInner {
id: <_>::default(), id: <_>::default(),
pc: Mutex::new(PolicyChannel::new(capacity, ordering)), pc: Mutex::new(pc),
next_op_id: <_>::default(), next_op_id: <_>::default(),
space_available,
data_available,
} }
.into(), .into(),
) )
...@@ -68,10 +75,12 @@ struct PolicyChannel<T: DataDeliveryPolicy> { ...@@ -68,10 +75,12 @@ struct PolicyChannel<T: DataDeliveryPolicy> {
queue: Deque<T>, queue: Deque<T>,
senders: usize, senders: usize,
receivers: usize, receivers: usize,
send_fut_wakers: VecDeque<(Waker, ClientId)>, send_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
send_fut_pending: BTreeSet<ClientId>, send_fut_pending: BTreeSet<ClientId>,
recv_fut_wakers: VecDeque<(Waker, ClientId)>, recv_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
recv_fut_pending: BTreeSet<ClientId>, recv_fut_pending: BTreeSet<ClientId>,
data_available: Arc<Condvar>,
space_available: Arc<Condvar>,
} }
impl<T> PolicyChannel<T> impl<T> PolicyChannel<T>
...@@ -88,6 +97,8 @@ where ...@@ -88,6 +97,8 @@ where
send_fut_pending: <_>::default(), send_fut_pending: <_>::default(),
recv_fut_wakers: <_>::default(), recv_fut_wakers: <_>::default(),
recv_fut_pending: <_>::default(), recv_fut_pending: <_>::default(),
data_available: <_>::default(),
space_available: <_>::default(),
} }
} }
...@@ -95,30 +106,39 @@ where ...@@ -95,30 +106,39 @@ where
#[inline] #[inline]
fn notify_data_sent(&mut self) { fn notify_data_sent(&mut self) {
self.wake_next_recv_fut(); self.wake_next_recv();
} }
#[inline] #[inline]
fn wake_next_send_fut(&mut self) { fn wake_next_send(&mut self) {
if let Some((waker, id)) = self.send_fut_wakers.pop_front() { if let Some(w) = self.send_fut_wakers.pop_front() {
self.send_fut_pending.insert(id); if let Some((waker, id)) = w {
waker.wake(); self.send_fut_pending.insert(id);
waker.wake();
} else {
self.space_available.notify_one();
}
} }
} }
#[inline] #[inline]
fn wake_all_send_futs(&mut self) { fn wake_all_sends(&mut self) {
for (waker, _) in mem::take(&mut self.send_fut_wakers) { for (waker, _) in mem::take(&mut self.send_fut_wakers).into_iter().flatten() {
waker.wake(); waker.wake();
} }
self.space_available.notify_all();
} }
#[inline] #[inline]
fn notify_send_fut_drop(&mut self, id: ClientId) { fn notify_send_fut_drop(&mut self, id: ClientId) {
if let Some(pos) = self.send_fut_wakers.iter().position(|(_, i)| *i == id) { if let Some(pos) = self
.send_fut_wakers
.iter()
.position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
{
self.send_fut_wakers.remove(pos); self.send_fut_wakers.remove(pos);
} }
if self.send_fut_pending.remove(&id) { if self.send_fut_pending.remove(&id) {
self.wake_next_send_fut(); self.wake_next_send();
} }
} }
...@@ -129,37 +149,52 @@ where ...@@ -129,37 +149,52 @@ where
#[inline] #[inline]
fn append_send_fut_waker(&mut self, waker: Waker, id: ClientId) { fn append_send_fut_waker(&mut self, waker: Waker, id: ClientId) {
self.send_fut_wakers.push_back((waker, id)); self.send_fut_wakers.push_back(Some((waker, id)));
}
#[inline]
fn append_send_sync_waker(&mut self) {
// use condvar
self.send_fut_wakers.push_back(None);
} }
// receivers // receivers
#[inline] #[inline]
fn notify_data_received(&mut self) { fn notify_data_received(&mut self) {
self.wake_next_send_fut(); self.wake_next_send();
} }
#[inline] #[inline]
fn wake_next_recv_fut(&mut self) { fn wake_next_recv(&mut self) {
if let Some((waker, id)) = self.recv_fut_wakers.pop_front() { if let Some(w) = self.recv_fut_wakers.pop_front() {
self.recv_fut_pending.insert(id); if let Some((waker, id)) = w {
waker.wake(); self.recv_fut_pending.insert(id);
waker.wake();
} else {
self.data_available.notify_one();
}
} }
} }
#[inline] #[inline]
fn wake_all_recv_futs(&mut self) { fn wake_all_recvs(&mut self) {
for (waker, _) in mem::take(&mut self.recv_fut_wakers) { for (waker, _) in mem::take(&mut self.recv_fut_wakers).into_iter().flatten() {
waker.wake(); waker.wake();
} }
self.data_available.notify_all();
} }
#[inline] #[inline]
fn notify_recv_fut_drop(&mut self, id: ClientId) { fn notify_recv_fut_drop(&mut self, id: ClientId) {
if let Some(pos) = self.recv_fut_wakers.iter().position(|(_, i)| *i == id) { if let Some(pos) = self
.recv_fut_wakers
.iter()
.position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
{
self.recv_fut_wakers.remove(pos); self.recv_fut_wakers.remove(pos);
} }
if self.recv_fut_pending.remove(&id) { if self.recv_fut_pending.remove(&id) {
self.wake_next_recv_fut(); self.wake_next_recv();
} }
} }
...@@ -171,7 +206,13 @@ where ...@@ -171,7 +206,13 @@ where
#[inline] #[inline]
fn append_recv_fut_waker(&mut self, waker: Waker, id: ClientId) { fn append_recv_fut_waker(&mut self, waker: Waker, id: ClientId) {
self.recv_fut_wakers.push_back((waker, id)); self.recv_fut_wakers.push_back(Some((waker, id)));
}
#[inline]
fn append_recv_sync_waker(&mut self) {
// use condvar
self.recv_fut_wakers.push_back(None);
} }
} }
...@@ -265,6 +306,27 @@ where ...@@ -265,6 +306,27 @@ where
Err(Error::ChannelFull) Err(Error::ChannelFull)
} }
} }
pub fn send_blocking(&self, mut value: T) -> Result<()> {
let mut pc = self.channel.0.pc.lock();
let pushed = loop {
if pc.receivers == 0 {
return Err(Error::ChannelClosed);
}
let push_result = pc.queue.try_push(value);
let Some(val) = push_result.value else {
break push_result.pushed;
};
value = val;
pc.append_send_sync_waker();
self.channel.0.space_available.wait(&mut pc);
};
pc.wake_next_recv();
if pushed {
Ok(())
} else {
Err(Error::ChannelSkipped)
}
}
#[inline] #[inline]
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.channel.0.pc.lock().queue.len() self.channel.0.pc.lock().queue.len()
...@@ -303,7 +365,7 @@ where ...@@ -303,7 +365,7 @@ where
let mut pc = self.channel.0.pc.lock(); let mut pc = self.channel.0.pc.lock();
pc.senders -= 1; pc.senders -= 1;
if pc.senders == 0 { if pc.senders == 0 {
pc.wake_all_recv_futs(); pc.wake_all_recvs();
} }
} }
} }
...@@ -379,6 +441,19 @@ where ...@@ -379,6 +441,19 @@ where
Err(Error::ChannelEmpty) Err(Error::ChannelEmpty)
} }
} }
pub fn recv_blocking(&self) -> Result<T> {
let mut pc = self.channel.0.pc.lock();
loop {
if let Some(val) = pc.queue.get() {
pc.wake_next_send();
return Ok(val);
} else if pc.senders == 0 {
return Err(Error::ChannelClosed);
}
pc.append_recv_sync_waker();
self.channel.0.data_available.wait(&mut pc);
}
}
#[inline] #[inline]
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.channel.0.pc.lock().queue.len() self.channel.0.pc.lock().queue.len()
...@@ -417,7 +492,7 @@ where ...@@ -417,7 +492,7 @@ where
let mut pc = self.channel.0.pc.lock(); let mut pc = self.channel.0.pc.lock();
pc.receivers -= 1; pc.receivers -= 1;
if pc.receivers == 0 { if pc.receivers == 0 {
pc.wake_all_send_futs(); pc.wake_all_sends();
} }
} }
} }
...@@ -527,6 +602,72 @@ mod test { ...@@ -527,6 +602,72 @@ mod test {
insta::assert_snapshot!(t, @"1"); insta::assert_snapshot!(t, @"1");
} }
#[tokio::test]
async fn test_sync_send_async_recv() {
let (tx, rx) = bounded::<Message>(8);
let tx_t = tx.clone();
tokio::spawn(async move {
for _ in 0..10 {
tx.send(Message::Test(123)).await.unwrap();
if let Err(e) = tx.send(Message::Spam).await {
assert!(e.is_data_skipped(), "{}", e);
}
}
});
tokio::task::spawn_blocking(move || {
for _ in 0..10 {
tx_t.send_blocking(Message::Test(123)).unwrap();
if let Err(e) = tx_t.send_blocking(Message::Spam) {
assert!(e.is_data_skipped(), "{}", e);
}
}
});
thread::sleep(Duration::from_secs(1));
let mut c = 0;
while let Ok(msg) = rx.recv().await {
if let Message::Test(_) = msg {
c += 1
}
}
insta::assert_snapshot!(c, @"20");
}
#[tokio::test]
async fn test_sync_send_sync_recv() {
let (tx, rx) = bounded::<Message>(8);
let tx_t = tx.clone();
tokio::spawn(async move {
for _ in 0..10 {
tx.send(Message::Test(123)).await.unwrap();
if let Err(e) = tx.send(Message::Spam).await {
assert!(e.is_data_skipped(), "{}", e);
}
tx.send(Message::Temperature(123.0)).await.unwrap();
}
});
tokio::task::spawn_blocking(move || {
for _ in 0..10 {
tx_t.send_blocking(Message::Test(123)).unwrap();
if let Err(e) = tx_t.send_blocking(Message::Spam) {
assert!(e.is_data_skipped(), "{}", e);
}
tx_t.send_blocking(Message::Temperature(123.0)).unwrap();
}
});
thread::sleep(Duration::from_secs(1));
let c = tokio::task::spawn_blocking(move || {
let mut c = 0;
while let Ok(msg) = rx.recv_blocking() {
if let Message::Test(_) = msg {
c += 1;
}
}
c
})
.await
.unwrap();
insta::assert_snapshot!(c, @"20");
}
#[tokio::test] #[tokio::test]
async fn test_poisoning() { async fn test_poisoning() {
let n = 5_000; let n = 5_000;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论