336 lines
11 KiB
Rust
336 lines
11 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt, copy_bidirectional};
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::RwLock;
|
|
use tracing::{error, info, warn};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct TcpProxyManager {
|
|
connections: Arc<RwLock<HashMap<String, TcpConnection>>>,
|
|
#[allow(dead_code)]
|
|
last_cleanup: Arc<tokio::sync::Mutex<Instant>>,
|
|
}
|
|
|
|
/// Statistics entry stored by [`TcpProxyManager`] for one connection id.
#[derive(Clone, Debug)]
pub struct TcpConnection {
    /// Upstream address (`host:port`) the connection was proxied to.
    pub target: String,
    /// Moment this entry was first inserted; `cleanup_expired` ages entries from here.
    pub created_at: Instant,
    /// Number of times statistics have been recorded under this entry's key.
    pub request_count: u64,
    /// Accumulated byte count recorded for this entry.
    pub bytes_transferred: u64,
}
|
|
|
|
/// How a proxied connection should be treated.
///
/// The enum carries no data, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyProtocol {
    /// Relay raw bytes in both directions.
    Tcp,
    /// Relay the HTTP upgrade handshake first, then relay raw bytes.
    WebSocket,
    /// Inspect the client's first bytes to choose between `Tcp` and `WebSocket`.
    AutoDetect,
}
|
|
|
|
/// Error reported by the proxy for connection, handshake and relay failures.
///
/// The `Display` output is exactly `message`.
#[derive(Clone, Debug)]
pub struct ProxyError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for ProxyError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the message verbatim; width/fill flags are not applied.
        formatter.write_str(&self.message)
    }
}

