···11+from abc import ABC, abstractmethod
22+from typing import Self, Optional, Dict, BinaryIO
33+import sqlite3
class BlockStore(ABC):
    """
    Abstract key/value store whose values are write-once: once a key has a
    value that value can never change, although the key may be deleted.

    In practice the key is the hash of the value, but nothing in this API
    relies on (or checks) that.

    The "native" mapping dunder methods are deliberately not used, because
    the semantics here differ subtly from a normal mapping:

    - put() called twice with the same key and value: the second call is a nop.
    - put() called with an existing key but a different value: ValueError.
    - get() has no default return value: a missing key raises KeyError.
    - delete() of a key that doesn't exist is a nop.
    """

    @abstractmethod
    def put(self, key: bytes, value: bytes) -> None:
        """Store `value` under `key` (see the class docstring for the rules)."""

    @abstractmethod
    def get(self, key: bytes) -> bytes:
        """Return the value stored under `key`, or raise KeyError."""

    @abstractmethod
    def delete(self, key: bytes) -> None:
        """Remove `key` if it is present; otherwise do nothing."""
class MemoryBlockStore(BlockStore):
    """
    BlockStore that keeps every block in a plain in-memory dict.
    """

    # annotation only (the original *assigned* the typing object to the class)
    _state: Dict[bytes, bytes]

    def __init__(self, state: Optional[Dict[bytes, bytes]]=None) -> None:
        """
        NB: if a state dict is passed, it'll get mutated in-place
        """
        self._state = dict() if state is None else state

    def put(self, key: bytes, value: bytes) -> None:
        """
        Store `value` under `key`.

        Re-putting an identical value is a nop; a different value for an
        existing key raises ValueError.
        """
        existing_value = self._state.get(key)
        # Compare against None explicitly: b"" is a legitimate stored value
        # and must be protected by the immutability check too.
        if existing_value is None:
            self._state[key] = value
            return
        if existing_value != value:
            raise ValueError("block values are immutable")
        # the value matches, there's nothing to do

    def get(self, key: bytes) -> bytes:
        """Return the value stored under `key`; raises KeyError if absent."""
        value = self._state.get(key)
        if value is None:
            raise KeyError("no block matches this key")
        return value

    def delete(self, key: bytes) -> None:
        """Remove `key` if present; deleting a missing key is a nop."""
        self._state.pop(key, None)
