use std::net::Ipv4Addr;
use std::{net::SocketAddr, str::from_utf8, sync::Arc, sync::Once};
use tokio::sync::{mpsc, oneshot};
use tracing_subscriber::EnvFilter;
use crate::{
config::Config,
filters::{prelude::*, FilterRegistry},
net::endpoint::metadata::Value,
net::endpoint::{Endpoint, EndpointAddress},
net::DualStackEpollSocket as DualStackLocalSocket,
pool::BufferPool,
ShutdownKind, ShutdownRx, ShutdownTx,
};
static LOG_ONCE: Once = Once::new();
pub fn enable_log(filter: impl Into<EnvFilter>) {
LOG_ONCE.call_once(|| {
tracing_subscriber::fmt()
.pretty()
.with_env_filter(filter)
.init()
});
}
#[derive(Copy, Clone)]
pub enum AddressType {
Random,
Ipv4,
Ipv6,
}
impl From<AddressType> for SocketAddr {
fn from(value: AddressType) -> Self {
match value {
AddressType::Random => {
if rand::random() {
(std::net::Ipv6Addr::LOCALHOST, 0).into()
} else {
(std::net::Ipv4Addr::LOCALHOST, 0).into()
}
}
AddressType::Ipv4 => (std::net::Ipv4Addr::LOCALHOST, 0).into(),
AddressType::Ipv6 => (std::net::Ipv6Addr::LOCALHOST, 0).into(),
}
}
}
pub async fn available_addr(address_type: AddressType) -> SocketAddr {
let socket = create_socket().await;
let addr = get_address(address_type, &socket);
tracing::debug!(addr = ?addr, "test_util::available_addr");
addr
}
fn get_address(address_type: AddressType, socket: &DualStackLocalSocket) -> SocketAddr {
let addr = match address_type {
AddressType::Random => {
match rand::random() {
true => socket.local_ipv6_addr().unwrap(),
false => socket.local_ipv4_addr().unwrap(),
}
}
AddressType::Ipv4 => socket.local_ipv4_addr().unwrap(),
AddressType::Ipv6 => socket.local_ipv6_addr().unwrap(),
};
tracing::debug!(addr = ?addr, "test_util::get_address");
addr
}
pub struct TestFilter;
#[async_trait::async_trait]
impl Filter for TestFilter {
async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
ctx.metadata
.entry("downstream".into())
.and_modify(|e| e.as_mut_string().unwrap().push_str(":receive"))
.or_insert_with(|| Value::String("receive".into()));
ctx.contents
.extend_from_slice(format!(":odr:{}", ctx.source).as_bytes());
Ok(())
}
async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
ctx.metadata
.entry("upstream".into())
.and_modify(|e| e.as_mut_string().unwrap().push_str(":receive"))
.or_insert_with(|| Value::String("receive".to_string()));
ctx.contents
.extend_from_slice(format!(":our:{}:{}", ctx.source, ctx.dest).as_bytes());
Ok(())
}
}
impl StaticFilter for TestFilter {
const NAME: &'static str = "TestFilter";
type Configuration = ();
type BinaryConfiguration = ();
fn try_from_config(_: Option<Self::Configuration>) -> Result<Self, CreationError> {
Ok(Self)
}
}
#[derive(Default)]
pub struct TestHelper {
shutdown_ch: Option<(ShutdownTx, ShutdownRx)>,
server_shutdown_tx: Vec<Option<ShutdownTx>>,
}
pub struct OpenSocketRecvPacket {
pub socket: Arc<DualStackLocalSocket>,
pub packet_rx: oneshot::Receiver<String>,
}
impl Drop for TestHelper {
fn drop(&mut self) {
for shutdown_tx in self.server_shutdown_tx.iter_mut().flat_map(|tx| tx.take()) {
shutdown_tx
.send(ShutdownKind::Testing)
.map_err(|error| {
tracing::warn!(
%error,
"Failed to send server shutdown over channel"
)
})
.ok();
}
if let Some((shutdown_tx, _)) = self.shutdown_ch.take() {
shutdown_tx.send(ShutdownKind::Testing).unwrap();
}
}
}
impl TestHelper {
pub async fn open_socket_and_recv_single_packet(&self) -> OpenSocketRecvPacket {
let socket = Arc::new(create_socket().await);
let (packet_tx, packet_rx) = oneshot::channel::<String>();
let socket_recv = socket.clone();
tokio::spawn(async move {
let mut buf = vec![0; 1024];
let (size, _) = socket_recv.recv_from(&mut buf).await.unwrap();
packet_tx
.send(from_utf8(&buf[..size]).unwrap().to_string())
.unwrap();
});
OpenSocketRecvPacket { socket, packet_rx }
}
pub async fn open_socket_and_recv_multiple_packets(
&mut self,
) -> (mpsc::Receiver<String>, Arc<DualStackLocalSocket>) {
let socket = Arc::new(create_socket().await);
let packet_rx = self.recv_multiple_packets(&socket).await;
(packet_rx, socket)
}
pub async fn open_ipv4_socket_and_recv_multiple_packets(
&mut self,
) -> (mpsc::Receiver<String>, Arc<DualStackLocalSocket>) {
let socket = Arc::new(
DualStackLocalSocket::new_with_address((Ipv4Addr::LOCALHOST, 0).into()).unwrap(),
);
let packet_rx = self.recv_multiple_packets(&socket).await;
(packet_rx, socket)
}
async fn recv_multiple_packets(
&mut self,
socket: &Arc<DualStackLocalSocket>,
) -> mpsc::Receiver<String> {
let (packet_tx, packet_rx) = mpsc::channel::<String>(10);
let mut shutdown_rx = self.get_shutdown_subscriber().await;
let socket_recv = socket.clone();
tokio::spawn(async move {
let mut buf = vec![0; 1024];
loop {
tokio::select! {
received = socket_recv.recv_from(&mut buf) => {
let (size, _) = received.unwrap();
let str = from_utf8(&buf[..size]).unwrap().to_string();
match packet_tx.send(str).await {
Ok(_) => {}
Err(error) => {
tracing::warn!(target: "recv_multiple_packets", %error, "recv_chan dropped");
return;
}
};
},
_ = shutdown_rx.changed() => {
return;
}
}
}
});
packet_rx
}
pub async fn run_echo_server(&mut self, address_type: AddressType) -> EndpointAddress {
self.run_echo_server_with_tap(address_type, |_, _, _| {})
.await
}
pub async fn run_echo_server_with_tap<F>(
&mut self,
address_type: AddressType,
tap: F,
) -> EndpointAddress
where
F: Fn(SocketAddr, &[u8], SocketAddr) + Send + 'static,
{
let socket = create_socket().await;
let mut addr = get_address(address_type, &socket);
crate::test::map_addr_to_localhost(&mut addr);
let mut shutdown = self.get_shutdown_subscriber().await;
let local_addr = addr;
tokio::spawn(async move {
loop {
let mut buf = vec![0; 1024];
tokio::select! {
recvd = socket.recv_from(&mut buf) => {
let (size, sender) = recvd.unwrap();
let packet = &buf[..size];
tracing::trace!(%sender, %size, "echo server received and returning packet");
tap(sender, packet, local_addr);
socket.send_to(packet, sender).await.unwrap();
},
_ = shutdown.changed() => {
return;
}
}
}
});
addr.into()
}
pub async fn run_server(
&mut self,
config: Arc<Config>,
server: Option<crate::components::proxy::Proxy>,
with_admin: Option<Option<SocketAddr>>,
) -> u16 {
let (shutdown_tx, shutdown_rx) = crate::make_shutdown_channel(crate::ShutdownKind::Testing);
self.server_shutdown_tx.push(Some(shutdown_tx));
let mode = crate::components::admin::Admin::Proxy(<_>::default());
if let Some(address) = with_admin {
mode.server(config.clone(), address);
}
let server = server.unwrap_or_else(|| {
let qcmp = crate::net::raw_socket_with_reuse(0).unwrap();
let phoenix = crate::net::TcpListener::bind(None).unwrap();
crate::components::proxy::Proxy {
num_workers: std::num::NonZeroUsize::new(1).unwrap(),
mmdb: None,
management_servers: Vec::new(),
to: Vec::new(),
socket: crate::net::raw_socket_with_reuse(0).unwrap(),
qcmp,
phoenix,
}
});
let (prox_tx, prox_rx) = tokio::sync::oneshot::channel();
let port = crate::net::socket_port(&server.socket);
tokio::spawn(async move {
server
.run(
crate::components::RunArgs {
config,
ready: Default::default(),
shutdown_rx,
},
Some(prox_tx),
)
.await
.unwrap();
});
prox_rx.await.unwrap();
port
}
async fn get_shutdown_subscriber(&mut self) -> ShutdownRx {
match self.shutdown_ch {
Some((_, ref rx)) => rx.clone(),
None => {
let ch = crate::make_shutdown_channel(crate::ShutdownKind::Testing);
let recv = ch.1.clone();
self.shutdown_ch = Some(ch);
recv
}
}
}
}
pub static BUFFER_POOL: once_cell::sync::Lazy<Arc<BufferPool>> =
once_cell::sync::Lazy::new(|| Arc::new(BufferPool::default()));
#[inline]
pub fn alloc_buffer(data: impl AsRef<[u8]>) -> crate::pool::PoolBuffer {
BUFFER_POOL.clone().alloc_slice(data.as_ref())
}
#[cfg(test)]
pub async fn assert_filter_read_no_change<F>(filter: &F)
where
F: Filter,
{
let endpoints = std::sync::Arc::new(crate::net::cluster::ClusterMap::default());
endpoints.insert_default(std::collections::BTreeSet::from(["127.0.0.1:80"
.parse::<Endpoint>()
.unwrap()]));
let source = "127.0.0.1:90".parse().unwrap();
let contents = b"hello";
let mut context = ReadContext::new(endpoints.clone(), source, alloc_buffer(contents));
filter.read(&mut context).await.unwrap();
assert!(context.destinations.is_empty());
assert_eq!(endpoints, context.endpoints);
assert_eq!(contents, &*context.contents);
}
pub async fn assert_write_no_change<F>(filter: &F)
where
F: Filter,
{
let endpoint = "127.0.0.1:90".parse::<Endpoint>().unwrap();
let contents = b"hello";
let mut context = WriteContext::new(
endpoint.address,
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(contents),
);
filter.write(&mut context).await.unwrap();
assert_eq!(contents, &*context.contents);
}
pub async fn map_to_localhost(address: &mut EndpointAddress) {
let mut socket_addr = address.to_socket_addr().await.unwrap();
map_addr_to_localhost(&mut socket_addr);
*address = socket_addr.into();
}
pub fn map_addr_to_localhost(address: &mut SocketAddr) {
match address {
SocketAddr::V4(addr) => addr.set_ip(std::net::Ipv4Addr::LOCALHOST),
SocketAddr::V6(addr) => addr.set_ip(std::net::Ipv6Addr::LOCALHOST),
}
}
pub async fn create_socket() -> DualStackLocalSocket {
DualStackLocalSocket::new(0).unwrap()
}
fn test_proxy_id() -> String {
"test-proxy-id".to_owned()
}
#[derive(serde::Serialize, serde::Deserialize, Default)]
pub struct TestConfig {
#[serde(default)]
pub clusters: crate::net::ClusterMap,
#[serde(default)]
pub filters: crate::filters::FilterChain,
#[serde(default = "test_proxy_id")]
pub id: String,
#[serde(default)]
pub version: crate::config::Version,
}
impl TestConfig {
pub fn with_dummy_endpoint() -> Self {
let config = Self::default();
config.clusters.insert(
None,
[Endpoint::new((std::net::Ipv4Addr::LOCALHOST, 8080).into())].into(),
);
config
}
#[track_caller]
pub fn new() -> Self {
load_test_filters();
Self {
filters: crate::filters::FilterChain::try_create(std::iter::once(
crate::config::Filter {
name: "TestFilter".into(),
label: None,
config: None,
},
))
.unwrap(),
..Default::default()
}
}
pub fn write_to_file(&self, path: impl AsRef<std::path::Path>) {
std::fs::write(
path,
serde_yaml::to_string(self).expect("failed to serialize TestConfig"),
)
.expect("failed to write TestConfig to path");
}
}
pub fn ep(id: u8) -> Endpoint {
Endpoint {
address: ([127, 0, 0, id], 8080).into(),
..<_>::default()
}
}
pub fn load_test_filters() {
FilterRegistry::register([TestFilter::factory()]);
}
#[macro_export]
macro_rules! __func_name {
() => {{
fn f() {}
fn type_name_of<T>(_: T) -> &'static str {
std::any::type_name::<T>()
}
let name = type_name_of(f);
&name[..name.len() - 3]
}};
}
#[macro_export]
macro_rules! temp_file {
($prefix:expr) => {{
let name = $crate::__func_name!();
let name = name.strip_suffix("::{{closure}}").unwrap_or(name);
let mut name = name.replace("::", ".");
name.push('-');
name.push_str($prefix);
name.push('-');
tempfile::NamedTempFile::with_prefix(name).unwrap()
}};
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use tokio::time::timeout;
use crate::test::{AddressType, TestHelper};
#[tokio::test]
async fn test_echo_server() {
let mut t = TestHelper::default();
let echo_addr = t.run_echo_server(AddressType::Random).await;
let endpoint = t.open_socket_and_recv_single_packet().await;
let msg = "hello";
endpoint
.socket
.send_to(msg.as_bytes(), &echo_addr.to_socket_addr().await.unwrap())
.await
.unwrap();
assert_eq!(
msg,
timeout(Duration::from_secs(5), endpoint.packet_rx)
.await
.expect("should not timeout")
.unwrap()
);
}
}