Skip to content

Commit

Permalink
[RSDK-6857] support multiple connections (viamrobotics#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
npmenard committed Mar 12, 2024
1 parent d57205d commit 8392f84
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 31 deletions.
4 changes: 3 additions & 1 deletion examples/esp32-with-cred/esp32-server-with-cred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ mod esp32 {
(ip, eth)
};

let mut max_connection = 3;
unsafe {
if !g_spiram_ok {
log::info!("spiram not initialized disabling cache feature of the wifi driver");
g_wifi_feature_caps &= !(CONFIG_FEATURE_CACHE_TX_BUF_BIT as u64);
max_connection = 1;
}
}

Expand Down Expand Up @@ -213,7 +215,7 @@ mod esp32 {

let cfg = AppClientConfig::new(nvs_vars.robot_secret, nvs_vars.robot_id, ip, "".to_owned());

serve_web(cfg, tls_cfg, repr, ip, webrtc_certificate);
serve_web(cfg, tls_cfg, repr, ip, webrtc_certificate, max_connection);
}

#[cfg(feature = "qemu")]
Expand Down
4 changes: 3 additions & 1 deletion examples/esp32/esp32-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ mod esp32 {
(ip, eth)
};

let mut max_connection = 3;
unsafe {
if !g_spiram_ok {
log::info!("spiram not initialized disabling cache feature of the wifi driver");
g_wifi_feature_caps &= !(CONFIG_FEATURE_CACHE_TX_BUF_BIT as u64);
max_connection = 1;
}
}
#[allow(clippy::redundant_clone)]
Expand Down Expand Up @@ -108,7 +110,7 @@ mod esp32 {
Esp32TlsServerConfig::new(cert, key.as_ptr(), key.len() as u32)
};

serve_web(cfg, tls_cfg, repr, ip, webrtc_certificate);
serve_web(cfg, tls_cfg, repr, ip, webrtc_certificate, max_connection);
}

#[cfg(feature = "qemu")]
Expand Down
105 changes: 84 additions & 21 deletions micro-rdk/src/common/conn/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub struct ViamServerBuilder<'a, M, C, T, CC = WebRtcNoOp, D = WebRtcNoOp, L = N
exec: Executor<'a>,
app_connector: C,
app_config: AppClientConfig,
max_connections: usize,
}

impl<'a, M, C, T> ViamServerBuilder<'a, M, C, T>
Expand All @@ -105,7 +106,13 @@ where
C: TlsClientConnector,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub fn new(mdns: M, exec: Executor<'a>, app_connector: C, app_config: AppClientConfig) -> Self {
pub fn new(
mdns: M,
exec: Executor<'a>,
app_connector: C,
app_config: AppClientConfig,
max_connections: usize,
) -> Self {
Self {
mdns,
http2_listener: NoHttp2 {},
Expand All @@ -115,6 +122,7 @@ where
exec,
app_connector,
app_config,
max_connections,
}
}
}
Expand Down Expand Up @@ -144,6 +152,7 @@ where
webrtc: self.webrtc,
app_connector: self.app_connector,
app_config: self.app_config,
max_connections: self.max_connections,
}
}
pub fn with_webrtc<D2, CC2>(
Expand All @@ -159,6 +168,7 @@ where
exec: self.exec,
app_connector: self.app_connector,
app_config: self.app_config,
max_connections: self.max_connections,
}
}
pub fn build(
Expand Down Expand Up @@ -207,6 +217,7 @@ where
cloned_exec,
self.app_connector,
self.app_config,
self.max_connections,
);

Ok(srv)
Expand Down Expand Up @@ -266,7 +277,7 @@ pub struct ViamServer<'a, C, T, CC, D, L> {
app_connector: C,
app_config: AppClientConfig,
app_client: Option<AppClient<'a>>,
webtrc_conn: Option<Task<Result<(), ServerError>>>,
webrtc_manager: WebRTCConnectionManager,
}
impl<'a, C, T, CC, D, L> ViamServer<'a, C, T, CC, D, L>
where
Expand All @@ -284,6 +295,7 @@ where
exec: Executor<'a>,
app_connector: C,
app_config: AppClientConfig,
max_concurent_connections: usize,
) -> Self {
Self {
http_listener,
Expand All @@ -292,12 +304,11 @@ where
app_connector,
app_config,
app_client: None,
webtrc_conn: None,
webrtc_manager: WebRTCConnectionManager::new(max_concurent_connections),
}
}
pub async fn serve(&mut self, robot: Arc<Mutex<LocalRobot>>) {
let cloned_robot = robot.clone();
let mut current_prio = None;
loop {
let _ = smol::Timer::after(std::time::Duration::from_millis(300)).await;

Expand Down Expand Up @@ -347,32 +358,19 @@ where
async {
let mut api = sig.await?;

let prio = self
.webtrc_conn
.as_ref()
.and_then(|f| (!f.is_finished()).then_some(&current_prio))
.unwrap_or(&None);
let prio = self.webrtc_manager.get_lowest_prio();

let sdp = api
.answer(prio)
.await
.map_err(ServerError::ServerWebRTCError)?;

// When the current priority is lower than the priority of the incoming connection then
// we cancel and close the current webrtc connection (if any)
if let Some(task) = self.webtrc_conn.take() {
if !task.is_finished() {
let _ = task.cancel().await;
}
}

let _ = current_prio.insert(sdp.1);

Ok(IncomingConnection::WebRtcConnection(WebRTCConnection {
webrtc_api: api,
sdp: sdp.0,
server: None,
robot: cloned_robot.clone(),
prio: sdp.1,
}))
},
);
Expand All @@ -384,7 +382,7 @@ where
let connection = match connection {
Ok(c) => c,
Err(ServerError::ServerWebRTCError(_)) => {
// all webrtc errors are arising from failing to connect and doesn't require a tls renegotiation
// all webrtc errors are arising from failing to connect and don't require a tls renegotiation
continue;
}
Err(_) => {
Expand All @@ -400,8 +398,9 @@ where
IncomingConnection::WebRtcConnection(mut c) => match c.open_data_channel().await {
Err(e) => Err(e),
Ok(_) => {
let prio = c.prio;
let t = self.exec.spawn(async move { c.run().await });
let _task = self.webtrc_conn.insert(t);
self.webrtc_manager.insert_new_conn(t, prio).await;
Ok(())
}
},
Expand Down Expand Up @@ -462,6 +461,7 @@ struct WebRTCConnection<C, D, E> {
sdp: Box<WebRtcSdp>,
server: Option<WebRtcGrpcServer<GrpcServer<WebRtcGrpcBody>>>,
robot: Arc<Mutex<LocalRobot>>,
prio: u32,
}
impl<C, D, E> WebRTCConnection<C, D, E>
where
Expand Down Expand Up @@ -561,3 +561,66 @@ where
)))
}
}

