···11-import hashlib
22-import dag_cbor
33-import operator
44-from multiformats import multihash, CID
55-from functools import cached_property
66-from more_itertools import ilen
77-from itertools import takewhile
88-from dataclasses import dataclass
99-from typing import Tuple, Self, Optional
@dataclass(frozen=True)  # frozen == immutable == win
class MSTNode:
    """
    A single immutable MST node.

    k/v pairs are interleaved between subtrees like so: ::

        keys:         (0,  1,  2,  3)
        vals:         (0,  1,  2,  3)
        subtrees:   (0,  1,  2,  3,  4)

    If a method is implemented in this class, it's because it's a function/property
    of a single node, as opposed to a whole tree
    """
    keys: Tuple[str, ...]  # collection/rkey
    vals: Tuple[CID, ...]  # record CIDs
    subtrees: Tuple[Optional[CID], ...]  # a None value represents an empty subtree

    # NB: __init__ is auto-generated by dataclass decorator

    # these checks should never fail, and could be skipped for performance
    def __post_init__(self) -> None:
        """
        Validate the structural invariants of the node.

        Raises:
            ValueError: if there is not exactly one more subtree than keys,
                or if keys and vals differ in length.
        """
        # TODO: maybe check that they're tuples here?
        # implicitly, the length of self.subtrees must be at least 1
        if len(self.subtrees) != len(self.keys) + 1:
            raise ValueError("Invalid subtree count")
        if len(self.keys) != len(self.vals):
            raise ValueError("Mismatched keys/vals lengths")

    @classmethod
    def empty_root(cls) -> Self:
        """Return a node with no keys, no vals and a single empty (None) subtree."""
        return cls(
            subtrees=(None,),
            keys=(),
            vals=(),
        )

    # this should maybe not be implemented here?
    @staticmethod
    def key_height(key: str) -> int:
        """
        Return the height of a key: the number of leading zero bits in the
        SHA-256 digest of the encoded key, integer-divided by two.
        """
        digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big")
        # bit_length() ignores leading zero bits, so the difference counts them
        leading_zeroes = 256 - digest.bit_length()
        return leading_zeroes // 2

    # since we're immutable, this can be cached
    @cached_property
    def cid(self) -> CID:
        """The base32 CIDv1 (dag-cbor codec, sha2-256) of the serialised node."""
        digest = multihash.digest(self.serialised, "sha2-256")
        return CID("base32", 1, "dag-cbor", digest)

    # likewise
    @cached_property
    def serialised(self) -> bytes:
        """
        The dag-cbor encoding of this node.

        The encoded map has two fields: "l" (the leftmost subtree) and "e", a
        list with one entry per key. Each entry holds "p" (number of leading
        bytes shared with the previous key), "k" (the remaining key bytes),
        "v" (the value) and "t" (the subtree to the right of the key).
        """
        entries = []
        prev_key = b""
        for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals):
            key_bytes = key_str.encode()
            # count how many leading bytes match the previous key
            shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes)))
            entries.append({
                "k": key_bytes[shared_prefix_len:],
                "p": shared_prefix_len,
                "t": subtree,
                "v": value,
            })
            prev_key = key_bytes
        return dag_cbor.encode({
            "e": entries,
            "l": self.subtrees[0],
        })

    @classmethod
    def deserialise(cls, data: bytes) -> Self:
        """
        Parse a dag-cbor encoded node, rejecting non-canonical encodings.

        Raises:
            ValueError: if the map sizes are wrong, a prefix length is out of
                range or not maximal, or the keys are not strictly increasing.
                (A map of the right size but with unexpected field names
                surfaces as KeyError from the lookups below.)
        """
        cbor = dag_cbor.decode(data)
        if len(cbor) != 2:  # e, l
            raise ValueError("malformed MST node")
        subtrees = [cbor["l"]]
        keys = []
        vals = []
        prev_key = b""
        for e in cbor["e"]:  # TODO: make extra sure that these checks are watertight wrt non-canonical representations
            if len(e) != 4:  # k, p, t, v
                raise ValueError("malformed MST node")
            prefix_len: int = e["p"]
            suffix: bytes = e["k"]
            # A negative prefix length would slice prev_key from the end and
            # let a non-canonical encoding through, so reject it as well as
            # anything longer than the previous key.
            if not 0 <= prefix_len <= len(prev_key):
                raise ValueError("invalid MST key prefix len")
            # if the next byte also matches, the prefix could have been longer
            if prev_key[prefix_len:prefix_len + 1] == suffix[:1]:
                raise ValueError("non-optimal MST key prefix len")
            this_key = prev_key[:prefix_len] + suffix
            if this_key <= prev_key:
                raise ValueError("invalid MST key sort order")
            keys.append(this_key.decode())
            vals.append(e["v"])
            subtrees.append(e["t"])
            prev_key = this_key

        return cls(
            subtrees=tuple(subtrees),
            keys=tuple(keys),
            vals=tuple(vals),
        )

    def is_empty(self) -> bool:
        """Return True if this node is exactly a single empty (None) subtree."""
        return self.subtrees == (None,)

    def _to_optional(self) -> Optional[CID]:
        """
        returns None if the node is empty
        """
        if self.is_empty():
            return None
        return self.cid

    @cached_property
    def height(self) -> int:
        """
        The height of this node, derived from its first key, or 0 when the node
        has no keys and its only subtree is empty.

        Raises:
            Exception: if the node has no keys but a non-empty subtree, since
                the height cannot be determined from this node alone.
        """
        # if there are keys at this level, query one directly
        if self.keys:
            return self.key_height(self.keys[0])

        # we're an empty tree
        if self.subtrees[0] is None:
            return 0

        # this should only happen for non-root nodes with no keys
        raise Exception("cannot determine node height")

    def gte_index(self, key: str) -> int:
        """
        find the index of the first key greater than or equal to the specified key
        if all keys are smaller, it returns len(keys)
        """
        # this loop could be a binary search but not worth it for small fanouts
        for i, existing in enumerate(self.keys):
            if existing >= key:
                return i
        return len(self.keys)
