Skip to main content

moqtap_client/draft15/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft15::endpoint::{Endpoint, EndpointError};
6use crate::draft15::event::{ClientEvent, Direction, StreamKind};
7use crate::draft15::observer::ConnectionObserver;
8use crate::draft15::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft15::data_stream::{FetchHeader, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft15::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::kvp::KeyValuePair;
18use moqtap_codec::types::*;
19use moqtap_codec::varint::VarInt;
20use moqtap_codec::version::DraftVersion;
21
/// ALPN token for MoQT carried directly over QUIC (no HTTP/3 / WebTransport).
///
/// Note: [`ClientConfig::alpn`] does not read this constant; for raw QUIC it
/// advertises the draft-specific value returned by `DraftVersion::quic_alpn()`.
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
24
25/// Errors from the connection layer.
26#[derive(Debug, thiserror::Error)]
27pub enum ConnectionError {
28    /// Endpoint state machine error.
29    #[error("endpoint error: {0}")]
30    Endpoint(#[from] EndpointError),
31    /// Wire codec error.
32    #[error("codec error: {0}")]
33    Codec(#[from] CodecError),
34    /// Transport-level error.
35    #[error("transport error: {0}")]
36    Transport(#[from] TransportError),
37    /// Variable-length integer decoding error.
38    #[error("varint error: {0}")]
39    VarInt(#[from] moqtap_codec::varint::VarIntError),
40    /// Control stream was not opened.
41    #[error("control stream not open")]
42    NoControlStream,
43    /// Stream ended before a complete message was read.
44    #[error("unexpected end of stream")]
45    UnexpectedEnd,
46    /// Stream was finished by the peer.
47    #[error("stream finished")]
48    StreamFinished,
49    /// Invalid server address string.
50    #[error("invalid server address: {0}")]
51    InvalidAddress(String),
52    /// TLS configuration error.
53    #[error("TLS config error: {0}")]
54    TlsConfig(String),
55    /// Data stream used out of order (e.g. object before header).
56    #[error("data stream state error: {0}")]
57    DataStreamState(&'static str),
58}
59
/// Which wire transport a [`ClientConfig`] should use.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// MoQT directly over QUIC (quinn).
    ///
    /// The `addr` argument given to `Connection::connect` must be an
    /// `ip:port` socket-address literal: it is parsed as a socket address,
    /// so host names are rejected.
    Quic,
    /// MoQT over WebTransport (wtransport).
    WebTransport {
        /// Full WebTransport URL to dial, e.g. `https://host:port/path`.
        url: String,
    },
}
71
72/// Configuration for a MoQT client connection.
73///
74/// Both `draft` and `transport` are required -- there is no `Default` impl.
75pub struct ClientConfig {
76    /// The MoQT draft version to use (primary, determines codec/framing).
77    pub draft: DraftVersion,
78    /// The transport type (QUIC or WebTransport).
79    pub transport: TransportType,
80    /// Whether to skip TLS certificate verification (for testing).
81    pub skip_cert_verification: bool,
82    /// Custom CA certificates to trust (DER-encoded).
83    pub ca_certs: Vec<Vec<u8>>,
84    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
85    pub setup_parameters: Vec<KeyValuePair>,
86}
87
88impl ClientConfig {
89    /// Returns the ALPN protocol identifiers for the transport.
90    pub fn alpn(&self) -> Vec<Vec<u8>> {
91        match &self.transport {
92            TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
93            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
94        }
95    }
96}
97
98/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
99pub struct FramedSendStream {
100    inner: SendStream,
101    draft: DraftVersion,
102    /// Stateful subgroup object writer (tracks delta encoding state and
103    /// extension-presence flag, seeded from the stream's
104    /// `SubgroupHeader`).
105    subgroup_io: Option<SubgroupObjectReader>,
106}
107
108impl FramedSendStream {
109    /// Create a new framed send stream for the given draft version.
110    pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
111        Self { inner, draft, subgroup_io: None }
112    }
113
114    /// Get the transport-level stream ID.
115    pub fn stream_id(&self) -> u64 {
116        self.inner.stream_id()
117    }
118
119    /// Write a control message to the stream with type+length framing.
120    /// Returns the raw bytes that were written (for event capture).
121    pub async fn write_control(
122        &mut self,
123        msg: &AnyControlMessage,
124    ) -> Result<Vec<u8>, ConnectionError> {
125        let mut buf = Vec::new();
126        msg.encode(&mut buf)?;
127        self.inner.write_all(&buf).await?;
128        Ok(buf)
129    }
130
131    /// Write a subgroup stream header. Also initializes the internal
132    /// delta-encoding state used by
133    /// [`FramedSendStream::write_subgroup_object`].
134    pub async fn write_subgroup_header(
135        &mut self,
136        header: &AnySubgroupHeader,
137    ) -> Result<(), ConnectionError> {
138        let mut buf = Vec::new();
139        header.encode(&mut buf);
140        self.inner.write_all(&buf).await?;
141        if let AnySubgroupHeader::Draft15(ref d15) = header {
142            self.subgroup_io = Some(SubgroupObjectReader::new(d15));
143        }
144        Ok(())
145    }
146
147    /// Write a fetch response header.
148    pub async fn write_fetch_header(
149        &mut self,
150        header: &AnyFetchHeader,
151    ) -> Result<(), ConnectionError> {
152        let mut buf = Vec::new();
153        header.encode(&mut buf);
154        self.inner.write_all(&buf).await?;
155        Ok(())
156    }
157
158    /// Append a draft-15 subgroup object to the stream. Uses the
159    /// stateful writer seeded from
160    /// [`FramedSendStream::write_subgroup_header`] to produce correct
161    /// delta-encoded object IDs.
162    pub async fn write_subgroup_object(
163        &mut self,
164        object: &SubgroupObject,
165    ) -> Result<(), ConnectionError> {
166        let writer = self
167            .subgroup_io
168            .as_mut()
169            .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
170        let mut buf = Vec::new();
171        writer.write_object(object, &mut buf)?;
172        self.inner.write_all(&buf).await?;
173        Ok(())
174    }
175
176    /// Finish the stream (send FIN).
177    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
178        self.inner.finish()?;
179        Ok(())
180    }
181
182    /// Returns the draft version this stream is framed for.
183    pub fn draft(&self) -> DraftVersion {
184        self.draft
185    }
186}
187
188/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
189pub struct FramedRecvStream {
190    inner: RecvStream,
191    buf: BytesMut,
192    draft: DraftVersion,
193    /// Stateful subgroup object reader (tracks delta-decode state and
194    /// extension-presence flag).
195    subgroup_io: Option<SubgroupObjectReader>,
196}
197
198impl FramedRecvStream {
199    /// Create a new framed receive stream for the given draft version.
200    pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
201        Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
202    }
203
204    /// Get the transport-level stream ID.
205    pub fn stream_id(&self) -> u64 {
206        self.inner.stream_id()
207    }
208
209    /// Read more data from the stream into the internal buffer.
210    async fn fill(&mut self) -> Result<bool, ConnectionError> {
211        let mut tmp = [0u8; 4096];
212        match self.inner.read(&mut tmp).await {
213            Ok(Some(n)) => {
214                self.buf.extend_from_slice(&tmp[..n]);
215                Ok(true)
216            }
217            Ok(None) => Ok(false),
218            Err(e) => Err(ConnectionError::Transport(e)),
219        }
220    }
221
222    /// Ensure at least `n` bytes are available in the buffer.
223    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
224        while self.buf.len() < n {
225            if !self.fill().await? {
226                return Err(ConnectionError::UnexpectedEnd);
227            }
228        }
229        Ok(())
230    }
231
232    /// Read a control message from the stream.
233    ///
234    /// When `capture_raw` is true, the returned tuple includes a clone of the
235    /// framed wire bytes (for observer emission). When false, the second
236    /// element is `None` and the payload clone is skipped.
237    pub async fn read_control(
238        &mut self,
239        capture_raw: bool,
240    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
241        // Read type ID varint
242        self.ensure(1).await?;
243        let type_len = varint_len(self.buf[0]);
244        self.ensure(type_len).await?;
245
246        let mut cursor = &self.buf[..type_len];
247        let _type_id = VarInt::decode(&mut cursor)?;
248
249        // Draft-15: 16-bit BE payload length
250        let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
251            self.ensure(type_len + 2).await?;
252            let hi = self.buf[type_len] as usize;
253            let lo = self.buf[type_len + 1] as usize;
254            ((hi << 8) | lo, 2)
255        } else {
256            self.ensure(type_len + 1).await?;
257            let payload_len_start = type_len;
258            let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
259            self.ensure(type_len + payload_len_varint_len).await?;
260            let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
261            let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
262            (payload_len, payload_len_varint_len)
263        };
264
265        // Read full payload
266        let total = type_len + len_field_size + payload_len;
267        self.ensure(total).await?;
268
269        // Capture raw bytes only if requested (observer attached).
270        let raw = capture_raw.then(|| self.buf[..total].to_vec());
271
272        // Now decode the whole message
273        let mut frame = &self.buf[..total];
274        let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
275        self.buf.advance(total);
276        Ok((msg, raw))
277    }
278
279    /// Read a subgroup stream header. Also initializes the internal
280    /// delta-decoding state.
281    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
282        self.ensure(1).await?;
283        loop {
284            let mut cursor = &self.buf[..];
285            match AnySubgroupHeader::decode(self.draft, &mut cursor) {
286                Ok(header) => {
287                    let consumed = self.buf.len() - cursor.remaining();
288                    self.buf.advance(consumed);
289                    if let AnySubgroupHeader::Draft15(ref d15) = header {
290                        self.subgroup_io = Some(SubgroupObjectReader::new(d15));
291                    }
292                    return Ok(header);
293                }
294                Err(CodecError::UnexpectedEnd) => {
295                    if !self.fill().await? {
296                        return Err(ConnectionError::UnexpectedEnd);
297                    }
298                }
299                Err(e) => return Err(ConnectionError::Codec(e)),
300            }
301        }
302    }
303
304    /// Read a fetch response header.
305    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
306        self.ensure(1).await?;
307        loop {
308            let mut cursor = &self.buf[..];
309            match AnyFetchHeader::decode(self.draft, &mut cursor) {
310                Ok(header) => {
311                    let consumed = self.buf.len() - cursor.remaining();
312                    self.buf.advance(consumed);
313                    return Ok(header);
314                }
315                Err(CodecError::UnexpectedEnd) => {
316                    if !self.fill().await? {
317                        return Err(ConnectionError::UnexpectedEnd);
318                    }
319                }
320                Err(e) => return Err(ConnectionError::Codec(e)),
321            }
322        }
323    }
324
325    /// Read the next draft-15 subgroup object from this stream. Uses
326    /// the stateful reader seeded by
327    /// [`FramedRecvStream::read_subgroup_header`] to decode the
328    /// delta-encoded object ID and (when the stream type says so) the
329    /// extension block. Returns an error if called before a subgroup
330    /// header was read.
331    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
332        if self.subgroup_io.is_none() {
333            return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
334        }
335        loop {
336            let reader = self.subgroup_io.as_mut().unwrap();
337            let mut probe = reader.clone();
338            let mut cursor = &self.buf[..];
339            match probe.read_object(&mut cursor) {
340                Ok(obj) => {
341                    let consumed = self.buf.len() - cursor.remaining();
342                    self.buf.advance(consumed);
343                    *reader = probe;
344                    return Ok(obj);
345                }
346                Err(CodecError::UnexpectedEnd) => {
347                    if !self.fill().await? {
348                        return Err(ConnectionError::UnexpectedEnd);
349                    }
350                }
351                Err(e) => return Err(ConnectionError::Codec(e)),
352            }
353        }
354    }
355
356    /// Read the next draft-15 fetch header from this stream.
357    pub async fn read_fetch_stream_header(&mut self) -> Result<FetchHeader, ConnectionError> {
358        loop {
359            let mut cursor = &self.buf[..];
360            match FetchHeader::decode(&mut cursor) {
361                Ok(hdr) => {
362                    let consumed = self.buf.len() - cursor.remaining();
363                    self.buf.advance(consumed);
364                    return Ok(hdr);
365                }
366                Err(CodecError::UnexpectedEnd) => {
367                    if !self.fill().await? {
368                        return Err(ConnectionError::UnexpectedEnd);
369                    }
370                }
371                Err(e) => return Err(ConnectionError::Codec(e)),
372            }
373        }
374    }
375
376    /// Returns the draft version this stream is framed for.
377    pub fn draft(&self) -> DraftVersion {
378        self.draft
379    }
380}
381
382/// A live MoQT connection over QUIC or WebTransport, combining the endpoint
383/// state machine with actual network I/O.
384pub struct Connection {
385    transport: Transport,
386    endpoint: Endpoint,
387    draft: DraftVersion,
388    control_send: Option<FramedSendStream>,
389    control_recv: Option<FramedRecvStream>,
390    observer: Option<Box<dyn ConnectionObserver>>,
391    /// Setup events buffered during `connect()` and replayed when an
392    /// observer attaches via `set_observer` — without this, an observer
393    /// attached after `connect` returns would never see the handshake.
394    pending_events: Vec<ClientEvent>,
395}
396
397impl Connection {
398    /// Connect to a MoQT server as a client.
399    ///
400    /// Establishes a QUIC or WebTransport connection (based on
401    /// `config.transport`), opens a bidirectional control stream,
402    /// performs the CLIENT_SETUP / SERVER_SETUP handshake, and returns
403    /// a ready-to-use connection.
404    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
405        let draft = config.draft;
406        let transport = match &config.transport {
407            TransportType::Quic => Self::connect_quic(addr, &config).await?,
408            TransportType::WebTransport { url } => {
409                let url = url.clone();
410                Self::connect_webtransport(&url, &config).await?
411            }
412        };
413
414        // Open bidirectional control stream
415        let (send, recv) = transport.open_bi().await?;
416        let mut control_send = FramedSendStream::new(send, draft);
417        let mut control_recv = FramedRecvStream::new(recv, draft);
418
419        // Perform setup handshake (draft-15: no versions)
420        let mut endpoint = Endpoint::new(Role::Client);
421        endpoint.connect()?;
422        let setup_msg = endpoint.send_client_setup(config.setup_parameters.clone())?;
423        let any_setup = AnyControlMessage::Draft15(setup_msg);
424        let raw_setup = control_send.write_control(&any_setup).await?;
425
426        let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
427        // Unwrap to draft-15 for the endpoint
428        match &server_setup {
429            AnyControlMessage::Draft15(ControlMessage::ServerSetup(ref ss)) => {
430                endpoint.receive_server_setup(ss)?;
431            }
432            _ => {
433                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
434            }
435        }
436
437        let pending_events = vec![
438            ClientEvent::ControlMessage {
439                direction: Direction::Send,
440                message: any_setup,
441                raw: Some(raw_setup),
442            },
443            ClientEvent::ControlMessage {
444                direction: Direction::Receive,
445                message: server_setup,
446                raw: raw_server_setup,
447            },
448            ClientEvent::SetupComplete { negotiated_version: 0xff000000 + 15 },
449        ];
450
451        Ok(Self {
452            transport,
453            endpoint,
454            draft,
455            control_send: Some(control_send),
456            control_recv: Some(control_recv),
457            observer: None,
458            pending_events,
459        })
460    }
461
462    /// Establish a raw QUIC connection.
463    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
464        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
465            ConnectionError::InvalidAddress(e.to_string())
466        })?;
467
468        // Build TLS config
469        let mut tls_config = if config.skip_cert_verification {
470            rustls::ClientConfig::builder()
471                .dangerous()
472                .with_custom_certificate_verifier(Arc::new(SkipVerification))
473                .with_no_client_auth()
474        } else {
475            let mut roots = rustls::RootCertStore::empty();
476            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
477            for der in &config.ca_certs {
478                roots
479                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
480                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
481            }
482            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
483        };
484
485        tls_config.alpn_protocols = config.alpn();
486
487        let quic_config: quinn::crypto::rustls::QuicClientConfig =
488            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
489        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
490
491        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
492            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
493        quinn_endpoint.set_default_client_config(client_config);
494
495        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
496
497        let quic = quinn_endpoint
498            .connect(server_addr, &server_name)
499            .map_err(TransportError::from)?
500            .await
501            .map_err(TransportError::from)?;
502
503        Ok(Transport::Quic(QuicTransport::new(quic)))
504    }
505
506    /// Establish a WebTransport connection.
507    #[cfg(feature = "webtransport")]
508    async fn connect_webtransport(
509        url: &str,
510        config: &ClientConfig,
511    ) -> Result<Transport, ConnectionError> {
512        use crate::transport::webtransport::WebTransportTransport;
513
514        let wt_config = if config.skip_cert_verification {
515            wtransport::ClientConfig::builder()
516                .with_bind_default()
517                .with_no_cert_validation()
518                .build()
519        } else {
520            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
521        };
522
523        let endpoint = wtransport::Endpoint::client(wt_config)
524            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
525
526        let connection = endpoint
527            .connect(url)
528            .await
529            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
530
531        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
532    }
533
534    /// Stub for when the webtransport feature is not enabled.
535    #[cfg(not(feature = "webtransport"))]
536    async fn connect_webtransport(
537        _url: &str,
538        _config: &ClientConfig,
539    ) -> Result<Transport, ConnectionError> {
540        Err(ConnectionError::Transport(TransportError::Connect(
541            "webtransport feature not enabled".into(),
542        )))
543    }
544
545    // -- Observer ---------------------------------------------------
546
547    /// Attach an observer. Buffered handshake events from `connect()` are
548    /// flushed in arrival order before this returns.
549    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
550        self.observer = Some(observer);
551        for event in self.pending_events.drain(..) {
552            if let Some(ref obs) = self.observer {
553                obs.on_event_owned(event);
554            }
555        }
556    }
557
558    /// Remove the observer.
559    pub fn clear_observer(&mut self) {
560        self.observer = None;
561    }
562
563    /// Emit an event to the observer, if one is attached.
564    fn emit(&self, event: ClientEvent) {
565        if let Some(ref obs) = self.observer {
566            obs.on_event_owned(event);
567        }
568    }
569
570    // -- Control message I/O ----------------------------------------
571
572    /// Send a control message on the control stream.
573    ///
574    /// Wraps the draft-15 message in `AnyControlMessage::Draft15` for
575    /// framing.
576    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
577        let any = AnyControlMessage::Draft15(msg.clone());
578        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
579        let raw = send.write_control(&any).await?;
580        self.emit(ClientEvent::ControlMessage {
581            direction: Direction::Send,
582            message: any,
583            raw: Some(raw),
584        });
585        Ok(())
586    }
587
588    /// Read the next control message from the control stream.
589    ///
590    /// Returns the `AnyControlMessage` and also extracts the draft-15
591    /// `ControlMessage` for internal endpoint dispatch.
592    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
593        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
594        let capture_raw = self.observer.is_some();
595        let (any, raw) = recv.read_control(capture_raw).await?;
596        if capture_raw {
597            self.emit(ClientEvent::ControlMessage {
598                direction: Direction::Receive,
599                message: any.clone(),
600                raw,
601            });
602        }
603        // Unwrap to draft-15 for the endpoint
604        match any {
605            AnyControlMessage::Draft15(msg) => Ok(msg),
606            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
607        }
608    }
609
610    /// Read and dispatch the next incoming control message through the
611    /// endpoint state machine. Returns the decoded message for inspection.
612    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
613        let msg = self.recv_control().await?;
614        self.endpoint.receive_message(msg.clone())?;
615
616        // Emit draining event if this was a GoAway
617        if let ControlMessage::GoAway(ref ga) = msg {
618            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
619        }
620
621        Ok(msg)
622    }
623
624    // -- Subscribe flow ---------------------------------------------
625
626    /// Send a SUBSCRIBE and return the allocated request ID.
627    pub async fn subscribe(
628        &mut self,
629        track_namespace: TrackNamespace,
630        track_name: Vec<u8>,
631        parameters: Vec<KeyValuePair>,
632    ) -> Result<VarInt, ConnectionError> {
633        let (req_id, msg) = self.endpoint.subscribe(track_namespace, track_name, parameters)?;
634        self.send_control(&msg).await?;
635        Ok(req_id)
636    }
637
638    /// Send an UNSUBSCRIBE for the given request ID.
639    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
640        let msg = self.endpoint.unsubscribe(request_id)?;
641        self.send_control(&msg).await
642    }
643
644    // -- Fetch flow -------------------------------------------------
645
646    /// Send a standalone FETCH and return the allocated request ID.
647    pub async fn fetch(
648        &mut self,
649        track_namespace: TrackNamespace,
650        track_name: Vec<u8>,
651        start_group: VarInt,
652        start_object: VarInt,
653        end_group: VarInt,
654        end_object: VarInt,
655    ) -> Result<VarInt, ConnectionError> {
656        let (req_id, msg) = self.endpoint.fetch(
657            track_namespace,
658            track_name,
659            start_group,
660            start_object,
661            end_group,
662            end_object,
663        )?;
664        self.send_control(&msg).await?;
665        Ok(req_id)
666    }
667
668    /// Send a joining FETCH and return the allocated request ID.
669    pub async fn joining_fetch(
670        &mut self,
671        joining_request_id: VarInt,
672        joining_start: VarInt,
673    ) -> Result<VarInt, ConnectionError> {
674        let (req_id, msg) = self.endpoint.joining_fetch(joining_request_id, joining_start)?;
675        self.send_control(&msg).await?;
676        Ok(req_id)
677    }
678
679    /// Send a FETCH_CANCEL for the given request ID.
680    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
681        let msg = self.endpoint.fetch_cancel(request_id)?;
682        self.send_control(&msg).await
683    }
684
685    // -- Namespace flows --------------------------------------------
686
687    /// Send a SUBSCRIBE_NAMESPACE and return the request ID.
688    pub async fn subscribe_namespace(
689        &mut self,
690        namespace_prefix: TrackNamespace,
691        parameters: Vec<KeyValuePair>,
692    ) -> Result<VarInt, ConnectionError> {
693        let (req_id, msg) = self.endpoint.subscribe_namespace(namespace_prefix, parameters)?;
694        self.send_control(&msg).await?;
695        Ok(req_id)
696    }
697
698    /// Send a PUBLISH_NAMESPACE and return the request ID.
699    pub async fn publish_namespace(
700        &mut self,
701        track_namespace: TrackNamespace,
702        parameters: Vec<KeyValuePair>,
703    ) -> Result<VarInt, ConnectionError> {
704        let (req_id, msg) = self.endpoint.publish_namespace(track_namespace, parameters)?;
705        self.send_control(&msg).await?;
706        Ok(req_id)
707    }
708
709    // -- Track Status flow ------------------------------------------
710
711    /// Send a TRACK_STATUS and return the allocated request ID.
712    pub async fn track_status(
713        &mut self,
714        track_namespace: TrackNamespace,
715        track_name: Vec<u8>,
716        parameters: Vec<KeyValuePair>,
717    ) -> Result<VarInt, ConnectionError> {
718        let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name, parameters)?;
719        self.send_control(&msg).await?;
720        Ok(req_id)
721    }
722
723    // -- Publish flow (publisher side) ------------------------------
724
725    /// Send a PUBLISH and return the allocated request ID.
726    pub async fn publish(
727        &mut self,
728        track_namespace: TrackNamespace,
729        track_name: Vec<u8>,
730        track_alias: VarInt,
731        parameters: Vec<KeyValuePair>,
732    ) -> Result<VarInt, ConnectionError> {
733        let (req_id, msg) =
734            self.endpoint.publish(track_namespace, track_name, track_alias, parameters)?;
735        self.send_control(&msg).await?;
736        Ok(req_id)
737    }
738
739    /// Send a PUBLISH_DONE for the given request ID.
740    pub async fn publish_done(
741        &mut self,
742        request_id: VarInt,
743        status_code: VarInt,
744        stream_count: VarInt,
745        reason_phrase: Vec<u8>,
746    ) -> Result<(), ConnectionError> {
747        let msg = self.endpoint.send_publish_done(
748            request_id,
749            status_code,
750            stream_count,
751            reason_phrase,
752        )?;
753        self.send_control(&msg).await
754    }
755
756    // -- Data streams -----------------------------------------------
757
758    /// Open a new unidirectional stream for sending subgroup data.
759    pub async fn open_subgroup_stream(
760        &self,
761        header: &AnySubgroupHeader,
762    ) -> Result<FramedSendStream, ConnectionError> {
763        let send = self.transport.open_uni().await?;
764        let mut framed = FramedSendStream::new(send, self.draft);
765        let sid = framed.stream_id();
766        framed.write_subgroup_header(header).await?;
767        self.emit(ClientEvent::StreamOpened {
768            direction: Direction::Send,
769            stream_kind: StreamKind::Subgroup,
770            stream_id: sid,
771        });
772        self.emit(ClientEvent::DataStreamHeader {
773            stream_id: sid,
774            direction: Direction::Send,
775            header: header.clone(),
776        });
777        Ok(framed)
778    }
779
780    /// Accept an incoming unidirectional data stream and read its subgroup
781    /// header.
782    pub async fn accept_subgroup_stream(
783        &self,
784    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
785        let recv = self.transport.accept_uni().await?;
786        let mut framed = FramedRecvStream::new(recv, self.draft);
787        let sid = framed.stream_id();
788        let header = framed.read_subgroup_header().await?;
789        self.emit(ClientEvent::StreamOpened {
790            direction: Direction::Receive,
791            stream_kind: StreamKind::Subgroup,
792            stream_id: sid,
793        });
794        self.emit(ClientEvent::DataStreamHeader {
795            stream_id: sid,
796            direction: Direction::Receive,
797            header: header.clone(),
798        });
799        Ok((header, framed))
800    }
801
802    /// Send an object via datagram.
803    pub fn send_datagram(
804        &self,
805        header: &AnyDatagramHeader,
806        payload: &[u8],
807    ) -> Result<(), ConnectionError> {
808        let mut buf = Vec::new();
809        header.encode(&mut buf);
810        buf.extend_from_slice(payload);
811        self.emit(ClientEvent::DatagramReceived {
812            direction: Direction::Send,
813            header: header.clone(),
814            payload_len: payload.len(),
815        });
816        self.transport.send_datagram(bytes::Bytes::from(buf))?;
817        Ok(())
818    }
819
820    /// Receive a datagram and decode its header.
821    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
822        let data = self.transport.recv_datagram().await?;
823        let mut cursor = &data[..];
824        let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
825        let consumed = data.len() - cursor.len();
826        let payload = data.slice(consumed..);
827        self.emit(ClientEvent::DatagramReceived {
828            direction: Direction::Receive,
829            header: header.clone(),
830            payload_len: payload.len(),
831        });
832        Ok((header, payload))
833    }
834
835    // -- Accessors --------------------------------------------------
836
837    /// Access the underlying endpoint state machine.
838    pub fn endpoint(&self) -> &Endpoint {
839        &self.endpoint
840    }
841
842    /// Mutable access to the endpoint state machine.
843    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
844        &mut self.endpoint
845    }
846
847    /// Returns the draft version this connection is using.
848    pub fn draft(&self) -> DraftVersion {
849        self.draft
850    }
851
852    /// Close the connection.
853    pub fn close(&self, code: u32, reason: &[u8]) {
854        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
855        self.transport.close(code, reason);
856    }
857}
858
/// Returns the total encoded length, in bytes, of a variable-length integer
/// given only its first byte.
///
/// The two most significant bits of the first byte select the length
/// (RFC 9000 §16 style): `00` → 1, `01` → 2, `10` → 4, `11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
863
864/// TLS certificate verifier that skips all verification (for testing only).
865#[derive(Debug)]
866struct SkipVerification;
867
868impl rustls::client::danger::ServerCertVerifier for SkipVerification {
869    fn verify_server_cert(
870        &self,
871        _end_entity: &rustls::pki_types::CertificateDer<'_>,
872        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
873        _server_name: &rustls::pki_types::ServerName<'_>,
874        _ocsp_response: &[u8],
875        _now: rustls::pki_types::UnixTime,
876    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
877        Ok(rustls::client::danger::ServerCertVerified::assertion())
878    }
879
880    fn verify_tls12_signature(
881        &self,
882        _message: &[u8],
883        _cert: &rustls::pki_types::CertificateDer<'_>,
884        _dcs: &rustls::DigitallySignedStruct,
885    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
886        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
887    }
888
889    fn verify_tls13_signature(
890        &self,
891        _message: &[u8],
892        _cert: &rustls::pki_types::CertificateDer<'_>,
893        _dcs: &rustls::DigitallySignedStruct,
894    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
895        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
896    }
897
898    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
899        vec![
900            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
901            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
902            rustls::SignatureScheme::ED25519,
903            rustls::SignatureScheme::RSA_PSS_SHA256,
904            rustls::SignatureScheme::RSA_PSS_SHA384,
905            rustls::SignatureScheme::RSA_PSS_SHA512,
906        ]
907    }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every first byte in `first_bytes` decodes to `expected` bytes.
    fn check_varint_len(first_bytes: &[u8], expected: usize) {
        for &first in first_bytes {
            assert_eq!(varint_len(first), expected);
        }
    }

    /// Builds a draft-15 config for `transport` with default security settings.
    fn draft15_config(transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft: DraftVersion::Draft15,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn varint_len_single_byte() {
        check_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        check_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        check_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        check_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn client_config_alpn_quic_draft15() {
        let config = draft15_config(TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moqt-15".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config = draft15_config(TransportType::WebTransport {
            url: "https://example.com".to_string(),
        });
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, "moq-00".as_bytes());
    }

    #[test]
    fn transport_type_debug() {
        let cases = [
            (TransportType::Quic, "Quic"),
            (
                TransportType::WebTransport { url: "https://example.com".to_string() },
                "WebTransport",
            ),
        ];
        for (transport, variant_name) in cases {
            assert!(format!("{transport:?}").contains(variant_name));
        }
    }
}