1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft18::endpoint::{Endpoint, EndpointError};
6use crate::draft18::event::{ClientEvent, Direction, StreamKind};
7use crate::draft18::observer::ConnectionObserver;
8use crate::draft18::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12 AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft18::data_stream::{FetchHeader, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft18::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::kvp::KeyValuePair;
18use moqtap_codec::types::*;
19use moqtap_codec::varint::VarInt;
20use moqtap_codec::version::DraftVersion;
21
/// Generic MoQT ALPN identifier (`moq-00`).
///
/// Note: `ClientConfig::alpn` does not use this constant; for raw QUIC it asks
/// `DraftVersion::quic_alpn` for the draft-specific identifier instead.
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
24
25#[derive(Debug, thiserror::Error)]
27pub enum ConnectionError {
28 #[error("endpoint error: {0}")]
30 Endpoint(#[from] EndpointError),
31 #[error("codec error: {0}")]
33 Codec(#[from] CodecError),
34 #[error("transport error: {0}")]
36 Transport(#[from] TransportError),
37 #[error("varint error: {0}")]
39 VarInt(#[from] moqtap_codec::varint::VarIntError),
40 #[error("control stream not open")]
42 NoControlStream,
43 #[error("unexpected end of stream")]
45 UnexpectedEnd,
46 #[error("stream finished")]
48 StreamFinished,
49 #[error("invalid server address: {0}")]
51 InvalidAddress(String),
52 #[error("TLS config error: {0}")]
54 TlsConfig(String),
55 #[error("data stream state error: {0}")]
57 DataStreamState(&'static str),
58}
59
/// Transport used to carry the MoQT session.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC, negotiating the draft-specific MoQT ALPN.
    Quic,
    /// WebTransport over HTTP/3, dialling `url` (ALPN `h3`).
    WebTransport { url: String },
}
71
72pub struct ClientConfig {
76 pub draft: DraftVersion,
78 pub transport: TransportType,
80 pub skip_cert_verification: bool,
82 pub ca_certs: Vec<Vec<u8>>,
84 pub setup_parameters: Vec<KeyValuePair>,
86}
87
88impl ClientConfig {
89 pub fn alpn(&self) -> Vec<Vec<u8>> {
91 match &self.transport {
92 TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
93 TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
94 }
95 }
96}
97
98pub struct FramedSendStream {
100 inner: SendStream,
101 draft: DraftVersion,
102 subgroup_io: Option<SubgroupObjectReader>,
104}
105
106impl FramedSendStream {
107 pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
109 Self { inner, draft, subgroup_io: None }
110 }
111
112 pub fn stream_id(&self) -> u64 {
114 self.inner.stream_id()
115 }
116
117 pub async fn write_control(
120 &mut self,
121 msg: &AnyControlMessage,
122 ) -> Result<Vec<u8>, ConnectionError> {
123 let mut buf = Vec::new();
124 msg.encode(&mut buf)?;
125 self.inner.write_all(&buf).await?;
126 Ok(buf)
127 }
128
129 pub async fn write_subgroup_header(
133 &mut self,
134 header: &AnySubgroupHeader,
135 ) -> Result<(), ConnectionError> {
136 let mut buf = Vec::new();
137 header.encode(&mut buf);
138 self.inner.write_all(&buf).await?;
139 if let AnySubgroupHeader::Draft18(ref d17) = header {
140 self.subgroup_io = Some(SubgroupObjectReader::new(d17));
141 }
142 Ok(())
143 }
144
145 pub async fn write_fetch_header(
147 &mut self,
148 header: &AnyFetchHeader,
149 ) -> Result<(), ConnectionError> {
150 let mut buf = Vec::new();
151 header.encode(&mut buf);
152 self.inner.write_all(&buf).await?;
153 Ok(())
154 }
155
156 pub async fn write_subgroup_object(
160 &mut self,
161 object: &SubgroupObject,
162 ) -> Result<(), ConnectionError> {
163 let writer = self
164 .subgroup_io
165 .as_mut()
166 .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
167 let mut buf = Vec::new();
168 writer.write_object(object, &mut buf)?;
169 self.inner.write_all(&buf).await?;
170 Ok(())
171 }
172
173 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
175 self.inner.finish()?;
176 Ok(())
177 }
178
179 pub fn draft(&self) -> DraftVersion {
181 self.draft
182 }
183}
184
185pub struct FramedRecvStream {
187 inner: RecvStream,
188 buf: BytesMut,
189 draft: DraftVersion,
190 subgroup_io: Option<SubgroupObjectReader>,
192}
193
194impl FramedRecvStream {
195 pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
197 Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
198 }
199
200 pub fn stream_id(&self) -> u64 {
202 self.inner.stream_id()
203 }
204
205 async fn fill(&mut self) -> Result<bool, ConnectionError> {
207 let mut tmp = [0u8; 4096];
208 match self.inner.read(&mut tmp).await {
209 Ok(Some(n)) => {
210 self.buf.extend_from_slice(&tmp[..n]);
211 Ok(true)
212 }
213 Ok(None) => Ok(false),
214 Err(e) => Err(ConnectionError::Transport(e)),
215 }
216 }
217
218 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
220 while self.buf.len() < n {
221 if !self.fill().await? {
222 return Err(ConnectionError::UnexpectedEnd);
223 }
224 }
225 Ok(())
226 }
227
228 pub async fn read_control(
234 &mut self,
235 capture_raw: bool,
236 ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
237 self.ensure(1).await?;
239 let type_len = varint_len(self.buf[0]);
240 self.ensure(type_len).await?;
241
242 let mut cursor = &self.buf[..type_len];
243 let _type_id = VarInt::decode(&mut cursor)?;
244
245 let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
247 self.ensure(type_len + 2).await?;
248 let hi = self.buf[type_len] as usize;
249 let lo = self.buf[type_len + 1] as usize;
250 ((hi << 8) | lo, 2)
251 } else {
252 self.ensure(type_len + 1).await?;
253 let payload_len_start = type_len;
254 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
255 self.ensure(type_len + payload_len_varint_len).await?;
256 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
257 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
258 (payload_len, payload_len_varint_len)
259 };
260
261 let total = type_len + len_field_size + payload_len;
263 self.ensure(total).await?;
264
265 let raw = capture_raw.then(|| self.buf[..total].to_vec());
267
268 let mut frame = &self.buf[..total];
270 let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
271 self.buf.advance(total);
272 Ok((msg, raw))
273 }
274
275 pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
278 self.ensure(1).await?;
279 loop {
280 let mut cursor = &self.buf[..];
281 match AnySubgroupHeader::decode(self.draft, &mut cursor) {
282 Ok(header) => {
283 let consumed = self.buf.len() - cursor.remaining();
284 self.buf.advance(consumed);
285 if let AnySubgroupHeader::Draft18(ref d17) = header {
286 self.subgroup_io = Some(SubgroupObjectReader::new(d17));
287 }
288 return Ok(header);
289 }
290 Err(CodecError::UnexpectedEnd) => {
291 if !self.fill().await? {
292 return Err(ConnectionError::UnexpectedEnd);
293 }
294 }
295 Err(e) => return Err(ConnectionError::Codec(e)),
296 }
297 }
298 }
299
300 pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
302 self.ensure(1).await?;
303 loop {
304 let mut cursor = &self.buf[..];
305 match AnyFetchHeader::decode(self.draft, &mut cursor) {
306 Ok(header) => {
307 let consumed = self.buf.len() - cursor.remaining();
308 self.buf.advance(consumed);
309 return Ok(header);
310 }
311 Err(CodecError::UnexpectedEnd) => {
312 if !self.fill().await? {
313 return Err(ConnectionError::UnexpectedEnd);
314 }
315 }
316 Err(e) => return Err(ConnectionError::Codec(e)),
317 }
318 }
319 }
320
321 pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
325 if self.subgroup_io.is_none() {
326 return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
327 }
328 loop {
329 let reader = self.subgroup_io.as_mut().unwrap();
330 let mut probe = reader.clone();
331 let mut cursor = &self.buf[..];
332 match probe.read_object(&mut cursor) {
333 Ok(obj) => {
334 let consumed = self.buf.len() - cursor.remaining();
335 self.buf.advance(consumed);
336 *reader = probe;
337 return Ok(obj);
338 }
339 Err(CodecError::UnexpectedEnd) => {
340 if !self.fill().await? {
341 return Err(ConnectionError::UnexpectedEnd);
342 }
343 }
344 Err(e) => return Err(ConnectionError::Codec(e)),
345 }
346 }
347 }
348
349 pub async fn read_fetch_stream_header(&mut self) -> Result<FetchHeader, ConnectionError> {
351 loop {
352 let mut cursor = &self.buf[..];
353 match FetchHeader::decode(&mut cursor) {
354 Ok(hdr) => {
355 let consumed = self.buf.len() - cursor.remaining();
356 self.buf.advance(consumed);
357 return Ok(hdr);
358 }
359 Err(CodecError::UnexpectedEnd) => {
360 if !self.fill().await? {
361 return Err(ConnectionError::UnexpectedEnd);
362 }
363 }
364 Err(e) => return Err(ConnectionError::Codec(e)),
365 }
366 }
367 }
368
369 pub fn draft(&self) -> DraftVersion {
371 self.draft
372 }
373}
374
375pub struct Connection {
378 transport: Transport,
379 endpoint: Endpoint,
380 draft: DraftVersion,
381 control_send: Option<FramedSendStream>,
382 control_recv: Option<FramedRecvStream>,
383 observer: Option<Box<dyn ConnectionObserver>>,
384 pending_events: Vec<ClientEvent>,
388}
389
390impl Connection {
391 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
398 let draft = config.draft;
399 let transport = match &config.transport {
400 TransportType::Quic => Self::connect_quic(addr, &config).await?,
401 TransportType::WebTransport { url } => {
402 let url = url.clone();
403 Self::connect_webtransport(&url, &config).await?
404 }
405 };
406
407 let (send, recv) = transport.open_bi().await?;
409 let mut control_send = FramedSendStream::new(send, draft);
410 let mut control_recv = FramedRecvStream::new(recv, draft);
411
412 let mut endpoint = Endpoint::new(Role::Client);
414 endpoint.connect()?;
415 let setup_msg = endpoint.send_setup(config.setup_parameters.clone())?;
416 let any_setup = AnyControlMessage::Draft18(setup_msg);
417 let raw_setup = control_send.write_control(&any_setup).await?;
418
419 let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
420 match &server_setup {
422 AnyControlMessage::Draft18(ControlMessage::Setup(ref s)) => {
423 endpoint.receive_setup(s)?;
424 }
425 _ => {
426 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
427 }
428 }
429
430 let pending_events = vec![
431 ClientEvent::ControlMessage {
432 direction: Direction::Send,
433 message: any_setup,
434 raw: Some(raw_setup),
435 },
436 ClientEvent::ControlMessage {
437 direction: Direction::Receive,
438 message: server_setup,
439 raw: raw_server_setup,
440 },
441 ClientEvent::SetupComplete { negotiated_version: 0xff000000 + 18 },
442 ];
443
444 Ok(Self {
445 transport,
446 endpoint,
447 draft,
448 control_send: Some(control_send),
449 control_recv: Some(control_recv),
450 observer: None,
451 pending_events,
452 })
453 }
454
455 async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
457 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
458 ConnectionError::InvalidAddress(e.to_string())
459 })?;
460
461 let mut tls_config = if config.skip_cert_verification {
463 rustls::ClientConfig::builder()
464 .dangerous()
465 .with_custom_certificate_verifier(Arc::new(SkipVerification))
466 .with_no_client_auth()
467 } else {
468 let mut roots = rustls::RootCertStore::empty();
469 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
470 for der in &config.ca_certs {
471 roots
472 .add(rustls::pki_types::CertificateDer::from(der.clone()))
473 .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
474 }
475 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
476 };
477
478 tls_config.alpn_protocols = config.alpn();
479
480 let quic_config: quinn::crypto::rustls::QuicClientConfig =
481 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
482 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
483
484 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
485 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
486 quinn_endpoint.set_default_client_config(client_config);
487
488 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
489
490 let quic = quinn_endpoint
491 .connect(server_addr, &server_name)
492 .map_err(TransportError::from)?
493 .await
494 .map_err(TransportError::from)?;
495
496 Ok(Transport::Quic(QuicTransport::new(quic)))
497 }
498
499 #[cfg(feature = "webtransport")]
501 async fn connect_webtransport(
502 url: &str,
503 config: &ClientConfig,
504 ) -> Result<Transport, ConnectionError> {
505 use crate::transport::webtransport::WebTransportTransport;
506
507 let wt_config = if config.skip_cert_verification {
508 wtransport::ClientConfig::builder()
509 .with_bind_default()
510 .with_no_cert_validation()
511 .build()
512 } else {
513 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
514 };
515
516 let endpoint = wtransport::Endpoint::client(wt_config)
517 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
518
519 let connection = endpoint
520 .connect(url)
521 .await
522 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
523
524 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
525 }
526
527 #[cfg(not(feature = "webtransport"))]
529 async fn connect_webtransport(
530 _url: &str,
531 _config: &ClientConfig,
532 ) -> Result<Transport, ConnectionError> {
533 Err(ConnectionError::Transport(TransportError::Connect(
534 "webtransport feature not enabled".into(),
535 )))
536 }
537
538 pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
543 self.observer = Some(observer);
544 for event in self.pending_events.drain(..) {
545 if let Some(ref obs) = self.observer {
546 obs.on_event_owned(event);
547 }
548 }
549 }
550
551 pub fn clear_observer(&mut self) {
553 self.observer = None;
554 }
555
556 fn emit(&self, event: ClientEvent) {
558 if let Some(ref obs) = self.observer {
559 obs.on_event_owned(event);
560 }
561 }
562
563 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
570 let any = AnyControlMessage::Draft18(msg.clone());
571 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
572 let raw = send.write_control(&any).await?;
573 self.emit(ClientEvent::ControlMessage {
574 direction: Direction::Send,
575 message: any,
576 raw: Some(raw),
577 });
578 Ok(())
579 }
580
581 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
586 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
587 let capture_raw = self.observer.is_some();
588 let (any, raw) = recv.read_control(capture_raw).await?;
589 if capture_raw {
590 self.emit(ClientEvent::ControlMessage {
591 direction: Direction::Receive,
592 message: any.clone(),
593 raw,
594 });
595 }
596 match any {
598 AnyControlMessage::Draft18(msg) => Ok(msg),
599 _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
600 }
601 }
602
603 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
606 let msg = self.recv_control().await?;
607 self.endpoint.receive_message(msg.clone())?;
608
609 if let ControlMessage::GoAway(ref ga) = msg {
611 self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
612 }
613
614 Ok(msg)
615 }
616
617 pub async fn subscribe(
621 &mut self,
622 track_namespace: TrackNamespace,
623 track_name: Vec<u8>,
624 parameters: Vec<KeyValuePair>,
625 ) -> Result<VarInt, ConnectionError> {
626 let (req_id, msg) = self.endpoint.subscribe(track_namespace, track_name, parameters)?;
627 self.send_control(&msg).await?;
628 Ok(req_id)
629 }
630
631 pub async fn fetch(
638 &mut self,
639 track_namespace: TrackNamespace,
640 track_name: Vec<u8>,
641 start_group: VarInt,
642 start_object: VarInt,
643 end_group: VarInt,
644 end_object: VarInt,
645 ) -> Result<VarInt, ConnectionError> {
646 let (req_id, msg) = self.endpoint.fetch(
647 track_namespace,
648 track_name,
649 start_group,
650 start_object,
651 end_group,
652 end_object,
653 )?;
654 self.send_control(&msg).await?;
655 Ok(req_id)
656 }
657
658 pub async fn joining_fetch(
660 &mut self,
661 joining_request_id: VarInt,
662 joining_start: VarInt,
663 ) -> Result<VarInt, ConnectionError> {
664 let (req_id, msg) = self.endpoint.joining_fetch(joining_request_id, joining_start)?;
665 self.send_control(&msg).await?;
666 Ok(req_id)
667 }
668
669 pub async fn subscribe_namespace(
682 &mut self,
683 namespace_prefix: TrackNamespace,
684 parameters: Vec<KeyValuePair>,
685 ) -> Result<VarInt, ConnectionError> {
686 let (req_id, msg) = self.endpoint.subscribe_namespace(namespace_prefix, parameters)?;
687 self.send_control(&msg).await?;
688 Ok(req_id)
689 }
690
691 pub async fn subscribe_tracks(
694 &mut self,
695 namespace_prefix: TrackNamespace,
696 parameters: Vec<KeyValuePair>,
697 ) -> Result<VarInt, ConnectionError> {
698 let (req_id, msg) = self.endpoint.subscribe_tracks(namespace_prefix, parameters)?;
699 self.send_control(&msg).await?;
700 Ok(req_id)
701 }
702
703 pub async fn publish_namespace(
705 &mut self,
706 track_namespace: TrackNamespace,
707 parameters: Vec<KeyValuePair>,
708 ) -> Result<VarInt, ConnectionError> {
709 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace, parameters)?;
710 self.send_control(&msg).await?;
711 Ok(req_id)
712 }
713
714 pub async fn track_status(
718 &mut self,
719 track_namespace: TrackNamespace,
720 track_name: Vec<u8>,
721 parameters: Vec<KeyValuePair>,
722 ) -> Result<VarInt, ConnectionError> {
723 let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name, parameters)?;
724 self.send_control(&msg).await?;
725 Ok(req_id)
726 }
727
728 pub async fn publish(
732 &mut self,
733 track_namespace: TrackNamespace,
734 track_name: Vec<u8>,
735 track_alias: VarInt,
736 parameters: Vec<KeyValuePair>,
737 track_properties: Vec<KeyValuePair>,
738 ) -> Result<VarInt, ConnectionError> {
739 let (req_id, msg) = self.endpoint.publish(
740 track_namespace,
741 track_name,
742 track_alias,
743 parameters,
744 track_properties,
745 )?;
746 self.send_control(&msg).await?;
747 Ok(req_id)
748 }
749
750 pub async fn publish_done(
752 &mut self,
753 request_id: VarInt,
754 status_code: VarInt,
755 stream_count: VarInt,
756 reason_phrase: Vec<u8>,
757 ) -> Result<(), ConnectionError> {
758 let msg = self.endpoint.send_publish_done(
759 request_id,
760 status_code,
761 stream_count,
762 reason_phrase,
763 )?;
764 self.send_control(&msg).await
765 }
766
767 pub async fn open_subgroup_stream(
771 &self,
772 header: &AnySubgroupHeader,
773 ) -> Result<FramedSendStream, ConnectionError> {
774 let send = self.transport.open_uni().await?;
775 let mut framed = FramedSendStream::new(send, self.draft);
776 let sid = framed.stream_id();
777 framed.write_subgroup_header(header).await?;
778 self.emit(ClientEvent::StreamOpened {
779 direction: Direction::Send,
780 stream_kind: StreamKind::Subgroup,
781 stream_id: sid,
782 });
783 self.emit(ClientEvent::DataStreamHeader {
784 stream_id: sid,
785 direction: Direction::Send,
786 header: header.clone(),
787 });
788 Ok(framed)
789 }
790
791 pub async fn accept_subgroup_stream(
794 &self,
795 ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
796 let recv = self.transport.accept_uni().await?;
797 let mut framed = FramedRecvStream::new(recv, self.draft);
798 let sid = framed.stream_id();
799 let header = framed.read_subgroup_header().await?;
800 self.emit(ClientEvent::StreamOpened {
801 direction: Direction::Receive,
802 stream_kind: StreamKind::Subgroup,
803 stream_id: sid,
804 });
805 self.emit(ClientEvent::DataStreamHeader {
806 stream_id: sid,
807 direction: Direction::Receive,
808 header: header.clone(),
809 });
810 Ok((header, framed))
811 }
812
813 pub fn send_datagram(
815 &self,
816 header: &AnyDatagramHeader,
817 payload: &[u8],
818 ) -> Result<(), ConnectionError> {
819 let mut buf = Vec::new();
820 header.encode(&mut buf);
821 buf.extend_from_slice(payload);
822 self.emit(ClientEvent::DatagramReceived {
823 direction: Direction::Send,
824 header: header.clone(),
825 payload_len: payload.len(),
826 });
827 self.transport.send_datagram(bytes::Bytes::from(buf))?;
828 Ok(())
829 }
830
831 pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
833 let data = self.transport.recv_datagram().await?;
834 let mut cursor = &data[..];
835 let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
836 let consumed = data.len() - cursor.len();
837 let payload = data.slice(consumed..);
838 self.emit(ClientEvent::DatagramReceived {
839 direction: Direction::Receive,
840 header: header.clone(),
841 payload_len: payload.len(),
842 });
843 Ok((header, payload))
844 }
845
846 pub fn endpoint(&self) -> &Endpoint {
850 &self.endpoint
851 }
852
853 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
855 &mut self.endpoint
856 }
857
858 pub fn draft(&self) -> DraftVersion {
860 self.draft
861 }
862
863 pub fn close(&self, code: u32, reason: &[u8]) {
865 self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
866 self.transport.close(code, reason);
867 }
868}
869
/// Total encoded length in bytes (1, 2, 4 or 8) of a variable-length integer.
///
/// The length is selected by the two most significant bits of its first byte
/// (`00`, `01`, `10`, `11`).
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
874
875#[derive(Debug)]
877struct SkipVerification;
878
879impl rustls::client::danger::ServerCertVerifier for SkipVerification {
880 fn verify_server_cert(
881 &self,
882 _end_entity: &rustls::pki_types::CertificateDer<'_>,
883 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
884 _server_name: &rustls::pki_types::ServerName<'_>,
885 _ocsp_response: &[u8],
886 _now: rustls::pki_types::UnixTime,
887 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
888 Ok(rustls::client::danger::ServerCertVerified::assertion())
889 }
890
891 fn verify_tls12_signature(
892 &self,
893 _message: &[u8],
894 _cert: &rustls::pki_types::CertificateDer<'_>,
895 _dcs: &rustls::DigitallySignedStruct,
896 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
897 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
898 }
899
900 fn verify_tls13_signature(
901 &self,
902 _message: &[u8],
903 _cert: &rustls::pki_types::CertificateDer<'_>,
904 _dcs: &rustls::DigitallySignedStruct,
905 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
906 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
907 }
908
909 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
910 vec![
911 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
912 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
913 rustls::SignatureScheme::ED25519,
914 rustls::SignatureScheme::RSA_PSS_SHA256,
915 rustls::SignatureScheme::RSA_PSS_SHA384,
916 rustls::SignatureScheme::RSA_PSS_SHA512,
917 ]
918 }
919}
920
#[cfg(test)]
mod tests {
    use super::*;

