hai/src/quic/server.rs

48 lines
1.9 KiB
Rust
Raw Normal View History

2025-02-22 17:09:33 +01:00
use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
use quinn::{Endpoint, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use rustls::pki_types::pem::PemObject;
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
/// and port.
///
/// ## Returns
///
/// - a stream of incoming QUIC connections
/// - server certificate serialized into DER format
pub fn make_server_endpoint(
bind_addr: SocketAddr,
2025-02-22 17:10:55 +01:00
cert_file: Option<String>
2025-02-22 17:09:33 +01:00
) -> Result<(Endpoint, CertificateDer<'static>), Box<dyn Error + Send + Sync + 'static>> {
2025-02-22 17:10:55 +01:00
let (server_config, server_cert) = configure_server(cert_file)?;
2025-02-22 17:09:33 +01:00
let endpoint = Endpoint::server(server_config, bind_addr)?;
Ok((endpoint, server_cert))
}
/// Returns default server configuration along with its certificate.
fn configure_server(
cert_file: Option<String>
) -> Result<(ServerConfig, CertificateDer<'static>), Box<dyn Error + Send + Sync + 'static>> {
let cert_closure = |cert_file: Option<String>| {
return if (cert_file.is_some()) {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let cert_der = CertificateDer::from(cert.cert);
(cert_der, PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()))
} else {
2025-02-22 17:10:55 +01:00
let mut certs: Vec<_> = CertificateDer::pem_file_iter(cert_file.clone().unwrap())
2025-02-22 17:09:33 +01:00
.unwrap()
.collect();
let cert = certs.pop().unwrap().unwrap();
(cert, PrivatePkcs8KeyDer::from_pem_file(cert_file.unwrap()).unwrap())
}
};
let (cert_der, priv_key) = cert_closure(cert_file);
let mut server_config =
ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into())?;
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
transport_config.max_concurrent_uni_streams(0_u8.into());
Ok((server_config, cert_der))
}