1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft15::endpoint::{Endpoint, EndpointError};
6use crate::draft15::event::{ClientEvent, Direction, StreamKind};
7use crate::draft15::observer::ConnectionObserver;
8use crate::draft15::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12 AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft15::data_stream::{FetchHeader, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft15::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::kvp::KeyValuePair;
18use moqtap_codec::types::*;
19use moqtap_codec::varint::VarInt;
20use moqtap_codec::version::DraftVersion;
21
/// Generic MoQT ALPN token (`moq-00`).
///
/// NOTE(review): `ClientConfig::alpn` does not use this constant; raw-QUIC connections
/// advertise the draft-specific token from `DraftVersion::quic_alpn` instead.
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
24
25#[derive(Debug, thiserror::Error)]
27pub enum ConnectionError {
28 #[error("endpoint error: {0}")]
30 Endpoint(#[from] EndpointError),
31 #[error("codec error: {0}")]
33 Codec(#[from] CodecError),
34 #[error("transport error: {0}")]
36 Transport(#[from] TransportError),
37 #[error("varint error: {0}")]
39 VarInt(#[from] moqtap_codec::varint::VarIntError),
40 #[error("control stream not open")]
42 NoControlStream,
43 #[error("unexpected end of stream")]
45 UnexpectedEnd,
46 #[error("stream finished")]
48 StreamFinished,
49 #[error("invalid server address: {0}")]
51 InvalidAddress(String),
52 #[error("TLS config error: {0}")]
54 TlsConfig(String),
55 #[error("data stream state error: {0}")]
57 DataStreamState(&'static str),
58}
59
/// Transport used to carry the MoQT session.
///
/// Derives `PartialEq`/`Eq` so callers can compare or assert on the selected transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportType {
    /// Raw QUIC; the ALPN is the draft-specific MoQT token (see `ClientConfig::alpn`).
    Quic,
    /// WebTransport over HTTP/3 (ALPN `h3`).
    WebTransport {
        /// URL of the WebTransport session to open.
        url: String,
    },
}
71
72pub struct ClientConfig {
76 pub draft: DraftVersion,
78 pub transport: TransportType,
80 pub skip_cert_verification: bool,
82 pub ca_certs: Vec<Vec<u8>>,
84 pub setup_parameters: Vec<KeyValuePair>,
86}
87
88impl ClientConfig {
89 pub fn alpn(&self) -> Vec<Vec<u8>> {
91 match &self.transport {
92 TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
93 TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
94 }
95 }
96}
97
98pub struct FramedSendStream {
100 inner: SendStream,
101 draft: DraftVersion,
102 subgroup_io: Option<SubgroupObjectReader>,
106}
107
108impl FramedSendStream {
109 pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
111 Self { inner, draft, subgroup_io: None }
112 }
113
114 pub fn stream_id(&self) -> u64 {
116 self.inner.stream_id()
117 }
118
119 pub async fn write_control(
122 &mut self,
123 msg: &AnyControlMessage,
124 ) -> Result<Vec<u8>, ConnectionError> {
125 let mut buf = Vec::new();
126 msg.encode(&mut buf)?;
127 self.inner.write_all(&buf).await?;
128 Ok(buf)
129 }
130
131 pub async fn write_subgroup_header(
135 &mut self,
136 header: &AnySubgroupHeader,
137 ) -> Result<(), ConnectionError> {
138 let mut buf = Vec::new();
139 header.encode(&mut buf);
140 self.inner.write_all(&buf).await?;
141 if let AnySubgroupHeader::Draft15(ref d15) = header {
142 self.subgroup_io = Some(SubgroupObjectReader::new(d15));
143 }
144 Ok(())
145 }
146
147 pub async fn write_fetch_header(
149 &mut self,
150 header: &AnyFetchHeader,
151 ) -> Result<(), ConnectionError> {
152 let mut buf = Vec::new();
153 header.encode(&mut buf);
154 self.inner.write_all(&buf).await?;
155 Ok(())
156 }
157
158 pub async fn write_subgroup_object(
163 &mut self,
164 object: &SubgroupObject,
165 ) -> Result<(), ConnectionError> {
166 let writer = self
167 .subgroup_io
168 .as_mut()
169 .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
170 let mut buf = Vec::new();
171 writer.write_object(object, &mut buf)?;
172 self.inner.write_all(&buf).await?;
173 Ok(())
174 }
175
176 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
178 self.inner.finish()?;
179 Ok(())
180 }
181
182 pub fn draft(&self) -> DraftVersion {
184 self.draft
185 }
186}
187
188pub struct FramedRecvStream {
190 inner: RecvStream,
191 buf: BytesMut,
192 draft: DraftVersion,
193 subgroup_io: Option<SubgroupObjectReader>,
196}
197
198impl FramedRecvStream {
199 pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
201 Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
202 }
203
204 pub fn stream_id(&self) -> u64 {
206 self.inner.stream_id()
207 }
208
209 async fn fill(&mut self) -> Result<bool, ConnectionError> {
211 let mut tmp = [0u8; 4096];
212 match self.inner.read(&mut tmp).await {
213 Ok(Some(n)) => {
214 self.buf.extend_from_slice(&tmp[..n]);
215 Ok(true)
216 }
217 Ok(None) => Ok(false),
218 Err(e) => Err(ConnectionError::Transport(e)),
219 }
220 }
221
222 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
224 while self.buf.len() < n {
225 if !self.fill().await? {
226 return Err(ConnectionError::UnexpectedEnd);
227 }
228 }
229 Ok(())
230 }
231
232 pub async fn read_control(
238 &mut self,
239 capture_raw: bool,
240 ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
241 self.ensure(1).await?;
243 let type_len = varint_len(self.buf[0]);
244 self.ensure(type_len).await?;
245
246 let mut cursor = &self.buf[..type_len];
247 let _type_id = VarInt::decode(&mut cursor)?;
248
249 let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
251 self.ensure(type_len + 2).await?;
252 let hi = self.buf[type_len] as usize;
253 let lo = self.buf[type_len + 1] as usize;
254 ((hi << 8) | lo, 2)
255 } else {
256 self.ensure(type_len + 1).await?;
257 let payload_len_start = type_len;
258 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
259 self.ensure(type_len + payload_len_varint_len).await?;
260 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
261 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
262 (payload_len, payload_len_varint_len)
263 };
264
265 let total = type_len + len_field_size + payload_len;
267 self.ensure(total).await?;
268
269 let raw = capture_raw.then(|| self.buf[..total].to_vec());
271
272 let mut frame = &self.buf[..total];
274 let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
275 self.buf.advance(total);
276 Ok((msg, raw))
277 }
278
279 pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
282 self.ensure(1).await?;
283 loop {
284 let mut cursor = &self.buf[..];
285 match AnySubgroupHeader::decode(self.draft, &mut cursor) {
286 Ok(header) => {
287 let consumed = self.buf.len() - cursor.remaining();
288 self.buf.advance(consumed);
289 if let AnySubgroupHeader::Draft15(ref d15) = header {
290 self.subgroup_io = Some(SubgroupObjectReader::new(d15));
291 }
292 return Ok(header);
293 }
294 Err(CodecError::UnexpectedEnd) => {
295 if !self.fill().await? {
296 return Err(ConnectionError::UnexpectedEnd);
297 }
298 }
299 Err(e) => return Err(ConnectionError::Codec(e)),
300 }
301 }
302 }
303
304 pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
306 self.ensure(1).await?;
307 loop {
308 let mut cursor = &self.buf[..];
309 match AnyFetchHeader::decode(self.draft, &mut cursor) {
310 Ok(header) => {
311 let consumed = self.buf.len() - cursor.remaining();
312 self.buf.advance(consumed);
313 return Ok(header);
314 }
315 Err(CodecError::UnexpectedEnd) => {
316 if !self.fill().await? {
317 return Err(ConnectionError::UnexpectedEnd);
318 }
319 }
320 Err(e) => return Err(ConnectionError::Codec(e)),
321 }
322 }
323 }
324
325 pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
332 if self.subgroup_io.is_none() {
333 return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
334 }
335 loop {
336 let reader = self.subgroup_io.as_mut().unwrap();
337 let mut probe = reader.clone();
338 let mut cursor = &self.buf[..];
339 match probe.read_object(&mut cursor) {
340 Ok(obj) => {
341 let consumed = self.buf.len() - cursor.remaining();
342 self.buf.advance(consumed);
343 *reader = probe;
344 return Ok(obj);
345 }
346 Err(CodecError::UnexpectedEnd) => {
347 if !self.fill().await? {
348 return Err(ConnectionError::UnexpectedEnd);
349 }
350 }
351 Err(e) => return Err(ConnectionError::Codec(e)),
352 }
353 }
354 }
355
356 pub async fn read_fetch_stream_header(&mut self) -> Result<FetchHeader, ConnectionError> {
358 loop {
359 let mut cursor = &self.buf[..];
360 match FetchHeader::decode(&mut cursor) {
361 Ok(hdr) => {
362 let consumed = self.buf.len() - cursor.remaining();
363 self.buf.advance(consumed);
364 return Ok(hdr);
365 }
366 Err(CodecError::UnexpectedEnd) => {
367 if !self.fill().await? {
368 return Err(ConnectionError::UnexpectedEnd);
369 }
370 }
371 Err(e) => return Err(ConnectionError::Codec(e)),
372 }
373 }
374 }
375
376 pub fn draft(&self) -> DraftVersion {
378 self.draft
379 }
380}
381
382pub struct Connection {
385 transport: Transport,
386 endpoint: Endpoint,
387 draft: DraftVersion,
388 control_send: Option<FramedSendStream>,
389 control_recv: Option<FramedRecvStream>,
390 observer: Option<Box<dyn ConnectionObserver>>,
391 pending_events: Vec<ClientEvent>,
395}
396
397impl Connection {
398 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
405 let draft = config.draft;
406 let transport = match &config.transport {
407 TransportType::Quic => Self::connect_quic(addr, &config).await?,
408 TransportType::WebTransport { url } => {
409 let url = url.clone();
410 Self::connect_webtransport(&url, &config).await?
411 }
412 };
413
414 let (send, recv) = transport.open_bi().await?;
416 let mut control_send = FramedSendStream::new(send, draft);
417 let mut control_recv = FramedRecvStream::new(recv, draft);
418
419 let mut endpoint = Endpoint::new(Role::Client);
421 endpoint.connect()?;
422 let setup_msg = endpoint.send_client_setup(config.setup_parameters.clone())?;
423 let any_setup = AnyControlMessage::Draft15(setup_msg);
424 let raw_setup = control_send.write_control(&any_setup).await?;
425
426 let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
427 match &server_setup {
429 AnyControlMessage::Draft15(ControlMessage::ServerSetup(ref ss)) => {
430 endpoint.receive_server_setup(ss)?;
431 }
432 _ => {
433 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
434 }
435 }
436
437 let pending_events = vec![
438 ClientEvent::ControlMessage {
439 direction: Direction::Send,
440 message: any_setup,
441 raw: Some(raw_setup),
442 },
443 ClientEvent::ControlMessage {
444 direction: Direction::Receive,
445 message: server_setup,
446 raw: raw_server_setup,
447 },
448 ClientEvent::SetupComplete { negotiated_version: 0xff000000 + 15 },
449 ];
450
451 Ok(Self {
452 transport,
453 endpoint,
454 draft,
455 control_send: Some(control_send),
456 control_recv: Some(control_recv),
457 observer: None,
458 pending_events,
459 })
460 }
461
462 async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
464 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
465 ConnectionError::InvalidAddress(e.to_string())
466 })?;
467
468 let mut tls_config = if config.skip_cert_verification {
470 rustls::ClientConfig::builder()
471 .dangerous()
472 .with_custom_certificate_verifier(Arc::new(SkipVerification))
473 .with_no_client_auth()
474 } else {
475 let mut roots = rustls::RootCertStore::empty();
476 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
477 for der in &config.ca_certs {
478 roots
479 .add(rustls::pki_types::CertificateDer::from(der.clone()))
480 .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
481 }
482 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
483 };
484
485 tls_config.alpn_protocols = config.alpn();
486
487 let quic_config: quinn::crypto::rustls::QuicClientConfig =
488 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
489 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
490
491 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
492 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
493 quinn_endpoint.set_default_client_config(client_config);
494
495 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
496
497 let quic = quinn_endpoint
498 .connect(server_addr, &server_name)
499 .map_err(TransportError::from)?
500 .await
501 .map_err(TransportError::from)?;
502
503 Ok(Transport::Quic(QuicTransport::new(quic)))
504 }
505
506 #[cfg(feature = "webtransport")]
508 async fn connect_webtransport(
509 url: &str,
510 config: &ClientConfig,
511 ) -> Result<Transport, ConnectionError> {
512 use crate::transport::webtransport::WebTransportTransport;
513
514 let wt_config = if config.skip_cert_verification {
515 wtransport::ClientConfig::builder()
516 .with_bind_default()
517 .with_no_cert_validation()
518 .build()
519 } else {
520 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
521 };
522
523 let endpoint = wtransport::Endpoint::client(wt_config)
524 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
525
526 let connection = endpoint
527 .connect(url)
528 .await
529 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
530
531 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
532 }
533
534 #[cfg(not(feature = "webtransport"))]
536 async fn connect_webtransport(
537 _url: &str,
538 _config: &ClientConfig,
539 ) -> Result<Transport, ConnectionError> {
540 Err(ConnectionError::Transport(TransportError::Connect(
541 "webtransport feature not enabled".into(),
542 )))
543 }
544
545 pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
550 self.observer = Some(observer);
551 for event in self.pending_events.drain(..) {
552 if let Some(ref obs) = self.observer {
553 obs.on_event_owned(event);
554 }
555 }
556 }
557
558 pub fn clear_observer(&mut self) {
560 self.observer = None;
561 }
562
563 fn emit(&self, event: ClientEvent) {
565 if let Some(ref obs) = self.observer {
566 obs.on_event_owned(event);
567 }
568 }
569
570 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
577 let any = AnyControlMessage::Draft15(msg.clone());
578 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
579 let raw = send.write_control(&any).await?;
580 self.emit(ClientEvent::ControlMessage {
581 direction: Direction::Send,
582 message: any,
583 raw: Some(raw),
584 });
585 Ok(())
586 }
587
588 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
593 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
594 let capture_raw = self.observer.is_some();
595 let (any, raw) = recv.read_control(capture_raw).await?;
596 if capture_raw {
597 self.emit(ClientEvent::ControlMessage {
598 direction: Direction::Receive,
599 message: any.clone(),
600 raw,
601 });
602 }
603 match any {
605 AnyControlMessage::Draft15(msg) => Ok(msg),
606 _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
607 }
608 }
609
610 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
613 let msg = self.recv_control().await?;
614 self.endpoint.receive_message(msg.clone())?;
615
616 if let ControlMessage::GoAway(ref ga) = msg {
618 self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
619 }
620
621 Ok(msg)
622 }
623
624 pub async fn subscribe(
628 &mut self,
629 track_namespace: TrackNamespace,
630 track_name: Vec<u8>,
631 parameters: Vec<KeyValuePair>,
632 ) -> Result<VarInt, ConnectionError> {
633 let (req_id, msg) = self.endpoint.subscribe(track_namespace, track_name, parameters)?;
634 self.send_control(&msg).await?;
635 Ok(req_id)
636 }
637
638 pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
640 let msg = self.endpoint.unsubscribe(request_id)?;
641 self.send_control(&msg).await
642 }
643
644 pub async fn fetch(
648 &mut self,
649 track_namespace: TrackNamespace,
650 track_name: Vec<u8>,
651 start_group: VarInt,
652 start_object: VarInt,
653 end_group: VarInt,
654 end_object: VarInt,
655 ) -> Result<VarInt, ConnectionError> {
656 let (req_id, msg) = self.endpoint.fetch(
657 track_namespace,
658 track_name,
659 start_group,
660 start_object,
661 end_group,
662 end_object,
663 )?;
664 self.send_control(&msg).await?;
665 Ok(req_id)
666 }
667
668 pub async fn joining_fetch(
670 &mut self,
671 joining_request_id: VarInt,
672 joining_start: VarInt,
673 ) -> Result<VarInt, ConnectionError> {
674 let (req_id, msg) = self.endpoint.joining_fetch(joining_request_id, joining_start)?;
675 self.send_control(&msg).await?;
676 Ok(req_id)
677 }
678
679 pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
681 let msg = self.endpoint.fetch_cancel(request_id)?;
682 self.send_control(&msg).await
683 }
684
685 pub async fn subscribe_namespace(
689 &mut self,
690 namespace_prefix: TrackNamespace,
691 parameters: Vec<KeyValuePair>,
692 ) -> Result<VarInt, ConnectionError> {
693 let (req_id, msg) = self.endpoint.subscribe_namespace(namespace_prefix, parameters)?;
694 self.send_control(&msg).await?;
695 Ok(req_id)
696 }
697
698 pub async fn publish_namespace(
700 &mut self,
701 track_namespace: TrackNamespace,
702 parameters: Vec<KeyValuePair>,
703 ) -> Result<VarInt, ConnectionError> {
704 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace, parameters)?;
705 self.send_control(&msg).await?;
706 Ok(req_id)
707 }
708
709 pub async fn track_status(
713 &mut self,
714 track_namespace: TrackNamespace,
715 track_name: Vec<u8>,
716 parameters: Vec<KeyValuePair>,
717 ) -> Result<VarInt, ConnectionError> {
718 let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name, parameters)?;
719 self.send_control(&msg).await?;
720 Ok(req_id)
721 }
722
723 pub async fn publish(
727 &mut self,
728 track_namespace: TrackNamespace,
729 track_name: Vec<u8>,
730 track_alias: VarInt,
731 parameters: Vec<KeyValuePair>,
732 ) -> Result<VarInt, ConnectionError> {
733 let (req_id, msg) =
734 self.endpoint.publish(track_namespace, track_name, track_alias, parameters)?;
735 self.send_control(&msg).await?;
736 Ok(req_id)
737 }
738
739 pub async fn publish_done(
741 &mut self,
742 request_id: VarInt,
743 status_code: VarInt,
744 stream_count: VarInt,
745 reason_phrase: Vec<u8>,
746 ) -> Result<(), ConnectionError> {
747 let msg = self.endpoint.send_publish_done(
748 request_id,
749 status_code,
750 stream_count,
751 reason_phrase,
752 )?;
753 self.send_control(&msg).await
754 }
755
756 pub async fn open_subgroup_stream(
760 &self,
761 header: &AnySubgroupHeader,
762 ) -> Result<FramedSendStream, ConnectionError> {
763 let send = self.transport.open_uni().await?;
764 let mut framed = FramedSendStream::new(send, self.draft);
765 let sid = framed.stream_id();
766 framed.write_subgroup_header(header).await?;
767 self.emit(ClientEvent::StreamOpened {
768 direction: Direction::Send,
769 stream_kind: StreamKind::Subgroup,
770 stream_id: sid,
771 });
772 self.emit(ClientEvent::DataStreamHeader {
773 stream_id: sid,
774 direction: Direction::Send,
775 header: header.clone(),
776 });
777 Ok(framed)
778 }
779
780 pub async fn accept_subgroup_stream(
783 &self,
784 ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
785 let recv = self.transport.accept_uni().await?;
786 let mut framed = FramedRecvStream::new(recv, self.draft);
787 let sid = framed.stream_id();
788 let header = framed.read_subgroup_header().await?;
789 self.emit(ClientEvent::StreamOpened {
790 direction: Direction::Receive,
791 stream_kind: StreamKind::Subgroup,
792 stream_id: sid,
793 });
794 self.emit(ClientEvent::DataStreamHeader {
795 stream_id: sid,
796 direction: Direction::Receive,
797 header: header.clone(),
798 });
799 Ok((header, framed))
800 }
801
802 pub fn send_datagram(
804 &self,
805 header: &AnyDatagramHeader,
806 payload: &[u8],
807 ) -> Result<(), ConnectionError> {
808 let mut buf = Vec::new();
809 header.encode(&mut buf);
810 buf.extend_from_slice(payload);
811 self.emit(ClientEvent::DatagramReceived {
812 direction: Direction::Send,
813 header: header.clone(),
814 payload_len: payload.len(),
815 });
816 self.transport.send_datagram(bytes::Bytes::from(buf))?;
817 Ok(())
818 }
819
820 pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
822 let data = self.transport.recv_datagram().await?;
823 let mut cursor = &data[..];
824 let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
825 let consumed = data.len() - cursor.len();
826 let payload = data.slice(consumed..);
827 self.emit(ClientEvent::DatagramReceived {
828 direction: Direction::Receive,
829 header: header.clone(),
830 payload_len: payload.len(),
831 });
832 Ok((header, payload))
833 }
834
835 pub fn endpoint(&self) -> &Endpoint {
839 &self.endpoint
840 }
841
842 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
844 &mut self.endpoint
845 }
846
847 pub fn draft(&self) -> DraftVersion {
849 self.draft
850 }
851
852 pub fn close(&self, code: u32, reason: &[u8]) {
854 self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
855 self.transport.close(code, reason);
856 }
857}
858
/// Returns the total encoded length, in bytes, of a varint given its first byte.
///
/// The two most-significant bits of the first byte select the length: `00` → 1,
/// `01` → 2, `10` → 4, `11` → 8 (the RFC 9000 §16 scheme).
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
863
864#[derive(Debug)]
866struct SkipVerification;
867
868impl rustls::client::danger::ServerCertVerifier for SkipVerification {
869 fn verify_server_cert(
870 &self,
871 _end_entity: &rustls::pki_types::CertificateDer<'_>,
872 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
873 _server_name: &rustls::pki_types::ServerName<'_>,
874 _ocsp_response: &[u8],
875 _now: rustls::pki_types::UnixTime,
876 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
877 Ok(rustls::client::danger::ServerCertVerified::assertion())
878 }
879
880 fn verify_tls12_signature(
881 &self,
882 _message: &[u8],
883 _cert: &rustls::pki_types::CertificateDer<'_>,
884 _dcs: &rustls::DigitallySignedStruct,
885 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
886 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
887 }
888
889 fn verify_tls13_signature(
890 &self,
891 _message: &[u8],
892 _cert: &rustls::pki_types::CertificateDer<'_>,
893 _dcs: &rustls::DigitallySignedStruct,
894 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
895 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
896 }
897
898 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
899 vec![
900 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
901 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
902 rustls::SignatureScheme::ED25519,
903 rustls::SignatureScheme::RSA_PSS_SHA256,
904 rustls::SignatureScheme::RSA_PSS_SHA384,
905 rustls::SignatureScheme::RSA_PSS_SHA512,
906 ]
907 }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every sample first byte in `samples` yields `expected` bytes.
    fn check_varint_len(samples: &[u8], expected: usize) {
        for &first in samples {
            assert_eq!(varint_len(first), expected, "first byte {first:#04x}");
        }
    }

    /// Draft-15 config with verification enabled and no extra certs or parameters.
    fn draft15_config(transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft: DraftVersion::Draft15,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    #[test]
    fn varint_len_single_byte() {
        check_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        check_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        check_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        check_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn client_config_alpn_quic_draft15() {
        let alpn = draft15_config(TransportType::Quic).alpn();
        assert_eq!(alpn, vec![b"moqt-15".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let transport = TransportType::WebTransport { url: "https://example.com".to_string() };
        assert_eq!(draft15_config(transport).alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let cases = [
            (TransportType::Quic, "Quic"),
            (
                TransportType::WebTransport { url: "https://example.com".to_string() },
                "WebTransport",
            ),
        ];
        for (transport, variant) in cases {
            assert!(format!("{transport:?}").contains(variant));
        }
    }
}