use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional}; use tokio::net::TcpStream; use tokio::sync::RwLock; use tracing::{error, info, warn}; #[derive(Debug, Clone)] pub struct TcpProxyManager { connections: Arc>>, #[allow(dead_code)] last_cleanup: Arc>, } #[derive(Debug, Clone)] pub struct TcpConnection { pub target: String, pub created_at: Instant, pub request_count: u64, pub bytes_transferred: u64, } #[derive(Debug, Clone)] pub enum ProxyProtocol { Tcp, WebSocket, AutoDetect, } #[derive(Debug, Clone)] pub struct ProxyError { pub message: String, } impl std::fmt::Display for ProxyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.message) } } impl std::error::Error for ProxyError {} impl TcpProxyManager { pub fn new() -> Self { Self { connections: Arc::new(RwLock::new(HashMap::new())), last_cleanup: Arc::new(tokio::sync::Mutex::new(Instant::now())), } } pub async fn handle_tcp_proxy( &self, mut client_stream: TcpStream, target: &str, protocol: ProxyProtocol, ) -> Result<(), Box> { let connection_id = format!( "{}:{}->{}", client_stream.local_addr()?.ip(), client_stream.local_addr()?.port(), target ); info!( "Starting TCP proxy connection {} to {} with protocol {:?}", connection_id, target, protocol ); let actual_protocol = if matches!(protocol, ProxyProtocol::AutoDetect) { self.detect_protocol(&mut client_stream).await? } else { protocol }; match actual_protocol { ProxyProtocol::Tcp => { self.handle_raw_tcp(&mut client_stream, target, &connection_id) .await? } ProxyProtocol::WebSocket => { self.handle_websocket(&mut client_stream, target, &connection_id) .await? } ProxyProtocol::AutoDetect => { warn!("Auto-detect should have been resolved to a specific protocol"); self.handle_raw_tcp(&mut client_stream, target, &connection_id) .await? } } self.update_connection_stats(&connection_id, target).await; info!("TCP proxy connection {} completed", connection_id); Ok(()) } async fn detect_protocol( &self, client_stream: &mut TcpStream, ) -> Result> { client_stream.set_nodelay(true)?; let mut peek_buf = [0u8; 1024]; match client_stream.peek(&mut peek_buf).await { Ok(0) => return Ok(ProxyProtocol::Tcp), Ok(n) => { let header = String::from_utf8_lossy(&peek_buf[..n]); if header.contains("Upgrade: websocket") || header.contains("upgrade: websocket") || header.contains("UPGRADE: websocket") { info!("Detected WebSocket protocol from handshake"); return Ok(ProxyProtocol::WebSocket); } Ok(ProxyProtocol::Tcp) } Err(e) => { warn!("Failed to peek at client stream: {}", e); Ok(ProxyProtocol::Tcp) } } } async fn handle_raw_tcp( &self, client_stream: &mut TcpStream, target: &str, connection_id: &str, ) -> Result<(), Box> { info!("Establishing raw TCP connection to: {}", target); let mut server_stream = TcpStream::connect(target).await.map_err(|e| { error!("Failed to connect to target {}: {}", target, e); ProxyError { message: format!("Failed to connect to target {}: {}", target, e), } })?; info!( "Established TCP connection {} -> {}", connection_id, server_stream.peer_addr()? ); client_stream.set_nodelay(true)?; server_stream.set_nodelay(true)?; let result = copy_bidirectional(client_stream, &mut server_stream).await; match result { Ok((client_bytes, server_bytes)) => { info!( "TCP proxy {} transferred {} bytes (client->server) and {} bytes (server->client)", connection_id, client_bytes, server_bytes ); Ok(()) } Err(e) => { error!("TCP proxy {} failed: {}", connection_id, e); Err(Box::new(ProxyError { message: format!("TCP proxy failed: {}", e), })) } } } async fn handle_websocket( &self, client_stream: &mut TcpStream, target: &str, connection_id: &str, ) -> Result<(), Box> { info!("Establishing WebSocket connection to: {}", target); let mut server_stream = TcpStream::connect(target).await.map_err(|e| { error!("Failed to connect to WebSocket target {}: {}", target, e); ProxyError { message: format!("Failed to connect to WebSocket target {}: {}", target, e), } })?; client_stream.set_nodelay(true)?; server_stream.set_nodelay(true)?; if let Err(e) = self .forward_websocket_handshake(client_stream, &mut server_stream) .await { error!( "WebSocket handshake failed for connection {}: {}", connection_id, e ); return Err(Box::new(e)); } info!( "WebSocket handshake completed for connection {} -> {}", connection_id, server_stream.peer_addr()? ); let result = copy_bidirectional(client_stream, &mut server_stream).await; match result { Ok((client_bytes, server_bytes)) => { info!( "WebSocket proxy {} transferred {} bytes (client->server) and {} bytes (server->client)", connection_id, client_bytes, server_bytes ); Ok(()) } Err(e) => { error!("WebSocket proxy {} failed: {}", connection_id, e); Err(Box::new(ProxyError { message: format!("WebSocket proxy failed: {}", e), })) } } } async fn forward_websocket_handshake( &self, client_stream: &mut TcpStream, server_stream: &mut TcpStream, ) -> Result<(), ProxyError> { let mut handshake = Vec::new(); let mut buf = [0u8; 1]; let mut header_end_found = false; while !header_end_found { match client_stream.read(&mut buf).await { Ok(0) => { return Err(ProxyError { message: "Client closed connection before handshake completed".to_string(), }); } Ok(n) => { handshake.extend_from_slice(&buf[..n]); if handshake.len() >= 4 && handshake[handshake.len() - 4..] == [b'\r', b'\n', b'\r', b'\n'] { header_end_found = true; } } Err(e) => { return Err(ProxyError { message: format!("Error reading handshake: {}", e), }); } } } server_stream .write_all(&handshake) .await .map_err(|e| ProxyError { message: format!("Failed to write handshake to server: {}", e), })?; let mut response_buf = [0u8; 1024]; let mut response = Vec::new(); let mut response_end_found = false; while !response_end_found { match server_stream.read(&mut response_buf).await { Ok(0) => { return Err(ProxyError { message: "Server closed connection before handshake completed".to_string(), }); } Ok(n) => { response.extend_from_slice(&response_buf[..n]); if response.len() >= 4 && response[response.len() - 4..] == [b'\r', b'\n', b'\r', b'\n'] { response_end_found = true; } } Err(e) => { return Err(ProxyError { message: format!("Error reading handshake response: {}", e), }); } } } client_stream .write_all(&response) .await .map_err(|e| ProxyError { message: format!("Failed to write handshake response to client: {}", e), })?; info!("WebSocket handshake forwarded successfully"); Ok(()) } async fn update_connection_stats(&self, connection_id: &str, target: &str) { let mut connections = self.connections.write().await; let conn = connections .entry(connection_id.to_string()) .or_insert_with(|| TcpConnection { target: target.to_string(), created_at: Instant::now(), request_count: 0, bytes_transferred: 0, }); conn.request_count += 1; } pub async fn cleanup_expired(&self, max_age: Duration) { let mut connections = self.connections.write().await; connections.retain(|_, conn| conn.created_at.elapsed() < max_age); let now = Instant::now(); let mut last_cleanup = self.last_cleanup.lock().await; if now.duration_since(*last_cleanup) > Duration::from_secs(60) { info!( "Cleaned up expired connections (total: {})", connections.len() ); *last_cleanup = now; } } pub async fn get_stats(&self) -> HashMap { self.connections.read().await.clone() } } impl Default for TcpProxyManager { fn default() -> Self { Self::new() } }