A better Rust ATProto crate
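//! Custom serde helpers for `Option<bytes::Bytes>`: a `{"$bytes": "<base64>"}` object in
//! human-readable formats, and a native byte string (via `serde_bytes`) otherwise.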
2
3use alloc::string::String;
4use alloc::vec::Vec;
5use base64::{
6 Engine,
7 prelude::{BASE64_STANDARD, BASE64_STANDARD_NO_PAD, BASE64_URL_SAFE, BASE64_URL_SAFE_NO_PAD},
8};
9use bytes::Bytes;
10use core::fmt;
11use serde::{
12 Deserializer, Serializer,
13 de::{self, MapAccess, Visitor},
14};
15
16/// Serialize Bytes as a CBOR byte string
17pub fn serialize<S>(bytes: &Option<Bytes>, serializer: S) -> Result<S::Ok, S::Error>
18where
19 S: Serializer,
20{
21 if let Some(bytes) = bytes {
22 if serializer.is_human_readable() {
23 // JSON: {"$bytes": "base64 string"}
24 use serde::ser::SerializeMap;
25 let mut map = serializer.serialize_map(Some(1))?;
26 map.serialize_entry("$bytes", &BASE64_STANDARD.encode(bytes))?;
27 map.end()
28 } else {
29 // CBOR: raw bytes
30 serde_bytes::serialize(bytes.as_ref(), serializer)
31 }
32 } else {
33 serializer.serialize_none()
34 }
35}
36
37/// Deserialize Bytes from a CBOR byte string
38pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Bytes>, D::Error>
39where
40 D: Deserializer<'de>,
41{
42 if deserializer.is_human_readable() {
43 Ok(deserializer.deserialize_any(OptBytesVisitor)?)
44 } else {
45 let vec: Option<Vec<u8>> = serde_bytes::deserialize(deserializer)?;
46 Ok(vec.map(Bytes::from))
47 }
48}
49
50struct OptBytesVisitor;
51
52impl<'de> Visitor<'de> for OptBytesVisitor {
53 type Value = Option<Bytes>;
54
55 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
56 formatter.write_str("a base64-encoded string")
57 }
58
59 fn visit_none<E>(self) -> Result<Self::Value, E>
60 where
61 E: de::Error,
62 {
63 Ok(None)
64 }
65
66 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
67 where
68 D: Deserializer<'de>,
69 {
70 let vec: Vec<u8> = serde_bytes::deserialize(deserializer)?;
71 Ok(Some(Bytes::from(vec)))
72 }
73
74 fn visit_unit<E>(self) -> Result<Self::Value, E> {
75 Ok(None)
76 }
77
78 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
79 where
80 A: MapAccess<'de>,
81 {
82 let mut bytes = None;
83
84 while let Some(key) = map.next_key()? {
85 match key {
86 "$bytes" => {
87 if bytes.is_some() {
88 return Err(de::Error::duplicate_field("$bytes"));
89 }
90 let bytes_str: String = map.next_value()?;
91 // First one should just work. rest are insurance.
92 bytes = if let Ok(bytes) = BASE64_STANDARD.decode(&bytes_str) {
93 Some(Bytes::from_owner(bytes))
94 } else if let Ok(bytes) = BASE64_STANDARD_NO_PAD.decode(&bytes_str) {
95 Some(Bytes::from_owner(bytes))
96 } else if let Ok(bytes) = BASE64_URL_SAFE.decode(&bytes_str) {
97 Some(Bytes::from_owner(bytes))
98 } else if let Ok(bytes) = BASE64_URL_SAFE_NO_PAD.decode(&bytes_str) {
99 Some(Bytes::from_owner(bytes))
100 } else {
101 return Err(de::Error::custom("invalid base64 string"));
102 }
103 }
104 _ => {
105 return Err(de::Error::unknown_field(key, &["$bytes"]));
106 }
107 }
108 }
109
110 Ok(bytes)
111 }
112}