提交 90604f54 authored 作者: Serhij S's avatar Serhij S

pchannel_async

上级 0b20c539
......@@ -20,3 +20,6 @@ parking_lot = "0.12.1"
serde = { version = "1.0.197", features = ["derive", "rc"] }
sysinfo = "0.30.6"
thiserror = "1.0.57"
[dev-dependencies]
tokio = { version = "1.36.0", features = ["rt", "macros", "time"] }
......@@ -7,8 +7,10 @@ use thread_rt::{RTParams, Scheduling};
pub mod buf;
/// In-process data communication pub/sub hub
pub mod hub;
/// Policy-based channels
/// Policy-based channels, synchronous edition
pub mod pchannel;
/// Policy-based channels, asynchronous edition
pub mod pchannel_async;
/// Policy-based data storages
pub mod pdeque;
/// Task supervisor to manage real-time threads
......
use std::{
collections::{BTreeSet, VecDeque},
future::Future,
mem,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use crate::{pdeque::Deque, DataDeliveryPolicy, Error, Result};
use object_id::UniqueId;
use parking_lot::Mutex;
type ClientId = usize;
struct Channel<T: DataDeliveryPolicy>(Arc<ChannelInner<T>>);
impl<T: DataDeliveryPolicy> Channel<T> {
fn id(&self) -> usize {
self.0.id.as_usize()
}
}
impl<T: DataDeliveryPolicy> Eq for Channel<T> {}
impl<T: DataDeliveryPolicy> PartialEq for Channel<T> {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
}
}
impl<T> Clone for Channel<T>
where
T: DataDeliveryPolicy,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
struct ChannelInner<T: DataDeliveryPolicy> {
id: UniqueId,
pc: Mutex<PolicyChannel<T>>,
}
impl<T: DataDeliveryPolicy> Channel<T> {
fn new(capacity: usize, ordering: bool) -> Self {
Self(
ChannelInner {
id: <_>::default(),
pc: Mutex::new(PolicyChannel::new(capacity, ordering)),
}
.into(),
)
}
}
struct PolicyChannel<T: DataDeliveryPolicy> {
queue: Deque<T>,
senders: usize,
receivers: usize,
send_fut_wakers: VecDeque<(Waker, ClientId)>,
send_fut_pending: BTreeSet<ClientId>,
recv_fut_wakers: VecDeque<(Waker, ClientId)>,
recv_fut_pending: BTreeSet<ClientId>,
}
impl<T> PolicyChannel<T>
where
T: DataDeliveryPolicy,
{
fn new(capacity: usize, ordering: bool) -> Self {
assert!(capacity > 0, "channel capacity MUST be > 0");
Self {
queue: Deque::bounded(capacity).set_ordering(ordering),
senders: 1,
receivers: 1,
send_fut_wakers: <_>::default(),
send_fut_pending: <_>::default(),
recv_fut_wakers: <_>::default(),
recv_fut_pending: <_>::default(),
}
}
// senders
#[inline]
fn notify_data_sent(&mut self) {
self.wake_next_recv_fut();
}
#[inline]
fn wake_next_send_fut(&mut self) {
if let Some((waker, id)) = self.send_fut_wakers.pop_front() {
self.send_fut_pending.insert(id);
waker.wake();
}
}
#[inline]
fn wake_all_send_futs(&mut self) {
for (waker, _) in mem::take(&mut self.send_fut_wakers) {
waker.wake();
}
}
#[inline]
fn notify_send_fut_drop(&mut self, id: ClientId) {
if let Some(pos) = self.send_fut_wakers.iter().position(|(_, i)| *i == id) {
self.send_fut_wakers.remove(pos);
}
if self.send_fut_pending.remove(&id) {
self.wake_next_send_fut();
}
}
#[inline]
fn confirm_send_fut_waked(&mut self, id: ClientId) {
self.send_fut_pending.remove(&id);
}
#[inline]
fn append_send_fut_waker(&mut self, waker: Waker, id: ClientId) {
self.send_fut_wakers.push_back((waker, id));
}
// receivers
#[inline]
fn notify_data_received(&mut self) {
self.wake_next_send_fut();
}
#[inline]
fn wake_next_recv_fut(&mut self) {
if let Some((waker, id)) = self.recv_fut_wakers.pop_front() {
self.recv_fut_pending.insert(id);
waker.wake();
}
}
#[inline]
fn wake_all_recv_futs(&mut self) {
for (waker, _) in mem::take(&mut self.recv_fut_wakers) {
waker.wake();
}
}
#[inline]
fn notify_recv_fut_drop(&mut self, id: ClientId) {
if let Some(pos) = self.recv_fut_wakers.iter().position(|(_, i)| *i == id) {
self.recv_fut_wakers.remove(pos);
}
if self.recv_fut_pending.remove(&id) {
self.wake_next_recv_fut();
}
}
#[inline]
fn confirm_recv_fut_waked(&mut self, id: ClientId) {
// the resource is taken, remove from pending
self.recv_fut_pending.remove(&id);
}
#[inline]
fn append_recv_fut_waker(&mut self, waker: Waker, id: ClientId) {
self.recv_fut_wakers.push_back((waker, id));
}
}
struct Send<'a, T: DataDeliveryPolicy> {
id: UniqueId,
channel: &'a Channel<T>,
queued: bool,
value: Option<T>,
}
impl<'a, T: DataDeliveryPolicy> Drop for Send<'a, T> {
fn drop(&mut self) {
self.channel
.0
.pc
.lock()
.notify_send_fut_drop(self.id.as_usize());
}
}
impl<'a, T> Future for Send<'a, T>
where
T: DataDeliveryPolicy,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut pc = self.channel.0.pc.lock();
if self.queued {
pc.confirm_send_fut_waked(self.id.as_usize());
}
if pc.receivers == 0 {
return Poll::Ready(Err(Error::ChannelClosed));
}
let this = unsafe { self.as_mut().get_unchecked_mut() };
if pc.send_fut_wakers.is_empty() || this.queued {
let push_result = pc.queue.try_push(this.value.take().unwrap());
if let Some(val) = push_result.value {
this.value = Some(val);
} else {
pc.notify_data_sent();
return Poll::Ready(if push_result.pushed {
Ok(())
} else {
Err(Error::ChannelSkipped)
});
}
}
this.queued = true;
pc.append_send_fut_waker(cx.waker().clone(), self.id.as_usize());
Poll::Pending
}
}
#[derive(Eq, PartialEq)]
pub struct Sender<T>
where
T: DataDeliveryPolicy,
{
channel: Channel<T>,
}
impl<T> Sender<T>
where
T: DataDeliveryPolicy,
{
#[inline]
pub fn send(&self, value: T) -> impl Future<Output = Result<()>> + '_ {
Send {
id: <_>::default(),
channel: &self.channel,
queued: false,
value: Some(value),
}
}
pub fn try_send(&self, value: T) -> Result<()> {
let mut pc = self.channel.0.pc.lock();
if pc.receivers == 0 {
return Err(Error::ChannelClosed);
}
let push_result = pc.queue.try_push(value);
if push_result.value.is_none() {
pc.notify_data_sent();
if push_result.pushed {
Ok(())
} else {
Err(Error::ChannelSkipped)
}
} else {
Err(Error::ChannelFull)
}
}
#[inline]
pub fn len(&self) -> usize {
self.channel.0.pc.lock().queue.len()
}
#[inline]
pub fn is_full(&self) -> bool {
self.channel.0.pc.lock().queue.is_full()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.channel.0.pc.lock().queue.is_empty()
}
#[inline]
pub fn is_alive(&self) -> bool {
self.channel.0.pc.lock().receivers > 0
}
}
impl<T> Clone for Sender<T>
where
T: DataDeliveryPolicy,
{
fn clone(&self) -> Self {
self.channel.0.pc.lock().senders += 1;
Self {
channel: self.channel.clone(),
}
}
}
impl<T> Drop for Sender<T>
where
T: DataDeliveryPolicy,
{
fn drop(&mut self) {
let mut pc = self.channel.0.pc.lock();
pc.senders -= 1;
if pc.senders == 0 {
pc.wake_all_recv_futs();
}
}
}
struct Recv<'a, T: DataDeliveryPolicy> {
id: UniqueId,
channel: &'a Channel<T>,
queued: bool,
}
impl<'a, T: DataDeliveryPolicy> Drop for Recv<'a, T> {
fn drop(&mut self) {
self.channel
.0
.pc
.lock()
.notify_recv_fut_drop(self.id.as_usize());
}
}
impl<'a, T> Future for Recv<'a, T>
where
T: DataDeliveryPolicy,
{
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut pc = self.channel.0.pc.lock();
if self.queued {
pc.confirm_recv_fut_waked(self.id.as_usize());
}
if pc.recv_fut_wakers.is_empty() || self.queued {
if let Some(val) = pc.queue.get() {
pc.notify_data_received();
return Poll::Ready(Ok(val));
} else if pc.senders == 0 {
return Poll::Ready(Err(Error::ChannelClosed));
}
}
self.queued = true;
pc.append_recv_fut_waker(cx.waker().clone(), self.id.as_usize());
Poll::Pending
}
}
#[derive(Eq, PartialEq)]
pub struct Receiver<T>
where
T: DataDeliveryPolicy,
{
channel: Channel<T>,
}
impl<T> Receiver<T>
where
T: DataDeliveryPolicy,
{
#[inline]
pub fn recv(&self) -> impl Future<Output = Result<T>> + '_ {
Recv {
id: <_>::default(),
channel: &self.channel,
queued: false,
}
}
pub fn try_recv(&self) -> Result<T> {
let mut pc = self.channel.0.pc.lock();
if let Some(val) = pc.queue.get() {
pc.notify_data_received();
Ok(val)
} else if pc.senders == 0 {
Err(Error::ChannelClosed)
} else {
Err(Error::ChannelEmpty)
}
}
#[inline]
pub fn len(&self) -> usize {
self.channel.0.pc.lock().queue.len()
}
#[inline]
pub fn is_full(&self) -> bool {
self.channel.0.pc.lock().queue.is_full()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.channel.0.pc.lock().queue.is_empty()
}
#[inline]
pub fn is_alive(&self) -> bool {
self.channel.0.pc.lock().senders > 0
}
}
impl<T> Clone for Receiver<T>
where
T: DataDeliveryPolicy,
{
fn clone(&self) -> Self {
self.channel.0.pc.lock().receivers += 1;
Self {
channel: self.channel.clone(),
}
}
}
impl<T> Drop for Receiver<T>
where
T: DataDeliveryPolicy,
{
fn drop(&mut self) {
let mut pc = self.channel.0.pc.lock();
pc.receivers -= 1;
if pc.receivers == 0 {
pc.wake_all_send_futs();
}
}
}
fn make_channel<T: DataDeliveryPolicy>(ch: Channel<T>) -> (Sender<T>, Receiver<T>) {
let tx = Sender {
channel: ch.clone(),
};
let rx = Receiver { channel: ch };
(tx, rx)
}
/// Creates a bounded async channel which respects [`DataDeliveryPolicy`] rules with no message
/// priority ordering
///
/// # Panics
///
/// Will panic if the capacity is zero
pub fn bounded<T: DataDeliveryPolicy>(capacity: usize) -> (Sender<T>, Receiver<T>) {
let ch = Channel::new(capacity, false);
make_channel(ch)
}
/// Creates a bounded async channel which respects [`DataDeliveryPolicy`] rules and has got message
/// priority ordering turned on
///
/// # Panics
///
/// Will panic if the capacity is zero
pub fn ordered<T: DataDeliveryPolicy>(capacity: usize) -> (Sender<T>, Receiver<T>) {
let ch = Channel::new(capacity, true);
make_channel(ch)
}
#[cfg(test)]
mod test {
use std::{thread, time::Duration};
use crate::{DataDeliveryPolicy, DeliveryPolicy};
use super::bounded;
#[derive(Debug)]
enum Message {
Test(usize),
Temperature(f64),
Spam,
}
impl DataDeliveryPolicy for Message {
fn delivery_policy(&self) -> DeliveryPolicy {
match self {
Message::Test(_) => DeliveryPolicy::Always,
Message::Temperature(_) => DeliveryPolicy::Single,
Message::Spam => DeliveryPolicy::Optional,
}
}
}
#[tokio::test]
async fn test_delivery_policy_optional() {
let (tx, rx) = bounded::<Message>(1);
tokio::spawn(async move {
for _ in 0..10 {
tx.send(Message::Test(123)).await.unwrap();
if let Err(e) = tx.send(Message::Spam).await {
assert!(e.is_skipped(), "{}", e);
}
tx.send(Message::Temperature(123.0)).await.unwrap();
}
});
thread::sleep(Duration::from_secs(1));
while let Ok(msg) = rx.recv().await {
thread::sleep(Duration::from_millis(10));
if matches!(msg, Message::Spam) {
panic!("delivery policy not respected ({:?})", msg);
}
}
}
#[tokio::test]
async fn test_delivery_policy_single() {
let (tx, rx) = bounded::<Message>(512);
tokio::spawn(async move {
for _ in 0..10 {
tx.send(Message::Test(123)).await.unwrap();
if let Err(e) = tx.send(Message::Spam).await {
assert!(e.is_skipped(), "{}", e);
}
tx.send(Message::Temperature(123.0)).await.unwrap();
}
});
thread::sleep(Duration::from_secs(1));
let mut c = 0;
let mut t = 0;
while let Ok(msg) = rx.recv().await {
match msg {
Message::Test(_) => c += 1,
Message::Temperature(_) => t += 1,
Message::Spam => {}
}
}
assert_eq!(c, 10);
assert_eq!(t, 1);
}
#[tokio::test]
async fn test_poisoning() {
let n = 20_000;
for _ in 0..n {
let (tx, rx) = bounded::<Message>(512);
let rx_t = tokio::spawn(async move { while rx.recv().await.is_ok() {} });
tokio::spawn(async move {
let _t = tx;
});
tokio::time::timeout(Duration::from_millis(100), rx_t)
.await
.unwrap()
.unwrap();
}
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论