riot_wrappers/socket_embedded_nal_tcp.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
//! An implementation of the [embedded_nal] (Network Abstraction Layer) TCP traits based on RIOT
//! sockets
//!
//! This is quite distinct from [the UDP version](crate::socket_embedded_nal), as it requires
//! very different workarounds (and because it was implemented when embedded-nal had already
//! switched over to a `&mut` stack).
//!
//! ## Warning
//!
//! The implementation of TcpExactStack is highly naïve: it may panic even with well-behaved
//! peers, let alone adversarial ones.
use core::convert::TryInto;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::pin::Pin;
use crate::error::{NegativeErrorExt, NumericError};
use embedded_nal::{SocketAddr, TcpClientStack, TcpFullStack};
/// A view on the RIOT socket stack that is prepared for a single listening socket that can accept
/// QUEUELEN connections simultaneously.
///
/// Note that unless CONFIG_GNRC_TCP_RCV_BUFFERS is overridden, QUEUELEN is limited to 1; anything
/// more makes it fail at setup time.
///
/// To use it as an implementation of TcpFullStack, it needs to be pinned, e.g. by
/// `pin_utils::pin_mut!(stack)`, and then passed as a mutable reference to the pinned item.
///
/// Note that while it would be perfectly feasible to count the number of open connections and
/// allow this to be dropped when all connections are closed, this will only be implemented once
/// there is a use case that needs it (as most RIOT servers are up indefinitely). Until then,
/// dropping a ListenStack panics.
pub struct ListenStack<const QUEUELEN: usize> {
    // This should be type state, but embedded-nal does not allow that.
    stage: ListenStage,
    // Listening queue that `bind` hands to RIOT's `sock_tcp_listen`.
    listener: riot_sys::sock_tcp_queue_t,
    // Storage for the queue's connection sockets, handed to RIOT together with `listener`;
    // accepted sockets refer to their slot in here by index.
    connections: [riot_sys::sock_tcp_t; QUEUELEN],
    // because by passing listener and connections to the socket API, we promise not to move them
    // any more
    _unpin: core::marker::PhantomPinned,
}
/// Lifecycle of a ListenStack's single listening socket.
///
/// This is a plain value-like flag, so it derives `Copy` and `Debug` in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenStage {
    /// No socket was populated
    New,
    /// The listener was handed out
    Bound,
}
// Pin projection helpers for ListenStack.
impl<const QUEUELEN: usize> ListenStack<QUEUELEN> {
    // Generates a `stage()` accessor taking `Pin<&mut Self>` and yielding `&mut ListenStage`
    // (used as `*self.as_mut().stage() = ...`).
    // unsafe: We never promise not to move *that* one. Only `listener` and `connections` are
    // handed to RIOT by pointer; `stage` is purely Rust-side bookkeeping.
    pin_utils::unsafe_unpinned!(stage: ListenStage);
}
impl<const QUEUELEN: usize> Default for ListenStack<QUEUELEN> {
fn default() -> Self {
ListenStack {
stage: ListenStage::New,
// As this is usually one-time cost, doing the additional code dance to make this
// uninit isn't worth it right now.
listener: Default::default(),
connections: [Default::default(); QUEUELEN],
_unpin: Default::default(),
}
}
}
impl<const QUEUELEN: usize> Drop for ListenStack<QUEUELEN> {
    /// Always panics: dropping a ListenStack is not supported.
    ///
    /// The stack does not keep track of which of its connection slots are still in use, so it
    /// cannot tell whether the `listener` / `connections` storage it handed to RIOT is still
    /// referenced. The panic is unconditional, so it applies even to a stack that was never bound.
    // NOTE(review): a panic raised here while the thread is already unwinding aborts the process.
    fn drop(&mut self) {
        unimplemented!("Sorry, I didn't keep track of which connections are all on")
    }
}
/// Socket handle handed out by a ListenStack (its `TcpSocket` type).
///
/// The lifetime is used as branding to ensure sockets are always used with their respective
/// stacks.
///
/// NOTE(review): `'a` is the lifetime of the `Pin<&'a mut ListenStack>` borrow, and
/// `PhantomData<&'a ()>` is covariant, so handles of two stacks borrowed for a common lifetime are
/// not told apart by the type system; treat the branding as best-effort.
#[derive(Debug)]
pub struct Socket<'a> {
    // The handle's role and, for connections, which slot of the stack it refers to.
    socket: SocketImpl,
    _phantom: PhantomData<&'a ()>, // I'd even say &'a ListenStack if that didn't take a queuelen
}
/// Role of a `Socket` handle within its stack.
#[derive(Debug, PartialEq)]
enum SocketImpl {
    // By the time socket() is called, we don't know yet what it'll be
    Unspecified,
    // The stack's single listening socket; a handle becomes this in `bind`.
    Listener,
    // An accepted connection; the number is its index into `ListenStack::connections`.
    // Assuming no more than 256 connections. Could really be u7 to allow the others to take
    // niches.
    Connection(u8),
    // No Closed state is defined as the close operation consumes the Rust wrapper around the
    // socket anyway.
}
impl<'a, const QUEUELEN: usize> TcpClientStack for Pin<&'a mut ListenStack<QUEUELEN>> {
    type TcpSocket = Socket<'a>;
    type Error = NumericError;

