提交 2017434a authored 作者: Serhij S's avatar Serhij S

advanced socket opts

上级 2d486bd7
use parking_lot::MutexGuard; use parking_lot::MutexGuard;
use std::sync::Arc; use std::{
io::{Read, Write},
sync::Arc,
time::Duration,
};
use crate::Result;
pub mod serial; // Serial communications pub mod serial; // Serial communications
pub mod tcp; // TCP communications pub mod tcp; // TCP communications
...@@ -18,11 +24,11 @@ impl Client { ...@@ -18,11 +24,11 @@ impl Client {
self.0.reconnect(); self.0.reconnect();
} }
/// Write data to the client /// Write data to the client
pub fn write(&self, buf: &[u8]) -> Result<(), std::io::Error> { pub fn write(&self, buf: &[u8]) -> Result<()> {
self.0.write(buf) self.0.write(buf).map_err(Into::into)
} }
/// Read data from the client /// Read data from the client
pub fn read_exact(&self, buf: &mut [u8]) -> Result<(), std::io::Error> { pub fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
self.0.read_exact(buf) self.0.read_exact(buf)
} }
/// Get the protocol of the client /// Get the protocol of the client
...@@ -36,11 +42,53 @@ pub enum Protocol { ...@@ -36,11 +42,53 @@ pub enum Protocol {
Serial, Serial,
} }
pub trait Stream: Read + Write + Send {}
trait Communicator { trait Communicator {
fn lock(&self) -> MutexGuard<()>; fn lock(&self) -> MutexGuard<()>;
fn reconnect(&self); fn reconnect(&self);
fn write(&self, buf: &[u8]) -> Result<(), std::io::Error>; fn write(&self, buf: &[u8]) -> Result<()>;
fn read_exact(&self, buf: &mut [u8]) -> Result<(), std::io::Error>; fn read_exact(&self, buf: &mut [u8]) -> Result<()>;
fn protocol(&self) -> Protocol; fn protocol(&self) -> Protocol;
fn session_id(&self) -> usize; fn session_id(&self) -> usize;
} }
/// Connection Options
pub struct ConnectionOptions {
with_reader: bool,
chat: Option<Box<ChatFn>>,
timeout: Duration,
}
pub type ChatFn = dyn Fn(&mut dyn Stream) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>
+ Send
+ Sync;
impl ConnectionOptions {
pub fn new(timeout: Duration) -> Self {
Self {
with_reader: false,
chat: None,
timeout,
}
}
/// Enable the reader channel. The reader channel allows the client to receive a clone of the
/// stream reader when the connection is established. This is useful for implementing custom
/// protocols that require reading from the stream.
pub fn with_reader(mut self) -> Self {
self.with_reader = true;
self
}
/// Set the chat function. The chat function is called after the connection is established. The
/// chat function can be used to implement custom protocols that require additional setup.
pub fn chat<F>(mut self, chat: F) -> Self
where
F: Fn(&mut dyn Stream) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>
+ Send
+ Sync
+ 'static,
{
self.chat = Some(Box::new(chat));
self
}
}
...@@ -154,7 +154,7 @@ impl Communicator for Serial { ...@@ -154,7 +154,7 @@ impl Communicator for Serial {
port.last_frame.take(); port.last_frame.take();
self.session_id.fetch_add(1, Ordering::Relaxed); self.session_id.fetch_add(1, Ordering::Relaxed);
} }
fn write(&self, buf: &[u8]) -> std::result::Result<(), std::io::Error> { fn write(&self, buf: &[u8]) -> Result<()> {
let mut port = self let mut port = self
.get_port() .get_port()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
...@@ -178,7 +178,7 @@ impl Communicator for Serial { ...@@ -178,7 +178,7 @@ impl Communicator for Serial {
} }
result.map_err(Into::into) result.map_err(Into::into)
} }
fn read_exact(&self, buf: &mut [u8]) -> std::result::Result<(), std::io::Error> { fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
let mut port = self let mut port = self
.get_port() .get_port()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
......
use crate::Error; use crate::pchannel;
use crate::{DataDeliveryPolicy, Error, Result};
use super::{Client, Communicator, Protocol}; use super::{ChatFn, Client, Communicator, ConnectionOptions, Protocol, Stream};
use core::fmt; use core::fmt;
use parking_lot::{Mutex, MutexGuard}; use parking_lot::{Mutex, MutexGuard};
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpStream; use std::net::{self, TcpStream};
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
const READER_CHANNEL_CAPACITY: usize = 1024;
/// Create a new TCP client. The client will attempt to connect to the given address at the time of /// Create a new TCP client. The client will attempt to connect to the given address at the time of
/// the first request. The client will automatically reconnect if the connection is lost. /// the first request. The client will automatically reconnect if the connection is lost.
pub fn connect<A: ToSocketAddrs + fmt::Debug>(addr: A, timeout: Duration) -> Result<Client, Error> { pub fn connect<A: ToSocketAddrs + fmt::Debug>(addr: A, timeout: Duration) -> Result<Client> {
Ok(Client(Tcp::create(addr, timeout)?)) Ok(Client(
Tcp::create(addr, ConnectionOptions::new(timeout))?.0,
))
}
/// Create a new TCP client with options. The client will attempt to connect to the given address
/// at the time of the first request. The client will automatically reconnect if the connection is
/// lost.
pub fn connect_with_options<A: ToSocketAddrs + fmt::Debug>(
addr: A,
options: ConnectionOptions,
) -> Result<(Client, pchannel::Receiver<TcpReader>)> {
let (tcp, rx) = Tcp::create(addr, options)?;
Ok((Client(tcp), rx.unwrap()))
}
impl Stream for TcpStream {}
pub struct TcpReader {
reader: Option<Box<dyn Read + Send + 'static>>,
}
impl TcpReader {
pub fn take(&mut self) -> Option<Box<dyn Read + Send + 'static>> {
self.reader.take()
}
} }
impl DataDeliveryPolicy for TcpReader {}
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub struct Tcp { pub struct Tcp {
addr: SocketAddr, addr: SocketAddr,
...@@ -23,17 +53,20 @@ pub struct Tcp { ...@@ -23,17 +53,20 @@ pub struct Tcp {
timeout: Duration, timeout: Duration,
busy: Mutex<()>, busy: Mutex<()>,
session_id: AtomicUsize, session_id: AtomicUsize,
reader_tx: Option<pchannel::Sender<TcpReader>>,
chat: Option<Box<ChatFn>>,
} }
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub type TcpClient = Arc<Tcp>; pub type TcpClient = Arc<Tcp>;
macro_rules! handle_tcp_stream_error { macro_rules! handle_tcp_stream_error {
($stream: expr, $err: expr, $any: expr) => {{ ($sess: expr, $stream: expr, $err: expr, $any: expr) => {{
if $any || $err.kind() == std::io::ErrorKind::TimedOut { if $any || $err.kind() == std::io::ErrorKind::TimedOut {
$stream.take(); $sess.fetch_add(1, Ordering::Relaxed);
$stream.take().map(|s| s.shutdown(net::Shutdown::Both));
} }
$err $err.into()
}}; }};
} }
...@@ -45,24 +78,26 @@ impl Communicator for Tcp { ...@@ -45,24 +78,26 @@ impl Communicator for Tcp {
self.session_id.load(Ordering::Relaxed) self.session_id.load(Ordering::Relaxed)
} }
fn reconnect(&self) { fn reconnect(&self) {
self.stream.lock().take(); self.stream.lock().take().map(|s| {
self.session_id.fetch_add(1, Ordering::Relaxed); self.session_id.fetch_add(1, Ordering::Relaxed);
s.shutdown(net::Shutdown::Both)
});
} }
fn write(&self, buf: &[u8]) -> Result<(), std::io::Error> { fn write(&self, buf: &[u8]) -> Result<()> {
let mut stream = self.get_stream()?; let mut stream = self.get_stream()?;
stream stream
.as_mut() .as_mut()
.unwrap() .unwrap()
.write_all(buf) .write_all(buf)
.map_err(|e| handle_tcp_stream_error!(stream, e, true)) .map_err(|e| handle_tcp_stream_error!(self.session_id, stream, e, true))
} }
fn read_exact(&self, buf: &mut [u8]) -> Result<(), std::io::Error> { fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
let mut stream = self.get_stream()?; let mut stream = self.get_stream()?;
stream stream
.as_mut() .as_mut()
.unwrap() .unwrap()
.read_exact(buf) .read_exact(buf)
.map_err(|e| handle_tcp_stream_error!(stream, e, false)) .map_err(|e| handle_tcp_stream_error!(self.session_id, stream, e, false))
} }
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
Protocol::Tcp Protocol::Tcp
...@@ -72,29 +107,54 @@ impl Communicator for Tcp { ...@@ -72,29 +107,54 @@ impl Communicator for Tcp {
impl Tcp { impl Tcp {
fn create<A: ToSocketAddrs + fmt::Debug>( fn create<A: ToSocketAddrs + fmt::Debug>(
addr: A, addr: A,
timeout: Duration, options: ConnectionOptions,
) -> Result<TcpClient, Error> { ) -> Result<(TcpClient, Option<pchannel::Receiver<TcpReader>>)> {
Ok(Self { let (tx, rx) = if options.with_reader {
let (tx, rx) = pchannel::bounded(READER_CHANNEL_CAPACITY);
(Some(tx), Some(rx))
} else {
(None, None)
};
let client = Self {
addr: addr addr: addr
.to_socket_addrs()? .to_socket_addrs()?
.next() .next()
.ok_or_else(|| Error::invalid_data(format!("Invalid address: {:?}", addr)))?, .ok_or_else(|| Error::invalid_data(format!("Invalid address: {:?}", addr)))?,
stream: <_>::default(), stream: <_>::default(),
busy: <_>::default(), busy: <_>::default(),
timeout, timeout: options.timeout,
session_id: <_>::default(), session_id: <_>::default(),
} reader_tx: tx,
.into()) chat: options.chat,
};
Ok((client.into(), rx))
} }
fn get_stream(&self) -> Result<MutexGuard<Option<TcpStream>>, std::io::Error> { fn get_stream(&self) -> Result<MutexGuard<Option<TcpStream>>> {
let mut lock = self.stream.lock(); let mut lock = self.stream.lock();
if lock.as_mut().is_none() { if lock.as_mut().is_none() {
let stream = TcpStream::connect_timeout(&self.addr, self.timeout)?; let mut stream = TcpStream::connect_timeout(&self.addr, self.timeout)?;
stream.set_read_timeout(Some(self.timeout))?; stream.set_read_timeout(Some(self.timeout))?;
stream.set_write_timeout(Some(self.timeout))?; stream.set_write_timeout(Some(self.timeout))?;
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
if let Some(ref chat) = self.chat {
chat(&mut stream).map_err(Error::io)?;
}
if let Some(ref tx) = self.reader_tx {
tx.send(TcpReader {
reader: Some(Box::new(stream.try_clone()?)),
})?;
}
lock.replace(stream); lock.replace(stream);
} }
Ok(lock) Ok(lock)
} }
} }
impl Drop for Tcp {
fn drop(&mut self) {
self.stream
.lock()
.take()
.map(|s| s.shutdown(net::Shutdown::Both));
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论