Skip to main content

moqtap_client/draft11/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft11::endpoint::{Endpoint, EndpointError};
6use crate::draft11::event::{ClientEvent, Direction, FetchObject, StreamKind, SubgroupObject};
7use crate::draft11::observer::ConnectionObserver;
8use crate::transport::quic::QuicTransport;
9use crate::transport::{RecvStream, SendStream, Transport, TransportError};
10use moqtap_codec::dispatch::{
11    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
12};
13use moqtap_codec::draft11::data_stream::{FetchObjectHeader, ObjectHeader};
14use moqtap_codec::draft11::message::ControlMessage;
15use moqtap_codec::error::CodecError;
16use moqtap_codec::types::*;
17use moqtap_codec::varint::VarInt;
18use moqtap_codec::version::DraftVersion;
19
/// MoQT ALPN identifier (used by raw QUIC transport).
///
/// NOTE(review): `ClientConfig::alpn` in this file offers
/// `DraftVersion::Draft11.quic_alpn()` for raw QUIC rather than this
/// constant — confirm whether this value is still consumed anywhere.
pub const MOQT_ALPN: &[u8] = b"moq-00";
22
/// Errors from the draft-11 connection layer.
///
/// The `#[from]` variants let `?` lift endpoint, codec, transport and varint
/// errors directly into this type.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Endpoint state machine error.
    ///
    /// `Connection::connect` also reports `EndpointError::NotActive` through
    /// this variant when the server's first control message is not
    /// SERVER_SETUP.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// Wire codec error.
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// Transport-level error.
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),
    /// Variable-length integer decoding error.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// Control stream was not opened.
    ///
    /// Returned by the control send/receive paths when the corresponding
    /// control stream half is not held.
    #[error("control stream not open")]
    NoControlStream,
    /// Stream ended before a complete message was read.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// Stream was finished by the peer.
    ///
    /// NOTE(review): not constructed anywhere in this file; the framed
    /// readers here report a finished stream as `UnexpectedEnd`.
    #[error("stream finished")]
    StreamFinished,
    /// Invalid server address string.
    ///
    /// The raw QUIC path also uses this variant when binding the local
    /// client socket fails.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// TLS configuration error (bad CA certificate, or a rustls config that
    /// quinn cannot use).
    #[error("TLS config error: {0}")]
    TlsConfig(String),
}
54
/// Transport type for the connection.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC via quinn.
    ///
    /// This variant has no fields. The `addr` argument of
    /// `Connection::connect` is parsed as a `std::net::SocketAddr`, so it
    /// must be an IP literal plus port (e.g. `127.0.0.1:4433`). Hostnames
    /// are rejected with `ConnectionError::InvalidAddress`.
    Quic,
    /// WebTransport via wtransport. The `url` field is the WebTransport URL.
    ///
    /// The `addr` argument of `Connection::connect` is not used for this
    /// transport. Requires the `webtransport` cargo feature; without it,
    /// connecting fails with a transport error.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        url: String,
    },
}
66
/// Configuration for a draft-11 MoQT client connection.
///
/// Consumed by `Connection::connect`.
pub struct ClientConfig {
    /// Additional draft versions to offer in CLIENT_SETUP (draft-11 is always
    /// offered first).
    pub additional_versions: Vec<DraftVersion>,
    /// The transport type (QUIC or WebTransport).
    pub transport: TransportType,
    /// Whether to skip TLS certificate verification (for testing).
    ///
    /// Honored by both transports. Raw QUIC installs a verifier that accepts
    /// any certificate; WebTransport uses `with_no_cert_validation`.
    pub skip_cert_verification: bool,
    /// Custom CA certificates to trust (DER-encoded).
    ///
    /// On raw QUIC these are added to the webpki root store when
    /// `skip_cert_verification` is false.
    ///
    /// NOTE(review): the WebTransport connect path uses native certificates
    /// and does not read this field.
    pub ca_certs: Vec<Vec<u8>>,
    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
    pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
}
81
82impl ClientConfig {
83    /// Returns the MoQT version varints for the CLIENT_SETUP message.
84    /// Draft-11 first, then any additional versions.
85    pub fn supported_versions(&self) -> Vec<VarInt> {
86        let mut versions = vec![DraftVersion::Draft11.version_varint()];
87        for v in &self.additional_versions {
88            let varint = v.version_varint();
89            if !versions.contains(&varint) {
90                versions.push(varint);
91            }
92        }
93        versions
94    }
95
96    /// Returns the ALPN protocol identifiers for the transport.
97    pub fn alpn(&self) -> Vec<Vec<u8>> {
98        match &self.transport {
99            TransportType::Quic => vec![DraftVersion::Draft11.quic_alpn().to_vec()],
100            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
101        }
102    }
103}
104
105/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
106pub struct FramedSendStream {
107    inner: SendStream,
108}
109
110impl FramedSendStream {
111    /// Create a new framed send stream.
112    pub fn new(inner: SendStream) -> Self {
113        Self { inner }
114    }
115
116    /// Get the transport-level stream ID.
117    pub fn stream_id(&self) -> u64 {
118        self.inner.stream_id()
119    }
120
121    /// Write a control message to the stream with type+length framing.
122    /// Returns the raw bytes that were written (for event capture).
123    pub async fn write_control(
124        &mut self,
125        msg: &AnyControlMessage,
126    ) -> Result<Vec<u8>, ConnectionError> {
127        let mut buf = Vec::new();
128        msg.encode(&mut buf)?;
129        self.inner.write_all(&buf).await?;
130        Ok(buf)
131    }
132
133    /// Write a subgroup stream header.
134    pub async fn write_subgroup_header(
135        &mut self,
136        header: &AnySubgroupHeader,
137    ) -> Result<(), ConnectionError> {
138        let mut buf = Vec::new();
139        header.encode(&mut buf);
140        self.inner.write_all(&buf).await?;
141        Ok(())
142    }
143
144    /// Write a fetch response header.
145    pub async fn write_fetch_header(
146        &mut self,
147        header: &AnyFetchHeader,
148    ) -> Result<(), ConnectionError> {
149        let mut buf = Vec::new();
150        header.encode(&mut buf);
151        self.inner.write_all(&buf).await?;
152        Ok(())
153    }
154
155    /// Append a draft-11 subgroup object (header + payload) to the stream.
156    /// Draft-11 subgroup objects may carry extension headers but the writer
157    /// path defers to the codec `ObjectHeader::encode` helper.
158    pub async fn write_subgroup_object(
159        &mut self,
160        object: &SubgroupObject,
161    ) -> Result<(), ConnectionError> {
162        let mut buf = Vec::new();
163        object.header.encode(&mut buf);
164        buf.extend_from_slice(&object.payload);
165        self.inner.write_all(&buf).await?;
166        Ok(())
167    }
168
169    /// Append a draft-11 fetch object (header + payload) to the stream.
170    pub async fn write_fetch_object(
171        &mut self,
172        object: &FetchObject,
173    ) -> Result<(), ConnectionError> {
174        let mut buf = Vec::new();
175        object.header.encode(&mut buf);
176        buf.extend_from_slice(&object.payload);
177        self.inner.write_all(&buf).await?;
178        Ok(())
179    }
180
181    /// Finish the stream (send FIN).
182    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
183        self.inner.finish()?;
184        Ok(())
185    }
186}
187
188/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
189pub struct FramedRecvStream {
190    inner: RecvStream,
191    buf: BytesMut,
192}
193
194impl FramedRecvStream {
195    /// Create a new framed receive stream.
196    pub fn new(inner: RecvStream) -> Self {
197        Self { inner, buf: BytesMut::with_capacity(4096) }
198    }
199
200    /// Get the transport-level stream ID.
201    pub fn stream_id(&self) -> u64 {
202        self.inner.stream_id()
203    }
204
205    /// Read more data from the stream into the internal buffer.
206    async fn fill(&mut self) -> Result<bool, ConnectionError> {
207        let mut tmp = [0u8; 4096];
208        match self.inner.read(&mut tmp).await {
209            Ok(Some(n)) => {
210                self.buf.extend_from_slice(&tmp[..n]);
211                Ok(true)
212            }
213            Ok(None) => Ok(false),
214            Err(e) => Err(ConnectionError::Transport(e)),
215        }
216    }
217
218    /// Ensure at least `n` bytes are available in the buffer.
219    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
220        while self.buf.len() < n {
221            if !self.fill().await? {
222                return Err(ConnectionError::UnexpectedEnd);
223            }
224        }
225        Ok(())
226    }
227
228    /// Read a control message from the stream.
229    ///
230    /// When `capture_raw` is true, the returned tuple includes a clone of the
231    /// framed wire bytes (for observer emission). When false, the second
232    /// element is `None` and the payload clone is skipped.
233    pub async fn read_control(
234        &mut self,
235        capture_raw: bool,
236    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
237        // Read type ID varint
238        self.ensure(1).await?;
239        let type_len = varint_len(self.buf[0]);
240        self.ensure(type_len).await?;
241
242        let mut cursor = &self.buf[..type_len];
243        let _type_id = VarInt::decode(&mut cursor)?;
244
245        // Draft-11 uses a fixed 16-bit length after the type id.
246        let len_field_size = 2;
247        self.ensure(type_len + len_field_size).await?;
248        let payload_len = u16::from_be_bytes([self.buf[type_len], self.buf[type_len + 1]]) as usize;
249
250        // Read full payload
251        let total = type_len + len_field_size + payload_len;
252        self.ensure(total).await?;
253
254        // Capture raw bytes only if requested (observer attached).
255        let raw = capture_raw.then(|| self.buf[..total].to_vec());
256
257        // Now decode the whole message using the draft-11 dispatcher
258        let mut frame = &self.buf[..total];
259        let msg = AnyControlMessage::decode(DraftVersion::Draft11, &mut frame)?;
260        self.buf.advance(total);
261        Ok((msg, raw))
262    }
263
264    /// Read a subgroup stream header.
265    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
266        self.ensure(1).await?;
267        loop {
268            let mut cursor = &self.buf[..];
269            match AnySubgroupHeader::decode(DraftVersion::Draft11, &mut cursor) {
270                Ok(header) => {
271                    let consumed = self.buf.len() - cursor.remaining();
272                    self.buf.advance(consumed);
273                    return Ok(header);
274                }
275                Err(CodecError::UnexpectedEnd) => {
276                    if !self.fill().await? {
277                        return Err(ConnectionError::UnexpectedEnd);
278                    }
279                }
280                Err(e) => return Err(ConnectionError::Codec(e)),
281            }
282        }
283    }
284
285    /// Read a fetch response header.
286    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
287        self.ensure(1).await?;
288        loop {
289            let mut cursor = &self.buf[..];
290            match AnyFetchHeader::decode(DraftVersion::Draft11, &mut cursor) {
291                Ok(header) => {
292                    let consumed = self.buf.len() - cursor.remaining();
293                    self.buf.advance(consumed);
294                    return Ok(header);
295                }
296                Err(CodecError::UnexpectedEnd) => {
297                    if !self.fill().await? {
298                        return Err(ConnectionError::UnexpectedEnd);
299                    }
300                }
301                Err(e) => return Err(ConnectionError::Codec(e)),
302            }
303        }
304    }
305
306    /// Read the next draft-11 subgroup object (header + payload). The codec
307    /// helper `ObjectHeader::decode` is called without extensions by default;
308    /// callers that care about extensions should drive decoding explicitly.
309    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
310        loop {
311            let mut cursor = &self.buf[..];
312            match ObjectHeader::decode(&mut cursor) {
313                Ok(header) => {
314                    let header_consumed = self.buf.len() - cursor.remaining();
315                    let payload_len = header.payload_length.into_inner() as usize;
316                    let total = header_consumed + payload_len;
317                    if self.buf.len() < total {
318                        if !self.fill().await? {
319                            return Err(ConnectionError::UnexpectedEnd);
320                        }
321                        continue;
322                    }
323                    let payload = self.buf[header_consumed..total].to_vec();
324                    self.buf.advance(total);
325                    return Ok(SubgroupObject { header, payload });
326                }
327                Err(CodecError::UnexpectedEnd) => {
328                    if !self.fill().await? {
329                        return Err(ConnectionError::UnexpectedEnd);
330                    }
331                }
332                Err(e) => return Err(ConnectionError::Codec(e)),
333            }
334        }
335    }
336
337    /// Read the next draft-11 fetch object (header + payload).
338    pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
339        loop {
340            let mut cursor = &self.buf[..];
341            match FetchObjectHeader::decode(&mut cursor) {
342                Ok(header) => {
343                    let header_consumed = self.buf.len() - cursor.remaining();
344                    let payload_len = header.payload_length.into_inner() as usize;
345                    let total = header_consumed + payload_len;
346                    if self.buf.len() < total {
347                        if !self.fill().await? {
348                            return Err(ConnectionError::UnexpectedEnd);
349                        }
350                        continue;
351                    }
352                    let payload = self.buf[header_consumed..total].to_vec();
353                    self.buf.advance(total);
354                    return Ok(FetchObject { header, payload });
355                }
356                Err(CodecError::UnexpectedEnd) => {
357                    if !self.fill().await? {
358                        return Err(ConnectionError::UnexpectedEnd);
359                    }
360                }
361                Err(e) => return Err(ConnectionError::Codec(e)),
362            }
363        }
364    }
365}
366
367/// A live draft-11 MoQT connection over QUIC or WebTransport.
368pub struct Connection {
369    transport: Transport,
370    endpoint: Endpoint,
371    control_send: Option<FramedSendStream>,
372    control_recv: Option<FramedRecvStream>,
373    observer: Option<Box<dyn ConnectionObserver>>,
374    /// Setup events buffered during `connect()` and replayed when an
375    /// observer attaches via `set_observer` — without this, an observer
376    /// attached after `connect` returns would never see the handshake.
377    pending_events: Vec<ClientEvent>,
378}
379
380impl Connection {
381    /// Connect to a draft-11 MoQT server as a client.
382    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
383        let transport = match &config.transport {
384            TransportType::Quic => Self::connect_quic(addr, &config).await?,
385            TransportType::WebTransport { url } => {
386                let url = url.clone();
387                Self::connect_webtransport(&url, &config).await?
388            }
389        };
390
391        // Open bidirectional control stream
392        let (send, recv) = transport.open_bi().await?;
393        let mut control_send = FramedSendStream::new(send);
394        let mut control_recv = FramedRecvStream::new(recv);
395
396        // Perform setup handshake
397        let mut endpoint = Endpoint::new();
398        endpoint.connect()?;
399        let setup_msg = endpoint
400            .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
401        let any_setup = AnyControlMessage::Draft11(setup_msg);
402        let raw_setup = control_send.write_control(&any_setup).await?;
403
404        let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
405        match &server_setup {
406            AnyControlMessage::Draft11(ControlMessage::ServerSetup(ref ss)) => {
407                endpoint.receive_server_setup(ss)?;
408            }
409            _ => {
410                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
411            }
412        }
413
414        let mut pending_events = Vec::with_capacity(3);
415        pending_events.push(ClientEvent::ControlMessage {
416            direction: Direction::Send,
417            message: any_setup,
418            raw: Some(raw_setup),
419        });
420        pending_events.push(ClientEvent::ControlMessage {
421            direction: Direction::Receive,
422            message: server_setup,
423            raw: raw_server_setup,
424        });
425        if let Some(v) = endpoint.negotiated_version() {
426            pending_events.push(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
427        }
428
429        Ok(Self {
430            transport,
431            endpoint,
432            control_send: Some(control_send),
433            control_recv: Some(control_recv),
434            observer: None,
435            pending_events,
436        })
437    }
438
439    /// Establish a raw QUIC connection.
440    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
441        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
442            ConnectionError::InvalidAddress(e.to_string())
443        })?;
444
445        let mut tls_config = if config.skip_cert_verification {
446            rustls::ClientConfig::builder()
447                .dangerous()
448                .with_custom_certificate_verifier(Arc::new(SkipVerification))
449                .with_no_client_auth()
450        } else {
451            let mut roots = rustls::RootCertStore::empty();
452            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
453            for der in &config.ca_certs {
454                roots
455                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
456                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
457            }
458            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
459        };
460
461        tls_config.alpn_protocols = config.alpn();
462
463        let quic_config: quinn::crypto::rustls::QuicClientConfig =
464            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
465        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
466
467        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
468            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
469        quinn_endpoint.set_default_client_config(client_config);
470
471        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
472
473        let quic = quinn_endpoint
474            .connect(server_addr, &server_name)
475            .map_err(TransportError::from)?
476            .await
477            .map_err(TransportError::from)?;
478
479        Ok(Transport::Quic(QuicTransport::new(quic)))
480    }
481
482    /// Establish a WebTransport connection.
483    #[cfg(feature = "webtransport")]
484    async fn connect_webtransport(
485        url: &str,
486        config: &ClientConfig,
487    ) -> Result<Transport, ConnectionError> {
488        use crate::transport::webtransport::WebTransportTransport;
489
490        let wt_config = if config.skip_cert_verification {
491            wtransport::ClientConfig::builder()
492                .with_bind_default()
493                .with_no_cert_validation()
494                .build()
495        } else {
496            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
497        };
498
499        let endpoint = wtransport::Endpoint::client(wt_config)
500            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
501
502        let connection = endpoint
503            .connect(url)
504            .await
505            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
506
507        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
508    }
509
510    /// Stub for when the webtransport feature is not enabled.
511    #[cfg(not(feature = "webtransport"))]
512    async fn connect_webtransport(
513        _url: &str,
514        _config: &ClientConfig,
515    ) -> Result<Transport, ConnectionError> {
516        Err(ConnectionError::Transport(TransportError::Connect(
517            "webtransport feature not enabled".into(),
518        )))
519    }
520
521    // ── Observer ───────────────────────────────────────────────
522
523    /// Attach an observer. Buffered handshake events from `connect()` are
524    /// flushed in arrival order before this returns.
525    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
526        self.observer = Some(observer);
527        for event in self.pending_events.drain(..) {
528            if let Some(ref obs) = self.observer {
529                obs.on_event_owned(event);
530            }
531        }
532    }
533
534    /// Remove the observer.
535    pub fn clear_observer(&mut self) {
536        self.observer = None;
537    }
538
539    /// Emit an event to the observer, if one is attached.
540    fn emit(&self, event: ClientEvent) {
541        if let Some(ref obs) = self.observer {
542            obs.on_event_owned(event);
543        }
544    }
545
546    // ── Control message I/O ─────────────────────────────────
547
548    /// Send a control message on the control stream.
549    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
550        let any = AnyControlMessage::Draft11(msg.clone());
551        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
552        let raw = send.write_control(&any).await?;
553        self.emit(ClientEvent::ControlMessage {
554            direction: Direction::Send,
555            message: any,
556            raw: Some(raw),
557        });
558        Ok(())
559    }
560
561    /// Read the next control message from the control stream.
562    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
563        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
564        let capture_raw = self.observer.is_some();
565        let (any, raw) = recv.read_control(capture_raw).await?;
566        if capture_raw {
567            self.emit(ClientEvent::ControlMessage {
568                direction: Direction::Receive,
569                message: any.clone(),
570                raw,
571            });
572        }
573        match any {
574            AnyControlMessage::Draft11(msg) => Ok(msg),
575            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
576        }
577    }
578
579    /// Read and dispatch the next incoming control message through the endpoint
580    /// state machine. Returns the decoded message for inspection.
581    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
582        let msg = self.recv_control().await?;
583        self.endpoint.receive_message(msg.clone())?;
584
585        if let ControlMessage::GoAway(ref ga) = msg {
586            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
587        }
588
589        Ok(msg)
590    }
591
592    // ── Subscribe flow ──────────────────────────────────────
593
594    /// Send a SUBSCRIBE and return the allocated request ID.
595    #[allow(clippy::too_many_arguments)]
596    pub async fn subscribe(
597        &mut self,
598        track_alias: VarInt,
599        track_namespace: TrackNamespace,
600        track_name: Vec<u8>,
601        subscriber_priority: u8,
602        group_order: VarInt,
603        filter_type: VarInt,
604    ) -> Result<VarInt, ConnectionError> {
605        let (req_id, msg) = self.endpoint.subscribe(
606            track_alias,
607            track_namespace,
608            track_name,
609            subscriber_priority,
610            group_order,
611            filter_type,
612        )?;
613        self.send_control(&msg).await?;
614        Ok(req_id)
615    }
616
617    /// Send an UNSUBSCRIBE for the given request ID.
618    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
619        let msg = self.endpoint.unsubscribe(request_id)?;
620        self.send_control(&msg).await
621    }
622
623    // ── Fetch flow ──────────────────────────────────────────
624
625    /// Send a FETCH and return the allocated request ID.
626    #[allow(clippy::too_many_arguments)]
627    pub async fn fetch(
628        &mut self,
629        track_namespace: TrackNamespace,
630        track_name: Vec<u8>,
631        subscriber_priority: u8,
632        group_order: VarInt,
633        start_group: VarInt,
634        start_object: VarInt,
635        end_group: VarInt,
636        end_object: VarInt,
637    ) -> Result<VarInt, ConnectionError> {
638        let (req_id, msg) = self.endpoint.fetch(
639            track_namespace,
640            track_name,
641            subscriber_priority,
642            group_order,
643            start_group,
644            start_object,
645            end_group,
646            end_object,
647        )?;
648        self.send_control(&msg).await?;
649        Ok(req_id)
650    }
651
652    /// Send a FETCH_CANCEL for the given request ID.
653    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
654        let msg = self.endpoint.fetch_cancel(request_id)?;
655        self.send_control(&msg).await
656    }
657
658    // ── Namespace flows ─────────────────────────────────────
659
660    /// Send a SUBSCRIBE_ANNOUNCES. Returns the allocated request ID.
661    pub async fn subscribe_announces(
662        &mut self,
663        track_namespace_prefix: TrackNamespace,
664    ) -> Result<VarInt, ConnectionError> {
665        let (req_id, msg) = self.endpoint.subscribe_announces(track_namespace_prefix)?;
666        self.send_control(&msg).await?;
667        Ok(req_id)
668    }
669
670    /// Send an ANNOUNCE. Returns the allocated request ID.
671    pub async fn announce(
672        &mut self,
673        track_namespace: TrackNamespace,
674    ) -> Result<VarInt, ConnectionError> {
675        let (req_id, msg) = self.endpoint.announce(track_namespace)?;
676        self.send_control(&msg).await?;
677        Ok(req_id)
678    }
679
680    /// Send an UNANNOUNCE.
681    pub async fn unannounce(
682        &mut self,
683        track_namespace: TrackNamespace,
684    ) -> Result<(), ConnectionError> {
685        let msg = self.endpoint.unannounce(track_namespace)?;
686        self.send_control(&msg).await
687    }
688
689    // ── Track Status flow ────────────────────────────────────
690
691    /// Send a TRACK_STATUS_REQUEST. Returns the allocated request ID.
692    pub async fn track_status_request(
693        &mut self,
694        track_namespace: TrackNamespace,
695        track_name: Vec<u8>,
696    ) -> Result<VarInt, ConnectionError> {
697        let (req_id, msg) = self.endpoint.track_status_request(track_namespace, track_name)?;
698        self.send_control(&msg).await?;
699        Ok(req_id)
700    }
701
702    // ── Data streams ────────────────────────────────────────
703
704    /// Open a new unidirectional stream for sending subgroup data.
705    pub async fn open_subgroup_stream(
706        &self,
707        header: &AnySubgroupHeader,
708    ) -> Result<FramedSendStream, ConnectionError> {
709        let send = self.transport.open_uni().await?;
710        let mut framed = FramedSendStream::new(send);
711        let sid = framed.stream_id();
712        framed.write_subgroup_header(header).await?;
713        self.emit(ClientEvent::StreamOpened {
714            direction: Direction::Send,
715            stream_kind: StreamKind::Subgroup,
716            stream_id: sid,
717        });
718        self.emit(ClientEvent::DataStreamHeader {
719            stream_id: sid,
720            direction: Direction::Send,
721            header: header.clone(),
722        });
723        Ok(framed)
724    }
725
726    /// Accept an incoming unidirectional data stream and read its subgroup
727    /// header.
728    pub async fn accept_subgroup_stream(
729        &self,
730    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
731        let recv = self.transport.accept_uni().await?;
732        let mut framed = FramedRecvStream::new(recv);
733        let sid = framed.stream_id();
734        let header = framed.read_subgroup_header().await?;
735        self.emit(ClientEvent::StreamOpened {
736            direction: Direction::Receive,
737            stream_kind: StreamKind::Subgroup,
738            stream_id: sid,
739        });
740        self.emit(ClientEvent::DataStreamHeader {
741            stream_id: sid,
742            direction: Direction::Receive,
743            header: header.clone(),
744        });
745        Ok((header, framed))
746    }
747
748    /// Send an object via datagram.
749    pub fn send_datagram(
750        &self,
751        header: &AnyDatagramHeader,
752        payload: &[u8],
753    ) -> Result<(), ConnectionError> {
754        let mut buf = Vec::new();
755        header.encode(&mut buf);
756        buf.extend_from_slice(payload);
757        self.emit(ClientEvent::DatagramReceived {
758            direction: Direction::Send,
759            header: header.clone(),
760            payload_len: payload.len(),
761        });
762        self.transport.send_datagram(bytes::Bytes::from(buf))?;
763        Ok(())
764    }
765
766    /// Receive a datagram and decode its header.
767    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
768        let data = self.transport.recv_datagram().await?;
769        let mut cursor = &data[..];
770        let header = AnyDatagramHeader::decode(DraftVersion::Draft11, &mut cursor)?;
771        let consumed = data.len() - cursor.len();
772        let payload = data.slice(consumed..);
773        self.emit(ClientEvent::DatagramReceived {
774            direction: Direction::Receive,
775            header: header.clone(),
776            payload_len: payload.len(),
777        });
778        Ok((header, payload))
779    }
780
781    // ── Accessors ───────────────────────────────────────────
782
783    /// Access the underlying endpoint state machine.
784    pub fn endpoint(&self) -> &Endpoint {
785        &self.endpoint
786    }
787
788    /// Mutable access to the endpoint state machine.
789    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
790        &mut self.endpoint
791    }
792
793    /// Get the negotiated MoQT version.
794    pub fn negotiated_version(&self) -> Option<VarInt> {
795        self.endpoint.negotiated_version()
796    }
797
798    /// Close the connection.
799    pub fn close(&self, code: u32, reason: &[u8]) {
800        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
801        self.transport.close(code, reason);
802    }
803}
804
/// Encoded length in bytes of a QUIC-style varint, judged only from the two
/// most-significant bits of its first byte (`00`→1, `01`→2, `10`→4, `11`→8).
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
809
810/// TLS certificate verifier that skips all verification (for testing only).
811#[derive(Debug)]
812struct SkipVerification;
813
814impl rustls::client::danger::ServerCertVerifier for SkipVerification {
815    fn verify_server_cert(
816        &self,
817        _end_entity: &rustls::pki_types::CertificateDer<'_>,
818        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
819        _server_name: &rustls::pki_types::ServerName<'_>,
820        _ocsp_response: &[u8],
821        _now: rustls::pki_types::UnixTime,
822    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
823        Ok(rustls::client::danger::ServerCertVerified::assertion())
824    }
825
826    fn verify_tls12_signature(
827        &self,
828        _message: &[u8],
829        _cert: &rustls::pki_types::CertificateDer<'_>,
830        _dcs: &rustls::DigitallySignedStruct,
831    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
832        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
833    }
834
835    fn verify_tls13_signature(
836        &self,
837        _message: &[u8],
838        _cert: &rustls::pki_types::CertificateDer<'_>,
839        _dcs: &rustls::DigitallySignedStruct,
840    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
841        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
842    }
843
844    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
845        vec![
846            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
847            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
848            rustls::SignatureScheme::ED25519,
849            rustls::SignatureScheme::RSA_PSS_SHA256,
850            rustls::SignatureScheme::RSA_PSS_SHA384,
851            rustls::SignatureScheme::RSA_PSS_SHA512,
852        ]
853    }
854}
855
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given transport and extra versions, and no
    /// CA certificates or setup parameters.
    fn config_with(
        transport: TransportType,
        additional_versions: Vec<DraftVersion>,
    ) -> ClientConfig {
        ClientConfig {
            additional_versions,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn client_config_supported_versions_default() {
        let versions = config_with(TransportType::Quic, Vec::new()).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 11);
    }

    #[test]
    fn client_config_supported_versions_dedups_draft11() {
        let config = config_with(TransportType::Quic, vec![DraftVersion::Draft11]);
        let versions = config.supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(
            versions[0].into_inner(),
            DraftVersion::Draft11.version_varint().into_inner()
        );
    }

    #[test]
    fn client_config_alpn_quic() {
        let config = config_with(TransportType::Quic, Vec::new());
        assert_eq!(config.alpn(), vec![DraftVersion::Draft11.quic_alpn().to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config = config_with(
            TransportType::WebTransport { url: "https://localhost:4433/moq".to_string() },
            Vec::new(),
        );
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn varint_len_follows_two_bit_prefix() {
        assert_eq!(varint_len(0x00), 1);
        assert_eq!(varint_len(0x3f), 1);
        assert_eq!(varint_len(0x40), 2);
        assert_eq!(varint_len(0x80), 4);
        assert_eq!(varint_len(0xc0), 8);
        assert_eq!(varint_len(0xff), 8);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }
}