class SqliteBlockStore(BlockStore):
    """
    BlockStore backed by a two-column table in an SQLite database.

    NB: Caller is responsible for calling commit(), etc.

    The table name is interpolated directly into the SQL text (identifiers
    cannot be bound as parameters), so it must come from trusted code and
    never from user input.
    """

    def __init__(self, con: sqlite3.Connection, table: str="mst_blocks") -> None:
        """
        Create the backing table on `con` if it does not exist yet.
        """
        self.table = table
        self._cur = con.cursor()
        self._cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.table} (
                block_key BLOB PRIMARY KEY,
                block_val BLOB NOT NULL
            ) WITHOUT ROWID;
        """)

    def put(self, key: bytes, value: bytes) -> None:
        """
        Store `value` under `key`.

        Re-putting an identical value is a nop; a different value for an
        existing key raises ValueError.
        """
        self._cur.execute(
            f"INSERT OR IGNORE INTO {self.table} (block_key, block_val) VALUES (?, ?)",
            (key, value),
        )
        if self._cur.rowcount > 0:
            return  # a fresh row was inserted

        # The insert was ignored, i.e. the key already exists. Make sure the
        # stored value is the one we were asked to store, rather than
        # silently dropping a conflicting write.
        row = self._cur.execute(
            f"SELECT block_val FROM {self.table} WHERE block_key=?", (key,)
        ).fetchone()
        if row is not None and row[0] != value:
            raise ValueError("block values are immutable")

    def get(self, key: bytes) -> bytes:
        """Return the value stored under `key`; raises KeyError if absent."""
        row = self._cur.execute(
            f"SELECT block_val FROM {self.table} WHERE block_key=?", (key,)
        ).fetchone()
        if row is None:
            raise KeyError("no block matches this key")
        return row[0]

    def delete(self, key: bytes) -> None:
        """Remove `key` if present; deleting a missing key is a nop."""
        self._cur.execute(f"DELETE FROM {self.table} WHERE block_key=?", (key,))
class OverlayBlockStore(BlockStore):
    """
    Stacks one store ("upper") on top of another ("lower").

    Lookups try upper first and fall back to lower on a miss.
    put() and delete() are only ever forwarded to upper.
    """

    def __init__(self, upper: BlockStore, lower: BlockStore) -> None:
        self.upper, self.lower = upper, lower

    def put(self, key: bytes, value: bytes) -> None:
        """Write to the upper store only."""
        self.upper.put(key, value)

    def get(self, key: bytes) -> bytes:
        """Read from upper, falling back to lower if upper raises KeyError."""
        try:
            found = self.upper.get(key)
        except KeyError:
            found = self.lower.get(key)
        return found

    def delete(self, key: bytes) -> None:
        """Delete from the upper store only."""
        self.upper.delete(key)
if __name__ == "__main__":
    # Ad-hoc self-test for the block stores defined in this module.
    import os
    import tempfile
    from contextlib import closing

    bs = MemoryBlockStore()
    bs.put(b"hello", b"world")

    bs.put(b"hello", b"world")  # putting twice is a nop

    # Explicit raises instead of `assert False` so the checks survive `python -O`.
    try:
        bs.put(b"hello", b"foobar")
    except ValueError:
        pass
    else:
        raise AssertionError("put() with a conflicting value should raise ValueError")

    print("hello ->", bs.get(b"hello"))

    bs.delete(b"nothing")  # nop

    bs.delete(b"hello")

    try:
        bs.get(b"hello")
    except KeyError:
        pass
    else:
        raise AssertionError("get() of a deleted key should raise KeyError")

    # Work in a throwaway directory so an existing ./test.db is never
    # clobbered and the file is cleaned up even if a check fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        TEST_DB = os.path.join(tmp_dir, "test.db")

        # `with db:` only commits/rolls back; closing() actually closes the connection.
        with closing(sqlite3.connect(TEST_DB)) as db, db:
            bs = SqliteBlockStore(db)
            bs.put(b"hello", b"sqlite world")

        with closing(sqlite3.connect(TEST_DB)) as db, db:
            bs = SqliteBlockStore(db)
            print("hello ->", bs.get(b"hello"))
            bs.delete(b"hello")

        try:
            with closing(sqlite3.connect(TEST_DB)) as db, db:
                bs = SqliteBlockStore(db)
                print("hello ->", bs.get(b"hello"))
        except KeyError:
            pass
        else:
            raise AssertionError("get() of a deleted key should raise KeyError")
+74
carfile.py
···11+from typing import Self, Optional, Dict, List, Tuple, BinaryIO
22+from multiformats import varint, CID
33+import dag_cbor
44+55+from blockstore import BlockStore
class ReadOnlyCARBlockStore(BlockStore):
    """
    Read-only BlockStore view over a CAR (v1) file.

    This is a sliiiightly unclean abstraction because BlockStores are indexed
    by `bytes` rather than CID, but same idea. This is convenient for verifying
    proofs provided in CAR format, and for testing.

    The file object is kept and read lazily by get(), so it must stay open
    for as long as the store is in use.
    """

    car_roots: List[CID]
    block_offsets: Dict[bytes, Tuple[int, int]]  # CID bytes -> (offset, length) of the block data

    def __init__(self, file: BinaryIO) -> None:
        """
        pre-scan over the whole file, recording the offsets of each block

        Raises EOFError if the header or a block's CID is truncated, and
        ValueError for an unsupported CAR version, a malformed header, or an
        unsupported CID type.
        """

        self.file = file
        file.seek(0)

        # parse out CAR header
        header_len = varint.decode(file)
        header = file.read(header_len)
        if len(header) != header_len:
            raise EOFError("not enough CAR header bytes")
        header_obj = dag_cbor.decode(header)
        if not isinstance(header_obj, dict):
            raise ValueError("malformed CAR header")
        if header_obj.get("version") != 1:
            raise ValueError(f"unsupported CAR version ({header_obj.get('version')})")
        self.car_roots = header_obj["roots"]

        CID_LENGTH = 36  # XXX: this is a questionable assumption!!!
        # CIDv1 (0x01), dag-cbor (0x71), sha2-256 (0x12), 32-byte digest (0x20).
        # Checking the prefix is what justifies the fixed CID_LENGTH above.
        CID_PREFIX = b"\x01\x71\x12\x20"

        # scan through the CAR to find block offsets
        self.block_offsets = {}
        while True:
            try:
                length = varint.decode(file)
            except ValueError:
                break  # EOF
            start = file.tell()
            if length < CID_LENGTH:
                # would otherwise record a negative data length
                raise ValueError("CAR block too short to contain a CID")
            cid = file.read(CID_LENGTH)
            if len(cid) != CID_LENGTH:
                # would otherwise index a truncated key
                raise EOFError("not enough CAR block bytes")
            if cid[:len(CID_PREFIX)] != CID_PREFIX:
                raise ValueError("unsupported CID type")
            self.block_offsets[cid] = (start + CID_LENGTH, length - CID_LENGTH)
            file.seek(start + length)  # skip over the block data

    def put(self, key: bytes, value: bytes) -> None:
        raise NotImplementedError("ReadOnlyCARBlockStore does not support put()")

    def get(self, key: bytes) -> bytes:
        """
        Read the block whose CID bytes are `key` straight from the file.

        Raises KeyError if the CAR contains no such block, EOFError if the
        file ends before the whole block could be read.
        """
        try:
            offset, length = self.block_offsets[key]
        except KeyError:
            # same message as the other BlockStore implementations
            raise KeyError("no block matches this key") from None
        self.file.seek(offset)
        value = self.file.read(length)
        if len(value) != length:
            raise EOFError("not enough CAR block bytes")
        return value

    def delete(self, key: bytes) -> None:
        raise NotImplementedError("ReadOnlyCARBlockStore does not support delete()")
if __name__ == "__main__":
    # Ad-hoc smoke test: dump the commit object and the MST root node of a CAR file.
    import sys

    # The CAR path may be given as the first CLI argument; without one the
    # original development path is used.
    car_path = sys.argv[1] if len(sys.argv) > 1 else "/home/david/programming/python/bskyclient/retr0id.car"

    # The store reads lazily, so everything happens while the file is open;
    # the `with` block makes sure the handle is closed afterwards.
    with open(car_path, "rb") as f:
        bs = ReadOnlyCARBlockStore(f)
        commit_obj = dag_cbor.decode(bs.get(bytes(bs.car_roots[0])))
        print(commit_obj)
        mst_root: CID = commit_obj["data"]

        from mst import NodeStore
        ns = NodeStore(bs)
        print(ns.get(mst_root))
+375
mst.py
···11+import hashlib
22+import dag_cbor
33+import operator
44+from multiformats import multihash, CID
55+from functools import cached_property
66+from more_itertools import ilen
77+from itertools import takewhile
88+from dataclasses import dataclass
99+from typing import Tuple, Self, Optional, Any, Type, Iterable
1010+1111+from util import indent, hash_to_cid
1212+from blockstore import BlockStore, MemoryBlockStore
1313+1414+# tuple helpers
def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple:
    """Return a copy of `original` in which the item at index `i` is `value`."""
    before = original[:i]
    after = original[i + 1:]
    return before + (value,) + after
def tuple_insert_at(original: tuple, i: int, value: Any) -> tuple:
    """Return a copy of `original` with `value` inserted before index `i`."""
    before = original[:i]
    after = original[i:]
    return before + (value,) + after
def tuple_remove_at(original: tuple, i: int) -> tuple:
    """Return a copy of `original` without the item at index `i`."""
    before = original[:i]
    after = original[i + 1:]
    return before + after
@dataclass(frozen=True) # frozen == immutable == win
class MSTNode:
    """
    One immutable node of a Merkle Search Tree.

    k/v pairs are interleaved between subtrees like so:

    keys:       (0, 1, 2, 3)
    vals:       (0, 1, 2, 3)
    subtrees: (0, 1, 2, 3, 4)

    so there is always exactly one more subtree than there are keys.
    """
    keys: Tuple[str, ...] # collection/rkey
    vals: Tuple[CID, ...] # record CIDs
    subtrees: Tuple[Optional[CID], ...] # a None value represents an empty subtree

    # NB: __init__ is auto-generated by dataclass decorator

    # these checks should never fail, and could be skipped for performance
    def __post_init__(self) -> None:
        # TODO: maybe check that they're tuples here?
        # implicitly, the length of self.subtrees must be at least 1
        if len(self.subtrees) != len(self.keys) + 1:
            raise ValueError("Invalid subtree count")
        if len(self.keys) != len(self.vals):
            raise ValueError("Mismatched keys/vals lengths")

    @classmethod
    def empty_root(cls) -> Self:
        """Return the node representing an empty tree: no keys, one empty subtree."""
        return cls(
            subtrees=(None,),
            keys=(),
            vals=()
        )

    @staticmethod
    def key_height(key: str) -> int:
        """
        Height of `key` in the tree: the number of leading zero bits of
        sha256(key), integer-divided by two.
        """
        digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big")
        leading_zeroes = 256 - digest.bit_length()
        return leading_zeroes // 2

    # since we're immutable, this can be cached
    @cached_property
    def cid(self) -> CID:
        """CID (v1, dag-cbor, sha2-256) of the serialised node."""
        digest = multihash.digest(self.serialised, "sha2-256")
        cid = CID("base32", 1, "dag-cbor", digest)
        return cid

    # likewise
    @cached_property
    def serialised(self) -> bytes:
        """
        dag-cbor encoding of the node.

        Keys are prefix-compressed: each entry stores the number of leading
        bytes it shares with the previous key ("p") plus the remaining
        suffix ("k").
        """
        e = []
        prev_key = b""
        for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals):
            key_bytes = key_str.encode()
            # length of the common prefix of prev_key and key_bytes
            shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes))) # I love functional programming
            e.append({
                "k": key_bytes[shared_prefix_len:],
                "p": shared_prefix_len,
                "t": subtree,
                "v": value,
            })
            prev_key = key_bytes
        return dag_cbor.encode({
            "e": e,
            "l": self.subtrees[0]
        })

    @classmethod
    def deserialise(cls, data: bytes) -> Self:
        """
        Parse a dag-cbor encoded node, the inverse of `serialised`.

        Raises ValueError if the structure is malformed, a prefix length is
        out of range or longer/shorter than necessary, or the keys are not
        strictly ascending.
        """
        cbor = dag_cbor.decode(data)
        # check the field *names*, not just the count, so a malformed node
        # surfaces as ValueError rather than a stray KeyError/TypeError
        if not isinstance(cbor, dict) or set(cbor) != {"e", "l"}:
            raise ValueError("malformed MST node")
        if not isinstance(cbor["e"], list):
            raise ValueError("malformed MST node")
        subtrees = [cbor["l"]]
        keys = []
        vals = []
        prev_key = b""
        for e in cbor["e"]: # TODO: make extra sure that these checks are watertight
            if not isinstance(e, dict) or set(e) != {"k", "p", "t", "v"}:
                raise ValueError("malformed MST node")
            prefix_len = e["p"]
            suffix = e["k"]
            # bool is a subclass of int, so it has to be excluded explicitly
            if isinstance(prefix_len, bool) or not isinstance(prefix_len, int):
                raise ValueError("malformed MST node")
            if not isinstance(suffix, bytes):
                raise ValueError("malformed MST node")
            # a negative length would slice from the end of prev_key and let
            # a non-canonical encoding through
            if prefix_len < 0 or prefix_len > len(prev_key):
                raise ValueError("invalid MST key prefix len")
            # if the next byte also matches, the prefix could have been longer
            if prev_key[prefix_len:prefix_len+1] == suffix[:1]:
                raise ValueError("non-optimal MST key prefix len")
            this_key = prev_key[:prefix_len] + suffix
            if this_key <= prev_key:
                raise ValueError("invalid MST key sort order")
            keys.append(this_key.decode())
            vals.append(e["v"])
            subtrees.append(e["t"])
            prev_key = this_key

        return cls(
            subtrees=tuple(subtrees),
            keys=tuple(keys),
            vals=tuple(vals)
        )

    def is_empty(self) -> bool:
        """True if the node has a single, empty subtree (and therefore no keys)."""
        return self.subtrees == (None,)

    def _to_optional(self) -> Optional[CID]:
        """
        returns None if the node is empty
        """
        if self.is_empty():
            return None
        return self.cid

    @cached_property
    def height(self) -> int:
        """
        Height of this node, derived from its first key.

        Raises Exception for a node with no keys but a non-empty subtree,
        since its height cannot be determined locally.
        """
        # if there are keys at this level, query one directly
        if self.keys:
            return self.key_height(self.keys[0])

        # we're an empty tree
        if self.subtrees[0] is None:
            return 0

        # this should only happen for non-root nodes with no keys
        raise Exception("cannot determine node height")

    def gte_index(self, key: str) -> int:
        """
        find the index of the first key greater than or equal to the specified key
        if all keys are smaller, it returns len(keys)
        """
        i = 0 # this loop could be a binary search but not worth it for small fanouts
        while i < len(self.keys) and key > self.keys[i]:
            i += 1
        return i
class NodeStore:
    """
    Thin convenience layer over a BlockStore that loads and stores
    MSTNode objects instead of raw bytes.
    """
    bs: BlockStore

    def __init__(self, bs: BlockStore) -> None:
        self.bs = bs

    # TODO: LRU cache this
    def get(self, cid: Optional[CID]) -> MSTNode:
        """
        Load the node with the given CID; a cid of None yields an empty MST node.
        """
        if cid is not None:
            raw = self.bs.get(bytes(cid))
            return MSTNode.deserialise(raw)
        return MSTNode.empty_root()

    # TODO: also put in cache
    def put(self, node: MSTNode) -> MSTNode:
        """
        Write the node to the block store, keyed by its CID bytes.
        The same node is handed back so calls can be chained.
        """
        block_key = bytes(node.cid)
        self.bs.put(block_key, node.serialised)
        return node
class MST:
    """
    A Merkle Search Tree whose nodes live in a NodeStore.

    Nodes are immutable: put() and delete() build new nodes along the
    affected path and then repoint `root` at the new root node's CID.
    """
    ns: NodeStore
    root: CID

    def __init__(self, ns: NodeStore, root: Optional[CID]=None) -> None:
        """
        If no root CID is given, an empty root node is stored and used.
        """
        self.ns = ns
        if root is None:
            root = ns.put(MSTNode.empty_root()).cid
        self.root = root

    def put(self, key: str, val: CID) -> None:
        """Insert `key` -> `val`, or replace the value of an existing key."""
        self.root = self._put(key, val)

    def delete(self, key: str) -> None:
        """Remove `key`; a key that is not present is a nop."""
        self.root = self._delete(key)

    def _put(self, key: str, val: CID) -> CID:
        """Return the root CID of the tree with `key` -> `val` added."""
        root = self.ns.get(self.root)
        if root.is_empty(): # special case for empty tree
            return self._put_here(root, key, val)
        return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height)

    def _put_here(self, node: MSTNode, key: str, val: CID) -> CID:
        """
        Put the key directly into `node`, splitting the subtree that
        straddles the insertion point. Returns the new node's CID.
        """
        i = node.gte_index(key)

        # the key is already present!
        if i < len(node.keys) and node.keys[i] == key:
            if node.vals[i] == val:
                return node.cid # we can return our old self if there is no change
            return self.ns.put(MSTNode(
                keys=node.keys,
                vals=tuple_replace_at(node.vals, i, val),
                subtrees=node.subtrees
            )).cid

        return self.ns.put(MSTNode(
            keys=tuple_insert_at(node.keys, i, key),
            vals=tuple_insert_at(node.vals, i, val),
            subtrees=node.subtrees[:i] +
                self._split_on_key(node.subtrees[i], key) +
                node.subtrees[i + 1:],
        )).cid

    def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> CID:
        """
        Insert into the (sub)tree rooted at `node`, which sits at
        `tree_height`. Returns the CID of the replacement node.
        """
        if key_height > tree_height: # we need to grow the tree
            # Wrap the current node in a keyless parent one level up and try
            # again from there. The existing tree is kept as the parent's only
            # subtree and gets split around the key once the right level is
            # reached. It is stored first so the split can load it by CID.
            parent = MSTNode(
                keys=(),
                vals=(),
                subtrees=(self.ns.put(node)._to_optional(),)
            )
            return self._put_recursive(parent, key, val, key_height, tree_height + 1)

        if key_height < tree_height: # we need to look below
            i = node.gte_index(key)
            return self.ns.put(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(
                    node.subtrees, i,
                    self._put_recursive(
                        self.ns.get(node.subtrees[i]),
                        key, val, key_height, tree_height - 1
                    )
                )
            )).cid

        # we can insert here
        assert(key_height == tree_height)
        return self._put_here(node, key, val)

    def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]:
        """
        Split a subtree into the part holding keys smaller than `key` and the
        part holding the rest. Either half is None if it ends up empty.
        """
        if node_cid is None:
            return None, None
        node = self.ns.get(node_cid)
        i = node.gte_index(key)
        lsub, rsub = self._split_on_key(node.subtrees[i], key)
        left = self.ns.put(MSTNode(
            keys=node.keys[:i],
            vals=node.vals[:i],
            subtrees=node.subtrees[:i] + (lsub,)
        ))
        right = self.ns.put(MSTNode(
            keys=node.keys[i:],
            vals=node.vals[i:],
            subtrees=(rsub,) + node.subtrees[i + 1:],
        ))
        return left._to_optional(), right._to_optional()

    def _squash_top(self, node_cid: Optional[CID]) -> Optional[CID]:
        """
        strip empty nodes from the top of the tree
        """
        node = self.ns.get(node_cid)
        if node.keys:
            return node_cid
        if node.subtrees[0] is None:
            return node_cid
        return self._squash_top(node.subtrees[0])

    def _delete(self, key: str) -> CID:
        """Return the root CID of the tree with `key` removed."""
        root = self.ns.get(self.root)
        new_root = self._squash_top(
            self._delete_recursive(root, key, MSTNode.key_height(key), root.height)
        )
        if new_root is None:
            # the last key was removed: fall back to a stored empty root so
            # that self.root is always a real CID, exactly as in __init__
            return self.ns.put(MSTNode.empty_root()).cid
        return new_root

    def _delete_recursive(self, node: MSTNode, key: str, key_height: int, tree_height: int) -> Optional[CID]:
        """
        Delete `key` from the (sub)tree rooted at `node`, which sits at
        `tree_height`. Returns the CID of the replacement node, or None if
        that node ended up empty.
        """
        if key_height > tree_height: # the key cannot possibly be in this tree, no change needed
            return node.cid

        i = node.gte_index(key)
        if key_height < tree_height: # the key must be deleted from a subtree
            if node.subtrees[i] is None:
                return node.cid # the key cannot be in this subtree, no change needed
            return self.ns.put(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(
                    node.subtrees,
                    i,
                    self._delete_recursive(self.ns.get(node.subtrees[i]), key, key_height, tree_height - 1)
                )
            ))._to_optional()

        if i == len(node.keys) or node.keys[i] != key:
            return node.cid # key already not present

        assert(node.keys[i] == key) # sanity check (should always be true)

        # drop the key and merge the two subtrees that used to surround it
        return self.ns.put(MSTNode(
            keys=tuple_remove_at(node.keys, i),
            vals=tuple_remove_at(node.vals, i),
            subtrees=node.subtrees[:i] + (
                self._merge(node.subtrees[i], node.subtrees[i + 1]),
            ) + node.subtrees[i + 2:]
        ))._to_optional()

    def _merge(self, left_cid: Optional[CID], right_cid: Optional[CID]) -> Optional[CID]:
        """
        Join two adjacent subtrees into one node, recursively merging the
        rightmost subtree of `left` with the leftmost subtree of `right`.
        """
        if left_cid is None:
            return right_cid # includes the case where left == right == None
        if right_cid is None:
            return left_cid
        left = self.ns.get(left_cid)
        right = self.ns.get(right_cid)
        return self.ns.put(MSTNode(
            keys=left.keys + right.keys,
            vals=left.vals + right.vals,
            subtrees=left.subtrees[:-1] + (
                self._merge(
                    left.subtrees[-1],
                    right.subtrees[0]
                ),
            ) + right.subtrees[1:]
        ))._to_optional()

    def __repr__(self):
        return self.pretty(self.root)

    def pretty(self, node_cid: Optional[CID]) -> str:
        """Render the subtree rooted at `node_cid` as a nested, indented string."""
        if node_cid is None:
            return "<empty>"
        node = self.ns.get(node_cid)
        # The encoded CIDs are pulled out into locals because reusing the
        # same quote character inside an f-string needs Python 3.12+.
        node_cid_str = node.cid.encode("base32")
        parts = [f"MSTNode<cid={node_cid_str}>(\n{indent(self.pretty(node.subtrees[0]))},\n"]
        for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]):
            val_cid_str = v.encode("base32")
            parts.append(f" {k!r} ({MSTNode.key_height(k)}) -> {val_cid_str},\n")
            parts.append(indent(self.pretty(t)) + ",\n")
        parts.append(")")
        return "".join(parts)
if __name__ == "__main__":
    # Ad-hoc demo. With a CAR file path as the first CLI argument, print the
    # MST stored in that file; without one, exercise an in-memory tree.
    import sys

    if len(sys.argv) > 1:
        from carfile import ReadOnlyCARBlockStore
        # the block store reads lazily, so the file has to stay open while
        # the tree is being printed
        with open(sys.argv[1], "rb") as f:
            bs = ReadOnlyCARBlockStore(f)
            commit_obj = dag_cbor.decode(bs.get(bytes(bs.car_roots[0])))
            mst_root: CID = commit_obj["data"]
            ns = NodeStore(bs)
            mst = MST(ns, mst_root)
            print(mst)
    else:
        bs = MemoryBlockStore()
        ns = NodeStore(bs)
        mst = MST(ns)
        print(mst)
        mst.put("hello", hash_to_cid(b"blah"))
        print(mst)
        mst.put("foo", hash_to_cid(b"bar"))
        print(mst)
        mst.put("bar", hash_to_cid(b"bat"))
        print(mst)
        mst.delete("foo")
        mst.delete("hello")
        print(mst)