# An Elixir implementation of AT Protocol-flavoured Merkle Search Trees (MSTs).
defmodule MST.Node do
  @moduledoc """
  Wire-format representation of a single MST node, plus encode/decode.

  An MST node holds an optional left subtree CID (`left`) and an ordered list
  of `MST.Node.Entry` values, each carrying a key suffix, a value CID, and an
  optional right subtree CID. This maps exactly to the AT Protocol node schema:

      { l: CID | null, e: [ { p, k, v, t } ] }

  Keys inside a node are prefix-compressed: each entry's `key_suffix` is the
  portion of the full key that follows the bytes it shares with the previous
  entry's full key. The first entry always has `prefix_len: 0` and carries its
  full key in `key_suffix`. Prefix compression is mandatory — the serialised
  form must be deterministic across implementations.

  Spec: https://atproto.com/specs/repository#mst-structure
  """

  use TypedStruct

  alias DASL.{CID, DRISL}
  alias MST.Node.Entry

  @type encode_error() :: {:error, :encode, atom()}
  @type decode_error() :: {:error, :decode, atom()}

  typedstruct enforce: true do
    field :left, CID.t() | nil
    field :entries, [Entry.t()], default: []
  end

  # ---------------------------------------------------------------------------
  # Construction helpers
  # ---------------------------------------------------------------------------

  @doc """
  Returns an empty MST node — the only valid representation of an empty tree.

  ## Examples

      iex> MST.Node.empty()
      %MST.Node{left: nil, entries: []}

  """
  @spec empty() :: t()
  def empty, do: %__MODULE__{left: nil, entries: []}

  # ---------------------------------------------------------------------------
  # Key expansion
  # ---------------------------------------------------------------------------

  @doc """
  Reconstructs the full keys for all entries in the node.

  Each entry stores only the suffix of its key relative to the previous entry.
  This function walks the entry list and accumulates the full key for each.

  Raises `ArgumentError` if an entry's `prefix_len` is longer than the previous
  entry's full key (or is non-zero on the first entry). Nodes produced by
  `compress_entries/1` or returned by `decode/1` never have this problem;
  `decode/1` rejects such nodes with `:invalid_prefix_len`.

  ## Examples

      iex> cid = DASL.CID.compute("a")
      iex> entries = [
      ...>   %MST.Node.Entry{prefix_len: 0, key_suffix: "foo/bar", value: cid, right: nil},
      ...>   %MST.Node.Entry{prefix_len: 4, key_suffix: "baz", value: cid, right: nil}
      ...> ]
      iex> MST.Node.keys(%MST.Node{left: nil, entries: entries})
      ["foo/bar", "foo/baz"]

  """
  @spec keys(t()) :: [binary()]
  def keys(%__MODULE__{entries: entries}), do: expand_keys(entries, "", [])

  # ---------------------------------------------------------------------------
  # CID computation
  # ---------------------------------------------------------------------------

  @doc """
  Computes the `:drisl`-codec CID for this node.

  Encodes the node to DRISL CBOR bytes and hashes them. Returns an error tuple
  if encoding fails.

  ## Examples

      iex> {:ok, cid} = MST.Node.cid(MST.Node.empty())
      iex> cid.codec
      :drisl

  """
  @spec cid(t()) :: {:ok, CID.t()} | encode_error()
  def cid(node) do
    with {:ok, bytes} <- encode(node) do
      {:ok, CID.compute(bytes, :drisl)}
    end
  end

  # ---------------------------------------------------------------------------
  # Encoding
  # ---------------------------------------------------------------------------

  @doc """
  Encodes an `MST.Node` to DRISL CBOR bytes.

  `nil` subtree links are serialised as explicit CBOR `null` — this is
  mandatory for cross-implementation CID compatibility: skipping a key vs.
  writing `null` produces different bytes and therefore a different CID.

  Errors reported by the DRISL encoder are normalised to
  `{:error, :encode, reason}`.

  ## Examples

      iex> {:ok, bytes} = MST.Node.encode(MST.Node.empty())
      iex> is_binary(bytes)
      true

  """
  @spec encode(t()) :: {:ok, binary()} | encode_error()
  def encode(%__MODULE__{left: left, entries: entries}) do
    # `left` is always written, even when nil, so the `l` key is present as null.
    node_map = %{"e" => Enum.map(entries, &encode_entry/1), "l" => left}

    case DRISL.encode(node_map) do
      {:ok, bytes} -> {:ok, bytes}
      {:error, reason} when is_atom(reason) -> {:error, :encode, reason}
      {:error, :encode, _reason} = err -> err
    end
  end

  # ---------------------------------------------------------------------------
  # Decoding
  # ---------------------------------------------------------------------------

  @doc """
  Decodes DRISL CBOR bytes into an `MST.Node`.

  Besides checking the node/entry shape, this validates the prefix compression:
  the first entry must have `prefix_len: 0`, and every later entry's
  `prefix_len` must not exceed the length of the previous entry's full key.
  Nodes that violate this are rejected with `:invalid_prefix_len`, so `keys/1`
  is always safe to call on a successfully decoded node.

  Error reasons produced here are `:trailing_bytes`, `:invalid_structure`,
  `:invalid_entry`, `:invalid_cid_link` and `:invalid_prefix_len`. Atom errors
  from the DRISL decoder are passed through as `{:error, :decode, reason}`.

  ## Examples

      iex> {:ok, bytes} = MST.Node.encode(MST.Node.empty())
      iex> {:ok, node} = MST.Node.decode(bytes)
      iex> node.entries
      []
      iex> node.left
      nil

      iex> cid = DASL.CID.compute("a")
      iex> entry = %MST.Node.Entry{prefix_len: 3, key_suffix: "x", value: cid, right: nil}
      iex> {:ok, bytes} = MST.Node.encode(%MST.Node{left: nil, entries: [entry]})
      iex> MST.Node.decode(bytes)
      {:error, :decode, :invalid_prefix_len}

  """
  @spec decode(binary()) :: {:ok, t()} | decode_error()
  def decode(bytes) when is_binary(bytes) do
    with {:ok, term, <<>>} <- DRISL.decode(bytes),
         {:ok, node} <- decode_term(term) do
      {:ok, node}
    else
      {:ok, _term, _leftover} -> {:error, :decode, :trailing_bytes}
      {:error, reason} when is_atom(reason) -> {:error, :decode, reason}
      {:error, :decode, _reason} = err -> err
    end
  end

  # ---------------------------------------------------------------------------
  # Compression helpers (used by MST.Tree)
  # ---------------------------------------------------------------------------

  @doc """
  Compresses a list of `{full_key, value_cid, right_cid | nil}` tuples into a
  list of `MST.Node.Entry` structs using the key prefix-compression scheme.

  The first entry always has `prefix_len: 0`. Each subsequent entry computes
  how many leading bytes it shares with the previous full key.

  ## Examples

      iex> cid = DASL.CID.compute("x")
      iex> entries = MST.Node.compress_entries([{"abc/def", cid, nil}, {"abc/ghi", cid, nil}])
      iex> hd(tl(entries)).prefix_len
      4

  """
  @spec compress_entries([{binary(), CID.t(), CID.t() | nil}]) :: [Entry.t()]
  def compress_entries(triples), do: do_compress(triples, "", [])

  # ---------------------------------------------------------------------------
  # Private helpers
  # ---------------------------------------------------------------------------

  # Rebuilds each full key from the previous full key's first `prefix_len`
  # bytes plus the entry's suffix. binary_part/3 raises ArgumentError if
  # prefix_len exceeds byte_size(prev).
  @spec expand_keys([Entry.t()], binary(), [binary()]) :: [binary()]
  defp expand_keys([], _prev, acc), do: Enum.reverse(acc)

  defp expand_keys([entry | rest], prev, acc) do
    full_key = binary_part(prev, 0, entry.prefix_len) <> entry.key_suffix
    expand_keys(rest, full_key, [full_key | acc])
  end

  @spec do_compress([{binary(), CID.t(), CID.t() | nil}], binary(), [Entry.t()]) :: [Entry.t()]
  defp do_compress([], _prev, acc), do: Enum.reverse(acc)

  defp do_compress([{key, value, right} | rest], prev, acc) do
    plen = common_prefix_length(prev, key)
    suffix = binary_part(key, plen, byte_size(key) - plen)

    entry = %Entry{
      prefix_len: plen,
      key_suffix: suffix,
      value: value,
      right: right
    }

    do_compress(rest, key, [entry | acc])
  end

  # Number of leading bytes `a` and `b` have in common (byte-wise, not graphemes).
  @spec common_prefix_length(binary(), binary()) :: non_neg_integer()
  defp common_prefix_length(a, b), do: cpl(a, b, 0)

  defp cpl(<<c, ra::binary>>, <<c, rb::binary>>, n), do: cpl(ra, rb, n + 1)
  defp cpl(_, _, n), do: n

  # Builds the wire map for one entry. `t` is written even when nil so it is
  # serialised as an explicit CBOR null.
  @spec encode_entry(Entry.t()) :: map()
  defp encode_entry(%Entry{prefix_len: p, key_suffix: k, value: v, right: t}) do
    %{
      # The key suffix is tagged so it is emitted as a CBOR byte string (the
      # atproto `k` field is bytes); decode_entry/1 matches the same tag.
      "k" => %CBOR.Tag{tag: :bytes, value: k},
      "p" => p,
      "t" => t,
      "v" => v
    }
  end

  @spec decode_term(any()) :: {:ok, t()} | decode_error()
  defp decode_term(%{"e" => entries_raw, "l" => left_raw}) when is_list(entries_raw) do
    with {:ok, left} <- decode_cid_or_null(left_raw),
         {:ok, entries} <- decode_entries(entries_raw),
         :ok <- validate_prefix_lengths(entries, 0) do
      {:ok, %__MODULE__{left: left, entries: entries}}
    end
  end

  defp decode_term(_), do: {:error, :decode, :invalid_structure}

  @spec decode_entries(list()) :: {:ok, [Entry.t()]} | decode_error()
  defp decode_entries(entries_raw) do
    result =
      Enum.reduce_while(entries_raw, {:ok, []}, fn raw, {:ok, acc} ->
        case decode_entry(raw) do
          {:ok, entry} -> {:cont, {:ok, [entry | acc]}}
          {:error, :decode, _} = err -> {:halt, err}
        end
      end)

    case result do
      {:ok, reversed} -> {:ok, Enum.reverse(reversed)}
      err -> err
    end
  end

  @spec decode_entry(any()) :: {:ok, Entry.t()} | decode_error()
  defp decode_entry(%{
         "k" => %CBOR.Tag{tag: :bytes, value: k},
         "p" => p,
         "t" => t_raw,
         "v" => %CID{} = v
       })
       when is_integer(p) and p >= 0 and is_binary(k) do
    with {:ok, right} <- decode_cid_or_null(t_raw) do
      {:ok, %Entry{prefix_len: p, key_suffix: k, value: v, right: right}}
    end
  end

  defp decode_entry(_), do: {:error, :decode, :invalid_entry}

  # Checks that every entry's prefix_len fits inside the previous full key.
  # `prev_len` is the byte length of the previous full key (0 before the first
  # entry, which therefore must have prefix_len 0). A full key's length is
  # prefix_len + byte_size(key_suffix), so the keys never have to be built here.
  @spec validate_prefix_lengths([Entry.t()], non_neg_integer()) :: :ok | decode_error()
  defp validate_prefix_lengths([], _prev_len), do: :ok

  defp validate_prefix_lengths([%Entry{prefix_len: p, key_suffix: k} | rest], prev_len)
       when p <= prev_len do
    validate_prefix_lengths(rest, p + byte_size(k))
  end

  defp validate_prefix_lengths(_entries, _prev_len), do: {:error, :decode, :invalid_prefix_len}

  @spec decode_cid_or_null(any()) :: {:ok, CID.t() | nil} | decode_error()
  defp decode_cid_or_null(nil), do: {:ok, nil}
  defp decode_cid_or_null(%CID{} = cid), do: {:ok, cid}
  defp decode_cid_or_null(_), do: {:error, :decode, :invalid_cid_link}
end
282end