    /// Draft-18 config with verification on, no extra CAs, and no setup parameters.
    fn draft18_config(transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft: DraftVersion::Draft18,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn varint_len_single_byte() {
        for first in [0x00u8, 0x3F] {
            assert_eq!(varint_len(first), 1, "first byte {first:#04x}");
        }
    }

    #[test]
    fn varint_len_two_bytes() {
        for first in [0x40u8, 0x7F] {
            assert_eq!(varint_len(first), 2, "first byte {first:#04x}");
        }
    }

    #[test]
    fn varint_len_four_bytes() {
        for first in [0x80u8, 0xBF] {
            assert_eq!(varint_len(first), 4, "first byte {first:#04x}");
        }
    }

    #[test]
    fn varint_len_eight_bytes() {
        for first in [0xC0u8, 0xFF] {
            assert_eq!(varint_len(first), 8, "first byte {first:#04x}");
        }
    }

    #[test]
    fn client_config_alpn_quic_draft18() {
        let config = draft18_config(TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moqt-18".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let url = "https://example.com".to_string();
        let config = draft18_config(TransportType::WebTransport { url });
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let cases = [
            (TransportType::Quic, "Quic"),
            (
                TransportType::WebTransport { url: "https://example.com".to_string() },
                "WebTransport",
            ),
        ];
        for (transport, expected) in cases {
            assert!(format!("{transport:?}").contains(expected));
        }
    }
}