Skip to content

Commit

Permalink
Work-around netlink buffer limit for setting peers, fixes #619 (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez authored Aug 4, 2024
1 parent 6a32595 commit 87a7c84
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 44 deletions.
36 changes: 18 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 10 additions & 7 deletions src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl Peer {
#[cfg(target_os = "linux")]
impl Peer {
#[must_use]
pub fn from_nlas(nlas: &[WgPeerAttrs]) -> Self {
pub(crate) fn from_nlas(nlas: &[WgPeerAttrs]) -> Self {
let mut peer = Self::default();

for nla in nlas {
Expand Down Expand Up @@ -136,15 +136,15 @@ impl Peer {
}

#[must_use]
pub fn as_nlas(&self, ifname: &str) -> Vec<WgDeviceAttrs> {
pub(crate) fn as_nlas(&self, ifname: &str) -> Vec<WgDeviceAttrs> {
vec![
WgDeviceAttrs::IfName(ifname.into()),
WgDeviceAttrs::Peers(vec![self.as_nlas_peer()]),
]
}

#[must_use]
pub fn as_nlas_peer(&self) -> WgPeer {
pub(crate) fn as_nlas_peer(&self) -> WgPeer {
let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.as_array())];
if let Some(keepalive) = self.persistent_keepalive_interval {
attrs.push(WgPeerAttrs::PersistentKeepalive(keepalive));
Expand Down Expand Up @@ -324,7 +324,7 @@ impl Host {

#[cfg(target_os = "linux")]
impl Host {
pub fn append_nlas(&mut self, nlas: &[WgDeviceAttrs]) {
pub(crate) fn append_nlas(&mut self, nlas: &[WgDeviceAttrs]) {
for nla in nlas {
match nla {
WgDeviceAttrs::PrivateKey(value) => self.private_key = Some(Key::new(*value)),
Expand All @@ -342,7 +342,7 @@ impl Host {
}

#[must_use]
pub fn as_nlas(&self, ifname: &str) -> Vec<WgDeviceAttrs> {
pub(crate) fn as_nlas(&self, ifname: &str) -> Vec<WgDeviceAttrs> {
let mut nlas = vec![
WgDeviceAttrs::IfName(ifname.into()),
WgDeviceAttrs::ListenPort(self.listen_port),
Expand All @@ -354,8 +354,11 @@ impl Host {
nlas.push(WgDeviceAttrs::Fwmark(*fwmark));
}
nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS));
let peers = self.peers.values().map(Peer::as_nlas_peer).collect();
nlas.push(WgDeviceAttrs::Peers(peers));

// IMPORTANT: To avoid buffer overflow, do not add peers here.
// let peers = self.peers.values().map(Peer::as_nlas_peer).collect();
// nlas.push(WgDeviceAttrs::Peers(peers));

nlas
}
}
Expand Down
64 changes: 45 additions & 19 deletions src/netlink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ pub(crate) enum NetlinkError {
UnexpectedPayload,
#[error("Failed to send netlink request")]
SendFailure,
#[error(
"Serialized netlink packet ({0} bytes) larger than maximum size {SOCKET_BUFFER_LENGTH}"
)]
InvalidPacketLength(usize),
#[error("Attribute value not found")]
AttributeNotFound,
#[error("Socket error: {0}")]
Expand Down Expand Up @@ -117,7 +113,7 @@ impl IpAddrMask {

impl IpVersion {
#[must_use]
fn address_family(&self) -> AddressFamily {
fn address_family(self) -> AddressFamily {
match self {
Self::IPv4 => AddressFamily::Inet,
Self::IPv6 => AddressFamily::Inet6,
Expand Down Expand Up @@ -170,19 +166,11 @@ where
{
let mut req = NetlinkMessage::from(message);

if req.buffer_len() > SOCKET_BUFFER_LENGTH {
error!(
"Serialized netlink packet ({} bytes) larger than maximum size {SOCKET_BUFFER_LENGTH}: {req:?}",
req.buffer_len(),
);
return Err(NetlinkError::InvalidPacketLength(req.buffer_len()));
}

req.header.flags = flags;
req.finalize();
let mut buf = [0; SOCKET_BUFFER_LENGTH];
req.serialize(&mut buf);
let len = req.buffer_len();
let mut buf = vec![0u8; len];
req.serialize(&mut buf);

let socket = Socket::new(protocol).map_err(|err| {
error!("Failed to open socket: {err}");
Expand All @@ -193,7 +181,7 @@ where
error!("Failed to connect to socket: {err}");
NetlinkError::SocketError(err.to_string())
})?;
let n_sent = socket.send(&buf[..len], 0).map_err(|err| {
let n_sent = socket.send(&buf, 0).map_err(|err| {
error!("Failed to send to socket: {err}");
NetlinkError::SocketError(err.to_string())
})?;
Expand All @@ -203,13 +191,14 @@ where

let mut responses = Vec::new();
loop {
let n_received = socket.recv(&mut &mut buf[..], 0).map_err(|err| {
let mut recv_buf = [0; SOCKET_BUFFER_LENGTH];
let n_received = socket.recv(&mut &mut recv_buf[..], 0).map_err(|err| {
error!("Failed to receive from socket: {err}");
NetlinkError::SocketError(err.to_string())
})?;
let mut offset = 0;
loop {
let response = NetlinkMessage::<I>::deserialize(&buf[offset..])?;
let response = NetlinkMessage::<I>::deserialize(&recv_buf[offset..])?;
debug!("Read netlink response from socket: {response:?}");
match response.payload {
// We've parsed all parts of the response and can leave the loop.
Expand Down Expand Up @@ -330,7 +319,7 @@ fn flush_addresses(index: u32) -> NetlinkResult<()> {
)?;
}
} else {
debug!("unknown nlmsg response")
debug!("unknown nlmsg response");
}
}

Expand Down Expand Up @@ -411,6 +400,11 @@ pub(crate) fn set_host(ifname: &str, host: &Host) -> NetlinkResult<()> {
nlas: host.as_nlas(ifname),
});
netlink_request_genl(genlmsg, NLM_F_REQUEST | NLM_F_ACK)?;
// Add peers one by one to avoid packet buffer overflow.
for peer in host.peers.values() {
set_peer(ifname, peer)?;
}

Ok(())
}

Expand Down Expand Up @@ -784,4 +778,36 @@ mod tests {

delete_interface(IF_NAME).unwrap();
}

#[ignore]
#[test]
fn docker_peers() {
use x25519_dalek::{EphemeralSecret, PublicKey};

const MAX_PEERS: usize = 1600;

let secret = EphemeralSecret::random();
let key = PublicKey::from(&secret);
// Peer secret key
let key: Key = key.as_ref().try_into().unwrap();
let mut host = Host::new(1234, key);

for _ in 0..MAX_PEERS {
let secret = EphemeralSecret::random();
let key = PublicKey::from(&secret);
let key: Key = key.as_ref().try_into().unwrap();
let peer = Peer::new(key.clone());
host.peers.insert(key, peer);
}

const IF_NAME: &str = "wg0";
create_interface(IF_NAME).unwrap();
set_host(IF_NAME, &host).unwrap();

let host = get_host(IF_NAME).unwrap();
assert_eq!(host.peers.len(), MAX_PEERS);

// With many peers, this takes a long time.
delete_interface(IF_NAME).unwrap();
}
}

0 comments on commit 87a7c84

Please sign in to comment.