    /// Create a socket handle whose role is not decided yet.
    ///
    /// Nothing is set up on the RIOT side here: the handle becomes the listener when passed to
    /// `bind`, and connection handles are only produced by `accept`.
    fn socket(&mut self) -> Result<Self::TcpSocket, Self::Error> {
        // Not knowing what it will be, we can't check anything yet
        Ok(Socket {
            socket: SocketImpl::Unspecified,
            _phantom: PhantomData,
        })
    }

    /// # Panics
    ///
    /// Always: a ListenStack only serves incoming connections.
    fn connect(
        &mut self,
        _sock: &mut Self::TcpSocket,
        _addr: SocketAddr,
    ) -> Result<(), nb::Error<Self::Error>> {
        panic!("A ListenStack can not connect out.")
    }

    /// Report whether `sock` is a connection obtained from `accept`.
    fn is_connected(&mut self, sock: &Self::TcpSocket) -> Result<bool, Self::Error> {
        // FIXME: Check whether that's what is meant (or whether more checks should be done through
        // RIOT)
        Ok(matches!(sock.socket, SocketImpl::Connection(_)))
    }

    /// Write (a prefix of) `buf` to an accepted connection via `sock_tcp_write`.
    ///
    /// Returns the number of bytes written. Buffers longer than `u32::MAX` are offered to RIOT
    /// as `u32::MAX` bytes. Negative RIOT returns are mapped through `again_is_wouldblock`.
    ///
    /// # Panics
    ///
    /// If `sock` is not a connection obtained from `accept`.
    fn send(
        &mut self,
        sock: &mut Self::TcpSocket,
        buf: &[u8],
    ) -> Result<usize, nb::Error<Self::Error>> {
        let index = match sock.socket {
            SocketImpl::Connection(n) => usize::from(n),
            SocketImpl::Unspecified | SocketImpl::Listener => {
                panic!("Send on unconnected socket")
            }
        };
        // SAFETY: The connection slot is only accessed in place; nothing is moved out of the
        // pinned stack.
        let connection = unsafe { &mut self.as_mut().get_unchecked_mut().connections[index] };
        unsafe {
            riot_sys::sock_tcp_write(
                connection,
                buf.as_ptr() as *const _,
                buf.len().try_into().unwrap_or(u32::MAX),
            )
        }
        .negative_to_error()
        .map_err(|e| e.again_is_wouldblock())
        .map(|n| n as _)
    }

    /// Read available data from an accepted connection into `buf`, with a timeout of 0.
    ///
    /// Returns the number of bytes read; with a zero timeout, a return of 0 means the
    /// connection was closed. At most `u32::MAX` bytes are requested. Negative RIOT returns are
    /// mapped through `again_is_wouldblock`.
    ///
    /// # Panics
    ///
    /// If `sock` is not a connection obtained from `accept`.
    fn receive(
        &mut self,
        sock: &mut Self::TcpSocket,
        buf: &mut [u8],
    ) -> Result<usize, nb::Error<Self::Error>> {
        let index = match sock.socket {
            SocketImpl::Connection(n) => usize::from(n),
            SocketImpl::Unspecified | SocketImpl::Listener => {
                panic!("Receive on unconnected socket")
            }
        };
        // SAFETY: The connection slot is only accessed in place; nothing is moved out of the
        // pinned stack.
        let connection = unsafe { &mut self.as_mut().get_unchecked_mut().connections[index] };
        unsafe {
            riot_sys::sock_tcp_read(
                connection,
                // RIOT writes into this buffer, so the pointer must come from the mutable borrow
                // (`as_ptr` would derive it from a shared reborrow, making the write UB).
                buf.as_mut_ptr() as *mut _,
                buf.len().try_into().unwrap_or(u32::MAX),
                0,
            )
        }
        .negative_to_error()
        .map_err(|e| e.again_is_wouldblock())
        .map(|n| n as _)
    }

