Skip to main content

moqtap_client/draft14/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft14::endpoint::{Endpoint, EndpointError};
6use crate::draft14::event::{ClientEvent, Direction, StreamKind};
7use crate::draft14::observer::ConnectionObserver;
8use crate::draft14::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft14::data_stream::{FetchObject, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft14::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::types::*;
18use moqtap_codec::varint::VarInt;
19use moqtap_codec::version::DraftVersion;
20
/// ALPN token advertised for MoQT over raw QUIC (the ASCII string `moq-00`).
///
/// NOTE(review): `ClientConfig::alpn` derives the QUIC ALPN from
/// `DraftVersion::quic_alpn()` rather than this constant — confirm whether
/// this constant is still used by callers.
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
23
24/// Errors from the connection layer.
25#[derive(Debug, thiserror::Error)]
26pub enum ConnectionError {
27    /// Endpoint state machine error.
28    #[error("endpoint error: {0}")]
29    Endpoint(#[from] EndpointError),
30    /// Wire codec error.
31    #[error("codec error: {0}")]
32    Codec(#[from] CodecError),
33    /// Transport-level error.
34    #[error("transport error: {0}")]
35    Transport(#[from] TransportError),
36    /// Variable-length integer decoding error.
37    #[error("varint error: {0}")]
38    VarInt(#[from] moqtap_codec::varint::VarIntError),
39    /// Control stream was not opened.
40    #[error("control stream not open")]
41    NoControlStream,
42    /// Stream ended before a complete message was read.
43    #[error("unexpected end of stream")]
44    UnexpectedEnd,
45    /// Stream was finished by the peer.
46    #[error("stream finished")]
47    StreamFinished,
48    /// Invalid server address string.
49    #[error("invalid server address: {0}")]
50    InvalidAddress(String),
51    /// TLS configuration error.
52    #[error("TLS config error: {0}")]
53    TlsConfig(String),
54    /// Data stream used out of order (e.g. object before header).
55    #[error("data stream state error: {0}")]
56    DataStreamState(&'static str),
57}
58
/// Which transport `Connection::connect` uses to reach the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportType {
    /// Raw QUIC via quinn.
    ///
    /// The server address is the `addr` argument of `Connection::connect`.
    /// It is parsed as a socket address literal (`ip:port` or
    /// `[ipv6]:port`), so host names are not accepted.
    Quic,
    /// WebTransport via wtransport.
    ///
    /// `Connection::connect` dials `url` and ignores its `addr` argument.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        url: String,
    },
}
70
71/// Configuration for a MoQT client connection.
72///
73/// Both `draft` and `transport` are required — there is no `Default` impl.
74pub struct ClientConfig {
75    /// The MoQT draft version to use (primary, determines codec/framing).
76    pub draft: DraftVersion,
77    /// Additional draft versions to offer in CLIENT_SETUP.
78    /// The primary `draft` is always included first.
79    pub additional_versions: Vec<DraftVersion>,
80    /// The transport type (QUIC or WebTransport).
81    pub transport: TransportType,
82    /// Whether to skip TLS certificate verification (for testing).
83    pub skip_cert_verification: bool,
84    /// Custom CA certificates to trust (DER-encoded).
85    pub ca_certs: Vec<Vec<u8>>,
86    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
87    pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
88}
89
90impl ClientConfig {
91    /// Returns the MoQT version varints for the CLIENT_SETUP message.
92    /// Primary draft first, then any additional versions.
93    pub fn supported_versions(&self) -> Vec<VarInt> {
94        let mut versions = vec![self.draft.version_varint()];
95        for v in &self.additional_versions {
96            let varint = v.version_varint();
97            if !versions.contains(&varint) {
98                versions.push(varint);
99            }
100        }
101        versions
102    }
103
104    /// Returns the ALPN protocol identifiers for the transport.
105    pub fn alpn(&self) -> Vec<Vec<u8>> {
106        match &self.transport {
107            TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
108            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
109        }
110    }
111}
112
113/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
114pub struct FramedSendStream {
115    inner: SendStream,
116    draft: DraftVersion,
117    /// Stateful subgroup object encoder. Initialized by
118    /// [`FramedSendStream::write_subgroup_header`] and used by
119    /// [`FramedSendStream::write_subgroup_object`] to track the delta-encoded
120    /// object ID state.
121    subgroup_io: Option<SubgroupObjectReader>,
122}
123
124impl FramedSendStream {
125    /// Create a new framed send stream for the given draft version.
126    pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
127        Self { inner, draft, subgroup_io: None }
128    }
129
130    /// Get the transport-level stream ID.
131    pub fn stream_id(&self) -> u64 {
132        self.inner.stream_id()
133    }
134
135    /// Write a control message to the stream with type+length framing.
136    /// Returns the raw bytes that were written (for event capture).
137    pub async fn write_control(
138        &mut self,
139        msg: &AnyControlMessage,
140    ) -> Result<Vec<u8>, ConnectionError> {
141        let mut buf = Vec::new();
142        msg.encode(&mut buf)?;
143        self.inner.write_all(&buf).await?;
144        Ok(buf)
145    }
146
147    /// Write a subgroup stream header. Also initializes the internal
148    /// delta-encoding state used by [`FramedSendStream::write_subgroup_object`].
149    pub async fn write_subgroup_header(
150        &mut self,
151        header: &AnySubgroupHeader,
152    ) -> Result<(), ConnectionError> {
153        let mut buf = Vec::new();
154        header.encode(&mut buf);
155        self.inner.write_all(&buf).await?;
156        if let AnySubgroupHeader::Draft14(ref d14) = header {
157            self.subgroup_io = Some(SubgroupObjectReader::new(d14));
158        }
159        Ok(())
160    }
161
162    /// Write a fetch response header.
163    pub async fn write_fetch_header(
164        &mut self,
165        header: &AnyFetchHeader,
166    ) -> Result<(), ConnectionError> {
167        let mut buf = Vec::new();
168        header.encode(&mut buf);
169        self.inner.write_all(&buf).await?;
170        Ok(())
171    }
172
173    /// Append a draft-14 subgroup object to the stream. Uses the stateful
174    /// reader installed by [`FramedSendStream::write_subgroup_header`] to
175    /// produce the correct delta-encoded object ID. Returns an error if
176    /// called before a subgroup header was written.
177    pub async fn write_subgroup_object(
178        &mut self,
179        object: &SubgroupObject,
180    ) -> Result<(), ConnectionError> {
181        let writer = self
182            .subgroup_io
183            .as_mut()
184            .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
185        let mut buf = Vec::new();
186        writer.write_object(object, &mut buf)?;
187        self.inner.write_all(&buf).await?;
188        Ok(())
189    }
190
191    /// Append a draft-14 fetch object to the stream. Fetch objects are
192    /// self-contained (each carries its own group/subgroup/object IDs and
193    /// priority), so no prior `write_fetch_header` bookkeeping is required.
194    pub async fn write_fetch_object(
195        &mut self,
196        object: &FetchObject,
197    ) -> Result<(), ConnectionError> {
198        let mut buf = Vec::new();
199        object.encode(&mut buf);
200        self.inner.write_all(&buf).await?;
201        Ok(())
202    }
203
204    /// Finish the stream (send FIN).
205    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
206        self.inner.finish()?;
207        Ok(())
208    }
209
210    /// Returns the draft version this stream is framed for.
211    pub fn draft(&self) -> DraftVersion {
212        self.draft
213    }
214}
215
216/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
217pub struct FramedRecvStream {
218    inner: RecvStream,
219    buf: BytesMut,
220    draft: DraftVersion,
221    /// Stateful subgroup object decoder. Initialized by
222    /// [`FramedRecvStream::read_subgroup_header`] and used by
223    /// [`FramedRecvStream::read_subgroup_object`] to track delta-encoded
224    /// object IDs and whether extension headers are present.
225    subgroup_io: Option<SubgroupObjectReader>,
226}
227
228impl FramedRecvStream {
229    /// Create a new framed receive stream for the given draft version.
230    pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
231        Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
232    }
233
234    /// Get the transport-level stream ID.
235    pub fn stream_id(&self) -> u64 {
236        self.inner.stream_id()
237    }
238
239    /// Read more data from the stream into the internal buffer.
240    async fn fill(&mut self) -> Result<bool, ConnectionError> {
241        let mut tmp = [0u8; 4096];
242        match self.inner.read(&mut tmp).await {
243            Ok(Some(n)) => {
244                self.buf.extend_from_slice(&tmp[..n]);
245                Ok(true)
246            }
247            Ok(None) => Ok(false),
248            Err(e) => Err(ConnectionError::Transport(e)),
249        }
250    }
251
252    /// Ensure at least `n` bytes are available in the buffer.
253    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
254        while self.buf.len() < n {
255            if !self.fill().await? {
256                return Err(ConnectionError::UnexpectedEnd);
257            }
258        }
259        Ok(())
260    }
261
262    /// Read a control message from the stream.
263    ///
264    /// When `capture_raw` is true, the returned tuple includes a clone of the
265    /// framed wire bytes (for observer emission). When false, the second
266    /// element is `None` and the payload clone is skipped.
267    pub async fn read_control(
268        &mut self,
269        capture_raw: bool,
270    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
271        // Read type ID varint
272        self.ensure(1).await?;
273        let type_len = varint_len(self.buf[0]);
274        self.ensure(type_len).await?;
275
276        let mut cursor = &self.buf[..type_len];
277        let _type_id = VarInt::decode(&mut cursor)?;
278
279        // Read payload length (16-bit BE for draft-11+, varint for earlier drafts)
280        let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
281            self.ensure(type_len + 2).await?;
282            let hi = self.buf[type_len] as usize;
283            let lo = self.buf[type_len + 1] as usize;
284            ((hi << 8) | lo, 2)
285        } else {
286            self.ensure(type_len + 1).await?;
287            let payload_len_start = type_len;
288            let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
289            self.ensure(type_len + payload_len_varint_len).await?;
290            let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
291            let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
292            (payload_len, payload_len_varint_len)
293        };
294
295        // Read full payload
296        let total = type_len + len_field_size + payload_len;
297        self.ensure(total).await?;
298
299        // Capture raw bytes only if requested (observer attached).
300        let raw = capture_raw.then(|| self.buf[..total].to_vec());
301
302        // Now decode the whole message
303        let mut frame = &self.buf[..total];
304        let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
305        self.buf.advance(total);
306        Ok((msg, raw))
307    }
308
309    /// Read a subgroup stream header. Also initializes the internal
310    /// delta-decoding state used by
311    /// [`FramedRecvStream::read_subgroup_object`].
312    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
313        self.ensure(1).await?;
314        loop {
315            let mut cursor = &self.buf[..];
316            match AnySubgroupHeader::decode(self.draft, &mut cursor) {
317                Ok(header) => {
318                    let consumed = self.buf.len() - cursor.remaining();
319                    self.buf.advance(consumed);
320                    if let AnySubgroupHeader::Draft14(ref d14) = header {
321                        self.subgroup_io = Some(SubgroupObjectReader::new(d14));
322                    }
323                    return Ok(header);
324                }
325                Err(CodecError::UnexpectedEnd) => {
326                    if !self.fill().await? {
327                        return Err(ConnectionError::UnexpectedEnd);
328                    }
329                }
330                Err(e) => return Err(ConnectionError::Codec(e)),
331            }
332        }
333    }
334
335    /// Read a fetch response header.
336    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
337        self.ensure(1).await?;
338        loop {
339            let mut cursor = &self.buf[..];
340            match AnyFetchHeader::decode(self.draft, &mut cursor) {
341                Ok(header) => {
342                    let consumed = self.buf.len() - cursor.remaining();
343                    self.buf.advance(consumed);
344                    return Ok(header);
345                }
346                Err(CodecError::UnexpectedEnd) => {
347                    if !self.fill().await? {
348                        return Err(ConnectionError::UnexpectedEnd);
349                    }
350                }
351                Err(e) => return Err(ConnectionError::Codec(e)),
352            }
353        }
354    }
355
356    /// Read the next draft-14 subgroup object from this stream. Uses the
357    /// stateful reader installed by
358    /// [`FramedRecvStream::read_subgroup_header`] to decode the delta-
359    /// encoded object ID and handle extension headers per the stream type.
360    /// Returns an error if called before a subgroup header was read.
361    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
362        if self.subgroup_io.is_none() {
363            return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
364        }
365        loop {
366            let reader = self.subgroup_io.as_mut().unwrap();
367            let mut probe = reader.clone();
368            let mut cursor = &self.buf[..];
369            match probe.read_object(&mut cursor) {
370                Ok(obj) => {
371                    let consumed = self.buf.len() - cursor.remaining();
372                    self.buf.advance(consumed);
373                    // Commit state from the successful probe decode.
374                    *reader = probe;
375                    return Ok(obj);
376                }
377                Err(CodecError::UnexpectedEnd) => {
378                    if !self.fill().await? {
379                        return Err(ConnectionError::UnexpectedEnd);
380                    }
381                }
382                Err(e) => return Err(ConnectionError::Codec(e)),
383            }
384        }
385    }
386
387    /// Read the next draft-14 fetch object from this stream. Fetch objects
388    /// are self-describing, so no prior `read_fetch_header` state is needed
389    /// to decode each object.
390    pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
391        loop {
392            let mut cursor = &self.buf[..];
393            match FetchObject::decode(&mut cursor) {
394                Ok(obj) => {
395                    let consumed = self.buf.len() - cursor.remaining();
396                    self.buf.advance(consumed);
397                    return Ok(obj);
398                }
399                Err(CodecError::UnexpectedEnd) => {
400                    if !self.fill().await? {
401                        return Err(ConnectionError::UnexpectedEnd);
402                    }
403                }
404                Err(e) => return Err(ConnectionError::Codec(e)),
405            }
406        }
407    }
408
409    /// Returns the draft version this stream is framed for.
410    pub fn draft(&self) -> DraftVersion {
411        self.draft
412    }
413}
414
415/// A live MoQT connection over QUIC or WebTransport, combining the endpoint
416/// state machine with actual network I/O.
417pub struct Connection {
418    transport: Transport,
419    endpoint: Endpoint,
420    draft: DraftVersion,
421    control_send: Option<FramedSendStream>,
422    control_recv: Option<FramedRecvStream>,
423    observer: Option<Box<dyn ConnectionObserver>>,
424    /// Setup events buffered during `connect()` and replayed when an
425    /// observer attaches via `set_observer` — without this, an observer
426    /// attached after `connect` returns would never see the handshake.
427    pending_events: Vec<ClientEvent>,
428}
429
430impl Connection {
431    /// Connect to a MoQT server as a client.
432    ///
433    /// Establishes a QUIC or WebTransport connection (based on `config.transport`),
434    /// opens a bidirectional control stream, performs the CLIENT_SETUP /
435    /// SERVER_SETUP handshake, and returns a ready-to-use connection.
436    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
437        let draft = config.draft;
438        let transport = match &config.transport {
439            TransportType::Quic => Self::connect_quic(addr, &config).await?,
440            TransportType::WebTransport { url } => {
441                let url = url.clone();
442                Self::connect_webtransport(&url, &config).await?
443            }
444        };
445
446        // Open bidirectional control stream
447        let (send, recv) = transport.open_bi().await?;
448        let mut control_send = FramedSendStream::new(send, draft);
449        let mut control_recv = FramedRecvStream::new(recv, draft);
450
451        // Perform setup handshake
452        let mut endpoint = Endpoint::new(Role::Client);
453        endpoint.connect()?;
454        let setup_msg = endpoint
455            .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
456        let any_setup = AnyControlMessage::Draft14(setup_msg);
457        let raw_setup = control_send.write_control(&any_setup).await?;
458
459        let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
460        // Unwrap to draft-14 for the endpoint (which is draft-14 only)
461        match &server_setup {
462            AnyControlMessage::Draft14(ControlMessage::ServerSetup(ref ss)) => {
463                endpoint.receive_server_setup(ss)?;
464            }
465            _ => {
466                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
467            }
468        }
469
470        let mut pending_events = Vec::with_capacity(3);
471        pending_events.push(ClientEvent::ControlMessage {
472            direction: Direction::Send,
473            message: any_setup,
474            raw: Some(raw_setup),
475        });
476        pending_events.push(ClientEvent::ControlMessage {
477            direction: Direction::Receive,
478            message: server_setup,
479            raw: raw_server_setup,
480        });
481        if let Some(v) = endpoint.negotiated_version() {
482            pending_events.push(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
483        }
484
485        Ok(Self {
486            transport,
487            endpoint,
488            draft,
489            control_send: Some(control_send),
490            control_recv: Some(control_recv),
491            observer: None,
492            pending_events,
493        })
494    }
495
496    /// Establish a raw QUIC connection.
497    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
498        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
499            ConnectionError::InvalidAddress(e.to_string())
500        })?;
501
502        // Build TLS config
503        let mut tls_config = if config.skip_cert_verification {
504            rustls::ClientConfig::builder()
505                .dangerous()
506                .with_custom_certificate_verifier(Arc::new(SkipVerification))
507                .with_no_client_auth()
508        } else {
509            let mut roots = rustls::RootCertStore::empty();
510            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
511            for der in &config.ca_certs {
512                roots
513                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
514                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
515            }
516            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
517        };
518
519        tls_config.alpn_protocols = config.alpn();
520
521        let quic_config: quinn::crypto::rustls::QuicClientConfig =
522            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
523        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
524
525        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
526            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
527        quinn_endpoint.set_default_client_config(client_config);
528
529        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
530
531        let quic = quinn_endpoint
532            .connect(server_addr, &server_name)
533            .map_err(TransportError::from)?
534            .await
535            .map_err(TransportError::from)?;
536
537        Ok(Transport::Quic(QuicTransport::new(quic)))
538    }
539
540    /// Establish a WebTransport connection.
541    #[cfg(feature = "webtransport")]
542    async fn connect_webtransport(
543        url: &str,
544        config: &ClientConfig,
545    ) -> Result<Transport, ConnectionError> {
546        use crate::transport::webtransport::WebTransportTransport;
547
548        let wt_config = if config.skip_cert_verification {
549            wtransport::ClientConfig::builder()
550                .with_bind_default()
551                .with_no_cert_validation()
552                .build()
553        } else {
554            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
555        };
556
557        let endpoint = wtransport::Endpoint::client(wt_config)
558            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
559
560        let connection = endpoint
561            .connect(url)
562            .await
563            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
564
565        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
566    }
567
568    /// Stub for when the webtransport feature is not enabled.
569    #[cfg(not(feature = "webtransport"))]
570    async fn connect_webtransport(
571        _url: &str,
572        _config: &ClientConfig,
573    ) -> Result<Transport, ConnectionError> {
574        Err(ConnectionError::Transport(TransportError::Connect(
575            "webtransport feature not enabled".into(),
576        )))
577    }
578
579    // ── Observer ───────────────────────────────────────────────
580
581    /// Attach an observer. Buffered handshake events from `connect()` are
582    /// flushed in arrival order before this returns.
583    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
584        self.observer = Some(observer);
585        for event in self.pending_events.drain(..) {
586            if let Some(ref obs) = self.observer {
587                obs.on_event_owned(event);
588            }
589        }
590    }
591
592    /// Remove the observer.
593    pub fn clear_observer(&mut self) {
594        self.observer = None;
595    }
596
597    /// Emit an event to the observer, if one is attached.
598    fn emit(&self, event: ClientEvent) {
599        if let Some(ref obs) = self.observer {
600            obs.on_event_owned(event);
601        }
602    }
603
604    // ── Control message I/O ─────────────────────────────────
605
606    /// Send a control message on the control stream.
607    ///
608    /// Wraps the draft-14 message in `AnyControlMessage::Draft14` for framing.
609    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
610        let any = AnyControlMessage::Draft14(msg.clone());
611        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
612        let raw = send.write_control(&any).await?;
613        self.emit(ClientEvent::ControlMessage {
614            direction: Direction::Send,
615            message: any,
616            raw: Some(raw),
617        });
618        Ok(())
619    }
620
621    /// Read the next control message from the control stream.
622    ///
623    /// Returns the `AnyControlMessage` and also extracts the draft-14
624    /// `ControlMessage` for internal endpoint dispatch.
625    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
626        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
627        let capture_raw = self.observer.is_some();
628        let (any, raw) = recv.read_control(capture_raw).await?;
629        if capture_raw {
630            self.emit(ClientEvent::ControlMessage {
631                direction: Direction::Receive,
632                message: any.clone(),
633                raw,
634            });
635        }
636        // Unwrap to draft-14 for the endpoint
637        match any {
638            AnyControlMessage::Draft14(msg) => Ok(msg),
639            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
640        }
641    }
642
643    /// Read and dispatch the next incoming control message through the endpoint
644    /// state machine. Returns the decoded message for inspection.
645    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
646        let msg = self.recv_control().await?;
647        self.endpoint.receive_message(msg.clone())?;
648
649        // Emit draining event if this was a GoAway
650        if let ControlMessage::GoAway(ref ga) = msg {
651            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
652        }
653
654        Ok(msg)
655    }
656
657    // ── Subscribe flow ──────────────────────────────────────
658
659    /// Send a SUBSCRIBE and return the allocated request ID.
660    pub async fn subscribe(
661        &mut self,
662        track_namespace: TrackNamespace,
663        track_name: Vec<u8>,
664        subscriber_priority: u8,
665        group_order: GroupOrder,
666        filter_type: FilterType,
667    ) -> Result<VarInt, ConnectionError> {
668        let (req_id, msg) = self.endpoint.subscribe(
669            track_namespace,
670            track_name,
671            subscriber_priority,
672            group_order,
673            filter_type,
674        )?;
675        self.send_control(&msg).await?;
676        Ok(req_id)
677    }
678
679    /// Send an UNSUBSCRIBE for the given request ID.
680    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
681        let msg = self.endpoint.unsubscribe(request_id)?;
682        self.send_control(&msg).await
683    }
684
685    /// Send a SUBSCRIBE_UPDATE for an active subscription. Returns the
686    /// allocated request ID of the update message.
687    pub async fn subscribe_update(
688        &mut self,
689        subscription_request_id: VarInt,
690        start_location: Location,
691        end_group: VarInt,
692        subscriber_priority: u8,
693        forward: Forward,
694        parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
695    ) -> Result<VarInt, ConnectionError> {
696        let (req_id, msg) = self.endpoint.subscribe_update(
697            subscription_request_id,
698            start_location,
699            end_group,
700            subscriber_priority,
701            forward,
702            parameters,
703        )?;
704        self.send_control(&msg).await?;
705        Ok(req_id)
706    }
707
708    // ── Fetch flow ──────────────────────────────────────────
709
710    /// Send a FETCH and return the allocated request ID.
711    pub async fn fetch(
712        &mut self,
713        track_namespace: TrackNamespace,
714        track_name: Vec<u8>,
715        start_group: VarInt,
716        start_object: VarInt,
717    ) -> Result<VarInt, ConnectionError> {
718        let (req_id, msg) =
719            self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
720        self.send_control(&msg).await?;
721        Ok(req_id)
722    }
723
724    /// Send a FETCH_CANCEL for the given request ID.
725    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
726        let msg = self.endpoint.fetch_cancel(request_id)?;
727        self.send_control(&msg).await
728    }
729
730    // ── Namespace flows ─────────────────────────────────────
731
732    /// Send a SUBSCRIBE_NAMESPACE and return the request ID.
733    pub async fn subscribe_namespace(
734        &mut self,
735        track_namespace: TrackNamespace,
736    ) -> Result<VarInt, ConnectionError> {
737        let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
738        self.send_control(&msg).await?;
739        Ok(req_id)
740    }
741
742    /// Send a PUBLISH_NAMESPACE and return the request ID.
743    pub async fn publish_namespace(
744        &mut self,
745        track_namespace: TrackNamespace,
746    ) -> Result<VarInt, ConnectionError> {
747        let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
748        self.send_control(&msg).await?;
749        Ok(req_id)
750    }
751
752    // ── Track Status flow ────────────────────────────────────
753
754    /// Send a TRACK_STATUS and return the allocated request ID.
755    pub async fn track_status(
756        &mut self,
757        track_namespace: TrackNamespace,
758        track_name: Vec<u8>,
759    ) -> Result<VarInt, ConnectionError> {
760        let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name)?;
761        self.send_control(&msg).await?;
762        Ok(req_id)
763    }
764
765    // ── Publish flow (publisher side) ───────────────────────
766
767    /// Send a PUBLISH and return the allocated request ID.
768    pub async fn publish(
769        &mut self,
770        track_namespace: TrackNamespace,
771        track_name: Vec<u8>,
772        forward: Forward,
773    ) -> Result<VarInt, ConnectionError> {
774        let (req_id, msg) = self.endpoint.publish(track_namespace, track_name, forward)?;
775        self.send_control(&msg).await?;
776        Ok(req_id)
777    }
778
779    /// Send a PUBLISH_DONE for the given request ID.
780    pub async fn publish_done(
781        &mut self,
782        request_id: VarInt,
783        status_code: VarInt,
784        reason_phrase: Vec<u8>,
785    ) -> Result<(), ConnectionError> {
786        let msg = self.endpoint.send_publish_done(request_id, status_code, reason_phrase)?;
787        self.send_control(&msg).await
788    }
789
790    // ── Data streams ────────────────────────────────────────
791
792    /// Open a new unidirectional stream for sending subgroup data.
793    pub async fn open_subgroup_stream(
794        &self,
795        header: &AnySubgroupHeader,
796    ) -> Result<FramedSendStream, ConnectionError> {
797        let send = self.transport.open_uni().await?;
798        let mut framed = FramedSendStream::new(send, self.draft);
799        let sid = framed.stream_id();
800        framed.write_subgroup_header(header).await?;
801        self.emit(ClientEvent::StreamOpened {
802            direction: Direction::Send,
803            stream_kind: StreamKind::Subgroup,
804            stream_id: sid,
805        });
806        self.emit(ClientEvent::DataStreamHeader {
807            stream_id: sid,
808            direction: Direction::Send,
809            header: header.clone(),
810        });
811        Ok(framed)
812    }
813
814    /// Accept an incoming unidirectional data stream and read its subgroup
815    /// header.
816    pub async fn accept_subgroup_stream(
817        &self,
818    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
819        let recv = self.transport.accept_uni().await?;
820        let mut framed = FramedRecvStream::new(recv, self.draft);
821        let sid = framed.stream_id();
822        let header = framed.read_subgroup_header().await?;
823        self.emit(ClientEvent::StreamOpened {
824            direction: Direction::Receive,
825            stream_kind: StreamKind::Subgroup,
826            stream_id: sid,
827        });
828        self.emit(ClientEvent::DataStreamHeader {
829            stream_id: sid,
830            direction: Direction::Receive,
831            header: header.clone(),
832        });
833        Ok((header, framed))
834    }
835
836    /// Send an object via datagram.
837    pub fn send_datagram(
838        &self,
839        header: &AnyDatagramHeader,
840        payload: &[u8],
841    ) -> Result<(), ConnectionError> {
842        let mut buf = Vec::new();
843        header.encode(&mut buf);
844        buf.extend_from_slice(payload);
845        self.emit(ClientEvent::DatagramReceived {
846            direction: Direction::Send,
847            header: header.clone(),
848            payload_len: payload.len(),
849        });
850        self.transport.send_datagram(bytes::Bytes::from(buf))?;
851        Ok(())
852    }
853
854    /// Receive a datagram and decode its header.
855    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
856        let data = self.transport.recv_datagram().await?;
857        let mut cursor = &data[..];
858        let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
859        let consumed = data.len() - cursor.len();
860        let payload = data.slice(consumed..);
861        self.emit(ClientEvent::DatagramReceived {
862            direction: Direction::Receive,
863            header: header.clone(),
864            payload_len: payload.len(),
865        });
866        Ok((header, payload))
867    }
868
869    // ── Accessors ───────────────────────────────────────────
870
871    /// Access the underlying endpoint state machine.
872    pub fn endpoint(&self) -> &Endpoint {
873        &self.endpoint
874    }
875
876    /// Mutable access to the endpoint state machine.
877    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
878        &mut self.endpoint
879    }
880
881    /// Get the negotiated MoQT version.
882    pub fn negotiated_version(&self) -> Option<VarInt> {
883        self.endpoint.negotiated_version()
884    }
885
886    /// Returns the draft version this connection is using.
887    pub fn draft(&self) -> DraftVersion {
888        self.draft
889    }
890
891    /// Close the connection.
892    pub fn close(&self, code: u32, reason: &[u8]) {
893        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
894        self.transport.close(code, reason);
895    }
896}
897
/// Number of bytes a variable-length integer occupies, read from its first byte.
///
/// Only the two most significant bits matter: `00` → 1 byte, `01` → 2 bytes,
/// `10` → 4 bytes, `11` → 8 bytes.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0 => 1,
        1 => 2,
        2 => 4,
        _ => 8,
    }
}
902
/// TLS certificate verifier that skips all verification (for testing only).
///
/// Its `ServerCertVerifier` implementation accepts every server certificate
/// and every TLS 1.2 / TLS 1.3 handshake signature unconditionally. A
/// connection using it is therefore not authenticated and is open to
/// man-in-the-middle attacks. Never use it outside tests or local development.
#[derive(Debug)]
struct SkipVerification;
906
907impl rustls::client::danger::ServerCertVerifier for SkipVerification {
908    fn verify_server_cert(
909        &self,
910        _end_entity: &rustls::pki_types::CertificateDer<'_>,
911        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
912        _server_name: &rustls::pki_types::ServerName<'_>,
913        _ocsp_response: &[u8],
914        _now: rustls::pki_types::UnixTime,
915    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
916        Ok(rustls::client::danger::ServerCertVerified::assertion())
917    }
918
919    fn verify_tls12_signature(
920        &self,
921        _message: &[u8],
922        _cert: &rustls::pki_types::CertificateDer<'_>,
923        _dcs: &rustls::DigitallySignedStruct,
924    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
925        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
926    }
927
928    fn verify_tls13_signature(
929        &self,
930        _message: &[u8],
931        _cert: &rustls::pki_types::CertificateDer<'_>,
932        _dcs: &rustls::DigitallySignedStruct,
933    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
934        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
935    }
936
937    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
938        vec![
939            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
940            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
941            rustls::SignatureScheme::ED25519,
942            rustls::SignatureScheme::RSA_PSS_SHA256,
943            rustls::SignatureScheme::RSA_PSS_SHA384,
944            rustls::SignatureScheme::RSA_PSS_SHA512,
945        ]
946    }
947}
948
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ClientConfig` for `draft` over `transport`, leaving every other
    /// field empty or disabled so each test spells out only what it varies.
    fn test_config(draft: DraftVersion, transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft,
            additional_versions: Vec::new(),
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn varint_len_single_byte() {
        // 0b00xxxxxx -> 1 byte
        assert_eq!(varint_len(0x00), 1);
        assert_eq!(varint_len(0x3F), 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        // 0b01xxxxxx -> 2 bytes
        assert_eq!(varint_len(0x40), 2);
        assert_eq!(varint_len(0x7F), 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        // 0b10xxxxxx -> 4 bytes
        assert_eq!(varint_len(0x80), 4);
        assert_eq!(varint_len(0xBF), 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        // 0b11xxxxxx -> 8 bytes
        assert_eq!(varint_len(0xC0), 8);
        assert_eq!(varint_len(0xFF), 8);
    }

    #[test]
    fn varint_len_depends_only_on_two_high_bits() {
        // Exhaustively check every possible first byte against the prefix table.
        for byte in 0u8..=u8::MAX {
            let expected = match byte >> 6 {
                0 => 1,
                1 => 2,
                2 => 4,
                _ => 8,
            };
            assert_eq!(varint_len(byte), expected, "first byte {byte:#04x}");
        }
    }

    #[test]
    fn client_config_supported_versions_draft14() {
        let versions =
            test_config(DraftVersion::Draft14, TransportType::Quic).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 14);
    }

    #[test]
    fn client_config_supported_versions_draft07() {
        let versions =
            test_config(DraftVersion::Draft07, TransportType::Quic).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 7);
    }

    #[test]
    fn client_config_alpn_quic() {
        let config = test_config(DraftVersion::Draft14, TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moq-00".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config = test_config(
            DraftVersion::Draft14,
            TransportType::WebTransport { url: "https://example.com".to_string() },
        );
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let quic = TransportType::Quic;
        assert!(format!("{quic:?}").contains("Quic"));

        let wt = TransportType::WebTransport { url: "https://example.com".to_string() };
        assert!(format!("{wt:?}").contains("WebTransport"));
    }
}