diff --git a/rosrust/src/api/raii.rs b/rosrust/src/api/raii.rs index 1c2f0d1..abe6263 100644 --- a/rosrust/src/api/raii.rs +++ b/rosrust/src/api/raii.rs @@ -8,7 +8,7 @@ use crate::rosxmlrpc::Response; use crate::tcpros::{Message, PublisherStream, ServicePair, ServiceResult}; use crate::{RawMessageDescription, SubscriptionHandler}; use log::error; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicBool, AtomicUsize}; use std::sync::Arc; #[derive(Clone)] @@ -144,13 +144,15 @@ impl Subscriber { T: Message, H: SubscriptionHandler, { - let id = slave.add_subscription::(name, queue_size, handler)?; + let unsub_signal = Arc::new(AtomicBool::new(false)); + let id = slave.add_subscription::(name, queue_size, handler, unsub_signal.clone())?; let info = Arc::new(InteractorRaii::new(SubscriberInfo { master, slave, name: name.into(), id, + unsub_signal, })); let publishers = info @@ -194,10 +196,13 @@ struct SubscriberInfo { slave: Arc, name: String, id: usize, + unsub_signal: Arc, } impl Interactor for SubscriberInfo { fn unregister(&mut self) -> Response<()> { + self.unsub_signal.store(true, std::sync::atomic::Ordering::Relaxed); + self.slave.remove_subscription(&self.name, self.id); self.master.unregister_subscriber(&self.name).map(|_| ()) } diff --git a/rosrust/src/api/slave/mod.rs b/rosrust/src/api/slave/mod.rs index d99c3ea..982c97d 100644 --- a/rosrust/src/api/slave/mod.rs +++ b/rosrust/src/api/slave/mod.rs @@ -14,6 +14,7 @@ use error_chain::bail; use log::error; use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; pub struct Slave { @@ -156,13 +157,14 @@ impl Slave { topic: &str, queue_size: usize, handler: H, + unsub_signal: Arc, ) -> Result where T: Message, H: SubscriptionHandler, { self.subscriptions - .add(&self.name, topic, queue_size, handler) + .add(&self.name, topic, queue_size, handler, unsub_signal) } #[inline] diff --git a/rosrust/src/api/slave/subscriptions.rs b/rosrust/src/api/slave/subscriptions.rs index 5253213..27e47a4 100644 --- a/rosrust/src/api/slave/subscriptions.rs +++ b/rosrust/src/api/slave/subscriptions.rs @@ -7,6 +7,7 @@ use log::error; use std::collections::{BTreeSet, HashMap}; use std::iter::FromIterator; use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, Ordering}; #[derive(Clone, Default)] pub struct SubscriptionsTracker { @@ -51,7 +52,7 @@ impl SubscriptionsTracker { .collect() } - pub fn add(&self, name: &str, topic: &str, queue_size: usize, handler: H) -> Result + pub fn add(&self, name: &str, topic: &str, queue_size: usize, handler: H, unsub_signal: Arc) -> Result where T: Message, H: SubscriptionHandler, @@ -67,6 +68,7 @@ impl SubscriptionsTracker { msg_definition, msg_type.clone(), md5sum.clone(), + unsub_signal, ) }); let connection_topic = connection.get_topic(); diff --git a/rosrust/src/tcpros/subscriber.rs b/rosrust/src/tcpros/subscriber.rs index 5f3009e..e3f38dd 100644 --- a/rosrust/src/tcpros/subscriber.rs +++ b/rosrust/src/tcpros/subscriber.rs @@ -11,6 +11,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::sync::Arc; use std::thread; +use std::sync::atomic::{AtomicBool, Ordering}; enum DataStreamConnectionChange { Connect( @@ -37,6 +38,7 @@ impl SubscriberRosConnection { msg_definition: String, msg_type: String, md5sum: String, + unsub_signal: Arc, ) -> SubscriberRosConnection { let subscriber_connection_queue_size = 8; let (data_stream_tx, data_stream_rx) = bounded(subscriber_connection_queue_size); @@ -56,6 +58,7 @@ impl SubscriberRosConnection { &msg_definition, &md5sum, &msg_type, + unsub_signal, ) } }); @@ -203,6 +206,7 @@ fn join_connections( msg_definition: &str, md5sum: &str, msg_type: &str, + unsub_signal: Arc, ) { type Sub = (LossySender, Sender>); let mut subs: BTreeMap = BTreeMap::new(); @@ -255,6 +259,7 @@ fn join_connections( msg_definition, md5sum, msg_type, + unsub_signal.clone(), ) .chain_err(|| ErrorKind::TopicConnectionFail(topic.into())); match result { @@ -290,8 +295,11 @@ fn join_connection( msg_definition: &str, md5sum: &str, msg_type: &str, + unsub_signal: Arc, ) -> Result> { let mut stream = TcpStream::connect(publisher)?; + stream.set_read_timeout(Some(std::time::Duration::from_secs(10)))?; + let headers = exchange_headers::<_>( &mut stream, caller_id, @@ -302,15 +310,29 @@ fn join_connection( )?; let pub_caller_id = headers.get("callerid").cloned(); let target = data_stream.clone(); + thread::spawn(move || { let pub_caller_id = Arc::new(pub_caller_id.unwrap_or_default()); - while let Ok(buffer) = package_to_vector(&mut stream) { - if let Err(TrySendError::Disconnected(_)) = - target.try_send(MessageInfo::new(Arc::clone(&pub_caller_id), buffer)) - { - // Data receiver has been destroyed after - // Subscriber destructor's kill signal - break; + loop { + match package_to_vector(&mut stream) { + Ok(buffer) => { + if let Err(TrySendError::Disconnected(_)) = target.try_send(MessageInfo::new(Arc::clone(&pub_caller_id), buffer)) + { + // Data receiver has been destroyed after + // Subscriber destructor's kill signal + break; + } + } + + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + if unsub_signal.load(Ordering::Relaxed) { + // SubscriberInfo has been dropped, so break out of here to close the + // socket and exit the thread + break; + } + } + + Err(_) => break, } } });