1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft14::endpoint::{Endpoint, EndpointError};
6use crate::draft14::event::{ClientEvent, Direction, StreamKind};
7use crate::draft14::observer::ConnectionObserver;
8use crate::draft14::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12 AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft14::data_stream::{FetchObject, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft14::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::types::*;
18use moqtap_codec::varint::VarInt;
19use moqtap_codec::version::DraftVersion;
20
/// ALPN token used for raw-QUIC MoQT sessions (`moq-00`).
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
23
/// Errors produced while establishing or driving a client `Connection`.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The protocol `Endpoint` state machine rejected an operation.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// A control message, stream header, or object failed to encode/decode.
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// The underlying QUIC / WebTransport layer reported a failure.
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),
    /// Error from the codec's variable-length integer helpers.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// The control stream is not available on this connection.
    #[error("control stream not open")]
    NoControlStream,
    /// The peer finished a stream before a complete item could be read.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// The stream has been finished.
    /// NOTE(review): not constructed anywhere in this module — confirm which
    /// caller produces it, or whether clean end-of-stream should use it.
    #[error("stream finished")]
    StreamFinished,
    /// The server address could not be parsed. Also used when the local QUIC
    /// client endpoint cannot be created.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// The TLS / QUIC crypto configuration could not be built.
    #[error("TLS config error: {0}")]
    TlsConfig(String),
    /// A data-stream operation was attempted in the wrong order (e.g. writing
    /// an object before the subgroup header).
    #[error("data stream state error: {0}")]
    DataStreamState(&'static str),
}
58
/// Transport used by `Connection::connect` to reach the server.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC; the ALPN is taken from the configured draft version.
    Quic,
    /// WebTransport over HTTP/3 (ALPN `h3`).
    WebTransport {
        /// URL handed to the WebTransport client when connecting.
        url: String,
    },
}
70
71pub struct ClientConfig {
75 pub draft: DraftVersion,
77 pub additional_versions: Vec<DraftVersion>,
80 pub transport: TransportType,
82 pub skip_cert_verification: bool,
84 pub ca_certs: Vec<Vec<u8>>,
86 pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
88}
89
90impl ClientConfig {
91 pub fn supported_versions(&self) -> Vec<VarInt> {
94 let mut versions = vec![self.draft.version_varint()];
95 for v in &self.additional_versions {
96 let varint = v.version_varint();
97 if !versions.contains(&varint) {
98 versions.push(varint);
99 }
100 }
101 versions
102 }
103
104 pub fn alpn(&self) -> Vec<Vec<u8>> {
106 match &self.transport {
107 TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
108 TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
109 }
110 }
111}
112
113pub struct FramedSendStream {
115 inner: SendStream,
116 draft: DraftVersion,
117 subgroup_io: Option<SubgroupObjectReader>,
122}
123
124impl FramedSendStream {
125 pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
127 Self { inner, draft, subgroup_io: None }
128 }
129
130 pub fn stream_id(&self) -> u64 {
132 self.inner.stream_id()
133 }
134
135 pub async fn write_control(
138 &mut self,
139 msg: &AnyControlMessage,
140 ) -> Result<Vec<u8>, ConnectionError> {
141 let mut buf = Vec::new();
142 msg.encode(&mut buf)?;
143 self.inner.write_all(&buf).await?;
144 Ok(buf)
145 }
146
147 pub async fn write_subgroup_header(
150 &mut self,
151 header: &AnySubgroupHeader,
152 ) -> Result<(), ConnectionError> {
153 let mut buf = Vec::new();
154 header.encode(&mut buf);
155 self.inner.write_all(&buf).await?;
156 if let AnySubgroupHeader::Draft14(ref d14) = header {
157 self.subgroup_io = Some(SubgroupObjectReader::new(d14));
158 }
159 Ok(())
160 }
161
162 pub async fn write_fetch_header(
164 &mut self,
165 header: &AnyFetchHeader,
166 ) -> Result<(), ConnectionError> {
167 let mut buf = Vec::new();
168 header.encode(&mut buf);
169 self.inner.write_all(&buf).await?;
170 Ok(())
171 }
172
173 pub async fn write_subgroup_object(
178 &mut self,
179 object: &SubgroupObject,
180 ) -> Result<(), ConnectionError> {
181 let writer = self
182 .subgroup_io
183 .as_mut()
184 .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
185 let mut buf = Vec::new();
186 writer.write_object(object, &mut buf)?;
187 self.inner.write_all(&buf).await?;
188 Ok(())
189 }
190
191 pub async fn write_fetch_object(
195 &mut self,
196 object: &FetchObject,
197 ) -> Result<(), ConnectionError> {
198 let mut buf = Vec::new();
199 object.encode(&mut buf);
200 self.inner.write_all(&buf).await?;
201 Ok(())
202 }
203
204 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
206 self.inner.finish()?;
207 Ok(())
208 }
209
210 pub fn draft(&self) -> DraftVersion {
212 self.draft
213 }
214}
215
216pub struct FramedRecvStream {
218 inner: RecvStream,
219 buf: BytesMut,
220 draft: DraftVersion,
221 subgroup_io: Option<SubgroupObjectReader>,
226}
227
228impl FramedRecvStream {
229 pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
231 Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
232 }
233
234 pub fn stream_id(&self) -> u64 {
236 self.inner.stream_id()
237 }
238
239 async fn fill(&mut self) -> Result<bool, ConnectionError> {
241 let mut tmp = [0u8; 4096];
242 match self.inner.read(&mut tmp).await {
243 Ok(Some(n)) => {
244 self.buf.extend_from_slice(&tmp[..n]);
245 Ok(true)
246 }
247 Ok(None) => Ok(false),
248 Err(e) => Err(ConnectionError::Transport(e)),
249 }
250 }
251
252 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
254 while self.buf.len() < n {
255 if !self.fill().await? {
256 return Err(ConnectionError::UnexpectedEnd);
257 }
258 }
259 Ok(())
260 }
261
262 pub async fn read_control(
268 &mut self,
269 capture_raw: bool,
270 ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
271 self.ensure(1).await?;
273 let type_len = varint_len(self.buf[0]);
274 self.ensure(type_len).await?;
275
276 let mut cursor = &self.buf[..type_len];
277 let _type_id = VarInt::decode(&mut cursor)?;
278
279 let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
281 self.ensure(type_len + 2).await?;
282 let hi = self.buf[type_len] as usize;
283 let lo = self.buf[type_len + 1] as usize;
284 ((hi << 8) | lo, 2)
285 } else {
286 self.ensure(type_len + 1).await?;
287 let payload_len_start = type_len;
288 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
289 self.ensure(type_len + payload_len_varint_len).await?;
290 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
291 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
292 (payload_len, payload_len_varint_len)
293 };
294
295 let total = type_len + len_field_size + payload_len;
297 self.ensure(total).await?;
298
299 let raw = capture_raw.then(|| self.buf[..total].to_vec());
301
302 let mut frame = &self.buf[..total];
304 let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
305 self.buf.advance(total);
306 Ok((msg, raw))
307 }
308
309 pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
313 self.ensure(1).await?;
314 loop {
315 let mut cursor = &self.buf[..];
316 match AnySubgroupHeader::decode(self.draft, &mut cursor) {
317 Ok(header) => {
318 let consumed = self.buf.len() - cursor.remaining();
319 self.buf.advance(consumed);
320 if let AnySubgroupHeader::Draft14(ref d14) = header {
321 self.subgroup_io = Some(SubgroupObjectReader::new(d14));
322 }
323 return Ok(header);
324 }
325 Err(CodecError::UnexpectedEnd) => {
326 if !self.fill().await? {
327 return Err(ConnectionError::UnexpectedEnd);
328 }
329 }
330 Err(e) => return Err(ConnectionError::Codec(e)),
331 }
332 }
333 }
334
335 pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
337 self.ensure(1).await?;
338 loop {
339 let mut cursor = &self.buf[..];
340 match AnyFetchHeader::decode(self.draft, &mut cursor) {
341 Ok(header) => {
342 let consumed = self.buf.len() - cursor.remaining();
343 self.buf.advance(consumed);
344 return Ok(header);
345 }
346 Err(CodecError::UnexpectedEnd) => {
347 if !self.fill().await? {
348 return Err(ConnectionError::UnexpectedEnd);
349 }
350 }
351 Err(e) => return Err(ConnectionError::Codec(e)),
352 }
353 }
354 }
355
356 pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
362 if self.subgroup_io.is_none() {
363 return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
364 }
365 loop {
366 let reader = self.subgroup_io.as_mut().unwrap();
367 let mut probe = reader.clone();
368 let mut cursor = &self.buf[..];
369 match probe.read_object(&mut cursor) {
370 Ok(obj) => {
371 let consumed = self.buf.len() - cursor.remaining();
372 self.buf.advance(consumed);
373 *reader = probe;
375 return Ok(obj);
376 }
377 Err(CodecError::UnexpectedEnd) => {
378 if !self.fill().await? {
379 return Err(ConnectionError::UnexpectedEnd);
380 }
381 }
382 Err(e) => return Err(ConnectionError::Codec(e)),
383 }
384 }
385 }
386
387 pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
391 loop {
392 let mut cursor = &self.buf[..];
393 match FetchObject::decode(&mut cursor) {
394 Ok(obj) => {
395 let consumed = self.buf.len() - cursor.remaining();
396 self.buf.advance(consumed);
397 return Ok(obj);
398 }
399 Err(CodecError::UnexpectedEnd) => {
400 if !self.fill().await? {
401 return Err(ConnectionError::UnexpectedEnd);
402 }
403 }
404 Err(e) => return Err(ConnectionError::Codec(e)),
405 }
406 }
407 }
408
409 pub fn draft(&self) -> DraftVersion {
411 self.draft
412 }
413}
414
415pub struct Connection {
418 transport: Transport,
419 endpoint: Endpoint,
420 draft: DraftVersion,
421 control_send: Option<FramedSendStream>,
422 control_recv: Option<FramedRecvStream>,
423 observer: Option<Box<dyn ConnectionObserver>>,
424 pending_events: Vec<ClientEvent>,
428}
429
430impl Connection {
431 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
437 let draft = config.draft;
438 let transport = match &config.transport {
439 TransportType::Quic => Self::connect_quic(addr, &config).await?,
440 TransportType::WebTransport { url } => {
441 let url = url.clone();
442 Self::connect_webtransport(&url, &config).await?
443 }
444 };
445
446 let (send, recv) = transport.open_bi().await?;
448 let mut control_send = FramedSendStream::new(send, draft);
449 let mut control_recv = FramedRecvStream::new(recv, draft);
450
451 let mut endpoint = Endpoint::new(Role::Client);
453 endpoint.connect()?;
454 let setup_msg = endpoint
455 .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
456 let any_setup = AnyControlMessage::Draft14(setup_msg);
457 let raw_setup = control_send.write_control(&any_setup).await?;
458
459 let (server_setup, raw_server_setup) = control_recv.read_control(true).await?;
460 match &server_setup {
462 AnyControlMessage::Draft14(ControlMessage::ServerSetup(ref ss)) => {
463 endpoint.receive_server_setup(ss)?;
464 }
465 _ => {
466 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
467 }
468 }
469
470 let mut pending_events = Vec::with_capacity(3);
471 pending_events.push(ClientEvent::ControlMessage {
472 direction: Direction::Send,
473 message: any_setup,
474 raw: Some(raw_setup),
475 });
476 pending_events.push(ClientEvent::ControlMessage {
477 direction: Direction::Receive,
478 message: server_setup,
479 raw: raw_server_setup,
480 });
481 if let Some(v) = endpoint.negotiated_version() {
482 pending_events.push(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
483 }
484
485 Ok(Self {
486 transport,
487 endpoint,
488 draft,
489 control_send: Some(control_send),
490 control_recv: Some(control_recv),
491 observer: None,
492 pending_events,
493 })
494 }
495
496 async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
498 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
499 ConnectionError::InvalidAddress(e.to_string())
500 })?;
501
502 let mut tls_config = if config.skip_cert_verification {
504 rustls::ClientConfig::builder()
505 .dangerous()
506 .with_custom_certificate_verifier(Arc::new(SkipVerification))
507 .with_no_client_auth()
508 } else {
509 let mut roots = rustls::RootCertStore::empty();
510 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
511 for der in &config.ca_certs {
512 roots
513 .add(rustls::pki_types::CertificateDer::from(der.clone()))
514 .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
515 }
516 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
517 };
518
519 tls_config.alpn_protocols = config.alpn();
520
521 let quic_config: quinn::crypto::rustls::QuicClientConfig =
522 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
523 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
524
525 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
526 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
527 quinn_endpoint.set_default_client_config(client_config);
528
529 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
530
531 let quic = quinn_endpoint
532 .connect(server_addr, &server_name)
533 .map_err(TransportError::from)?
534 .await
535 .map_err(TransportError::from)?;
536
537 Ok(Transport::Quic(QuicTransport::new(quic)))
538 }
539
540 #[cfg(feature = "webtransport")]
542 async fn connect_webtransport(
543 url: &str,
544 config: &ClientConfig,
545 ) -> Result<Transport, ConnectionError> {
546 use crate::transport::webtransport::WebTransportTransport;
547
548 let wt_config = if config.skip_cert_verification {
549 wtransport::ClientConfig::builder()
550 .with_bind_default()
551 .with_no_cert_validation()
552 .build()
553 } else {
554 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
555 };
556
557 let endpoint = wtransport::Endpoint::client(wt_config)
558 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
559
560 let connection = endpoint
561 .connect(url)
562 .await
563 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
564
565 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
566 }
567
568 #[cfg(not(feature = "webtransport"))]
570 async fn connect_webtransport(
571 _url: &str,
572 _config: &ClientConfig,
573 ) -> Result<Transport, ConnectionError> {
574 Err(ConnectionError::Transport(TransportError::Connect(
575 "webtransport feature not enabled".into(),
576 )))
577 }
578
579 pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
584 self.observer = Some(observer);
585 for event in self.pending_events.drain(..) {
586 if let Some(ref obs) = self.observer {
587 obs.on_event_owned(event);
588 }
589 }
590 }
591
592 pub fn clear_observer(&mut self) {
594 self.observer = None;
595 }
596
597 fn emit(&self, event: ClientEvent) {
599 if let Some(ref obs) = self.observer {
600 obs.on_event_owned(event);
601 }
602 }
603
604 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
610 let any = AnyControlMessage::Draft14(msg.clone());
611 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
612 let raw = send.write_control(&any).await?;
613 self.emit(ClientEvent::ControlMessage {
614 direction: Direction::Send,
615 message: any,
616 raw: Some(raw),
617 });
618 Ok(())
619 }
620
621 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
626 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
627 let capture_raw = self.observer.is_some();
628 let (any, raw) = recv.read_control(capture_raw).await?;
629 if capture_raw {
630 self.emit(ClientEvent::ControlMessage {
631 direction: Direction::Receive,
632 message: any.clone(),
633 raw,
634 });
635 }
636 match any {
638 AnyControlMessage::Draft14(msg) => Ok(msg),
639 _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
640 }
641 }
642
643 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
646 let msg = self.recv_control().await?;
647 self.endpoint.receive_message(msg.clone())?;
648
649 if let ControlMessage::GoAway(ref ga) = msg {
651 self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
652 }
653
654 Ok(msg)
655 }
656
657 pub async fn subscribe(
661 &mut self,
662 track_namespace: TrackNamespace,
663 track_name: Vec<u8>,
664 subscriber_priority: u8,
665 group_order: GroupOrder,
666 filter_type: FilterType,
667 ) -> Result<VarInt, ConnectionError> {
668 let (req_id, msg) = self.endpoint.subscribe(
669 track_namespace,
670 track_name,
671 subscriber_priority,
672 group_order,
673 filter_type,
674 )?;
675 self.send_control(&msg).await?;
676 Ok(req_id)
677 }
678
679 pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
681 let msg = self.endpoint.unsubscribe(request_id)?;
682 self.send_control(&msg).await
683 }
684
685 pub async fn subscribe_update(
688 &mut self,
689 subscription_request_id: VarInt,
690 start_location: Location,
691 end_group: VarInt,
692 subscriber_priority: u8,
693 forward: Forward,
694 parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
695 ) -> Result<VarInt, ConnectionError> {
696 let (req_id, msg) = self.endpoint.subscribe_update(
697 subscription_request_id,
698 start_location,
699 end_group,
700 subscriber_priority,
701 forward,
702 parameters,
703 )?;
704 self.send_control(&msg).await?;
705 Ok(req_id)
706 }
707
708 pub async fn fetch(
712 &mut self,
713 track_namespace: TrackNamespace,
714 track_name: Vec<u8>,
715 start_group: VarInt,
716 start_object: VarInt,
717 ) -> Result<VarInt, ConnectionError> {
718 let (req_id, msg) =
719 self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
720 self.send_control(&msg).await?;
721 Ok(req_id)
722 }
723
724 pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
726 let msg = self.endpoint.fetch_cancel(request_id)?;
727 self.send_control(&msg).await
728 }
729
730 pub async fn subscribe_namespace(
734 &mut self,
735 track_namespace: TrackNamespace,
736 ) -> Result<VarInt, ConnectionError> {
737 let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
738 self.send_control(&msg).await?;
739 Ok(req_id)
740 }
741
742 pub async fn publish_namespace(
744 &mut self,
745 track_namespace: TrackNamespace,
746 ) -> Result<VarInt, ConnectionError> {
747 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
748 self.send_control(&msg).await?;
749 Ok(req_id)
750 }
751
752 pub async fn track_status(
756 &mut self,
757 track_namespace: TrackNamespace,
758 track_name: Vec<u8>,
759 ) -> Result<VarInt, ConnectionError> {
760 let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name)?;
761 self.send_control(&msg).await?;
762 Ok(req_id)
763 }
764
765 pub async fn publish(
769 &mut self,
770 track_namespace: TrackNamespace,
771 track_name: Vec<u8>,
772 forward: Forward,
773 ) -> Result<VarInt, ConnectionError> {
774 let (req_id, msg) = self.endpoint.publish(track_namespace, track_name, forward)?;
775 self.send_control(&msg).await?;
776 Ok(req_id)
777 }
778
779 pub async fn publish_done(
781 &mut self,
782 request_id: VarInt,
783 status_code: VarInt,
784 reason_phrase: Vec<u8>,
785 ) -> Result<(), ConnectionError> {
786 let msg = self.endpoint.send_publish_done(request_id, status_code, reason_phrase)?;
787 self.send_control(&msg).await
788 }
789
790 pub async fn open_subgroup_stream(
794 &self,
795 header: &AnySubgroupHeader,
796 ) -> Result<FramedSendStream, ConnectionError> {
797 let send = self.transport.open_uni().await?;
798 let mut framed = FramedSendStream::new(send, self.draft);
799 let sid = framed.stream_id();
800 framed.write_subgroup_header(header).await?;
801 self.emit(ClientEvent::StreamOpened {
802 direction: Direction::Send,
803 stream_kind: StreamKind::Subgroup,
804 stream_id: sid,
805 });
806 self.emit(ClientEvent::DataStreamHeader {
807 stream_id: sid,
808 direction: Direction::Send,
809 header: header.clone(),
810 });
811 Ok(framed)
812 }
813
814 pub async fn accept_subgroup_stream(
817 &self,
818 ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
819 let recv = self.transport.accept_uni().await?;
820 let mut framed = FramedRecvStream::new(recv, self.draft);
821 let sid = framed.stream_id();
822 let header = framed.read_subgroup_header().await?;
823 self.emit(ClientEvent::StreamOpened {
824 direction: Direction::Receive,
825 stream_kind: StreamKind::Subgroup,
826 stream_id: sid,
827 });
828 self.emit(ClientEvent::DataStreamHeader {
829 stream_id: sid,
830 direction: Direction::Receive,
831 header: header.clone(),
832 });
833 Ok((header, framed))
834 }
835
836 pub fn send_datagram(
838 &self,
839 header: &AnyDatagramHeader,
840 payload: &[u8],
841 ) -> Result<(), ConnectionError> {
842 let mut buf = Vec::new();
843 header.encode(&mut buf);
844 buf.extend_from_slice(payload);
845 self.emit(ClientEvent::DatagramReceived {
846 direction: Direction::Send,
847 header: header.clone(),
848 payload_len: payload.len(),
849 });
850 self.transport.send_datagram(bytes::Bytes::from(buf))?;
851 Ok(())
852 }
853
854 pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
856 let data = self.transport.recv_datagram().await?;
857 let mut cursor = &data[..];
858 let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
859 let consumed = data.len() - cursor.len();
860 let payload = data.slice(consumed..);
861 self.emit(ClientEvent::DatagramReceived {
862 direction: Direction::Receive,
863 header: header.clone(),
864 payload_len: payload.len(),
865 });
866 Ok((header, payload))
867 }
868
869 pub fn endpoint(&self) -> &Endpoint {
873 &self.endpoint
874 }
875
876 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
878 &mut self.endpoint
879 }
880
881 pub fn negotiated_version(&self) -> Option<VarInt> {
883 self.endpoint.negotiated_version()
884 }
885
886 pub fn draft(&self) -> DraftVersion {
888 self.draft
889 }
890
891 pub fn close(&self, code: u32, reason: &[u8]) {
893 self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
894 self.transport.close(code, reason);
895 }
896}
897
/// Total encoded length, in bytes, of a QUIC-style varint, determined by the
/// two most significant bits of its first byte (00→1, 01→2, 10→4, 11→8).
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
902
903#[derive(Debug)]
905struct SkipVerification;
906
907impl rustls::client::danger::ServerCertVerifier for SkipVerification {
908 fn verify_server_cert(
909 &self,
910 _end_entity: &rustls::pki_types::CertificateDer<'_>,
911 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
912 _server_name: &rustls::pki_types::ServerName<'_>,
913 _ocsp_response: &[u8],
914 _now: rustls::pki_types::UnixTime,
915 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
916 Ok(rustls::client::danger::ServerCertVerified::assertion())
917 }
918
919 fn verify_tls12_signature(
920 &self,
921 _message: &[u8],
922 _cert: &rustls::pki_types::CertificateDer<'_>,
923 _dcs: &rustls::DigitallySignedStruct,
924 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
925 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
926 }
927
928 fn verify_tls13_signature(
929 &self,
930 _message: &[u8],
931 _cert: &rustls::pki_types::CertificateDer<'_>,
932 _dcs: &rustls::DigitallySignedStruct,
933 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
934 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
935 }
936
937 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
938 vec![
939 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
940 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
941 rustls::SignatureScheme::ED25519,
942 rustls::SignatureScheme::RSA_PSS_SHA256,
943 rustls::SignatureScheme::RSA_PSS_SHA384,
944 rustls::SignatureScheme::RSA_PSS_SHA512,
945 ]
946 }
947}
948
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with no extras for the given draft and transport.
    fn config_with(draft: DraftVersion, transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft,
            additional_versions: Vec::new(),
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// Builds a raw-QUIC config with no extras for the given draft.
    fn quic_config(draft: DraftVersion) -> ClientConfig {
        config_with(draft, TransportType::Quic)
    }

    #[test]
    fn varint_len_single_byte() {
        for first in [0x00u8, 0x3F] {
            assert_eq!(varint_len(first), 1);
        }
    }

    #[test]
    fn varint_len_two_bytes() {
        for first in [0x40u8, 0x7F] {
            assert_eq!(varint_len(first), 2);
        }
    }

    #[test]
    fn varint_len_four_bytes() {
        for first in [0x80u8, 0xBF] {
            assert_eq!(varint_len(first), 4);
        }
    }

    #[test]
    fn varint_len_eight_bytes() {
        for first in [0xC0u8, 0xFF] {
            assert_eq!(varint_len(first), 8);
        }
    }

    #[test]
    fn client_config_supported_versions_draft14() {
        let versions = quic_config(DraftVersion::Draft14).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff00_0000 + 14);
    }

    #[test]
    fn client_config_supported_versions_draft07() {
        let versions = quic_config(DraftVersion::Draft07).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff00_0000 + 7);
    }

    #[test]
    fn client_config_alpn_quic() {
        let config = quic_config(DraftVersion::Draft14);
        assert_eq!(config.alpn(), vec![b"moq-00".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config = config_with(
            DraftVersion::Draft14,
            TransportType::WebTransport { url: "https://example.com".to_string() },
        );
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let rendered_quic = format!("{:?}", TransportType::Quic);
        assert!(rendered_quic.contains("Quic"));

        let rendered_wt = format!(
            "{:?}",
            TransportType::WebTransport { url: "https://example.com".to_string() }
        );
        assert!(rendered_wt.contains("WebTransport"));
    }
}