148148-149149-150150-"""
151151-if __name__ == "__main__":
152152- from .blockstore import MemoryBlockStore, OverlayBlockStore
153153- from .blockstore.car_reader import ReadOnlyCARBlockStore
154154-155155- if 0:
156156- import sys
157157- sys.setrecursionlimit(999999999)
158158- f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb")
159159- bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f))
160160- commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0])))
161161- mst_root: CID = commit_obj["data"]
162162- ns = NodeStore(bs)
163163- wrangler = NodeWrangler(ns)
164164- #print(wrangler)
165165- #enumerate_mst(ns, mst_root)
166166- enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff")
167167-168168- root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext")
169169- root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22")
170170- root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah"))
171171-172172- c, d = mst_diff(ns, mst_root, root2)
173173- print("CREATED:")
174174- for x in c:
175175- print("created", x.encode("base32"))
176176- print("DELETED:")
177177- for x in d:
178178- print("deleted", x.encode("base32"))
179179-180180- for op in record_diff(ns, c, d):
181181- print(op)
182182-183183- e, f = very_slow_mst_diff(ns, mst_root, root2)
184184- assert(e == c)
185185- assert(f == d)
186186- else:
187187- bs = MemoryBlockStore()
188188- ns = NodeStore(bs)
189189- wrangler = NodeWrangler(ns)
190190- root = ns.get_node(None).cid
191191- print(ns.pretty(root))
192192- root = wrangler.put_record(root, "hello", hash_to_cid(b"blah"))
193193- print(ns.pretty(root))
194194- root = wrangler.put_record(root, "foo", hash_to_cid(b"bar"))
195195- print(ns.pretty(root))
196196- root_a = root
197197- root = wrangler.put_record(root, "bar", hash_to_cid(b"bat"))
198198- root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat"))
199199- root = wrangler.del_record(root, "foo")
200200- print("=============")
201201- print(ns.pretty(root_a))
202202- print("=============")
203203- print(ns.pretty(root))
204204- #exit()
205205- print("=============")
206206- enumerate_mst(ns, root)
207207- c, d = mst_diff(ns, root_a, root)
208208- print("CREATED:")
209209- for x in c:
210210- print("created", x.encode("base32"))
211211- print("DELETED:")
212212- for x in d:
213213- print("deleted", x.encode("base32"))
214214-215215- e, f = very_slow_mst_diff(ns, root_a, root)
216216- assert(e == c)
217217- assert(f == d)
218218-219219- exit()
220220- root = wrangler.delete(root, "foo")
221221- root = wrangler.delete(root, "hello")
222222- print(ns.pretty(root))
223223- root = wrangler.delete(root, "bar")
224224- print(ns.pretty(root))
225225- root = wrangler.delete(root, "bar")
226226- print(ns.pretty(root))
227227-"""
+18-13
src/atmst/mst/diff.py
···4455from multiformats import CID
6677-from . import MSTNode
77+from .node import MSTNode
88from .node_store import NodeStore
99from .node_walker import NodeWalker
101011111212def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]) -> Iterable[tuple]:
1313 """
1414- Given two sets of MST nodes (for example, the result of `mst_diff`), this
1515- returns an iterator of record changes, in one of 3 formats:
1414+ Given two sets of MST nodes (for example, the result of :meth:`mst_diff`), this
1515+ returns an iterator of record changes, in one of 3 formats: ::
16161717- ("created", key, value)
1818- ("updated", key, old_value, new_value)
1919- ("deleted", key, value)
1717+ ("created", key, value)
1818+ ("updated", key, old_value, new_value)
1919+ ("deleted", key, value)
2020+2021 """
2122 created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, created)), {})
2223 deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, deleted)), {})
···32333334def very_slow_mst_diff(ns: NodeStore, root_a: CID, root_b: CID):
3435 """
3535- This should return the same result as mst_diff, but it gets there in a very slow
3636- yet less error-prone way, so it's useful for testing.
3636+ This should return the same result as :meth:`mst_diff`, but it gets there in a slow
3737+ but much more obvious way (enumerating all nodes), so it's useful for testing.
37383839 It's actually faster for smaller trees, but it chokes on trees with thousands of nodes (especially if the NodeStore is slow).
3940 """
···4445EMPTY_NODE_CID = MSTNode.empty_root().cid
45464647def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created, deleted
4848+ """
4949+ Given two MST root node CIDs, efficiently compute the difference between them, represented as
5050+ two sets holding the created and deleted MST nodes respectively (referenced by CIDs).
5151+ """
4752 created = set() # MST nodes in b but not in a
4853 deleted = set() # MST nodes in a but not in b
4949- mst_diff_recursive(created, deleted, NodeWalker(ns, root_a), NodeWalker(ns, root_b))
5454+ _mst_diff_recursive(created, deleted, NodeWalker(ns, root_a), NodeWalker(ns, root_b))
5055 middle = created & deleted # my algorithm has occasional false-positives
5156 #assert(not middle) # this fails
5257 #print("middle", len(middle))
···5964 created.add(EMPTY_NODE_CID)
6065 return created, deleted
61666262-def mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted
6767+def _mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted
6368 # the easiest of all cases
6464- if a.frame.node.cid == b.frame.node.cid:
6969+ if a.frame.node == b.frame.node:
6570 return # no difference
66716772 # trivial
113118114119 # the rkeys now match, but the subtrees below us might not
115120116116- mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker())
121121+ _mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker())
117122118123 # check if we can still go right XXX: do we need to care about the case where one can, but the other can't?
119124 # To consider: maybe if I just step a, b will catch up automagically
120120- if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey:
125125+ if a.rkey == a.stack[0].rkey and b.rkey == b.stack[0].rkey:
121126 break
122127123128 a.right()
+227
src/atmst/mst/node.py
···11+import hashlib
22+import dag_cbor
33+import operator
44+from multiformats import multihash, CID
55+from functools import cached_property
66+from more_itertools import ilen
77+from itertools import takewhile
88+from dataclasses import dataclass
99+from typing import Tuple, Self, Optional
@dataclass(frozen=True)  # frozen == immutable == win
class MSTNode:
    """
    A single immutable MST node.

    k/v pairs are interleaved between subtrees like so: ::

        keys:         (0,  1,  2,  3)
        vals:         (0,  1,  2,  3)
        subtrees:   (0,  1,  2,  3,  4)

    If a method is implemented in this class, it's because it's a function/property
    of a single node, as opposed to a whole tree
    """
    keys: Tuple[str, ...]  # collection/rkey
    vals: Tuple[CID, ...]  # record CIDs
    subtrees: Tuple[Optional[CID], ...]  # a None value represents an empty subtree

    # NB: __init__ is auto-generated by dataclass decorator

    # these checks should never fail, and could be skipped for performance
    def __post_init__(self) -> None:
        """
        Validate the structural invariants of the node.

        Raises:
            ValueError: if there is not exactly one more subtree than keys,
                or if keys and vals differ in length.
        """
        # TODO: maybe check that they're tuples here?
        # implicitly, the length of self.subtrees must be at least 1
        if len(self.subtrees) != len(self.keys) + 1:
            raise ValueError("Invalid subtree count")
        if len(self.keys) != len(self.vals):
            raise ValueError("Mismatched keys/vals lengths")

    @classmethod
    def empty_root(cls) -> Self:
        """Return a node with no keys, no vals and a single empty (None) subtree."""
        return cls(
            subtrees=(None,),
            keys=(),
            vals=(),
        )

    # this should maybe not be implemented here?
    @staticmethod
    def key_height(key: str) -> int:
        """
        Return the height of a key: the number of leading zero bits in the
        SHA-256 digest of the encoded key, integer-divided by two.
        """
        digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big")
        # bit_length() ignores leading zero bits, so the difference counts them
        leading_zeroes = 256 - digest.bit_length()
        return leading_zeroes // 2

    # since we're immutable, this can be cached
    @cached_property
    def cid(self) -> CID:
        """The base32 CIDv1 (dag-cbor codec, sha2-256) of the serialised node."""
        digest = multihash.digest(self.serialised, "sha2-256")
        return CID("base32", 1, "dag-cbor", digest)

    # likewise
    @cached_property
    def serialised(self) -> bytes:
        """
        The dag-cbor encoding of this node.

        The encoded map has two fields: "l" (the leftmost subtree) and "e", a
        list with one entry per key. Each entry holds "p" (number of leading
        bytes shared with the previous key), "k" (the remaining key bytes),
        "v" (the value) and "t" (the subtree to the right of the key).
        """
        entries = []
        prev_key = b""
        for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals):
            key_bytes = key_str.encode()
            # count how many leading bytes match the previous key
            shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes)))
            entries.append({
                "k": key_bytes[shared_prefix_len:],
                "p": shared_prefix_len,
                "t": subtree,
                "v": value,
            })
            prev_key = key_bytes
        return dag_cbor.encode({
            "e": entries,
            "l": self.subtrees[0],
        })

    @classmethod
    def deserialise(cls, data: bytes) -> Self:
        """
        Parse a dag-cbor encoded node, rejecting non-canonical encodings.

        Raises:
            ValueError: if the map sizes are wrong, a prefix length is out of
                range or not maximal, or the keys are not strictly increasing.
                (A map of the right size but with unexpected field names
                surfaces as KeyError from the lookups below.)
        """
        cbor = dag_cbor.decode(data)
        if len(cbor) != 2:  # e, l
            raise ValueError("malformed MST node")
        subtrees = [cbor["l"]]
        keys = []
        vals = []
        prev_key = b""
        for e in cbor["e"]:  # TODO: make extra sure that these checks are watertight wrt non-canonical representations
            if len(e) != 4:  # k, p, t, v
                raise ValueError("malformed MST node")
            prefix_len: int = e["p"]
            suffix: bytes = e["k"]
            # A negative prefix length would slice prev_key from the end and
            # let a non-canonical encoding through, so reject it as well as
            # anything longer than the previous key.
            if not 0 <= prefix_len <= len(prev_key):
                raise ValueError("invalid MST key prefix len")
            # if the next byte also matches, the prefix could have been longer
            if prev_key[prefix_len:prefix_len + 1] == suffix[:1]:
                raise ValueError("non-optimal MST key prefix len")
            this_key = prev_key[:prefix_len] + suffix
            if this_key <= prev_key:
                raise ValueError("invalid MST key sort order")
            keys.append(this_key.decode())
            vals.append(e["v"])
            subtrees.append(e["t"])
            prev_key = this_key

        return cls(
            subtrees=tuple(subtrees),
            keys=tuple(keys),
            vals=tuple(vals),
        )

    def is_empty(self) -> bool:
        """Return True if this node is exactly a single empty (None) subtree."""
        return self.subtrees == (None,)

    def _to_optional(self) -> Optional[CID]:
        """
        returns None if the node is empty
        """
        if self.is_empty():
            return None
        return self.cid

    @cached_property
    def height(self) -> int:
        """
        The height of this node, derived from its first key, or 0 when the node
        has no keys and its only subtree is empty.

        Raises:
            Exception: if the node has no keys but a non-empty subtree, since
                the height cannot be determined from this node alone.
        """
        # if there are keys at this level, query one directly
        if self.keys:
            return self.key_height(self.keys[0])

        # we're an empty tree
        if self.subtrees[0] is None:
            return 0

        # this should only happen for non-root nodes with no keys
        raise Exception("cannot determine node height")

    def gte_index(self, key: str) -> int:
        """
        find the index of the first key greater than or equal to the specified key
        if all keys are smaller, it returns len(keys)
        """
        # this loop could be a binary search but not worth it for small fanouts
        for i, existing in enumerate(self.keys):
            if existing >= key:
                return i
        return len(self.keys)
