defmodule MST.Node do
  @moduledoc """
  Wire-format representation of a single MST node, plus encode/decode.

  An MST node holds an optional left subtree CID (`left`) and an ordered list
  of `MST.Node.Entry` values, each carrying a key suffix, a value CID, and an
  optional right subtree CID. This maps exactly to the AT Protocol node schema:

      {
        l: CID | null,
        e: [ { p, k, v, t } ]
      }

  Keys inside a node are prefix-compressed: each entry's `key_suffix` is the
  portion of the full key that follows the bytes it shares with the previous
  entry's full key. The first entry always has `prefix_len: 0` and carries its
  full key in `key_suffix`. Prefix compression is mandatory — the serialised
  form must be deterministic across implementations.

  Spec: https://atproto.com/specs/repository#mst-structure
  """

  use TypedStruct

  alias DASL.{CID, DRISL}
  alias MST.Node.Entry

  @type encode_error() :: {:error, :encode, atom()}
  @type decode_error() :: {:error, :decode, atom()}

  typedstruct enforce: true do
    field :left, CID.t() | nil
    field :entries, [Entry.t()], default: []
  end

  # ---------------------------------------------------------------------------
  # Construction helpers
  # ---------------------------------------------------------------------------

  @doc """
  Returns an empty MST node — the only valid representation of an empty tree.

  ## Examples

      iex> MST.Node.empty()
      %MST.Node{left: nil, entries: []}
  """
  @spec empty() :: t()
  def empty, do: %__MODULE__{left: nil, entries: []}

  # ---------------------------------------------------------------------------
  # Key expansion
  # ---------------------------------------------------------------------------

  @doc """
  Reconstructs the full keys for all entries in the node.

  Each entry stores only the suffix of its key relative to the previous entry.
  This function walks the entry list and accumulates the full key for each.

  Raises `ArgumentError` if an entry's `prefix_len` exceeds the byte length of
  the previous full key (or is non-zero on the first entry). Nodes returned by
  `decode/1` are validated against this and never raise here.

  ## Examples

      iex> cid = DASL.CID.compute("a")
      iex> entries = [
      ...>   %MST.Node.Entry{prefix_len: 0, key_suffix: "foo/bar", value: cid, right: nil},
      ...>   %MST.Node.Entry{prefix_len: 4, key_suffix: "baz", value: cid, right: nil},
      ...> ]
      iex> MST.Node.keys(%MST.Node{left: nil, entries: entries})
      ["foo/bar", "foo/baz"]
  """
  @spec keys(t()) :: [binary()]
  def keys(%__MODULE__{entries: entries}), do: expand_keys(entries, "", [])

  # ---------------------------------------------------------------------------
  # CID computation
  # ---------------------------------------------------------------------------

  @doc """
  Computes the `:drisl`-codec CID for this node.

  Encodes the node to DRISL CBOR bytes and hashes them. Returns an error tuple
  if encoding fails.

  ## Examples

      iex> {:ok, cid} = MST.Node.cid(MST.Node.empty())
      iex> cid.codec
      :drisl
  """
  @spec cid(t()) :: {:ok, CID.t()} | encode_error()
  def cid(node) do
    with {:ok, bytes} <- encode(node) do
      {:ok, CID.compute(bytes, :drisl)}
    end
  end

  # ---------------------------------------------------------------------------
  # Encoding
  # ---------------------------------------------------------------------------

  @doc """
  Encodes an `MST.Node` to DRISL CBOR bytes.

  `nil` subtree links are serialised as explicit CBOR `null` — this is
  mandatory for cross-implementation CID compatibility: skipping a key vs.
  writing `null` produces different bytes and therefore a different CID.

  Errors reported by `DRISL.encode/1` as `{:error, reason}` (atom reason) are
  normalised to `{:error, :encode, reason}`.

  ## Examples

      iex> {:ok, bytes} = MST.Node.encode(MST.Node.empty())
      iex> is_binary(bytes)
      true
  """
  @spec encode(t()) :: {:ok, binary()} | encode_error()
  def encode(%__MODULE__{left: left, entries: entries}) do
    # Both "e" and "l" are always present; a nil `left` is written as null.
    node_map = %{"e" => encode_entries(entries), "l" => left}

    case DRISL.encode(node_map) do
      {:ok, bytes} -> {:ok, bytes}
      {:error, reason} when is_atom(reason) -> {:error, :encode, reason}
      {:error, :encode, _reason} = err -> err
    end
  end

  # ---------------------------------------------------------------------------
  # Decoding
  # ---------------------------------------------------------------------------

  @doc """
  Decodes DRISL CBOR bytes into an `MST.Node`.

  Returns `{:error, :decode, reason}` when the bytes are not a single well-formed
  node. Reasons produced here include `:trailing_bytes`, `:invalid_structure`,
  `:invalid_entry`, `:invalid_cid_link` and `:invalid_prefix_len` (an entry's
  `prefix_len` is longer than the previous entry's full key, or non-zero on
  the first entry). Atom reasons from `DRISL.decode/1` are passed through.

  ## Examples

      iex> {:ok, bytes} = MST.Node.encode(MST.Node.empty())
      iex> {:ok, node} = MST.Node.decode(bytes)
      iex> node.entries
      []
      iex> node.left
      nil

      iex> cid = DASL.CID.compute("a")
      iex> entry = %MST.Node.Entry{prefix_len: 3, key_suffix: "x", value: cid, right: nil}
      iex> {:ok, bytes} = MST.Node.encode(%MST.Node{left: nil, entries: [entry]})
      iex> MST.Node.decode(bytes)
      {:error, :decode, :invalid_prefix_len}
  """
  @spec decode(binary()) :: {:ok, t()} | decode_error()
  def decode(bytes) when is_binary(bytes) do
    with {:ok, term, <<>>} <- DRISL.decode(bytes),
         {:ok, node} <- decode_term(term) do
      {:ok, node}
    else
      {:ok, _, _leftover} -> {:error, :decode, :trailing_bytes}
      {:error, reason} when is_atom(reason) -> {:error, :decode, reason}
      {:error, :decode, _} = err -> err
    end
  end

  # ---------------------------------------------------------------------------
  # Compression helpers (used by MST.Tree)
  # ---------------------------------------------------------------------------

  @doc """
  Compresses a list of `{full_key, value_cid, right_cid | nil}` tuples into a
  list of `MST.Node.Entry` structs using the key prefix-compression scheme.

  The first entry always has `prefix_len: 0`. Each subsequent entry computes
  how many leading bytes it shares with the previous full key.

  ## Examples

      iex> cid = DASL.CID.compute("x")
      iex> entries = MST.Node.compress_entries([{"abc/def", cid, nil}, {"abc/ghi", cid, nil}])
      iex> hd(tl(entries)).prefix_len
      4
  """
  @spec compress_entries([{binary(), CID.t(), CID.t() | nil}]) :: [Entry.t()]
  def compress_entries(triples), do: do_compress(triples, "", [])

  # ---------------------------------------------------------------------------
  # Private helpers
  # ---------------------------------------------------------------------------

  @spec expand_keys([Entry.t()], binary(), [binary()]) :: [binary()]
  defp expand_keys([], _prev, acc), do: Enum.reverse(acc)

  defp expand_keys([entry | rest], prev, acc) do
    # binary_part/3 raises ArgumentError if prefix_len > byte_size(prev).
    full_key = binary_part(prev, 0, entry.prefix_len) <> entry.key_suffix
    expand_keys(rest, full_key, [full_key | acc])
  end

  @spec do_compress([{binary(), CID.t(), CID.t() | nil}], binary(), [Entry.t()]) ::
          [Entry.t()]
  defp do_compress([], _prev, acc), do: Enum.reverse(acc)

  defp do_compress([{key, value, right} | rest], prev, acc) do
    plen = common_prefix_length(prev, key)
    suffix = binary_part(key, plen, byte_size(key) - plen)

    entry = %Entry{
      prefix_len: plen,
      key_suffix: suffix,
      value: value,
      right: right
    }

    do_compress(rest, key, [entry | acc])
  end

  # Number of leading bytes shared by `a` and `b` (0 when either is empty).
  @spec common_prefix_length(binary(), binary()) :: non_neg_integer()
  defp common_prefix_length(a, b), do: :binary.longest_common_prefix([a, b])

  # Entry encoding cannot fail, so this returns the list of wire maps directly.
  @spec encode_entries([Entry.t()]) :: [map()]
  defp encode_entries(entries), do: Enum.map(entries, &encode_entry/1)

  # The key suffix is wrapped as a CBOR byte string; a nil right link is
  # written as null (never omitted).
  @spec encode_entry(Entry.t()) :: map()
  defp encode_entry(%Entry{prefix_len: p, key_suffix: k, value: v, right: t}) do
    %{
      "k" => %CBOR.Tag{tag: :bytes, value: k},
      "p" => p,
      "t" => t,
      "v" => v
    }
  end

  @spec decode_term(any()) :: {:ok, t()} | decode_error()
  defp decode_term(%{"e" => entries_raw, "l" => left_raw}) when is_list(entries_raw) do
    with {:ok, left} <- decode_cid_or_null(left_raw),
         {:ok, entries} <- decode_entries(entries_raw) do
      {:ok, %__MODULE__{left: left, entries: entries}}
    end
  end

  defp decode_term(_), do: {:error, :decode, :invalid_structure}

  # Decodes every entry and, in the same pass, checks that each prefix_len
  # fits within the previous entry's full key so `keys/1` cannot crash on a
  # decoded node. The previous full key is threaded through the fold.
  @spec decode_entries(list()) :: {:ok, [Entry.t()]} | decode_error()
  defp decode_entries(entries_raw) do
    result =
      Enum.reduce_while(entries_raw, {:ok, [], ""}, fn raw, {:ok, acc, prev_key} ->
        with {:ok, entry} <- decode_entry(raw),
             {:ok, full_key} <- expand_key(entry, prev_key) do
          {:cont, {:ok, [entry | acc], full_key}}
        else
          {:error, :decode, _reason} = err -> {:halt, err}
        end
      end)

    case result do
      {:ok, reversed, _last_key} -> {:ok, Enum.reverse(reversed)}
      {:error, :decode, _reason} = err -> err
    end
  end

  # Rebuilds one entry's full key from the previous full key, rejecting a
  # prefix_len that reaches past the end of that key.
  @spec expand_key(Entry.t(), binary()) :: {:ok, binary()} | decode_error()
  defp expand_key(%Entry{prefix_len: p, key_suffix: suffix}, prev_key)
       when p <= byte_size(prev_key),
       do: {:ok, binary_part(prev_key, 0, p) <> suffix}

  defp expand_key(_entry, _prev_key), do: {:error, :decode, :invalid_prefix_len}

  @spec decode_entry(any()) :: {:ok, Entry.t()} | decode_error()
  defp decode_entry(%{
         "k" => %CBOR.Tag{tag: :bytes, value: k},
         "p" => p,
         "t" => t_raw,
         "v" => %CID{} = v
       })
       when is_integer(p) and p >= 0 and is_binary(k) do
    with {:ok, right} <- decode_cid_or_null(t_raw) do
      {:ok, %Entry{prefix_len: p, key_suffix: k, value: v, right: right}}
    end
  end

  defp decode_entry(_), do: {:error, :decode, :invalid_entry}

  @spec decode_cid_or_null(any()) :: {:ok, CID.t() | nil} | decode_error()
  defp decode_cid_or_null(nil), do: {:ok, nil}
  defp decode_cid_or_null(%CID{} = cid), do: {:ok, cid}
  defp decode_cid_or_null(_), do: {:error, :decode, :invalid_cid_link}
end