1use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10
11use tokio_util::sync::CancellationToken;
12
13use crate::error::ProxyError;
14use crate::event::{ProxyEvent, SessionId};
15use crate::hook::{NoOpHook, ProxyHook};
16use crate::listener::{AcceptedConn, Listener, ListenerConfig};
17use crate::observer::ProxyObserver;
18use crate::session::{ProxySession, ProxySessionConfig};
19
20pub struct ProxyConfig {
22 pub listener: ListenerConfig,
24 pub session: ProxySessionConfig,
26}
27
28pub struct TransparentProxy {
37 config: ProxyConfig,
38 observer: Arc<dyn ProxyObserver>,
39 hook: Arc<dyn ProxyHook>,
40 cancel: CancellationToken,
41 next_session_id: AtomicU64,
42}
43
44impl TransparentProxy {
45 pub fn new(config: ProxyConfig, observer: Arc<dyn ProxyObserver>) -> Self {
47 Self {
48 config,
49 observer,
50 hook: Arc::new(NoOpHook),
51 cancel: CancellationToken::new(),
52 next_session_id: AtomicU64::new(1),
53 }
54 }
55
56 pub fn with_hook(
58 config: ProxyConfig,
59 observer: Arc<dyn ProxyObserver>,
60 hook: Arc<dyn ProxyHook>,
61 ) -> Self {
62 Self {
63 config,
64 observer,
65 hook,
66 cancel: CancellationToken::new(),
67 next_session_id: AtomicU64::new(1),
68 }
69 }
70
71 pub fn cancel_token(&self) -> CancellationToken {
73 self.cancel.clone()
74 }
75
76 pub async fn run(&self) -> Result<(), ProxyError> {
79 let listener = Listener::bind(ListenerConfig {
80 bind_addr: self.config.listener.bind_addr,
81 cert_chain: self.config.listener.cert_chain.clone(),
82 key_der: self.config.listener.key_der.clone_key(),
83 })?;
84
85 loop {
86 tokio::select! {
87 result = listener.accept() => {
88 self.dispatch(result?);
89 }
90 _ = self.cancel.cancelled() => {
91 listener.close();
92 return Ok(());
93 }
94 }
95 }
96 }
97
98 fn dispatch(&self, accepted: AcceptedConn) {
101 match accepted {
102 AcceptedConn::Quic { conn, alpn } => {
103 let session_id = self.next_session_id();
104 let client_addr = conn.remote_address();
105 self.emit_session_started(session_id, client_addr, "QUIC");
106 let session = self.new_session(session_id, alpn);
107 tokio::spawn(async move {
108 let _ = session.run(conn).await;
109 });
110 }
111 #[cfg(feature = "webtransport")]
112 AcceptedConn::WebTransport(conn) => {
113 let session_id = self.next_session_id();
114 let client_addr = conn.remote_address();
115 self.emit_session_started(session_id, client_addr, "WebTransport");
116 let session = self.new_session(session_id, Vec::new());
119 tokio::spawn(async move {
120 let _ = session.run_webtransport(conn).await;
121 });
122 }
123 }
124 }
125
126 fn next_session_id(&self) -> SessionId {
129 SessionId(self.next_session_id.fetch_add(1, Ordering::Relaxed))
130 }
131
132 fn emit_session_started(
133 &self,
134 session_id: SessionId,
135 client_addr: std::net::SocketAddr,
136 client_transport: &str,
137 ) {
138 if self.observer.wants_events() {
139 self.observer.on_event(&ProxyEvent::SessionStarted {
140 session_id,
141 client_addr,
142 client_transport: client_transport.to_string(),
143 });
144 }
145 }
146
147 fn new_session(&self, session_id: SessionId, client_alpn: Vec<u8>) -> ProxySession {
148 ProxySession::new(
149 session_id,
150 ProxySessionConfig {
151 draft: self.config.session.draft,
152 upstream_transport: self.config.session.upstream_transport.clone(),
153 upstream_addr: self.config.session.upstream_addr.clone(),
154 skip_upstream_cert_verify: self.config.session.skip_upstream_cert_verify,
155 upstream_ca_certs: self.config.session.upstream_ca_certs.clone(),
156 upstream_connect_timeout_secs: self.config.session.upstream_connect_timeout_secs,
157 },
158 client_alpn,
159 Arc::clone(&self.observer),
160 Arc::clone(&self.hook),
161 self.cancel.child_token(),
162 )
163 }
164}