148148+149149+150150+"""
151151+if __name__ == "__main__":
152152+ from .blockstore import MemoryBlockStore, OverlayBlockStore
153153+ from .blockstore.car_reader import ReadOnlyCARBlockStore
154154+155155+ if 0:
156156+ import sys
157157+ sys.setrecursionlimit(999999999)
158158+ f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb")
159159+ bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f))
160160+ commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0])))
161161+ mst_root: CID = commit_obj["data"]
162162+ ns = NodeStore(bs)
163163+ wrangler = NodeWrangler(ns)
164164+ #print(wrangler)
165165+ #enumerate_mst(ns, mst_root)
166166+ enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff")
167167+168168+ root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext")
169169+ root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22")
170170+ root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah"))
171171+172172+ c, d = mst_diff(ns, mst_root, root2)
173173+ print("CREATED:")
174174+ for x in c:
175175+ print("created", x.encode("base32"))
176176+ print("DELETED:")
177177+ for x in d:
178178+ print("deleted", x.encode("base32"))
179179+180180+ for op in record_diff(ns, c, d):
181181+ print(op)
182182+183183+ e, f = very_slow_mst_diff(ns, mst_root, root2)
184184+ assert(e == c)
185185+ assert(f == d)
186186+ else:
187187+ bs = MemoryBlockStore()
188188+ ns = NodeStore(bs)
189189+ wrangler = NodeWrangler(ns)
190190+ root = ns.get_node(None).cid
191191+ print(ns.pretty(root))
192192+ root = wrangler.put_record(root, "hello", hash_to_cid(b"blah"))
193193+ print(ns.pretty(root))
194194+ root = wrangler.put_record(root, "foo", hash_to_cid(b"bar"))
195195+ print(ns.pretty(root))
196196+ root_a = root
197197+ root = wrangler.put_record(root, "bar", hash_to_cid(b"bat"))
198198+ root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat"))
199199+ root = wrangler.del_record(root, "foo")
200200+ print("=============")
201201+ print(ns.pretty(root_a))
202202+ print("=============")
203203+ print(ns.pretty(root))
204204+ #exit()
205205+ print("=============")
206206+ enumerate_mst(ns, root)
207207+ c, d = mst_diff(ns, root_a, root)
208208+ print("CREATED:")
209209+ for x in c:
210210+ print("created", x.encode("base32"))
211211+ print("DELETED:")
212212+ for x in d:
213213+ print("deleted", x.encode("base32"))
214214+215215+ e, f = very_slow_mst_diff(ns, root_a, root)
216216+ assert(e == c)
217217+ assert(f == d)
218218+219219+ exit()
220220+ root = wrangler.delete(root, "foo")
221221+ root = wrangler.delete(root, "hello")
222222+ print(ns.pretty(root))
223223+ root = wrangler.delete(root, "bar")
224224+ print(ns.pretty(root))
225225+ root = wrangler.delete(root, "bar")
226226+ print(ns.pretty(root))
227227+"""