    /// Close a socket handle.
    ///
    /// Accepted connections are disconnected through `sock_tcp_disconnect`. A handle that was
    /// never bound or accepted has no RIOT-side state, so closing it is a no-op.
    ///
    /// # Panics
    ///
    /// When closing the listening socket, which this stack does not support.
    fn close(&mut self, sock: Self::TcpSocket) -> Result<(), Self::Error> {
        let index = match sock.socket {
            SocketImpl::Connection(n) => usize::from(n),
            SocketImpl::Unspecified => return Ok(()),
            SocketImpl::Listener => panic!("Closing the listening socket is not supported"),
        };
        // SAFETY: The connection slot is only accessed in place; nothing is moved out of the
        // pinned stack.
        let connection = unsafe { &mut self.as_mut().get_unchecked_mut().connections[index] };
        unsafe { riot_sys::sock_tcp_disconnect(connection) };
        Ok(())
    }
}
impl<'a, const QUEUELEN: usize> TcpFullStack for Pin<&'a mut ListenStack<QUEUELEN>> {
fn bind(&mut self, sock: &mut Self::TcpSocket, port: u16) -> Result<(), Self::Error> {
assert!(
self.stage == ListenStage::New,
"Stack already has its listening socket bound"
);
*self.as_mut().stage() = ListenStage::Bound;
assert!(
sock.socket == SocketImpl::Unspecified,
"Attempted to bind running socket"
);
sock.socket = SocketImpl::Listener;
// Reusing UdpEp because TcpEp is probably (FIXME) all the same.
let local = crate::socket::UdpEp::ipv6_any().with_port(port);
unsafe {
riot_sys::sock_tcp_listen(
&mut self.as_mut().get_unchecked_mut().listener,
local.as_ref(),
self.as_mut().get_unchecked_mut().connections.as_mut_ptr(),
self.connections
.len()
.try_into()
.expect("Size exceeds expressible size"),
0,
)
}
.negative_to_error()?;
Ok(())
}
fn listen(&mut self, _sock: &mut Self::TcpSocket) -> Result<(), Self::Error> {
// Done already in bind
Ok(())
}
fn accept(
&mut self,
// This is and can actually be ignored, because our stack object only serves a single
// listening socket.
_sock: &mut Self::TcpSocket,
) -> Result<(Self::TcpSocket, SocketAddr), nb::Error<Self::Error>> {
let mut sockptr = core::ptr::null_mut();
unsafe {
riot_sys::sock_tcp_accept(
&mut self.as_mut().get_unchecked_mut().listener,
&mut sockptr,
0, // return immediately / nonblocking
)
}
.negative_to_error()
.map_err(|e| e.again_is_wouldblock())?;
// unsafe: That's what sock_tcp_accept implicitly (FIXME) promises
let index = unsafe { sockptr.offset_from(self.connections.as_ptr()) };
let remote = unsafe {
let mut remote: MaybeUninit<riot_sys::sock_tcp_ep_t> = MaybeUninit::uninit();
riot_sys::sock_tcp_get_remote(sockptr, remote.as_mut_ptr());
remote.assume_init()
};
Ok((
Socket {
socket: SocketImpl::Connection(index.try_into().expect("Excessive pool")),
_phantom: PhantomData,
},
crate::socket::UdpEp(remote).into(),
))
}
}
impl<'a, const QUEUELEN: usize> embedded_nal_tcpextensions::TcpExactStack
    for Pin<&'a mut ListenStack<QUEUELEN>>
{
    const RECVBUFLEN: usize = 1152;
    const SENDBUFLEN: usize = 1152;

    /// Fill all of `buf` from the connection with a single non-waiting read.
    ///
    /// An empty `buf` is trivially filled and returns `Ok(())` without calling into RIOT.
    ///
    /// # Errors
    ///
    /// `ENOTCONN` if the read returned 0 bytes (the connection was closed); other errors
    /// (including `WouldBlock`) as returned by `receive`.
    ///
    /// # Panics
    ///
    /// If only part of `buf` could be filled (see the module-level warning), or if `sock` is
    /// not a connection.
    fn receive_exact(
        &mut self,
        sock: &mut Self::TcpSocket,
        buf: &mut [u8],
    ) -> Result<(), nb::Error<Self::Error>> {
        if buf.is_empty() {
            // A zero-length read may legitimately return 0, which below would be misreported
            // as a closed connection.
            return Ok(());
        }
        let ret = self.receive(sock, buf)?;
        if ret == 0 {
            // Could mean timeout *or* connection closed, but with timeout 0 it's always
            // connection closed.
            //
            // FIXME is returning an error right here?
            return Err(nb::Error::Other(NumericError::from_constant(
                riot_sys::ENOTCONN as _,
            )));
        }
        assert!(
            ret == buf.len(),
            "Well that's a bad TcpExactStack, only got {} of {}",
            ret,
            buf.len()
        );
        Ok(())
    }

    /// Send all of `buf` with a single `send` call.
    ///
    /// # Errors
    ///
    /// As returned by `send`.
    ///
    /// # Panics
    ///
    /// If RIOT accepted only part of `buf` (see the module-level warning).
    fn send_all(
        &mut self,
        sock: &mut Self::TcpSocket,
        buf: &[u8],
    ) -> Result<(), nb::Error<Self::Error>> {
        let sent = self.send(sock, buf)?;
        assert!(sent == buf.len(), "Well that's a bad TcpExactStack");
        Ok(())
    }
}