Skip to main content

moqtap_client/draft18/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft18::endpoint::{Endpoint, EndpointError};
6use crate::draft18::event::{ClientEvent, Direction, StreamKind};
7use crate::draft18::observer::ConnectionObserver;
8use crate::draft18::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft18::data_stream::{FetchHeader, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft18::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::kvp::KeyValuePair;
18use moqtap_codec::types::*;
19use moqtap_codec::varint::VarInt;
20use moqtap_codec::version::DraftVersion;
21
/// Legacy MoQT ALPN identifier for raw QUIC.
///
/// Note: [`Connection::connect`] does not offer this constant. The ALPN list
/// used for raw QUIC comes from [`ClientConfig::alpn`], which asks the
/// configured [`DraftVersion`] for its identifier via `quic_alpn()`.
pub const MOQT_ALPN: &[u8] = b"moq-00";
24
25/// Errors from the connection layer.
26#[derive(Debug, thiserror::Error)]
27pub enum ConnectionError {
28    /// Endpoint state machine error.
29    #[error("endpoint error: {0}")]
30    Endpoint(#[from] EndpointError),
31    /// Wire codec error.
32    #[error("codec error: {0}")]
33    Codec(#[from] CodecError),
34    /// Transport-level error.
35    #[error("transport error: {0}")]
36    Transport(#[from] TransportError),
37    /// Variable-length integer decoding error.
38    #[error("varint error: {0}")]
39    VarInt(#[from] moqtap_codec::varint::VarIntError),
40    /// Control stream was not opened.
41    #[error("control stream not open")]
42    NoControlStream,
43    /// Stream ended before a complete message was read.
44    #[error("unexpected end of stream")]
45    UnexpectedEnd,
46    /// Stream was finished by the peer.
47    #[error("stream finished")]
48    StreamFinished,
49    /// Invalid server address string.
50    #[error("invalid server address: {0}")]
51    InvalidAddress(String),
52    /// TLS configuration error.
53    #[error("TLS config error: {0}")]
54    TlsConfig(String),
55    /// Data stream used out of order (e.g. object before header).
56    #[error("data stream state error: {0}")]
57    DataStreamState(&'static str),
58}
59
/// Transport used to reach the MoQT server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportType {
    /// Raw QUIC via quinn.
    ///
    /// The `addr` argument of [`Connection::connect`] is parsed as a
    /// [`std::net::SocketAddr`], so it must be a numeric `ip:port` literal
    /// such as `127.0.0.1:4443`; hostnames are not resolved.
    Quic,
    /// WebTransport via wtransport.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        ///
        /// When this variant is selected, [`Connection::connect`] dials this
        /// URL and ignores its `addr` argument.
        url: String,
    },
}
71
72/// Configuration for a MoQT client connection.
73///
74/// Both `draft` and `transport` are required -- there is no `Default` impl.
75pub struct ClientConfig {
76    /// The MoQT draft version to use (primary, determines codec/framing).
77    pub draft: DraftVersion,
78    /// The transport type (QUIC or WebTransport).
79    pub transport: TransportType,
80    /// Whether to skip TLS certificate verification (for testing).
81    pub skip_cert_verification: bool,
82    /// Custom CA certificates to trust (DER-encoded).
83    pub ca_certs: Vec<Vec<u8>>,
84    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
85    pub setup_parameters: Vec<KeyValuePair>,
86}
87
88impl ClientConfig {
89    /// Returns the ALPN protocol identifiers for the transport.
90    pub fn alpn(&self) -> Vec<Vec<u8>> {
91        match &self.transport {
92            TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
93            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
94        }
95    }
96}
97
98/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
99pub struct FramedSendStream {
100    inner: SendStream,
101    draft: DraftVersion,
102    /// Stateful subgroup object writer.
103    subgroup_io: Option<SubgroupObjectReader>,
104}
105
106impl FramedSendStream {
107    /// Create a new framed send stream for the given draft version.
108    pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
109        Self { inner, draft, subgroup_io: None }
110    }
111
112    /// Get the transport-level stream ID.
113    pub fn stream_id(&self) -> u64 {
114        self.inner.stream_id()
115    }
116
117    /// Write a control message to the stream with type+length framing.
118    /// Returns the raw bytes that were written (for event capture).
119    pub async fn write_control(
120        &mut self,
121        msg: &AnyControlMessage,
122    ) -> Result<Vec<u8>, ConnectionError> {
123        let mut buf = Vec::new();
124        msg.encode(&mut buf)?;
125        self.inner.write_all(&buf).await?;
126        Ok(buf)
127    }
128
129    /// Write a subgroup stream header. Also initializes the internal
130    /// delta-encoding state used by
131    /// [`FramedSendStream::write_subgroup_object`].
132    pub async fn write_subgroup_header(
133        &mut self,
134        header: &AnySubgroupHeader,
135    ) -> Result<(), ConnectionError> {
136        let mut buf = Vec::new();
137        header.encode(&mut buf);
138        self.inner.write_all(&buf).await?;
139        if let AnySubgroupHeader::Draft18(ref d17) = header {
140            self.subgroup_io = Some(SubgroupObjectReader::new(d17));
141        }
142        Ok(())
143    }
144
145    /// Write a fetch response header.
146    pub async fn write_fetch_header(
147        &mut self,
148        header: &AnyFetchHeader,
149    ) -> Result<(), ConnectionError> {
150        let mut buf = Vec::new();
151        header.encode(&mut buf);
152        self.inner.write_all(&buf).await?;
153        Ok(())
154    }
155
156    /// Append a draft-18 subgroup object to the stream using the
157    /// stateful writer seeded from
158    /// [`FramedSendStream::write_subgroup_header`].
159    pub async fn write_subgroup_object(
160        &mut self,
161        object: &SubgroupObject,
162    ) -> Result<(), ConnectionError> {
163        let writer = self
164            .subgroup_io
165            .as_mut()
166            .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
167        let mut buf = Vec::new();
168        writer.write_object(object, &mut buf)?;
169        self.inner.write_all(&buf).await?;
170        Ok(())
171    }
172
173    /// Finish the stream (send FIN).
174    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
175        self.inner.finish()?;
176        Ok(())
177    }
178
179    /// Returns the draft version this stream is framed for.
180    pub fn draft(&self) -> DraftVersion {
181        self.draft
182    }
183}
184
185/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
186pub struct FramedRecvStream {
187    inner: RecvStream,
188    buf: BytesMut,
189    draft: DraftVersion,
190    /// Stateful subgroup object reader.
191    subgroup_io: Option<SubgroupObjectReader>,
192}
193
194impl FramedRecvStream {
195    /// Create a new framed receive stream for the given draft version.
196    pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
197        Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
198    }
199
200    /// Get the transport-level stream ID.
201    pub fn stream_id(&self) -> u64 {
202        self.inner.stream_id()
203    }
204
205    /// Read more data from the stream into the internal buffer.
206    async fn fill(&mut self) -> Result<bool, ConnectionError> {
207        let mut tmp = [0u8; 4096];
208        match self.inner.read(&mut tmp).await {
209            Ok(Some(n)) => {
210                self.buf.extend_from_slice(&tmp[..n]);
211                Ok(true)
212            }
213            Ok(None) => Ok(false),
214            Err(e) => Err(ConnectionError::Transport(e)),
215        }
216    }
217
218    /// Ensure at least `n` bytes are available in the buffer.
219    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
220        while self.buf.len() < n {
221            if !self.fill().await? {
222                return Err(ConnectionError::UnexpectedEnd);
223            }
224        }
225        Ok(())
226    }
227
228    /// Read a control message from the stream.
229    ///
230    /// When `capture_raw` is true, the returned tuple includes a clone of the
231    /// framed wire bytes (for observer emission). When false, the second
232    /// element is `None` and the payload clone is skipped.
233    pub async fn read_control(
234        &mut self,
235        capture_raw: bool,
236    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
237        // Read type ID varint
238        self.ensure(1).await?;
239        let type_len = varint_len(self.buf[0]);
240        self.ensure(type_len).await?;
241
242        let mut cursor = &self.buf[..type_len];
243        let _type_id = VarInt::decode(&mut cursor)?;
244
245        // Draft-18: 16-bit BE payload length
246        let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
247            self.ensure(type_len + 2).await?;
248            let hi = self.buf[type_len] as usize;
249            let lo = self.buf[type_len + 1] as usize;
250            ((hi << 8) | lo, 2)
251        } else {
252            self.ensure(type_len + 1).await?;
253            let payload_len_start = type_len;
254            let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
255            self.ensure(type_len + payload_len_varint_len).await?;
256            let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
257            let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
258            (payload_len, payload_len_varint_len)
259        };
260
261        // Read full payload
262        let total = type_len + len_field_size + payload_len;
263        self.ensure(total).await?;
264
265        // Capture raw bytes only if requested (observer attached).
266        let raw = capture_raw.then(|| self.buf[..total].to_vec());
267
268        // Now decode the whole message
269        let mut frame = &self.buf[..total];
270        let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
271        self.buf.advance(total);
272        Ok((msg, raw))
273    }
274
275    /// Read a subgroup stream header. Also initializes the internal
276    /// delta-decoding state.
277    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
278        self.ensure(1).await?;
279        loop {
280            let mut cursor = &self.buf[..];
281            match AnySubgroupHeader::decode(self.draft, &mut cursor) {
282                Ok(header) => {
283                    let consumed = self.buf.len() - cursor.remaining();
284                    self.buf.advance(consumed);
285                    if let AnySubgroupHeader::Draft18(ref d17) = header {
286                        self.subgroup_io = Some(SubgroupObjectReader::new(d17));
287                    }
288                    return Ok(header);
289                }
290                Err(CodecError::UnexpectedEnd) => {
291                    if !self.fill().await? {
292                        return Err(ConnectionError::UnexpectedEnd);
293                    }
294                }
295                Err(e) => return Err(ConnectionError::Codec(e)),
296            }
297        }
298    }
299
300    /// Read a fetch response header.
301    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
302        self.ensure(1).await?;
303        loop {
304            let mut cursor = &self.buf[..];
305            match AnyFetchHeader::decode(self.draft, &mut cursor) {
306                Ok(header) => {
307                    let consumed = self.buf.len() - cursor.remaining();
308                    self.buf.advance(consumed);
309                    return Ok(header);
310                }
311                Err(CodecError::UnexpectedEnd) => {
312                    if !self.fill().await? {
313                        return Err(ConnectionError::UnexpectedEnd);
314                    }
315                }
316                Err(e) => return Err(ConnectionError::Codec(e)),
317            }
318        }
319    }
320
321    /// Read the next draft-18 subgroup object from this stream using
322    /// the stateful reader seeded by
323    /// [`FramedRecvStream::read_subgroup_header`].
324    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
325        if self.subgroup_io.is_none() {
326            return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
327        }
328        loop {
329            let reader = self.subgroup_io.as_mut().unwrap();
330            let mut probe = reader.clone();
331            let mut cursor = &self.buf[..];
332            match probe.read_object(&mut cursor) {
333                Ok(obj) => {
334                    let consumed = self.buf.len() - cursor.remaining();
335                    self.buf.advance(consumed);
336                    *reader = probe;
337                    return Ok(obj);
338                }
339                Err(CodecError::UnexpectedEnd) => {
340                    if !self.fill().await? {
341                        return Err(ConnectionError::UnexpectedEnd);
342                    }
343                }
344                Err(e) => return Err(ConnectionError::Codec(e)),
345            }
346        }
347    }
348
349    /// Read the next draft-18 fetch header from this stream.
350    pub async fn read_fetch_stream_header(&mut self) -> Result<FetchHeader, ConnectionError> {
351        loop {
352            let mut cursor = &self.buf[..];
353            match FetchHeader::decode(&mut cursor) {
354                Ok(hdr) => {
355                    let consumed = self.buf.len() - cursor.remaining();
356                    self.buf.advance(consumed);
357                    return Ok(hdr);
358                }
359                Err(CodecError::UnexpectedEnd) => {
360                    if !self.fill().await? {
361                        return Err(ConnectionError::UnexpectedEnd);
362                    }
363                }
364                Err(e) => return Err(ConnectionError::Codec(e)),
365            }
366        }
367    }
368
369    /// Returns the draft version this stream is framed for.
370    pub fn draft(&self) -> DraftVersion {
371        self.draft
372    }
373}
374
375/// A live MoQT connection over QUIC or WebTransport, combining the endpoint
376/// state machine with actual network I/O.
377pub struct Connection {
378    transport: Transport,
379    endpoint: Endpoint,
380    draft: DraftVersion,
381    control_send: Option<FramedSendStream>,
382    control_recv: Option<FramedRecvStream>,
383    observer: Option<Box<dyn ConnectionObserver>>,
384    /// Setup events buffered during `connect()` and replayed when an
385    /// observer attaches via `set_observer` — without this, an observer
386    /// attached after `connect` returns would never see the handshake.
387    pending_events: Vec<ClientEvent>,
388}
389
390impl Connection {
391    /// Connect to a MoQT server as a client.
392    ///
393    /// Establishes a QUIC or WebTransport connection (based on
394    /// `config.transport`), opens a bidirectional control stream,
395    /// performs the CLIENT_SETUP / SERVER_SETUP handshake, and returns
396    /// a ready-to-use connection.
397    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
398        let draft = config.draft;
399        let transport = match &config.transport {
400            TransportType::Quic => Self::connect_quic(addr, &config).await?,
401            TransportType::WebTransport { url } => {
402                let url = url.clone();
403                Self::connect_webtransport(&url, &config).await?
404            }
405        };
406
407        // Open bidirectional control stream
408        let (send, recv) = transport.open_bi().await?;
409        let mut control_send = FramedSendStream::new(send, draft);
410        let mut control_recv = FramedRecvStream::new(recv, draft);
411
412        // Perform setup handshake (draft-18: no versions)
413        let mut endpoint = Endpoint::new(Role::Client);
414        endpoint.connect()?;
415        let setup_msg = endpoint.send_setup(config.setup_parameters.clone())?;
416        let any_setup = AnyControlMessage::Draft18(setup_msg);
417        let raw_setup = control_send.write_control(&any_setup).await?;
418
419        let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
420        // Unified SETUP in draft-18: server responds with the same message type.
421        match &server_setup {
422            AnyControlMessage::Draft18(ControlMessage::Setup(ref s)) => {
423                endpoint.receive_setup(s)?;
424            }
425            _ => {
426                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
427            }
428        }
429
430        let pending_events = vec![
431            ClientEvent::ControlMessage {
432                direction: Direction::Send,
433                message: any_setup,
434                raw: Some(raw_setup),
435            },
436            ClientEvent::ControlMessage {
437                direction: Direction::Receive,
438                message: server_setup,
439                raw: raw_server_setup,
440            },
441            ClientEvent::SetupComplete { negotiated_version: 0xff000000 + 18 },
442        ];
443
444        Ok(Self {
445            transport,
446            endpoint,
447            draft,
448            control_send: Some(control_send),
449            control_recv: Some(control_recv),
450            observer: None,
451            pending_events,
452        })
453    }
454
455    /// Establish a raw QUIC connection.
456    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
457        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
458            ConnectionError::InvalidAddress(e.to_string())
459        })?;
460
461        // Build TLS config
462        let mut tls_config = if config.skip_cert_verification {
463            rustls::ClientConfig::builder()
464                .dangerous()
465                .with_custom_certificate_verifier(Arc::new(SkipVerification))
466                .with_no_client_auth()
467        } else {
468            let mut roots = rustls::RootCertStore::empty();
469            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
470            for der in &config.ca_certs {
471                roots
472                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
473                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
474            }
475            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
476        };
477
478        tls_config.alpn_protocols = config.alpn();
479
480        let quic_config: quinn::crypto::rustls::QuicClientConfig =
481            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
482        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
483
484        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
485            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
486        quinn_endpoint.set_default_client_config(client_config);
487
488        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
489
490        let quic = quinn_endpoint
491            .connect(server_addr, &server_name)
492            .map_err(TransportError::from)?
493            .await
494            .map_err(TransportError::from)?;
495
496        Ok(Transport::Quic(QuicTransport::new(quic)))
497    }
498
499    /// Establish a WebTransport connection.
500    #[cfg(feature = "webtransport")]
501    async fn connect_webtransport(
502        url: &str,
503        config: &ClientConfig,
504    ) -> Result<Transport, ConnectionError> {
505        use crate::transport::webtransport::WebTransportTransport;
506
507        let wt_config = if config.skip_cert_verification {
508            wtransport::ClientConfig::builder()
509                .with_bind_default()
510                .with_no_cert_validation()
511                .build()
512        } else {
513            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
514        };
515
516        let endpoint = wtransport::Endpoint::client(wt_config)
517            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
518
519        let connection = endpoint
520            .connect(url)
521            .await
522            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
523
524        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
525    }
526
527    /// Stub for when the webtransport feature is not enabled.
528    #[cfg(not(feature = "webtransport"))]
529    async fn connect_webtransport(
530        _url: &str,
531        _config: &ClientConfig,
532    ) -> Result<Transport, ConnectionError> {
533        Err(ConnectionError::Transport(TransportError::Connect(
534            "webtransport feature not enabled".into(),
535        )))
536    }
537
538    // -- Observer ---------------------------------------------------
539
540    /// Attach an observer. Buffered handshake events from `connect()` are
541    /// flushed in arrival order before this returns.
542    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
543        self.observer = Some(observer);
544        for event in self.pending_events.drain(..) {
545            if let Some(ref obs) = self.observer {
546                obs.on_event_owned(event);
547            }
548        }
549    }
550
551    /// Remove the observer.
552    pub fn clear_observer(&mut self) {
553        self.observer = None;
554    }
555
556    /// Emit an event to the observer, if one is attached.
557    fn emit(&self, event: ClientEvent) {
558        if let Some(ref obs) = self.observer {
559            obs.on_event_owned(event);
560        }
561    }
562
563    // -- Control message I/O ----------------------------------------
564
565    /// Send a control message on the control stream.
566    ///
567    /// Wraps the draft-18 message in `AnyControlMessage::Draft18` for
568    /// framing.
569    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
570        let any = AnyControlMessage::Draft18(msg.clone());
571        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
572        let raw = send.write_control(&any).await?;
573        self.emit(ClientEvent::ControlMessage {
574            direction: Direction::Send,
575            message: any,
576            raw: Some(raw),
577        });
578        Ok(())
579    }
580
581    /// Read the next control message from the control stream.
582    ///
583    /// Returns the `AnyControlMessage` and also extracts the draft-18
584    /// `ControlMessage` for internal endpoint dispatch.
585    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
586        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
587        let capture_raw = self.observer.is_some();
588        let (any, raw) = recv.read_control(capture_raw).await?;
589        if capture_raw {
590            self.emit(ClientEvent::ControlMessage {
591                direction: Direction::Receive,
592                message: any.clone(),
593                raw,
594            });
595        }
596        // Unwrap to draft-18 for the endpoint
597        match any {
598            AnyControlMessage::Draft18(msg) => Ok(msg),
599            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
600        }
601    }
602
603    /// Read and dispatch the next incoming control message through the
604    /// endpoint state machine. Returns the decoded message for inspection.
605    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
606        let msg = self.recv_control().await?;
607        self.endpoint.receive_message(msg.clone())?;
608
609        // Emit draining event if this was a GoAway
610        if let ControlMessage::GoAway(ref ga) = msg {
611            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
612        }
613
614        Ok(msg)
615    }
616
617    // -- Subscribe flow ---------------------------------------------
618
619    /// Send a SUBSCRIBE and return the allocated request ID.
620    pub async fn subscribe(
621        &mut self,
622        track_namespace: TrackNamespace,
623        track_name: Vec<u8>,
624        parameters: Vec<KeyValuePair>,
625    ) -> Result<VarInt, ConnectionError> {
626        let (req_id, msg) = self.endpoint.subscribe(track_namespace, track_name, parameters)?;
627        self.send_control(&msg).await?;
628        Ok(req_id)
629    }
630
631    // Draft-18 inherits draft-17's removal of UNSUBSCRIBE: subscribers end
632    // via RequestUpdate or by waiting for PublishDone.
633
634    // -- Fetch flow -------------------------------------------------
635
636    /// Send a standalone FETCH and return the allocated request ID.
637    pub async fn fetch(
638        &mut self,
639        track_namespace: TrackNamespace,
640        track_name: Vec<u8>,
641        start_group: VarInt,
642        start_object: VarInt,
643        end_group: VarInt,
644        end_object: VarInt,
645    ) -> Result<VarInt, ConnectionError> {
646        let (req_id, msg) = self.endpoint.fetch(
647            track_namespace,
648            track_name,
649            start_group,
650            start_object,
651            end_group,
652            end_object,
653        )?;
654        self.send_control(&msg).await?;
655        Ok(req_id)
656    }
657
658    /// Send a joining FETCH and return the allocated request ID.
659    pub async fn joining_fetch(
660        &mut self,
661        joining_request_id: VarInt,
662        joining_start: VarInt,
663    ) -> Result<VarInt, ConnectionError> {
664        let (req_id, msg) = self.endpoint.joining_fetch(joining_request_id, joining_start)?;
665        self.send_control(&msg).await?;
666        Ok(req_id)
667    }
668
669    // Draft-18 inherits draft-17's removal of FETCH_CANCEL: fetchers abort
670    // via stream reset.
671
672    // -- Namespace flows --------------------------------------------
673
674    /// Send a SUBSCRIBE_NAMESPACE and return the request ID.
675    ///
676    /// Draft-18 split the draft-17 SUBSCRIBE_NAMESPACE into two messages.
677    /// This call sends the renumbered SUBSCRIBE_NAMESPACE (type 0x50), which
678    /// subscribes to NAMESPACE / NAMESPACE_DONE announcements only. To
679    /// receive PUBLISH messages for matching tracks, use
680    /// [`Self::subscribe_tracks`] instead.
681    pub async fn subscribe_namespace(
682        &mut self,
683        namespace_prefix: TrackNamespace,
684        parameters: Vec<KeyValuePair>,
685    ) -> Result<VarInt, ConnectionError> {
686        let (req_id, msg) = self.endpoint.subscribe_namespace(namespace_prefix, parameters)?;
687        self.send_control(&msg).await?;
688        Ok(req_id)
689    }
690
691    /// Send a SUBSCRIBE_TRACKS (type 0x51, new in draft-18) and return the
692    /// request ID. Causes the relay to PUBLISH matching tracks back to us.
693    pub async fn subscribe_tracks(
694        &mut self,
695        namespace_prefix: TrackNamespace,
696        parameters: Vec<KeyValuePair>,
697    ) -> Result<VarInt, ConnectionError> {
698        let (req_id, msg) = self.endpoint.subscribe_tracks(namespace_prefix, parameters)?;
699        self.send_control(&msg).await?;
700        Ok(req_id)
701    }
702
703    /// Send a PUBLISH_NAMESPACE and return the request ID.
704    pub async fn publish_namespace(
705        &mut self,
706        track_namespace: TrackNamespace,
707        parameters: Vec<KeyValuePair>,
708    ) -> Result<VarInt, ConnectionError> {
709        let (req_id, msg) = self.endpoint.publish_namespace(track_namespace, parameters)?;
710        self.send_control(&msg).await?;
711        Ok(req_id)
712    }
713
714    // -- Track Status flow ------------------------------------------
715
716    /// Send a TRACK_STATUS and return the allocated request ID.
717    pub async fn track_status(
718        &mut self,
719        track_namespace: TrackNamespace,
720        track_name: Vec<u8>,
721        parameters: Vec<KeyValuePair>,
722    ) -> Result<VarInt, ConnectionError> {
723        let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name, parameters)?;
724        self.send_control(&msg).await?;
725        Ok(req_id)
726    }
727
728    // -- Publish flow (publisher side) ------------------------------
729
730    /// Send a PUBLISH and return the allocated request ID.
731    pub async fn publish(
732        &mut self,
733        track_namespace: TrackNamespace,
734        track_name: Vec<u8>,
735        track_alias: VarInt,
736        parameters: Vec<KeyValuePair>,
737        track_properties: Vec<KeyValuePair>,
738    ) -> Result<VarInt, ConnectionError> {
739        let (req_id, msg) = self.endpoint.publish(
740            track_namespace,
741            track_name,
742            track_alias,
743            parameters,
744            track_properties,
745        )?;
746        self.send_control(&msg).await?;
747        Ok(req_id)
748    }
749
750    /// Send a PUBLISH_DONE for the given request ID.
751    pub async fn publish_done(
752        &mut self,
753        request_id: VarInt,
754        status_code: VarInt,
755        stream_count: VarInt,
756        reason_phrase: Vec<u8>,
757    ) -> Result<(), ConnectionError> {
758        let msg = self.endpoint.send_publish_done(
759            request_id,
760            status_code,
761            stream_count,
762            reason_phrase,
763        )?;
764        self.send_control(&msg).await
765    }
766
767    // -- Data streams -----------------------------------------------
768
769    /// Open a new unidirectional stream for sending subgroup data.
770    pub async fn open_subgroup_stream(
771        &self,
772        header: &AnySubgroupHeader,
773    ) -> Result<FramedSendStream, ConnectionError> {
774        let send = self.transport.open_uni().await?;
775        let mut framed = FramedSendStream::new(send, self.draft);
776        let sid = framed.stream_id();
777        framed.write_subgroup_header(header).await?;
778        self.emit(ClientEvent::StreamOpened {
779            direction: Direction::Send,
780            stream_kind: StreamKind::Subgroup,
781            stream_id: sid,
782        });
783        self.emit(ClientEvent::DataStreamHeader {
784            stream_id: sid,
785            direction: Direction::Send,
786            header: header.clone(),
787        });
788        Ok(framed)
789    }
790
791    /// Accept an incoming unidirectional data stream and read its subgroup
792    /// header.
793    pub async fn accept_subgroup_stream(
794        &self,
795    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
796        let recv = self.transport.accept_uni().await?;
797        let mut framed = FramedRecvStream::new(recv, self.draft);
798        let sid = framed.stream_id();
799        let header = framed.read_subgroup_header().await?;
800        self.emit(ClientEvent::StreamOpened {
801            direction: Direction::Receive,
802            stream_kind: StreamKind::Subgroup,
803            stream_id: sid,
804        });
805        self.emit(ClientEvent::DataStreamHeader {
806            stream_id: sid,
807            direction: Direction::Receive,
808            header: header.clone(),
809        });
810        Ok((header, framed))
811    }
812
813    /// Send an object via datagram.
814    pub fn send_datagram(
815        &self,
816        header: &AnyDatagramHeader,
817        payload: &[u8],
818    ) -> Result<(), ConnectionError> {
819        let mut buf = Vec::new();
820        header.encode(&mut buf);
821        buf.extend_from_slice(payload);
822        self.emit(ClientEvent::DatagramReceived {
823            direction: Direction::Send,
824            header: header.clone(),
825            payload_len: payload.len(),
826        });
827        self.transport.send_datagram(bytes::Bytes::from(buf))?;
828        Ok(())
829    }
830
831    /// Receive a datagram and decode its header.
832    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
833        let data = self.transport.recv_datagram().await?;
834        let mut cursor = &data[..];
835        let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
836        let consumed = data.len() - cursor.len();
837        let payload = data.slice(consumed..);
838        self.emit(ClientEvent::DatagramReceived {
839            direction: Direction::Receive,
840            header: header.clone(),
841            payload_len: payload.len(),
842        });
843        Ok((header, payload))
844    }
845
846    // -- Accessors --------------------------------------------------
847
848    /// Access the underlying endpoint state machine.
849    pub fn endpoint(&self) -> &Endpoint {
850        &self.endpoint
851    }
852
853    /// Mutable access to the endpoint state machine.
854    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
855        &mut self.endpoint
856    }
857
858    /// Returns the draft version this connection is using.
859    pub fn draft(&self) -> DraftVersion {
860        self.draft
861    }
862
863    /// Close the connection.
864    pub fn close(&self, code: u32, reason: &[u8]) {
865        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
866        self.transport.close(code, reason);
867    }
868}
869
/// Return the total encoded size, in bytes, of a variable-length integer
/// given only its first byte.
///
/// The two most significant bits of the first byte select the length
/// (the RFC 9000 §16 layout): `0b00` → 1, `0b01` → 2, `0b10` → 4,
/// `0b11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
874
/// TLS certificate verifier that skips all verification (for testing only).
///
/// Every server certificate and every TLS 1.2/1.3 handshake signature is
/// accepted unconditionally. A connection using it therefore has no server
/// authentication at all.
///
/// NOTE(review): presumably installed when `ClientConfig::skip_cert_verification`
/// is true. Confirm at the construction site.
#[derive(Debug)]
struct SkipVerification;
878
879impl rustls::client::danger::ServerCertVerifier for SkipVerification {
880    fn verify_server_cert(
881        &self,
882        _end_entity: &rustls::pki_types::CertificateDer<'_>,
883        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
884        _server_name: &rustls::pki_types::ServerName<'_>,
885        _ocsp_response: &[u8],
886        _now: rustls::pki_types::UnixTime,
887    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
888        Ok(rustls::client::danger::ServerCertVerified::assertion())
889    }
890
891    fn verify_tls12_signature(
892        &self,
893        _message: &[u8],
894        _cert: &rustls::pki_types::CertificateDer<'_>,
895        _dcs: &rustls::DigitallySignedStruct,
896    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
897        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
898    }
899
900    fn verify_tls13_signature(
901        &self,
902        _message: &[u8],
903        _cert: &rustls::pki_types::CertificateDer<'_>,
904        _dcs: &rustls::DigitallySignedStruct,
905    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
906        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
907    }
908
909    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
910        vec![
911            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
912            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
913            rustls::SignatureScheme::ED25519,
914            rustls::SignatureScheme::RSA_PSS_SHA256,
915            rustls::SignatureScheme::RSA_PSS_SHA384,
916            rustls::SignatureScheme::RSA_PSS_SHA512,
917        ]
918    }
919}
920
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a draft-18 `ClientConfig` using `transport`, with certificate
    /// verification enabled and no CA certs or setup parameters.
    fn draft18_config(transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft: DraftVersion::Draft18,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// Assert that every byte in `first_bytes` decodes to `expected` length.
    fn assert_varint_len(first_bytes: &[u8], expected: usize) {
        for &byte in first_bytes {
            assert_eq!(varint_len(byte), expected);
        }
    }

    #[test]
    fn varint_len_single_byte() {
        assert_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        assert_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        assert_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        assert_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn client_config_alpn_quic_draft18() {
        let config = draft18_config(TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moqt-18".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let url = "https://example.com".to_string();
        let config = draft18_config(TransportType::WebTransport { url });
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let quic_repr = format!("{:?}", TransportType::Quic);
        assert!(quic_repr.contains("Quic"));

        let webtransport = TransportType::WebTransport {
            url: "https://example.com".to_string(),
        };
        let webtransport_repr = format!("{:?}", webtransport);
        assert!(webtransport_repr.contains("WebTransport"));
    }
}