#[derive(Default)]
struct WebRTCTask {
task: Option<Task<Result<(), ServerError>>>,
prio: Option<u32>,
}

impl WebRTCTask {
fn replace(&mut self, task: Task<Result<(), ServerError>>, prio: u32) {
let _ = self.task.replace(task);
let _ = self.prio.replace(prio);
}
fn is_finished(&self) -> bool {
if let Some(task) = self.task.as_ref() {
return task.is_finished();
}
true
}
async fn cancel(&mut self) -> Option<ServerError> {
if let Some(task) = self.task.take() {
return task.cancel().await?.err();
}
None
}
fn get_prio(&self) -> u32 {
if !self.is_finished() {
return *self.prio.as_ref().unwrap_or(&0);
}
0
}
}

struct WebRTCConnectionManager {
connections: Vec<WebRTCTask>,
}

impl WebRTCConnectionManager {
fn new(size: usize) -> Self {
let mut connections = Vec::with_capacity(size);
connections.resize_with(size, Default::default);
Self { connections }
}
// return the lowest priority of active webrtc tasks or 0
fn get_lowest_prio(&self) -> u32 {
self.connections
.iter()
.min_by(|a, b| a.get_prio().cmp(&b.get_prio()))
.map_or(0, |c| c.get_prio())
}
// function will never fail and the lowest priority will always be replaced
async fn insert_new_conn(&mut self, task: Task<Result<(), ServerError>>, prio: u32) {
if let Some(slot) = self
.connections
.iter_mut()
.min_by(|a, b| a.get_prio().cmp(&b.get_prio()))
{
if let Some(last_error) = slot.cancel().await {
log::info!("last_error {:?}", last_error);
}
slot.replace(task, prio);
}
}
}
4 changes: 1 addition & 3 deletions micro-rdk/src/common/webrtc/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ where

pub async fn answer(
&mut self,
current_prio: &Option<u32>,
current_prio: u32,
) -> Result<(Box<WebRtcSdp>, u32), WebRtcError> {
let offer = self
.signaling
Expand All @@ -502,8 +502,6 @@ where
.map_or(Ok(u32::MAX), |a| a.parse::<u32>())
.unwrap_or(u32::MAX);

let current_prio = current_prio.unwrap_or(0);

// TODO use is_some_then when rust min version reach 1.70
if current_prio >= caller_prio {
return Err(WebRtcError::CurrentConnectionHigherPrority());
Expand Down
17 changes: 13 additions & 4 deletions micro-rdk/src/esp32/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub async fn serve_web_inner(
_ip: Ipv4Addr,
webrtc_certificate: WebRtcCertificate,
exec: Esp32Executor<'_>,
max_webrtc_connection: usize,
) {
let (mut srv, robot) = {
let mut client_connector = Esp32Tls::new_client();
Expand Down Expand Up @@ -131,10 +132,16 @@ pub async fn serve_web_inner(

(
Box::new(
ViamServerBuilder::new(mdns, cloned_exec, client_connector, app_config)
.with_webrtc(webrtc)
.build(&cfg_response)
.unwrap(),
ViamServerBuilder::new(
mdns,
cloned_exec,
client_connector,
app_config,
max_webrtc_connection,
)
.with_webrtc(webrtc)
.build(&cfg_response)
.unwrap(),
),
robot,
)
Expand All @@ -155,6 +162,7 @@ pub fn serve_web(
repr: RobotRepresentation,
_ip: Ipv4Addr,
webrtc_certificate: WebRtcCertificate,
max_webrtc_connection: usize,
) {
// set the TWDT to expire after 5 minutes
crate::esp32::esp_idf_svc::sys::esp!(unsafe {
Expand All @@ -180,6 +188,7 @@ pub fn serve_web(
_ip,
webrtc_certificate,
exec,
max_webrtc_connection,
)));
futures_lite::pin!(fut);

Expand Down
2 changes: 1 addition & 1 deletion micro-rdk/src/native/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub async fn serve_web_inner(
exec.clone(),
));

let mut srv = ViamServerBuilder::new(mdns, cloned_exec, client_connector, app_config)
let mut srv = ViamServerBuilder::new(mdns, cloned_exec, client_connector, app_config, 3)
.with_http2(tls_listener, 12346)
.with_webrtc(webrtc)
.build(&cfg_response)
Expand Down

0 comments on commit 8392f84

Please sign in to comment.