/// Makes `ProxyError` a standard error so it can be boxed into
/// `Box<dyn std::error::Error + Send + Sync>` and propagated with `?`.
impl std::error::Error for ProxyError {}
|
|
|
|
impl TcpProxyManager {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
connections: Arc::new(RwLock::new(HashMap::new())),
|
|
last_cleanup: Arc::new(tokio::sync::Mutex::new(Instant::now())),
|
|
}
|
|
}
|
|
|
|
pub async fn handle_tcp_proxy(
|
|
&self,
|
|
mut client_stream: TcpStream,
|
|
target: &str,
|
|
protocol: ProxyProtocol,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let connection_id = format!(
|
|
"{}:{}->{}",
|
|
client_stream.local_addr()?.ip(),
|
|
client_stream.local_addr()?.port(),
|
|
target
|
|
);
|
|
|
|
info!(
|
|
"Starting TCP proxy connection {} to {} with protocol {:?}",
|
|
connection_id, target, protocol
|
|
);
|
|
|
|
let actual_protocol = if matches!(protocol, ProxyProtocol::AutoDetect) {
|
|
self.detect_protocol(&mut client_stream).await?
|
|
} else {
|
|
protocol
|
|
};
|
|
|
|
match actual_protocol {
|
|
ProxyProtocol::Tcp => {
|
|
self.handle_raw_tcp(&mut client_stream, target, &connection_id)
|
|
.await?
|
|
}
|
|
ProxyProtocol::WebSocket => {
|
|
self.handle_websocket(&mut client_stream, target, &connection_id)
|
|
.await?
|
|
}
|
|
ProxyProtocol::AutoDetect => {
|
|
warn!("Auto-detect should have been resolved to a specific protocol");
|
|
self.handle_raw_tcp(&mut client_stream, target, &connection_id)
|
|
.await?
|
|
}
|
|
}
|
|
|
|
self.update_connection_stats(&connection_id, target).await;
|
|
info!("TCP proxy connection {} completed", connection_id);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn detect_protocol(
|
|
&self,
|
|
client_stream: &mut TcpStream,
|
|
) -> Result<ProxyProtocol, Box<dyn std::error::Error + Send + Sync>> {
|
|
client_stream.set_nodelay(true)?;
|
|
let mut peek_buf = [0u8; 1024];
|
|
|
|
match client_stream.peek(&mut peek_buf).await {
|
|
Ok(0) => return Ok(ProxyProtocol::Tcp),
|
|
Ok(n) => {
|
|
let header = String::from_utf8_lossy(&peek_buf[..n]);
|
|
if header.contains("Upgrade: websocket")
|
|
|| header.contains("upgrade: websocket")
|
|
|| header.contains("UPGRADE: websocket")
|
|
{
|
|
info!("Detected WebSocket protocol from handshake");
|
|
return Ok(ProxyProtocol::WebSocket);
|
|
}
|
|
Ok(ProxyProtocol::Tcp)
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to peek at client stream: {}", e);
|
|
Ok(ProxyProtocol::Tcp)
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_raw_tcp(
|
|
&self,
|
|
client_stream: &mut TcpStream,
|
|
target: &str,
|
|
connection_id: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
info!("Establishing raw TCP connection to: {}", target);
|
|
let mut server_stream = TcpStream::connect(target).await.map_err(|e| {
|
|
error!("Failed to connect to target {}: {}", target, e);
|
|
ProxyError {
|
|
message: format!("Failed to connect to target {}: {}", target, e),
|
|
}
|
|
})?;
|
|
|
|
info!(
|
|
"Established TCP connection {} -> {}",
|
|
connection_id,
|
|
server_stream.peer_addr()?
|
|
);
|
|
|
|
client_stream.set_nodelay(true)?;
|
|
server_stream.set_nodelay(true)?;
|
|
|
|
let result = copy_bidirectional(client_stream, &mut server_stream).await;
|
|
|
|
match result {
|
|
Ok((client_bytes, server_bytes)) => {
|
|
info!(
|
|
"TCP proxy {} transferred {} bytes (client->server) and {} bytes (server->client)",
|
|
connection_id, client_bytes, server_bytes
|
|
);
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
error!("TCP proxy {} failed: {}", connection_id, e);
|
|
Err(Box::new(ProxyError {
|
|
message: format!("TCP proxy failed: {}", e),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_websocket(
|
|
&self,
|
|
client_stream: &mut TcpStream,
|
|
target: &str,
|
|
connection_id: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
info!("Establishing WebSocket connection to: {}", target);
|
|
|
|
let mut server_stream = TcpStream::connect(target).await.map_err(|e| {
|
|
error!("Failed to connect to WebSocket target {}: {}", target, e);
|
|
ProxyError {
|
|
message: format!("Failed to connect to WebSocket target {}: {}", target, e),
|
|
}
|
|
})?;
|
|
|
|
client_stream.set_nodelay(true)?;
|
|
server_stream.set_nodelay(true)?;
|
|
|
|
if let Err(e) = self
|
|
.forward_websocket_handshake(client_stream, &mut server_stream)
|
|
.await
|
|
{
|
|
error!(
|
|
"WebSocket handshake failed for connection {}: {}",
|
|
connection_id, e
|
|
);
|
|
return Err(Box::new(e));
|
|
}
|
|
|
|
info!(
|
|
"WebSocket handshake completed for connection {} -> {}",
|
|
connection_id,
|
|
server_stream.peer_addr()?
|
|
);
|
|
|
|
let result = copy_bidirectional(client_stream, &mut server_stream).await;
|
|
|
|
match result {
|
|
Ok((client_bytes, server_bytes)) => {
|
|
info!(
|
|
"WebSocket proxy {} transferred {} bytes (client->server) and {} bytes (server->client)",
|
|
connection_id, client_bytes, server_bytes
|
|
);
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
error!("WebSocket proxy {} failed: {}", connection_id, e);
|
|
Err(Box::new(ProxyError {
|
|
message: format!("WebSocket proxy failed: {}", e),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn forward_websocket_handshake(
|
|
&self,
|
|
client_stream: &mut TcpStream,
|
|
server_stream: &mut TcpStream,
|
|
) -> Result<(), ProxyError> {
|
|
let mut handshake = Vec::new();
|
|
let mut buf = [0u8; 1];
|
|
let mut header_end_found = false;
|
|
|
|
while !header_end_found {
|
|
match client_stream.read(&mut buf).await {
|
|
Ok(0) => {
|
|
return Err(ProxyError {
|
|
message: "Client closed connection before handshake completed".to_string(),
|
|
});
|
|
}
|
|
Ok(n) => {
|
|
handshake.extend_from_slice(&buf[..n]);
|
|
if handshake.len() >= 4
|
|
&& handshake[handshake.len() - 4..] == [b'\r', b'\n', b'\r', b'\n']
|
|
{
|
|
header_end_found = true;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
return Err(ProxyError {
|
|
message: format!("Error reading handshake: {}", e),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
server_stream
|
|
.write_all(&handshake)
|
|
.await
|
|
.map_err(|e| ProxyError {
|
|
message: format!("Failed to write handshake to server: {}", e),
|
|
})?;
|
|
|
|
let mut response_buf = [0u8; 1024];
|
|
let mut response = Vec::new();
|
|
let mut response_end_found = false;
|
|
|
|
while !response_end_found {
|
|
match server_stream.read(&mut response_buf).await {
|
|
Ok(0) => {
|
|
return Err(ProxyError {
|
|
message: "Server closed connection before handshake completed".to_string(),
|
|
});
|
|
}
|
|
Ok(n) => {
|
|
response.extend_from_slice(&response_buf[..n]);
|
|
if response.len() >= 4
|
|
&& response[response.len() - 4..] == [b'\r', b'\n', b'\r', b'\n']
|
|
{
|
|
response_end_found = true;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
return Err(ProxyError {
|
|
message: format!("Error reading handshake response: {}", e),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
client_stream
|
|
.write_all(&response)
|
|
.await
|
|
.map_err(|e| ProxyError {
|
|
message: format!("Failed to write handshake response to client: {}", e),
|
|
})?;
|
|
|
|
info!("WebSocket handshake forwarded successfully");
|
|
Ok(())
|
|
}
|
|
|
|
async fn update_connection_stats(&self, connection_id: &str, target: &str) {
|
|
let mut connections = self.connections.write().await;
|
|
let conn = connections
|
|
.entry(connection_id.to_string())
|
|
.or_insert_with(|| TcpConnection {
|
|
target: target.to_string(),
|
|
created_at: Instant::now(),
|
|
request_count: 0,
|
|
bytes_transferred: 0,
|
|
});
|
|
conn.request_count += 1;
|
|
}
|
|
|
|
pub async fn cleanup_expired(&self, max_age: Duration) {
|
|
let mut connections = self.connections.write().await;
|
|
connections.retain(|_, conn| conn.created_at.elapsed() < max_age);
|
|
|
|
let now = Instant::now();
|
|
let mut last_cleanup = self.last_cleanup.lock().await;
|
|
if now.duration_since(*last_cleanup) > Duration::from_secs(60) {
|
|
info!(
|
|
"Cleaned up expired connections (total: {})",
|
|
connections.len()
|
|
);
|
|
*last_cleanup = now;
|
|
}
|
|
}
|
|
|
|
pub async fn get_stats(&self) -> HashMap<String, TcpConnection> {
|
|
self.connections.read().await.clone()
|
|
}
|
|
}
|
|
|
|
impl Default for TcpProxyManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|