Skip to main content

moqtap_client/draft10/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft10::endpoint::{Endpoint, EndpointError};
6use crate::draft10::event::{ClientEvent, Direction, FetchObject, StreamKind, SubgroupObject};
7use crate::draft10::observer::ConnectionObserver;
8use crate::transport::quic::QuicTransport;
9use crate::transport::{RecvStream, SendStream, Transport, TransportError};
10use moqtap_codec::dispatch::{
11    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
12};
13use moqtap_codec::draft10::data_stream::{FetchObjectHeader, ObjectHeader};
14use moqtap_codec::draft10::message::ControlMessage;
15use moqtap_codec::error::CodecError;
16use moqtap_codec::types::*;
17use moqtap_codec::varint::VarInt;
18use moqtap_codec::version::DraftVersion;
19
/// MoQT ALPN identifier (`moq-00`) for raw QUIC connections.
///
/// NOTE(review): `ClientConfig::alpn` in this module does not read this
/// constant; it takes the raw-QUIC ALPN from `DraftVersion::Draft10.quic_alpn()`.
/// Confirm the two values agree before relying on this constant.
pub const MOQT_ALPN: &[u8] = b"moq-00";
22
/// Errors from the draft-10 connection layer.
///
/// The `#[from]` variants let endpoint, codec, transport and varint errors
/// propagate into this type with `?`.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Endpoint state machine error. `Connection::connect` also reports
    /// `EndpointError::NotActive` through this variant when the server's first
    /// control message is not a draft-10 SERVER_SETUP.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// Wire codec error while encoding or decoding messages, headers or objects.
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// Transport-level error (e.g. connecting, opening, reading or writing streams).
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),
    /// Variable-length integer decoding error.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// Control stream was not opened.
    #[error("control stream not open")]
    NoControlStream,
    /// The stream ended before a complete message, header or object was read.
    /// The framed readers in this module report every end-of-stream this way.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// Stream was finished by the peer.
    ///
    /// NOTE(review): not constructed anywhere in this module — confirm whether
    /// other modules produce it.
    #[error("stream finished")]
    StreamFinished,
    /// Invalid server address string. This module also uses it when the local
    /// QUIC client endpoint cannot be created.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// TLS configuration error: a rejected CA certificate, or a failed
    /// conversion of the rustls config into a QUIC client config.
    #[error("TLS config error: {0}")]
    TlsConfig(String),
}
54
/// Transport type for the connection.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC via quinn. The server is the `addr` argument of
    /// `Connection::connect`. It must parse as a socket address (for example
    /// `127.0.0.1:4443`); host names are not resolved.
    Quic,
    /// WebTransport via wtransport. The `addr` argument of `Connection::connect`
    /// is not used for this transport.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        url: String,
    },
}
66
/// Configuration for a draft-10 MoQT client connection.
pub struct ClientConfig {
    /// Additional draft versions to offer in CLIENT_SETUP. Draft-10 is always
    /// offered first, and duplicates are dropped.
    pub additional_versions: Vec<DraftVersion>,
    /// The transport type (QUIC or WebTransport).
    pub transport: TransportType,
    /// Whether to skip TLS certificate verification (for testing). Honoured by
    /// both the raw QUIC and the WebTransport connect paths.
    pub skip_cert_verification: bool,
    /// Custom CA certificates to trust (DER-encoded), added on top of the
    /// webpki roots. Only applied to raw QUIC connections; the WebTransport
    /// path uses the platform's native certificates instead.
    pub ca_certs: Vec<Vec<u8>>,
    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
    pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
}
81
82impl ClientConfig {
83    /// Returns the MoQT version varints for the CLIENT_SETUP message.
84    /// Draft-10 first, then any additional versions.
85    pub fn supported_versions(&self) -> Vec<VarInt> {
86        let mut versions = vec![DraftVersion::Draft10.version_varint()];
87        for v in &self.additional_versions {
88            let varint = v.version_varint();
89            if !versions.contains(&varint) {
90                versions.push(varint);
91            }
92        }
93        versions
94    }
95
96    /// Returns the ALPN protocol identifiers for the transport.
97    pub fn alpn(&self) -> Vec<Vec<u8>> {
98        match &self.transport {
99            TransportType::Quic => vec![DraftVersion::Draft10.quic_alpn().to_vec()],
100            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
101        }
102    }
103}
104
/// A framed writer for a send stream.
///
/// Each `write_*` method encodes its item into a temporary buffer and sends it
/// with a single `write_all`. Control-message framing is produced by the
/// codec's `encode`.
pub struct FramedSendStream {
    // Underlying transport send stream.
    inner: SendStream,
}
109
110impl FramedSendStream {
111    /// Create a new framed send stream.
112    pub fn new(inner: SendStream) -> Self {
113        Self { inner }
114    }
115
116    /// Get the transport-level stream ID.
117    pub fn stream_id(&self) -> u64 {
118        self.inner.stream_id()
119    }
120
121    /// Write a control message to the stream with type+length framing.
122    /// Returns the raw bytes that were written (for event capture).
123    pub async fn write_control(
124        &mut self,
125        msg: &AnyControlMessage,
126    ) -> Result<Vec<u8>, ConnectionError> {
127        let mut buf = Vec::new();
128        msg.encode(&mut buf)?;
129        self.inner.write_all(&buf).await?;
130        Ok(buf)
131    }
132
133    /// Write a subgroup stream header.
134    pub async fn write_subgroup_header(
135        &mut self,
136        header: &AnySubgroupHeader,
137    ) -> Result<(), ConnectionError> {
138        let mut buf = Vec::new();
139        header.encode(&mut buf);
140        self.inner.write_all(&buf).await?;
141        Ok(())
142    }
143
144    /// Write a fetch response header.
145    pub async fn write_fetch_header(
146        &mut self,
147        header: &AnyFetchHeader,
148    ) -> Result<(), ConnectionError> {
149        let mut buf = Vec::new();
150        header.encode(&mut buf);
151        self.inner.write_all(&buf).await?;
152        Ok(())
153    }
154
155    /// Append a draft-10 subgroup object (header + payload) to the stream.
156    /// Draft-10 subgroup objects are stateless — no header-dependent bookkeeping.
157    pub async fn write_subgroup_object(
158        &mut self,
159        object: &SubgroupObject,
160    ) -> Result<(), ConnectionError> {
161        let mut buf = Vec::new();
162        object.header.encode(&mut buf);
163        buf.extend_from_slice(&object.payload);
164        self.inner.write_all(&buf).await?;
165        Ok(())
166    }
167
168    /// Append a draft-10 fetch object (header + payload) to the stream.
169    pub async fn write_fetch_object(
170        &mut self,
171        object: &FetchObject,
172    ) -> Result<(), ConnectionError> {
173        let mut buf = Vec::new();
174        object.header.encode(&mut buf);
175        buf.extend_from_slice(&object.payload);
176        self.inner.write_all(&buf).await?;
177        Ok(())
178    }
179
180    /// Finish the stream (send FIN).
181    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
182        self.inner.finish()?;
183        Ok(())
184    }
185}
186
/// A framed reader for a recv stream.
///
/// Bytes are read from the transport in chunks of up to 4096 bytes into an
/// internal buffer. The `read_*` methods decode items from the front of that
/// buffer and consume exactly the bytes each item occupied.
pub struct FramedRecvStream {
    // Underlying transport receive stream.
    inner: RecvStream,
    // Bytes read from the transport but not yet consumed by a decoder.
    buf: BytesMut,
}
192
193impl FramedRecvStream {
194    /// Create a new framed receive stream.
195    pub fn new(inner: RecvStream) -> Self {
196        Self { inner, buf: BytesMut::with_capacity(4096) }
197    }
198
199    /// Get the transport-level stream ID.
200    pub fn stream_id(&self) -> u64 {
201        self.inner.stream_id()
202    }
203
204    /// Read more data from the stream into the internal buffer.
205    async fn fill(&mut self) -> Result<bool, ConnectionError> {
206        let mut tmp = [0u8; 4096];
207        match self.inner.read(&mut tmp).await {
208            Ok(Some(n)) => {
209                self.buf.extend_from_slice(&tmp[..n]);
210                Ok(true)
211            }
212            Ok(None) => Ok(false),
213            Err(e) => Err(ConnectionError::Transport(e)),
214        }
215    }
216
217    /// Ensure at least `n` bytes are available in the buffer.
218    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
219        while self.buf.len() < n {
220            if !self.fill().await? {
221                return Err(ConnectionError::UnexpectedEnd);
222            }
223        }
224        Ok(())
225    }
226
227    /// Read a control message from the stream.
228    ///
229    /// When `capture_raw` is true, the returned tuple includes a clone of the
230    /// framed wire bytes (for observer emission). When false, the second
231    /// element is `None` and the payload clone is skipped.
232    pub async fn read_control(
233        &mut self,
234        capture_raw: bool,
235    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
236        // Read type ID varint
237        self.ensure(1).await?;
238        let type_len = varint_len(self.buf[0]);
239        self.ensure(type_len).await?;
240
241        let mut cursor = &self.buf[..type_len];
242        let _type_id = VarInt::decode(&mut cursor)?;
243
244        // Draft-10 uses varint length framing.
245        self.ensure(type_len + 1).await?;
246        let payload_len_start = type_len;
247        let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
248        self.ensure(type_len + payload_len_varint_len).await?;
249        let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
250        let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
251        let len_field_size = payload_len_varint_len;
252
253        // Read full payload
254        let total = type_len + len_field_size + payload_len;
255        self.ensure(total).await?;
256
257        // Capture raw bytes only if requested (observer attached).
258        let raw = capture_raw.then(|| self.buf[..total].to_vec());
259
260        // Now decode the whole message using the draft-10 dispatcher
261        let mut frame = &self.buf[..total];
262        let msg = AnyControlMessage::decode(DraftVersion::Draft10, &mut frame)?;
263        self.buf.advance(total);
264        Ok((msg, raw))
265    }
266
267    /// Read a subgroup stream header.
268    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
269        self.ensure(1).await?;
270        loop {
271            let mut cursor = &self.buf[..];
272            match AnySubgroupHeader::decode(DraftVersion::Draft10, &mut cursor) {
273                Ok(header) => {
274                    let consumed = self.buf.len() - cursor.remaining();
275                    self.buf.advance(consumed);
276                    return Ok(header);
277                }
278                Err(CodecError::UnexpectedEnd) => {
279                    if !self.fill().await? {
280                        return Err(ConnectionError::UnexpectedEnd);
281                    }
282                }
283                Err(e) => return Err(ConnectionError::Codec(e)),
284            }
285        }
286    }
287
288    /// Read a fetch response header.
289    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
290        self.ensure(1).await?;
291        loop {
292            let mut cursor = &self.buf[..];
293            match AnyFetchHeader::decode(DraftVersion::Draft10, &mut cursor) {
294                Ok(header) => {
295                    let consumed = self.buf.len() - cursor.remaining();
296                    self.buf.advance(consumed);
297                    return Ok(header);
298                }
299                Err(CodecError::UnexpectedEnd) => {
300                    if !self.fill().await? {
301                        return Err(ConnectionError::UnexpectedEnd);
302                    }
303                }
304                Err(e) => return Err(ConnectionError::Codec(e)),
305            }
306        }
307    }
308
309    /// Read the next draft-10 subgroup object (header + payload). Since
310    /// draft-10 subgroup objects are stateless, this does not require any
311    /// prior header-decoding state.
312    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
313        loop {
314            let mut cursor = &self.buf[..];
315            match ObjectHeader::decode(&mut cursor) {
316                Ok(header) => {
317                    let header_consumed = self.buf.len() - cursor.remaining();
318                    let payload_len = header.payload_length.into_inner() as usize;
319                    let total = header_consumed + payload_len;
320                    if self.buf.len() < total {
321                        if !self.fill().await? {
322                            return Err(ConnectionError::UnexpectedEnd);
323                        }
324                        continue;
325                    }
326                    let payload = self.buf[header_consumed..total].to_vec();
327                    self.buf.advance(total);
328                    return Ok(SubgroupObject { header, payload });
329                }
330                Err(CodecError::UnexpectedEnd) => {
331                    if !self.fill().await? {
332                        return Err(ConnectionError::UnexpectedEnd);
333                    }
334                }
335                Err(e) => return Err(ConnectionError::Codec(e)),
336            }
337        }
338    }
339
340    /// Read the next draft-10 fetch object (header + payload).
341    pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
342        loop {
343            let mut cursor = &self.buf[..];
344            match FetchObjectHeader::decode(&mut cursor) {
345                Ok(header) => {
346                    let header_consumed = self.buf.len() - cursor.remaining();
347                    let payload_len = header.payload_length.into_inner() as usize;
348                    let total = header_consumed + payload_len;
349                    if self.buf.len() < total {
350                        if !self.fill().await? {
351                            return Err(ConnectionError::UnexpectedEnd);
352                        }
353                        continue;
354                    }
355                    let payload = self.buf[header_consumed..total].to_vec();
356                    self.buf.advance(total);
357                    return Ok(FetchObject { header, payload });
358                }
359                Err(CodecError::UnexpectedEnd) => {
360                    if !self.fill().await? {
361                        return Err(ConnectionError::UnexpectedEnd);
362                    }
363                }
364                Err(e) => return Err(ConnectionError::Codec(e)),
365            }
366        }
367    }
368}
369
/// A live draft-10 MoQT connection over QUIC or WebTransport.
///
/// Created by `Connection::connect`, which has already completed the
/// CLIENT_SETUP / SERVER_SETUP handshake when it returns.
pub struct Connection {
    // Underlying QUIC or WebTransport session.
    transport: Transport,
    // Draft-10 protocol state machine; builds outgoing and validates incoming control messages.
    endpoint: Endpoint,
    // Send half of the bidirectional control stream.
    control_send: Option<FramedSendStream>,
    // Receive half of the bidirectional control stream.
    control_recv: Option<FramedRecvStream>,
    // Optional event sink; when present, events are emitted to it.
    observer: Option<Box<dyn ConnectionObserver>>,
    /// Setup events buffered during `connect()` and replayed when an
    /// observer attaches via `set_observer`. Without this buffer, an observer
    /// attached after `connect` returns would never see the handshake.
    pending_events: Vec<ClientEvent>,
}
382
383impl Connection {
384    /// Connect to a draft-10 MoQT server as a client.
385    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
386        let transport = match &config.transport {
387            TransportType::Quic => Self::connect_quic(addr, &config).await?,
388            TransportType::WebTransport { url } => {
389                let url = url.clone();
390                Self::connect_webtransport(&url, &config).await?
391            }
392        };
393
394        // Open bidirectional control stream
395        let (send, recv) = transport.open_bi().await?;
396        let mut control_send = FramedSendStream::new(send);
397        let mut control_recv = FramedRecvStream::new(recv);
398
399        // Perform setup handshake
400        let mut endpoint = Endpoint::new();
401        endpoint.connect()?;
402        let setup_msg = endpoint
403            .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
404        let any_setup = AnyControlMessage::Draft10(setup_msg);
405        let raw_setup = control_send.write_control(&any_setup).await?;
406
407        let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
408        match &server_setup {
409            AnyControlMessage::Draft10(ControlMessage::ServerSetup(ref ss)) => {
410                endpoint.receive_server_setup(ss)?;
411            }
412            _ => {
413                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
414            }
415        }
416
417        let mut pending_events = Vec::with_capacity(3);
418        pending_events.push(ClientEvent::ControlMessage {
419            direction: Direction::Send,
420            message: any_setup,
421            raw: Some(raw_setup),
422        });
423        pending_events.push(ClientEvent::ControlMessage {
424            direction: Direction::Receive,
425            message: server_setup,
426            raw: raw_server_setup,
427        });
428        if let Some(v) = endpoint.negotiated_version() {
429            pending_events.push(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
430        }
431
432        Ok(Self {
433            transport,
434            endpoint,
435            control_send: Some(control_send),
436            control_recv: Some(control_recv),
437            observer: None,
438            pending_events,
439        })
440    }
441
442    /// Establish a raw QUIC connection.
443    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
444        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
445            ConnectionError::InvalidAddress(e.to_string())
446        })?;
447
448        let mut tls_config = if config.skip_cert_verification {
449            rustls::ClientConfig::builder()
450                .dangerous()
451                .with_custom_certificate_verifier(Arc::new(SkipVerification))
452                .with_no_client_auth()
453        } else {
454            let mut roots = rustls::RootCertStore::empty();
455            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
456            for der in &config.ca_certs {
457                roots
458                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
459                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
460            }
461            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
462        };
463
464        tls_config.alpn_protocols = config.alpn();
465
466        let quic_config: quinn::crypto::rustls::QuicClientConfig =
467            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
468        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
469
470        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
471            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
472        quinn_endpoint.set_default_client_config(client_config);
473
474        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
475
476        let quic = quinn_endpoint
477            .connect(server_addr, &server_name)
478            .map_err(TransportError::from)?
479            .await
480            .map_err(TransportError::from)?;
481
482        Ok(Transport::Quic(QuicTransport::new(quic)))
483    }
484
485    /// Establish a WebTransport connection.
486    #[cfg(feature = "webtransport")]
487    async fn connect_webtransport(
488        url: &str,
489        config: &ClientConfig,
490    ) -> Result<Transport, ConnectionError> {
491        use crate::transport::webtransport::WebTransportTransport;
492
493        let wt_config = if config.skip_cert_verification {
494            wtransport::ClientConfig::builder()
495                .with_bind_default()
496                .with_no_cert_validation()
497                .build()
498        } else {
499            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
500        };
501
502        let endpoint = wtransport::Endpoint::client(wt_config)
503            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
504
505        let connection = endpoint
506            .connect(url)
507            .await
508            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
509
510        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
511    }
512
513    /// Stub for when the webtransport feature is not enabled.
514    #[cfg(not(feature = "webtransport"))]
515    async fn connect_webtransport(
516        _url: &str,
517        _config: &ClientConfig,
518    ) -> Result<Transport, ConnectionError> {
519        Err(ConnectionError::Transport(TransportError::Connect(
520            "webtransport feature not enabled".into(),
521        )))
522    }
523
524    // ── Observer ───────────────────────────────────────────────
525
526    /// Attach an observer. Buffered handshake events from `connect()` are
527    /// flushed in arrival order before this returns.
528    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
529        self.observer = Some(observer);
530        for event in self.pending_events.drain(..) {
531            if let Some(ref obs) = self.observer {
532                obs.on_event_owned(event);
533            }
534        }
535    }
536
537    /// Remove the observer.
538    pub fn clear_observer(&mut self) {
539        self.observer = None;
540    }
541
542    /// Emit an event to the observer, if one is attached.
543    fn emit(&self, event: ClientEvent) {
544        if let Some(ref obs) = self.observer {
545            obs.on_event_owned(event);
546        }
547    }
548
549    // ── Control message I/O ─────────────────────────────────
550
551    /// Send a control message on the control stream.
552    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
553        let any = AnyControlMessage::Draft10(msg.clone());
554        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
555        let raw = send.write_control(&any).await?;
556        self.emit(ClientEvent::ControlMessage {
557            direction: Direction::Send,
558            message: any,
559            raw: Some(raw),
560        });
561        Ok(())
562    }
563
564    /// Read the next control message from the control stream.
565    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
566        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
567        let capture_raw = self.observer.is_some();
568        let (any, raw) = recv.read_control(capture_raw).await?;
569        if capture_raw {
570            self.emit(ClientEvent::ControlMessage {
571                direction: Direction::Receive,
572                message: any.clone(),
573                raw,
574            });
575        }
576        match any {
577            AnyControlMessage::Draft10(msg) => Ok(msg),
578            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
579        }
580    }
581
582    /// Read and dispatch the next incoming control message through the endpoint
583    /// state machine. Returns the decoded message for inspection.
584    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
585        let msg = self.recv_control().await?;
586        self.endpoint.receive_message(msg.clone())?;
587
588        if let ControlMessage::GoAway(ref ga) = msg {
589            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
590        }
591
592        Ok(msg)
593    }
594
595    // ── Subscribe flow ──────────────────────────────────────
596
597    /// Send a SUBSCRIBE and return the allocated subscribe ID.
598    #[allow(clippy::too_many_arguments)]
599    pub async fn subscribe(
600        &mut self,
601        track_alias: VarInt,
602        track_namespace: TrackNamespace,
603        track_name: Vec<u8>,
604        subscriber_priority: u8,
605        group_order: GroupOrder,
606        filter_type: FilterType,
607    ) -> Result<VarInt, ConnectionError> {
608        let (sub_id, msg) = self.endpoint.subscribe(
609            track_alias,
610            track_namespace,
611            track_name,
612            subscriber_priority,
613            group_order,
614            filter_type,
615        )?;
616        self.send_control(&msg).await?;
617        Ok(sub_id)
618    }
619
620    /// Send an UNSUBSCRIBE for the given subscribe ID.
621    pub async fn unsubscribe(&mut self, subscribe_id: VarInt) -> Result<(), ConnectionError> {
622        let msg = self.endpoint.unsubscribe(subscribe_id)?;
623        self.send_control(&msg).await
624    }
625
626    // ── Fetch flow ──────────────────────────────────────────
627
628    /// Send a FETCH and return the allocated subscribe ID.
629    #[allow(clippy::too_many_arguments)]
630    pub async fn fetch(
631        &mut self,
632        track_namespace: TrackNamespace,
633        track_name: Vec<u8>,
634        subscriber_priority: u8,
635        group_order: GroupOrder,
636        start_group: VarInt,
637        start_object: VarInt,
638        end_group: VarInt,
639        end_object: VarInt,
640    ) -> Result<VarInt, ConnectionError> {
641        let (sub_id, msg) = self.endpoint.fetch(
642            track_namespace,
643            track_name,
644            subscriber_priority,
645            group_order,
646            start_group,
647            start_object,
648            end_group,
649            end_object,
650        )?;
651        self.send_control(&msg).await?;
652        Ok(sub_id)
653    }
654
655    /// Send a FETCH_CANCEL for the given subscribe ID.
656    pub async fn fetch_cancel(&mut self, subscribe_id: VarInt) -> Result<(), ConnectionError> {
657        let msg = self.endpoint.fetch_cancel(subscribe_id)?;
658        self.send_control(&msg).await
659    }
660
661    // ── Namespace flows ─────────────────────────────────────
662
663    /// Send a SUBSCRIBE_ANNOUNCES.
664    pub async fn subscribe_announces(
665        &mut self,
666        track_namespace_prefix: TrackNamespace,
667    ) -> Result<(), ConnectionError> {
668        let msg = self.endpoint.subscribe_announces(track_namespace_prefix)?;
669        self.send_control(&msg).await
670    }
671
672    /// Send an ANNOUNCE.
673    pub async fn announce(
674        &mut self,
675        track_namespace: TrackNamespace,
676    ) -> Result<(), ConnectionError> {
677        let msg = self.endpoint.announce(track_namespace)?;
678        self.send_control(&msg).await
679    }
680
681    /// Send an UNANNOUNCE.
682    pub async fn unannounce(
683        &mut self,
684        track_namespace: TrackNamespace,
685    ) -> Result<(), ConnectionError> {
686        let msg = self.endpoint.unannounce(track_namespace)?;
687        self.send_control(&msg).await
688    }
689
690    // ── Track Status flow ────────────────────────────────────
691
692    /// Send a TRACK_STATUS_REQUEST.
693    pub async fn track_status_request(
694        &mut self,
695        track_namespace: TrackNamespace,
696        track_name: Vec<u8>,
697    ) -> Result<(), ConnectionError> {
698        let msg = self.endpoint.track_status_request(track_namespace, track_name)?;
699        self.send_control(&msg).await
700    }
701
702    // ── Data streams ────────────────────────────────────────
703
704    /// Open a new unidirectional stream for sending subgroup data.
705    pub async fn open_subgroup_stream(
706        &self,
707        header: &AnySubgroupHeader,
708    ) -> Result<FramedSendStream, ConnectionError> {
709        let send = self.transport.open_uni().await?;
710        let mut framed = FramedSendStream::new(send);
711        let sid = framed.stream_id();
712        framed.write_subgroup_header(header).await?;
713        self.emit(ClientEvent::StreamOpened {
714            direction: Direction::Send,
715            stream_kind: StreamKind::Subgroup,
716            stream_id: sid,
717        });
718        self.emit(ClientEvent::DataStreamHeader {
719            stream_id: sid,
720            direction: Direction::Send,
721            header: header.clone(),
722        });
723        Ok(framed)
724    }
725
726    /// Accept an incoming unidirectional data stream and read its subgroup
727    /// header.
728    pub async fn accept_subgroup_stream(
729        &self,
730    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
731        let recv = self.transport.accept_uni().await?;
732        let mut framed = FramedRecvStream::new(recv);
733        let sid = framed.stream_id();
734        let header = framed.read_subgroup_header().await?;
735        self.emit(ClientEvent::StreamOpened {
736            direction: Direction::Receive,
737            stream_kind: StreamKind::Subgroup,
738            stream_id: sid,
739        });
740        self.emit(ClientEvent::DataStreamHeader {
741            stream_id: sid,
742            direction: Direction::Receive,
743            header: header.clone(),
744        });
745        Ok((header, framed))
746    }
747
748    /// Send an object via datagram.
749    pub fn send_datagram(
750        &self,
751        header: &AnyDatagramHeader,
752        payload: &[u8],
753    ) -> Result<(), ConnectionError> {
754        let mut buf = Vec::new();
755        header.encode(&mut buf);
756        buf.extend_from_slice(payload);
757        self.emit(ClientEvent::DatagramReceived {
758            direction: Direction::Send,
759            header: header.clone(),
760            payload_len: payload.len(),
761        });
762        self.transport.send_datagram(bytes::Bytes::from(buf))?;
763        Ok(())
764    }
765
766    /// Receive a datagram and decode its header.
767    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
768        let data = self.transport.recv_datagram().await?;
769        let mut cursor = &data[..];
770        let header = AnyDatagramHeader::decode(DraftVersion::Draft10, &mut cursor)?;
771        let consumed = data.len() - cursor.len();
772        let payload = data.slice(consumed..);
773        self.emit(ClientEvent::DatagramReceived {
774            direction: Direction::Receive,
775            header: header.clone(),
776            payload_len: payload.len(),
777        });
778        Ok((header, payload))
779    }
780
781    // ── Accessors ───────────────────────────────────────────
782
783    /// Access the underlying endpoint state machine.
784    pub fn endpoint(&self) -> &Endpoint {
785        &self.endpoint
786    }
787
788    /// Mutable access to the endpoint state machine.
789    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
790        &mut self.endpoint
791    }
792
793    /// Get the negotiated MoQT version.
794    pub fn negotiated_version(&self) -> Option<VarInt> {
795        self.endpoint.negotiated_version()
796    }
797
798    /// Close the connection.
799    pub fn close(&self, code: u32, reason: &[u8]) {
800        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
801        self.transport.close(code, reason);
802    }
803}
804
/// Width in bytes of a QUIC-style variable-length integer.
///
/// The width is determined by the two most-significant bits of its first byte:
/// `00` → 1, `01` → 2, `10` → 4, `11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0 => 1,
        1 => 2,
        2 => 4,
        _ => 8,
    }
}
809
/// TLS certificate verifier that accepts every server certificate and
/// handshake signature without checking them (for testing only).
///
/// Installed by `Connection::connect_quic` when
/// `ClientConfig::skip_cert_verification` is set.
#[derive(Debug)]
struct SkipVerification;
813
814impl rustls::client::danger::ServerCertVerifier for SkipVerification {
815    fn verify_server_cert(
816        &self,
817        _end_entity: &rustls::pki_types::CertificateDer<'_>,
818        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
819        _server_name: &rustls::pki_types::ServerName<'_>,
820        _ocsp_response: &[u8],
821        _now: rustls::pki_types::UnixTime,
822    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
823        Ok(rustls::client::danger::ServerCertVerified::assertion())
824    }
825
826    fn verify_tls12_signature(
827        &self,
828        _message: &[u8],
829        _cert: &rustls::pki_types::CertificateDer<'_>,
830        _dcs: &rustls::DigitallySignedStruct,
831    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
832        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
833    }
834
835    fn verify_tls13_signature(
836        &self,
837        _message: &[u8],
838        _cert: &rustls::pki_types::CertificateDer<'_>,
839        _dcs: &rustls::DigitallySignedStruct,
840    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
841        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
842    }
843
844    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
845        vec![
846            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
847            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
848            rustls::SignatureScheme::ED25519,
849            rustls::SignatureScheme::RSA_PSS_SHA256,
850            rustls::SignatureScheme::RSA_PSS_SHA384,
851            rustls::SignatureScheme::RSA_PSS_SHA512,
852        ]
853    }
854}
855
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw-QUIC client config with every optional setting left empty.
    fn quic_config() -> ClientConfig {
        ClientConfig {
            additional_versions: Vec::new(),
            transport: TransportType::Quic,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn client_config_supported_versions_default() {
        let versions = quic_config().supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 10);
    }

    #[test]
    fn client_config_supported_versions_dedups_draft10() {
        // Draft-10 is always offered first; listing it again must not duplicate it.
        let config = ClientConfig { additional_versions: vec![DraftVersion::Draft10], ..quic_config() };
        let versions = config.supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 10);
    }

    #[test]
    fn client_config_alpn_quic() {
        assert_eq!(quic_config().alpn(), vec![DraftVersion::Draft10.quic_alpn().to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport_is_h3() {
        let config = ClientConfig {
            transport: TransportType::WebTransport {
                url: "https://localhost:4443/moq".to_string(),
            },
            ..quic_config()
        };
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn varint_len_follows_two_bit_prefix() {
        assert_eq!(varint_len(0x00), 1);
        assert_eq!(varint_len(0x3f), 1);
        assert_eq!(varint_len(0x40), 2);
        assert_eq!(varint_len(0x80), 4);
        assert_eq!(varint_len(0xc0), 8);
        assert_eq!(varint_len(0xff), 8);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }
}