···11-import hashlib
22-import dag_cbor
33-import operator
44-from multiformats import multihash, CID
55-from functools import cached_property, reduce
66-from more_itertools import ilen
77-from itertools import takewhile
88-from dataclasses import dataclass
99-from typing import Tuple, Self, Optional, Any, Dict, List, Set, Type, Iterable
1010-from collections import namedtuple
1111-1212-from .util import indent, hash_to_cid
1313-from .blockstore import BlockStore
1414-1515-# tuple helpers
def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple:
    """
    Return a new tuple equal to ``original`` with the item at index ``i``
    replaced by ``value``. Built from slices, so ``i`` should be non-negative.
    """
    head = original[:i]
    tail = original[i + 1:]
    return head + (value,) + tail
def tuple_insert_at(original: tuple, i: int, value: Any) -> tuple:
    """
    Return a new tuple equal to ``original`` with ``value`` inserted so that
    it ends up at index ``i``; the items from ``i`` onwards shift right.
    """
    before = original[:i]
    after = original[i:]
    return before + (value,) + after
def tuple_remove_at(original: tuple, i: int) -> tuple:
    """
    Return a new tuple equal to ``original`` with the item at index ``i``
    dropped. Built from slices, so ``i`` should be non-negative.
    """
    kept_left = original[:i]
    kept_right = original[i + 1:]
    return kept_left + kept_right
@dataclass(frozen=True)  # frozen == immutable == win
class MSTNode:
    """
    A single node of a Merkle Search Tree.

    k/v pairs are interleaved between subtrees like so: ::

        keys:         (0,    1,    2,    3)
        vals:         (0,    1,    2,    3)
        subtrees:  (0,    1,    2,    3,    4)

    If a method is implemented in this class, it's because it's a function/property
    of a single node, as opposed to a whole tree
    """
    keys: Tuple[str, ...]  # collection/rkey
    vals: Tuple[CID, ...]  # record CIDs
    subtrees: Tuple[Optional[CID], ...]  # a None value represents an empty subtree

    # NB: __init__ is auto-generated by dataclass decorator

    # these checks should never fail, and could be skipped for performance
    def __post_init__(self) -> None:
        """
        Check the structural invariants of the node.

        Raises ValueError if there is not exactly one more subtree than keys,
        or if the number of keys and values differ.
        """
        # TODO: maybe check that they're tuples here?
        # implicitly, the length of self.subtrees must be at least 1
        if len(self.subtrees) != len(self.keys) + 1:
            raise ValueError("Invalid subtree count")
        if len(self.keys) != len(self.vals):
            raise ValueError("Mismatched keys/vals lengths")

    @classmethod
    def empty_root(cls) -> Self:
        """Return the node with no keys, no values and a single empty subtree."""
        return cls(
            subtrees=(None,),
            keys=(),
            vals=()
        )

    # this should maybe not be implemented here?
    @staticmethod
    def key_height(key: str) -> int:
        """
        Return the tree height of a key: the number of leading zero bits in
        the SHA-256 digest of the encoded key, divided by two (rounded down).
        """
        digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big")
        leading_zeroes = 256 - digest.bit_length()
        return leading_zeroes // 2

    # since we're immutable, this can be cached
    @cached_property
    def cid(self) -> CID:
        """CIDv1 (base32, dag-cbor) over the sha2-256 multihash of `serialised`."""
        digest = multihash.digest(self.serialised, "sha2-256")
        cid = CID("base32", 1, "dag-cbor", digest)
        return cid

    # likewise
    @cached_property
    def serialised(self) -> bytes:
        """
        DAG-CBOR encoding of this node.

        The top-level map holds "l" (the leftmost subtree) and "e" (one entry
        per key). Each entry stores the key prefix-compressed against the
        previous key: "p" is the number of leading bytes shared with it, "k"
        the remaining suffix, "v" the value and "t" the subtree to its right.
        """
        e = []
        prev_key = b""
        for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals):
            key_bytes = key_str.encode()
            # count how many leading bytes match the previous key
            shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes)))  # I love functional programming
            e.append({
                "k": key_bytes[shared_prefix_len:],
                "p": shared_prefix_len,
                "t": subtree,
                "v": value,
            })
            prev_key = key_bytes
        return dag_cbor.encode({
            "e": e,
            "l": self.subtrees[0]
        })

    @classmethod
    def deserialise(cls, data: bytes) -> Self:
        """
        Decode a node from its DAG-CBOR form, rejecting non-canonical input.

        Raises ValueError if the map sizes are wrong, if a prefix length is
        negative, longer than the previous key, or shorter than it could be,
        or if the keys are not strictly increasing.
        """
        cbor = dag_cbor.decode(data)
        if len(cbor) != 2:  # e, l
            raise ValueError("malformed MST node")
        subtrees = [cbor["l"]]
        keys = []
        vals = []
        prev_key = b""
        for e in cbor["e"]:  # TODO: make extra sure that these checks are watertight wrt non-canonical representations
            if len(e) != 4:  # k, p, t, v
                raise ValueError("malformed MST node")
            prefix_len: int = e["p"]
            suffix: bytes = e["k"]
            # a negative length would otherwise slice prev_key from the end
            # and smuggle a non-canonical encoding past the checks below
            if prefix_len < 0 or prefix_len > len(prev_key):
                raise ValueError("invalid MST key prefix len")
            # if the next byte also matches, the prefix could have been longer
            if prev_key[prefix_len:prefix_len + 1] == suffix[:1]:
                raise ValueError("non-optimal MST key prefix len")
            this_key = prev_key[:prefix_len] + suffix
            if this_key <= prev_key:
                raise ValueError("invalid MST key sort order")
            keys.append(this_key.decode())
            vals.append(e["v"])
            subtrees.append(e["t"])
            prev_key = this_key

        return cls(
            subtrees=tuple(subtrees),
            keys=tuple(keys),
            vals=tuple(vals)
        )

    def is_empty(self) -> bool:
        """True if the node has no keys and its only subtree is empty."""
        return self.subtrees == (None,)

    def _to_optional(self) -> Optional[CID]:
        """
        returns None if the node is empty
        """
        if self.is_empty():
            return None
        return self.cid

    @cached_property
    def height(self) -> int:
        """
        Height of this node, derived from its first key (0 for an empty tree).

        Raises Exception for a keyless node that still has a subtree, since
        its height cannot be determined from the node alone.
        """
        # if there are keys at this level, query one directly
        if self.keys:
            return self.key_height(self.keys[0])

        # we're an empty tree
        if self.subtrees[0] is None:
            return 0

        # this should only happen for non-root nodes with no keys
        raise Exception("cannot determine node height")

    def gte_index(self, key: str) -> int:
        """
        find the index of the first key greater than or equal to the specified key
        if all keys are smaller, it returns len(keys)
        """
        i = 0  # this loop could be a binary search but not worth it for small fanouts
        while i < len(self.keys) and key > self.keys[i]:
            i += 1
        return i
class NodeStore:
    """
    NodeStore wraps a BlockStore to provide a more ergonomic interface
    for loading and storing MSTNodes
    """
    bs: BlockStore
    cache: Dict[Optional[CID], MSTNode]  # XXX: this cache will grow forever!
    #cache_counts: Dict[Optional[CID], int]

    def __init__(self, bs: BlockStore) -> None:
        self.bs = bs
        self.cache = {}
        #self.cache_counts = {}

    # TODO: LRU cache this - this package looks ideal: https://github.com/amitdev/lru-dict
    def get_node(self, cid: Optional[CID]) -> MSTNode:
        """
        Return the node for `cid`, from the cache if present, otherwise by
        reading and deserialising the block from the underlying BlockStore.

        if cid is None, returns an empty MST node
        """
        cached = self.cache.get(cid)
        if cached is not None:
            return cached

        if cid is None:
            # the empty node is stored (and cached under its own CID) like any other
            return self.put_node(MSTNode.empty_root())

        res = MSTNode.deserialise(self.bs.get_block(bytes(cid)))
        self.cache[cid] = res
        return res

    def put_node(self, node: MSTNode) -> MSTNode:
        """Cache `node`, write its serialised form to the BlockStore, and return it."""
        self.cache[node.cid] = node
        self.bs.put_block(bytes(node.cid), node.serialised)
        return node  # this is convenient

    # MST pretty-printing
    # this should maybe not be implemented here
    def pretty(self, node_cid: Optional[CID]) -> str:
        """
        Render the tree rooted at `node_cid` as an indented multi-line string.
        A None CID renders as "<empty>".
        """
        if node_cid is None:
            return "<empty>"
        node = self.get_node(node_cid)
        # inner quotes differ from the outer ones so this parses before Python 3.12
        parts = [f"MSTNode<cid={node.cid.encode('base32')}>(\n{indent(self.pretty(node.subtrees[0]))},\n"]
        for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]):
            parts.append(f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode('base32')},\n")
            parts.append(indent(self.pretty(t)) + ",\n")
        parts.append(")")
        return "".join(parts)
class NodeWrangler:
    """
    NodeWrangler is where core MST transformation ops are implemented, backed
    by a NodeStore

    The external APIs take a CID (the MST root) and return a CID (the new root),
    while storing any newly created nodes in the NodeStore.

    Neither method should ever fail - deleting a node that doesn't exist is a nop,
    and adding the same node twice with the same value is also a nop. Callers
    can detect these cases by seeing if the initial and final CIDs changed.
    """
    ns: NodeStore

    def __init__(self, ns: NodeStore) -> None:
        self.ns = ns

    def put_record(self, root_cid: CID, key: str, val: CID) -> CID:
        """Insert or update `key` -> `val` in the tree at `root_cid`; return the new root CID."""
        root = self.ns.get_node(root_cid)
        if root.is_empty():  # special case for empty tree
            return self._put_here(root, key, val).cid
        return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid

    def del_record(self, root_cid: CID, key: str) -> CID:
        """Remove `key` from the tree at `root_cid` (a nop if absent); return the new root CID."""
        root = self.ns.get_node(root_cid)
        remaining = self._delete_recursive(root, key, MSTNode.key_height(key), root.height)

        # Note: the seemingly redundant outer .get().cid is required to transform
        # a None cid into the cid representing an empty node (we could maybe find a more elegant
        # way of doing this...)
        return self.ns.get_node(self._squash_top(remaining)).cid

    def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode:
        """Insert or update `key` directly in `node`, splitting the subtree it lands in."""
        i = node.gte_index(key)

        # the key is already present!
        if i < len(node.keys) and node.keys[i] == key:
            if node.vals[i] == val:
                return node  # we can return our old self if there is no change
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=tuple_replace_at(node.vals, i, val),
                subtrees=node.subtrees
            ))

        # the subtree that straddled the new key is split into a left and right half
        return self.ns.put_node(MSTNode(
            keys=tuple_insert_at(node.keys, i, key),
            vals=tuple_insert_at(node.vals, i, val),
            subtrees=node.subtrees[:i]
            + self._split_on_key(node.subtrees[i], key)
            + node.subtrees[i + 1:],
        ))

    def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode:
        """
        Insert `key` into the tree rooted at `node`, whose height is
        `tree_height`, growing the tree or descending as needed.
        """
        if key_height > tree_height:  # we need to grow the tree
            # Put the existing tree underneath a new keyless parent one level
            # up, so its records are kept. The parent is stored because a
            # later _split_on_key looks it up by CID.
            parent = self.ns.put_node(MSTNode(
                keys=(),
                vals=(),
                subtrees=(node._to_optional(),)
            ))
            return self._put_recursive(parent, key, val, key_height, tree_height + 1)

        if key_height < tree_height:  # we need to look below
            i = node.gte_index(key)
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(
                    node.subtrees, i,
                    self._put_recursive(
                        self.ns.get_node(node.subtrees[i]),
                        key, val, key_height, tree_height - 1
                    ).cid
                )
            ))

        # we can insert here
        assert key_height == tree_height
        return self._put_here(node, key, val)

    def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]:
        """
        Split the subtree at `node_cid` into the part with keys below `key`
        and the part with keys at or above it. Either half is None if empty.
        """
        if node_cid is None:
            return None, None
        node = self.ns.get_node(node_cid)
        i = node.gte_index(key)
        lsub, rsub = self._split_on_key(node.subtrees[i], key)
        return self.ns.put_node(MSTNode(
            keys=node.keys[:i],
            vals=node.vals[:i],
            subtrees=node.subtrees[:i] + (lsub,)
        ))._to_optional(), self.ns.put_node(MSTNode(
            keys=node.keys[i:],
            vals=node.vals[i:],
            subtrees=(rsub,) + node.subtrees[i + 1:],
        ))._to_optional()

    def _squash_top(self, node_cid: Optional[CID]) -> Optional[CID]:
        """
        strip empty nodes from the top of the tree
        """
        node = self.ns.get_node(node_cid)
        if node.keys:
            return node_cid
        if node.subtrees[0] is None:
            return node_cid
        return self._squash_top(node.subtrees[0])

    def _delete_recursive(self, node: MSTNode, key: str, key_height: int, tree_height: int) -> Optional[CID]:
        """
        Delete `key` from the tree rooted at `node` (height `tree_height`).
        Returns the CID of the resulting node, or None if it ends up empty.
        """
        if key_height > tree_height:  # the key cannot possibly be in this tree, no change needed
            return node._to_optional()

        i = node.gte_index(key)
        if key_height < tree_height:  # the key must be deleted from a subtree
            if node.subtrees[i] is None:
                return node._to_optional()  # the key cannot be in this subtree, no change needed
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(
                    node.subtrees,
                    i,
                    self._delete_recursive(self.ns.get_node(node.subtrees[i]), key, key_height, tree_height - 1)
                )
            ))._to_optional()

        if i == len(node.keys) or node.keys[i] != key:
            return node._to_optional()  # key already not present

        assert node.keys[i] == key  # sanity check (should always be true)

        # the subtrees either side of the removed key are merged into one
        return self.ns.put_node(MSTNode(
            keys=tuple_remove_at(node.keys, i),
            vals=tuple_remove_at(node.vals, i),
            subtrees=node.subtrees[:i] + (
                self._merge(node.subtrees[i], node.subtrees[i + 1]),
            ) + node.subtrees[i + 2:]
        ))._to_optional()

    def _merge(self, left_cid: Optional[CID], right_cid: Optional[CID]) -> Optional[CID]:
        """
        Join two adjacent subtrees into one node, recursively merging the
        rightmost subtree of `left` with the leftmost subtree of `right`.
        """
        if left_cid is None:
            return right_cid  # includes the case where left == right == None
        if right_cid is None:
            return left_cid
        left = self.ns.get_node(left_cid)
        right = self.ns.get_node(right_cid)
        return self.ns.put_node(MSTNode(
            keys=left.keys + right.keys,
            vals=left.vals + right.vals,
            subtrees=left.subtrees[:-1] + (
                self._merge(
                    left.subtrees[-1],
                    right.subtrees[0]
                ),
            ) + right.subtrees[1:]
        ))._to_optional()
class NodeWalker:
    """
    NodeWalker makes implementing tree diffing and other MST query ops more
    convenient (but it does not, itself, implement them).

    A NodeWalker starts off at the root of a tree, and can walk along or recurse
    down into subtrees.

    Walking "off the end" of a subtree brings you back up to its next non-empty parent.

    Recall MSTNode layout: ::

        keys:  (lkey)  (0,    1,    2,    3)  (rkey)
        vals:          (0,    1,    2,    3)
        subtrees:   (0,    1,    2,    3,    4)

    """
    KEY_MIN = ""  # string that compares less than all legal key strings
    KEY_MAX = "\xff"  # string that compares greater than all legal key strings

    @dataclass
    class StackFrame:
        node: MSTNode  # could store CIDs only to save memory, in theory, but not much point
        lkey: str
        rkey: str
        idx: int

    ns: NodeStore
    stack: List[StackFrame]

    def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None:
        """Start a walk at the leftmost subtree of the node at `root_cid`."""
        self.ns = ns
        root_frame = self.StackFrame(
            node=self.ns.get_node(root_cid),
            lkey=lkey,
            rkey=rkey,
            idx=0
        )
        self.stack = [root_frame]

    def subtree_walker(self) -> Self:
        """Return a fresh walker rooted at the current subtree, bounded by the current lkey/rkey."""
        return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey)

    @property
    def frame(self) -> StackFrame:
        """The innermost (most recently entered) stack frame."""
        return self.stack[-1]

    @property
    def lkey(self) -> str:
        """Key just left of the cursor, or the frame's left bound at index 0."""
        top = self.frame
        if top.idx == 0:
            return top.lkey
        return top.node.keys[top.idx - 1]

    @property
    def lval(self) -> Optional[CID]:
        """Value just left of the cursor, or None at index 0."""
        top = self.frame
        if top.idx == 0:
            return None
        return top.node.vals[top.idx - 1]

    @property
    def subtree(self) -> Optional[CID]:
        """The subtree at the cursor position (None if empty)."""
        top = self.frame
        return top.node.subtrees[top.idx]

    # hmmmm rkey is overloaded here... "right key" not "record key"...
    @property
    def rkey(self) -> str:
        """Key just right of the cursor, or the frame's right bound past the last key."""
        top = self.frame
        if top.idx == len(top.node.keys):
            return top.rkey
        return top.node.keys[top.idx]

    @property
    def rval(self) -> Optional[CID]:
        """Value just right of the cursor, or None past the last value."""
        top = self.frame
        if top.idx == len(top.node.vals):
            return None
        return top.node.vals[top.idx]

    @property
    def is_final(self) -> bool:
        """True once the stack is empty, or the cursor sits on an empty subtree at the root's right bound."""
        if not self.stack:
            return True
        return self.subtree is None and self.rkey == self.stack[0].rkey

    def right(self) -> None:
        """
        Step one position to the right, climbing out of any node whose end
        has been reached. Raises StopIteration when the stack runs out.
        """
        # pop every exhausted frame, to skip over empty intermediates on the way back up
        while (self.frame.idx + 1) >= len(self.frame.node.subtrees):
            self.stack.pop()
            if not self.stack:
                raise StopIteration  # you probably want to check .is_final instead of hitting this
        self.frame.idx += 1

    def down(self) -> None:
        """Descend into the subtree at the cursor; raises Exception if it is empty."""
        child = self.subtree
        if child is None:
            raise Exception("oi, you can't recurse here mate")

        child_frame = self.StackFrame(
            node=self.ns.get_node(child),
            lkey=self.lkey,
            rkey=self.rkey,
            idx=0
        )
        self.stack.append(child_frame)

    # everything above here is core tree walking logic
    # everything below here is helper functions

    def next_kv(self) -> Tuple[str, CID]:
        """Advance past the next key in sorted order and return it with its value."""
        while self.subtree:  # recurse down every subtree
            self.down()
        self.right()
        # the kv pair we just jumped over
        key, val = self.lkey, self.lval
        return key, val

    # iterate over every k/v pair in key-sorted order
    def iter_kv(self):
        """Yield (key, value) pairs from the cursor onwards until `is_final`."""
        while True:
            if self.is_final:
                return
            yield self.next_kv()

    # get all mst nodes down and to the right of the current position
    def iter_node_cids(self):
        """Yield the CID of the current node, then of every node entered while walking on."""
        yield self.frame.node.cid
        while not self.is_final:
            while self.subtree:  # recurse down every subtree
                self.down()
                yield self.frame.node.cid
            self.right()
def enumerate_mst(ns: NodeStore, root_cid: CID):
    """Print every key and its base32-encoded value CID, in key order."""
    walker = NodeWalker(ns, root_cid)
    for key, val in walker.iter_kv():
        print(key, "->", val.encode("base32"))
# start inclusive, end exclusive
def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str):
    """Print the key/value pairs whose keys fall in the range [start, end)."""
    cur = NodeWalker(ns, root_cid)

    # seek: skip right past keys below `start`, descending wherever a subtree exists
    while True:
        while cur.rkey < start:
            cur.right()
        if cur.subtree:
            cur.down()
            continue
        break

    # emit from the seek position until the first key at or beyond `end`
    for key, val in cur.iter_kv():
        if key >= end:
            return
        print(key, "->", val.encode("base32"))
def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]):
    """
    Turn two sets of MST node CIDs into a stream of record-level changes.

    Yields ("created", key, value), ("updated", key, created_value,
    deleted_value) or ("deleted", key, value), with values base32-encoded.
    """
    # flatten the keys/vals of every node in each set into one mapping
    created_kv = {
        k: v
        for node in map(ns.get_node, created)
        for k, v in zip(node.keys, node.vals)
    }
    deleted_kv = {
        k: v
        for node in map(ns.get_node, deleted)
        for k, v in zip(node.keys, node.vals)
    }

    new_keys = created_kv.keys()
    old_keys = deleted_kv.keys()

    for key in new_keys - old_keys:
        yield ("created", key, created_kv[key].encode("base32"))

    for key in new_keys & old_keys:
        new_val = created_kv[key]
        old_val = deleted_kv[key]
        if new_val != old_val:
            yield ("updated", key, new_val.encode("base32"), old_val.encode("base32"))

    for key in old_keys - new_keys:
        yield ("deleted", key, deleted_kv[key].encode("base32"))  #XXX: encode() is just for debugging
def very_slow_mst_diff(ns, root_a: CID, root_b: CID):
    """
    This should return the same result as mst_diff, but it gets there in a very slow
    yet less error-prone way, so it's useful for testing.

    Returns (node CIDs only in b, node CIDs only in a).
    """
    nodes_in_a = set(NodeWalker(ns, root_a).iter_node_cids())
    nodes_in_b = set(NodeWalker(ns, root_b).iter_node_cids())
    only_in_b = nodes_in_b.difference(nodes_in_a)
    only_in_a = nodes_in_a.difference(nodes_in_b)
    return only_in_b, only_in_a
EMPTY_NODE_CID = MSTNode.empty_root().cid


def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]:  # created_deleted
    """
    Diff two trees at node granularity.

    Returns (created, deleted): the MST node CIDs in b but not in a, and
    those in a but not in b.
    """
    created = set()  # MST nodes in b but not in a
    deleted = set()  # MST nodes in a but not in b
    mst_diff_recursive(created, deleted, NodeWalker(ns, root_a), NodeWalker(ns, root_b))

    # my algorithm has occasional false-positives: nodes reported on both
    # sides are dropped from both
    overlap = created & deleted
    #assert(not overlap) # this fails
    #print("middle", len(overlap))
    created.difference_update(overlap)
    deleted.difference_update(overlap)

    # special case: if one of the root nodes was empty
    a_is_empty = root_a == EMPTY_NODE_CID
    b_is_empty = root_b == EMPTY_NODE_CID
    if a_is_empty and not b_is_empty:
        deleted.add(EMPTY_NODE_CID)
    if b_is_empty and not a_is_empty:
        created.add(EMPTY_NODE_CID)
    return created, deleted
def mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted
    """
    Walk two trees in lockstep, adding the CIDs of nodes seen only on the
    `b` side to `created` and those seen only on the `a` side to `deleted`.

    Both sets are mutated in place and nothing is returned. The sets may end
    up sharing entries (see the NB comment below).
    """
    # the easiest of all cases
    if a.frame.node.cid == b.frame.node.cid:
        return # no difference

    # trivial
    if a.frame.node.is_empty():
        #mst_deleted.add(a.frame.node.cid) # this doesn't work because it might've been a null subtree node
        created |= set(b.iter_node_cids())
        return

    # likewise
    if b.frame.node.is_empty():
        #mst_created.add(b.frame.node.cid)
        deleted |= set(a.iter_node_cids())
        return

    # now we're onto the hard part

    """
    theory: most trees that get compared will have lots of shared blocks (which we can skip over, due to identical CIDs)
    completely different trees will inevitably have to visit every node.

    general idea:
    1. if one cursor is "behind" the other, catch it up
    2. when we're matched up, skip over identical subtrees (and recursively diff non-identical subtrees)

    XXX: this seems to work nicely but I'm not sure if it's necessarily efficient for all tree layouts?
    """

    # NB: these will end up as false-positives if one tree is a subtree of the other
    created.add(b.frame.node.cid)
    deleted.add(a.frame.node.cid)

    while True:
        while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other
            # "catch up" cursor a, if it's behind
            while a.rkey < b.rkey and not a.is_final:
                if a.subtree: # recurse down every subtree
                    a.down()
                    deleted.add(a.frame.node.cid)
                else:
                    a.right()

            # catch up cursor b, likewise
            while b.rkey < a.rkey and not b.is_final:
                if b.subtree: # recurse down every subtree
                    b.down()
                    created.add(b.frame.node.cid)
                else:
                    b.right()

        # the rkeys now match, but the subtrees below us might not

        # diff the two subtrees at the cursors using fresh walkers
        mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker())

        # check if we can still go right XXX: do we need to care about the case where one can, but the other can't?
        # To consider: maybe if I just step a, b will catch up automagically
        # NOTE(review): both halves compare against a.stack[0].rkey; b.stack[0].rkey may have been intended - confirm
        if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey:
            break

        a.right()
        b.right()
# Ad-hoc smoke test / demo. The relative imports below resolve only when the
# module is run as part of its package.
if __name__ == "__main__":
    from .blockstore import MemoryBlockStore, OverlayBlockStore
    from .blockstore.car_reader import ReadOnlyCARBlockStore

    if 0:
        # disabled: diff a real repo loaded from a local CAR file
        import sys
        sys.setrecursionlimit(999999999)
        # the block store reads from the file, so everything stays inside the `with`
        with open("/home/david/programming/python/bskyclient/retr0id.car", "rb") as car_file:
            bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(car_file))
            commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0])))
            mst_root: CID = commit_obj["data"]
            ns = NodeStore(bs)
            wrangler = NodeWrangler(ns)
            #print(wrangler)
            #enumerate_mst(ns, mst_root)
            enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff")

            root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext")
            root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22")
            root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah"))

            c, d = mst_diff(ns, mst_root, root2)
            print("CREATED:")
            for x in c:
                print("created", x.encode("base32"))
            print("DELETED:")
            for x in d:
                print("deleted", x.encode("base32"))

            for op in record_diff(ns, c, d):
                print(op)

            # the fast diff must agree with the slow reference implementation
            e, f = very_slow_mst_diff(ns, mst_root, root2)
            assert(e == c)
            assert(f == d)
    else:
        # build a tiny in-memory tree, mutate it, and diff the two versions
        bs = MemoryBlockStore()
        ns = NodeStore(bs)
        wrangler = NodeWrangler(ns)
        root = ns.get_node(None).cid
        print(ns.pretty(root))
        root = wrangler.put_record(root, "hello", hash_to_cid(b"blah"))
        print(ns.pretty(root))
        root = wrangler.put_record(root, "foo", hash_to_cid(b"bar"))
        print(ns.pretty(root))
        root_a = root
        root = wrangler.put_record(root, "bar", hash_to_cid(b"bat"))
        root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat"))
        root = wrangler.del_record(root, "foo")
        print("=============")
        print(ns.pretty(root_a))
        print("=============")
        print(ns.pretty(root))
        #exit()
        print("=============")
        enumerate_mst(ns, root)
        c, d = mst_diff(ns, root_a, root)
        print("CREATED:")
        for x in c:
            print("created", x.encode("base32"))
        print("DELETED:")
        for x in d:
            print("deleted", x.encode("base32"))

        # the fast diff must agree with the slow reference implementation
        e, f = very_slow_mst_diff(ns, root_a, root)
        assert(e == c)
        assert(f == d)

        exit()
        # not reached because of the exit() above; kept as a scratch deletion
        # test, using the method NodeWrangler actually defines (del_record)
        root = wrangler.del_record(root, "foo")
        root = wrangler.del_record(root, "hello")
        print(ns.pretty(root))
        root = wrangler.del_record(root, "bar")
        print(ns.pretty(root))
        root = wrangler.del_record(root, "bar")
        print(ns.pretty(root))
+227
src/atmst/mst/__init__.py
···11+import hashlib
22+import dag_cbor
33+import operator
44+from multiformats import multihash, CID
55+from functools import cached_property
66+from more_itertools import ilen
77+from itertools import takewhile
88+from dataclasses import dataclass
99+from typing import Tuple, Self, Optional
@dataclass(frozen=True)  # frozen == immutable == win
class MSTNode:
    """
    A single node of a Merkle Search Tree.

    k/v pairs are interleaved between subtrees like so: ::

        keys:         (0,    1,    2,    3)
        vals:         (0,    1,    2,    3)
        subtrees:  (0,    1,    2,    3,    4)

    If a method is implemented in this class, it's because it's a function/property
    of a single node, as opposed to a whole tree
    """
    keys: Tuple[str, ...]  # collection/rkey
    vals: Tuple[CID, ...]  # record CIDs
    subtrees: Tuple[Optional[CID], ...]  # a None value represents an empty subtree

    # NB: __init__ is auto-generated by dataclass decorator

    # these checks should never fail, and could be skipped for performance
    def __post_init__(self) -> None:
        """
        Check the structural invariants of the node.

        Raises ValueError if there is not exactly one more subtree than keys,
        or if the number of keys and values differ.
        """
        # TODO: maybe check that they're tuples here?
        # implicitly, the length of self.subtrees must be at least 1
        if len(self.subtrees) != len(self.keys) + 1:
            raise ValueError("Invalid subtree count")
        if len(self.keys) != len(self.vals):
            raise ValueError("Mismatched keys/vals lengths")

    @classmethod
    def empty_root(cls) -> Self:
        """Return the node with no keys, no values and a single empty subtree."""
        return cls(
            subtrees=(None,),
            keys=(),
            vals=()
        )

    # this should maybe not be implemented here?
    @staticmethod
    def key_height(key: str) -> int:
        """
        Return the tree height of a key: the number of leading zero bits in
        the SHA-256 digest of the encoded key, divided by two (rounded down).
        """
        digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big")
        leading_zeroes = 256 - digest.bit_length()
        return leading_zeroes // 2

    # since we're immutable, this can be cached
    @cached_property
    def cid(self) -> CID:
        """CIDv1 (base32, dag-cbor) over the sha2-256 multihash of `serialised`."""
        digest = multihash.digest(self.serialised, "sha2-256")
        cid = CID("base32", 1, "dag-cbor", digest)
        return cid

    # likewise
    @cached_property
    def serialised(self) -> bytes:
        """
        DAG-CBOR encoding of this node.

        The top-level map holds "l" (the leftmost subtree) and "e" (one entry
        per key). Each entry stores the key prefix-compressed against the
        previous key: "p" is the number of leading bytes shared with it, "k"
        the remaining suffix, "v" the value and "t" the subtree to its right.
        """
        e = []
        prev_key = b""
        for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals):
            key_bytes = key_str.encode()
            # count how many leading bytes match the previous key
            shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes)))  # I love functional programming
            e.append({
                "k": key_bytes[shared_prefix_len:],
                "p": shared_prefix_len,
                "t": subtree,
                "v": value,
            })
            prev_key = key_bytes
        return dag_cbor.encode({
            "e": e,
            "l": self.subtrees[0]
        })

    @classmethod
    def deserialise(cls, data: bytes) -> Self:
        """
        Decode a node from its DAG-CBOR form, rejecting non-canonical input.

        Raises ValueError if the map sizes are wrong, if a prefix length is
        negative, longer than the previous key, or shorter than it could be,
        or if the keys are not strictly increasing.
        """
        cbor = dag_cbor.decode(data)
        if len(cbor) != 2:  # e, l
            raise ValueError("malformed MST node")
        subtrees = [cbor["l"]]
        keys = []
        vals = []
        prev_key = b""
        for e in cbor["e"]:  # TODO: make extra sure that these checks are watertight wrt non-canonical representations
            if len(e) != 4:  # k, p, t, v
                raise ValueError("malformed MST node")
            prefix_len: int = e["p"]
            suffix: bytes = e["k"]
            # a negative length would otherwise slice prev_key from the end
            # and smuggle a non-canonical encoding past the checks below
            if prefix_len < 0 or prefix_len > len(prev_key):
                raise ValueError("invalid MST key prefix len")
            # if the next byte also matches, the prefix could have been longer
            if prev_key[prefix_len:prefix_len + 1] == suffix[:1]:
                raise ValueError("non-optimal MST key prefix len")
            this_key = prev_key[:prefix_len] + suffix
            if this_key <= prev_key:
                raise ValueError("invalid MST key sort order")
            keys.append(this_key.decode())
            vals.append(e["v"])
            subtrees.append(e["t"])
            prev_key = this_key

        return cls(
            subtrees=tuple(subtrees),
            keys=tuple(keys),
            vals=tuple(vals)
        )

    def is_empty(self) -> bool:
        """True if the node has no keys and its only subtree is empty."""
        return self.subtrees == (None,)

    def _to_optional(self) -> Optional[CID]:
        """
        returns None if the node is empty
        """
        if self.is_empty():
            return None
        return self.cid

    @cached_property
    def height(self) -> int:
        """
        Height of this node, derived from its first key (0 for an empty tree).

        Raises Exception for a keyless node that still has a subtree, since
        its height cannot be determined from the node alone.
        """
        # if there are keys at this level, query one directly
        if self.keys:
            return self.key_height(self.keys[0])

        # we're an empty tree
        if self.subtrees[0] is None:
            return 0

        # this should only happen for non-root nodes with no keys
        raise Exception("cannot determine node height")

    def gte_index(self, key: str) -> int:
        """
        find the index of the first key greater than or equal to the specified key
        if all keys are smaller, it returns len(keys)
        """
        i = 0  # this loop could be a binary search but not worth it for small fanouts
        while i < len(self.keys) and key > self.keys[i]:
            i += 1
        return i
148148+149149+150150+"""
151151+if __name__ == "__main__":
152152+ from .blockstore import MemoryBlockStore, OverlayBlockStore
153153+ from .blockstore.car_reader import ReadOnlyCARBlockStore
154154+155155+ if 0:
156156+ import sys
157157+ sys.setrecursionlimit(999999999)
158158+ f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb")
159159+ bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f))
160160+ commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0])))
161161+ mst_root: CID = commit_obj["data"]
162162+ ns = NodeStore(bs)
163163+ wrangler = NodeWrangler(ns)
164164+ #print(wrangler)
165165+ #enumerate_mst(ns, mst_root)
166166+ enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff")
167167+168168+ root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext")
169169+ root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22")
170170+ root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah"))
171171+172172+ c, d = mst_diff(ns, mst_root, root2)
173173+ print("CREATED:")
174174+ for x in c:
175175+ print("created", x.encode("base32"))
176176+ print("DELETED:")
177177+ for x in d:
178178+ print("deleted", x.encode("base32"))
179179+180180+ for op in record_diff(ns, c, d):
181181+ print(op)
182182+183183+ e, f = very_slow_mst_diff(ns, mst_root, root2)
184184+ assert(e == c)
185185+ assert(f == d)
186186+ else:
187187+ bs = MemoryBlockStore()
188188+ ns = NodeStore(bs)
189189+ wrangler = NodeWrangler(ns)
190190+ root = ns.get_node(None).cid
191191+ print(ns.pretty(root))
192192+ root = wrangler.put_record(root, "hello", hash_to_cid(b"blah"))
193193+ print(ns.pretty(root))
194194+ root = wrangler.put_record(root, "foo", hash_to_cid(b"bar"))
195195+ print(ns.pretty(root))
196196+ root_a = root
197197+ root = wrangler.put_record(root, "bar", hash_to_cid(b"bat"))
198198+ root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat"))
199199+ root = wrangler.del_record(root, "foo")
200200+ print("=============")
201201+ print(ns.pretty(root_a))
202202+ print("=============")
203203+ print(ns.pretty(root))
204204+ #exit()
205205+ print("=============")
206206+ enumerate_mst(ns, root)
207207+ c, d = mst_diff(ns, root_a, root)
208208+ print("CREATED:")
209209+ for x in c:
210210+ print("created", x.encode("base32"))
211211+ print("DELETED:")
212212+ for x in d:
213213+ print("deleted", x.encode("base32"))
214214+215215+ e, f = very_slow_mst_diff(ns, root_a, root)
216216+ assert(e == c)
217217+ assert(f == d)
218218+219219+ exit()
220220+ root = wrangler.delete(root, "foo")
221221+ root = wrangler.delete(root, "hello")
222222+ print(ns.pretty(root))
223223+ root = wrangler.delete(root, "bar")
224224+ print(ns.pretty(root))
225225+ root = wrangler.delete(root, "bar")
226226+ print(ns.pretty(root))
227227+"""
+124
src/atmst/mst/diff.py
···11+import operator
22+from typing import Tuple, Set, Iterable
33+from functools import reduce
44+55+from multiformats import CID
66+77+from . import MSTNode
88+from .node_store import NodeStore
99+from .node_walker import NodeWalker
def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]) -> Iterable[tuple]:
    """
    Given two sets of MST node CIDs (for example, the result of `mst_diff`),
    yield the record-level changes they imply, in one of 3 formats:

        ("created", key, value)
        ("updated", key, old_value, new_value)
        ("deleted", key, value)

    `created` holds nodes that only exist in the new tree, `deleted` holds
    nodes that only exist in the old tree. A key seen on both sides with the
    same value was merely moved between nodes and is not reported.

    Values are emitted base32-encoded (XXX: encode() is just for debugging).
    The order of the yielded operations within each category is unspecified.
    """
    def collect_records(node_cids) -> dict:
        # flatten the k/v pairs of every listed node into a single mapping
        records = {}
        for node_cid in node_cids:
            node = ns.get_node(node_cid)
            records.update(zip(node.keys, node.vals))
        return records

    new_kv = collect_records(created)  # records present in the new tree's nodes
    old_kv = collect_records(deleted)  # records present in the old tree's nodes

    for key in new_kv.keys() - old_kv.keys():
        yield ("created", key, new_kv[key].encode("base32"))

    for key in new_kv.keys() & old_kv.keys():
        old_val = old_kv[key]
        new_val = new_kv[key]
        if old_val != new_val:
            # documented order is (old, new); the old value comes from the deleted nodes
            yield ("updated", key, old_val.encode("base32"), new_val.encode("base32"))

    for key in old_kv.keys() - new_kv.keys():
        yield ("deleted", key, old_kv[key].encode("base32"))
def very_slow_mst_diff(ns: NodeStore, root_a: CID, root_b: CID):
    """
    Reference implementation of mst_diff: it should return the same
    (created, deleted) pair, but gets there by enumerating every node of both
    trees, which is very slow yet less error-prone - so it's useful for testing.

    It's actually faster for smaller trees, but it chokes on trees with
    thousands of nodes (especially if the NodeStore is slow).
    """
    nodes_in_a = set(NodeWalker(ns, root_a).iter_node_cids())
    nodes_in_b = set(NodeWalker(ns, root_b).iter_node_cids())
    only_in_b = nodes_in_b - nodes_in_a  # "created"
    only_in_a = nodes_in_a - nodes_in_b  # "deleted"
    return only_in_b, only_in_a
# CID of the node produced by MSTNode.empty_root(); mst_diff compares tree
# roots against this to special-case diffs where one side is the empty tree.
EMPTY_NODE_CID = MSTNode.empty_root().cid
def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created, deleted
    """
    Compute which MST nodes differ between the tree rooted at `root_a` and
    the tree rooted at `root_b`.

    Returns a `(created, deleted)` pair of CID sets: nodes in b but not in a,
    and nodes in a but not in b.
    """
    created = set()
    deleted = set()
    walker_a = NodeWalker(ns, root_a)
    walker_b = NodeWalker(ns, root_b)
    mst_diff_recursive(created, deleted, walker_a, walker_b)

    # the recursive walk has occasional false-positives: nodes it reports as
    # both created and deleted are really unchanged, so drop them from both
    false_positives = created & deleted
    created.difference_update(false_positives)
    deleted.difference_update(false_positives)

    # special case: exactly one of the root nodes was the empty node
    if root_a != root_b:
        if root_a == EMPTY_NODE_CID:
            deleted.add(EMPTY_NODE_CID)
        if root_b == EMPTY_NODE_CID:
            created.add(EMPTY_NODE_CID)

    return created, deleted
def mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted
    # Walks the two trees in lock-step and records differing node CIDs by
    # mutating the `created` (seen only via b) and `deleted` (seen only via a)
    # sets in place; nothing is returned. The sets may contain false-positives
    # (a CID in both), which the caller is expected to filter out.

    # the easiest of all cases
    if a.frame.node.cid == b.frame.node.cid:
        return # no difference

    # trivial: a is empty, so everything reachable from b is new
    if a.frame.node.is_empty():
        #mst_deleted.add(a.frame.node.cid) # this doesn't work because it might've been a null subtree node
        created |= set(b.iter_node_cids())
        return

    # likewise: b is empty, so everything reachable from a is gone
    if b.frame.node.is_empty():
        #mst_created.add(b.frame.node.cid)
        deleted |= set(a.iter_node_cids())
        return

    # now we're onto the hard part

    """
    theory: most trees that get compared will have lots of shared blocks (which we can skip over, due to identical CIDs)
    completely different trees will inevitably have to visit every node.

    general idea:
    1. if one cursor is "behind" the other, catch it up
    2. when we're matched up, skip over identical subtrees (and recursively diff non-identical subtrees)

    XXX: this seems to work nicely but I'm not sure if it's necessarily efficient for all tree layouts?
    """

    # NB: these will end up as false-positives if one tree is a subtree of the other
    created.add(b.frame.node.cid)
    deleted.add(a.frame.node.cid)

    while True:
        while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other
            # "catch up" cursor a, if it's behind
            while a.rkey < b.rkey and not a.is_final:
                if a.subtree: # recurse down every subtree
                    a.down()
                    deleted.add(a.frame.node.cid) # every node a visits on the way is a deletion candidate
                else:
                    a.right()

            # catch up cursor b, likewise
            while b.rkey < a.rkey and not b.is_final:
                if b.subtree: # recurse down every subtree
                    b.down()
                    created.add(b.frame.node.cid) # every node b visits on the way is a creation candidate
                else:
                    b.right()

        # the rkeys now match, but the subtrees below us might not

        # diff the subtrees at the current positions using fresh walkers, so
        # the positions of a and b themselves are left untouched
        mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker())

        # check if we can still go right XXX: do we need to care about the case where one can, but the other can't?
        # To consider: maybe if I just step a, b will catch up automagically
        # (both cursors are compared against a's root right-bound here)
        if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey:
            break

        a.right()
        b.right()
+55
src/atmst/mst/node_store.py
···11+from typing import Optional, Dict
22+33+from multiformats import CID
44+55+from ..blockstore import BlockStore
66+from ..util import indent
77+from . import MSTNode
class NodeStore:
    """
    NodeStore wraps a BlockStore to provide a more ergonomic interface
    for loading and storing MSTNodes
    """
    bs: BlockStore
    cache: Dict[Optional[CID], MSTNode] # XXX: this cache will grow forever!
    #cache_counts: Dict[Optional[CID], int]

    def __init__(self, bs: BlockStore) -> None:
        self.bs = bs
        self.cache = {}
        #self.cache_counts = {}

    # TODO: LRU cache this - this package looks ideal: https://github.com/amitdev/lru-dict
    def get_node(self, cid: Optional[CID]) -> MSTNode:
        """
        Load the MSTNode with the given CID, consulting the in-memory cache
        before falling back to the underlying BlockStore.

        if cid is None, returns an empty MST node (which is also stored)
        """
        cached = self.cache.get(cid)
        # compare against None explicitly, so that a cached node is returned
        # regardless of how its truthiness happens to be defined
        if cached is not None:
            return cached

        if cid is None:
            return self.put_node(MSTNode.empty_root())

        node = MSTNode.deserialise(self.bs.get_block(bytes(cid)))
        self.cache[cid] = node
        return node

    def put_node(self, node: MSTNode) -> MSTNode:
        """
        Cache `node` and write its serialised form to the BlockStore, keyed by
        its CID. Returns the same node, for call-chaining convenience.
        """
        self.cache[node.cid] = node
        self.bs.put_block(bytes(node.cid), node.serialised)
        return node # this is convenient

    # MST pretty-printing
    # this should maybe not be implemented here
    def pretty(self, node_cid: Optional[CID]) -> str:
        """
        Render the (sub)tree rooted at `node_cid` as an indented, multi-line
        string. A None CID renders as "<empty>".
        """
        if node_cid is None:
            return "<empty>"
        node = self.get_node(node_cid)
        # NB: single quotes inside the f-string expressions keep this parseable
        # on interpreters older than Python 3.12
        parts = [
            f"MSTNode<cid={node.cid.encode('base32')}>(\n{indent(self.pretty(node.subtrees[0]))},\n"
        ]
        for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]):
            parts.append(f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode('base32')},\n")
            parts.append(indent(self.pretty(t)) + ",\n")
        parts.append(")")
        return "".join(parts)
+142
src/atmst/mst/node_walker.py
···11+from dataclasses import dataclass
22+from typing import Tuple, Self, Optional, List
33+44+from multiformats import CID
55+66+from . import MSTNode
77+from .node_store import NodeStore
88+99+class NodeWalker:
1010+ """
1111+ NodeWalker makes implementing tree diffing and other MST query ops more
1212+ convenient (but it does not, itself, implement them).
1313+1414+ A NodeWalker starts off at the root of a tree, and can walk along or recurse
1515+ down into subtrees.
1616+1717+ Walking "off the end" of a subtree brings you back up to its next non-empty parent.
1818+1919+ Recall MSTNode layout: ::
2020+2121+ keys: (lkey) (0, 1, 2, 3) (rkey)
2222+ vals: (0, 1, 2, 3)
2323+ subtrees: (0, 1, 2, 3, 4)
2424+2525+ """
2626+ KEY_MIN = "" # string that compares less than all legal key strings
2727+ KEY_MAX = "\xff" # string that compares greater than all legal key strings
2828+2929+ @dataclass
3030+ class StackFrame:
3131+ node: MSTNode # could store CIDs only to save memory, in theory, but not much point
3232+ lkey: str
3333+ rkey: str
3434+ idx: int
3535+3636+ ns: NodeStore
3737+ stack: List[StackFrame]
3838+3939+ def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None:
4040+ self.ns = ns
4141+ self.stack = [self.StackFrame(
4242+ node=self.ns.get_node(root_cid),
4343+ lkey=lkey,
4444+ rkey=rkey,
4545+ idx=0
4646+ )]
4747+4848+ def subtree_walker(self) -> Self:
4949+ return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey)
5050+5151+ @property
5252+ def frame(self) -> StackFrame:
5353+ return self.stack[-1]
5454+5555+ @property
5656+ def lkey(self) -> str:
5757+ return self.frame.lkey if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1]
5858+5959+ @property
6060+ def lval(self) -> Optional[CID]:
6161+ return None if self.frame.idx == 0 else self.frame.node.vals[self.frame.idx - 1]
6262+6363+ @property
6464+ def subtree(self) -> Optional[CID]:
6565+ return self.frame.node.subtrees[self.frame.idx]
6666+6767+ # hmmmm rkey is overloaded here... "right key" not "record key"...
6868+ @property
6969+ def rkey(self) -> str:
7070+ return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx]
7171+7272+ @property
7373+ def rval(self) -> Optional[CID]:
7474+ return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx]
7575+7676+ @property
7777+ def is_final(self) -> bool:
7878+ return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey)
7979+8080+ def right(self) -> None:
8181+ if (self.frame.idx + 1) >= len(self.frame.node.subtrees):
8282+ # we reached the end of this node, go up a level
8383+ self.stack.pop()
8484+ if not self.stack:
8585+ raise StopIteration # you probably want to check .final instead of hitting this
8686+ return self.right() # we need to recurse, to skip over empty intermediates on the way back up
8787+ self.frame.idx += 1
8888+8989+ def down(self) -> None:
9090+ subtree = self.frame.node.subtrees[self.frame.idx]
9191+ if subtree is None:
9292+ raise Exception("oi, you can't recurse here mate")
9393+9494+ self.stack.append(self.StackFrame(
9595+ node=self.ns.get_node(subtree),
9696+ lkey=self.lkey,
9797+ rkey=self.rkey,
9898+ idx=0
9999+ ))
100100+101101+ # everything above here is core tree walking logic
102102+ # everything below here is helper functions
103103+104104+ def next_kv(self) -> Tuple[str, CID]:
105105+ while self.subtree: # recurse down every subtree
106106+ self.down()
107107+ self.right()
108108+ return self.lkey, self.lval # the kv pair we just jumped over
109109+110110+ # iterate over every k/v pair in key-sorted order
111111+ def iter_kv(self):
112112+ while not self.is_final:
113113+ yield self.next_kv()
114114+115115+ # get all mst nodes down and to the right of the current position
116116+ def iter_node_cids(self):
117117+ yield self.frame.node.cid
118118+ while not self.is_final:
119119+ while self.subtree: # recurse down every subtree
120120+ self.down()
121121+ yield self.frame.node.cid
122122+ self.right()
def enumerate_mst(ns: NodeStore, root_cid: CID):
    """Print every record of the tree rooted at `root_cid`, as `key -> base32 value` lines in walk order."""
    walker = NodeWalker(ns, root_cid)
    for key, value in walker.iter_kv():
        print(key, "->", value.encode("base32"))
# start inclusive, end exclusive
def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str):
    """
    Print the records of the tree rooted at `root_cid` whose keys fall in
    the range [start, end), as `key -> base32 value` lines.
    """
    walker = NodeWalker(ns, root_cid)

    # seek: move right until the key on our right is no longer below `start`,
    # then descend and repeat, until there is nothing left to descend into
    seeking = True
    while seeking:
        while walker.rkey < start:
            walker.right()
        if walker.subtree:
            walker.down()
        else:
            seeking = False

    # scan: emit records from here until we reach `end`
    for key, value in walker.iter_kv():
        if key >= end:
            break
        print(key, "->", value.encode("base32"))
+173
src/atmst/mst/wrangler.py
···11+from typing import Tuple, Optional, Any
22+33+from multiformats import CID
44+55+from . import MSTNode
66+from .node_store import NodeStore
77+88+# tuple helpers
def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple:
    """Return a copy of `original` in which the element at index `i` is replaced by `value`."""
    before = original[:i]
    after = original[i + 1:]
    return before + (value,) + after
def tuple_insert_at(original: tuple, i: int, value: Any) -> tuple:
    """Return a copy of `original` with `value` inserted before index `i`."""
    before = original[:i]
    after = original[i:]
    return before + (value,) + after
def tuple_remove_at(original: tuple, i: int) -> tuple:
    """Return a copy of `original` with the element at index `i` removed."""
    before = original[:i]
    after = original[i + 1:]
    return before + after
class NodeWrangler:
    """
    NodeWrangler is where core MST transformation ops are implemented, backed
    by a NodeStore

    The external APIs take a CID (the MST root) and return a CID (the new root),
    while storing any newly created nodes in the NodeStore.

    Neither method should ever fail - deleting a node that doesn't exist is a nop,
    and adding the same node twice with the same value is also a nop. Callers
    can detect these cases by seeing if the initial and final CIDs changed.
    """
    ns: NodeStore

    def __init__(self, ns: NodeStore) -> None:
        self.ns = ns

    def put_record(self, root_cid: CID, key: str, val: CID) -> CID:
        """Insert or update `key` -> `val` in the tree rooted at `root_cid`, returning the new root CID."""
        root = self.ns.get_node(root_cid)
        if root.is_empty(): # special case for empty tree
            return self._put_here(root, key, val).cid
        return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid

    def del_record(self, root_cid: CID, key: str) -> CID:
        """Remove `key` from the tree rooted at `root_cid` (a nop if absent), returning the new root CID."""
        root = self.ns.get_node(root_cid)
        new_root = self._delete_recursive(root, key, MSTNode.key_height(key), root.height)

        # Note: the seemingly redundant outer .get_node().cid is required to transform
        # a None cid into the cid representing an empty node (we could maybe find a more elegant
        # way of doing this...)
        return self.ns.get_node(self._squash_top(new_root)).cid

    def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode:
        """Insert or update `key` directly in `node`, splitting the subtree that the new key lands in."""
        i = node.gte_index(key)

        # the key is already present!
        if i < len(node.keys) and node.keys[i] == key:
            if node.vals[i] == val:
                return node # we can return our old self if there is no change
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=tuple_replace_at(node.vals, i, val),
                subtrees=node.subtrees
            ))

        # new key: the subtree that used to span this position is split around it
        return self.ns.put_node(MSTNode(
            keys=tuple_insert_at(node.keys, i, key),
            vals=tuple_insert_at(node.vals, i, val),
            subtrees=node.subtrees[:i]
                + self._split_on_key(node.subtrees[i], key)
                + node.subtrees[i + 1:],
        ))

    def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode:
        """
        Insert `key` into the tree whose top node is `node` (at height
        `tree_height`), growing the tree or descending as dictated by
        `key_height`. Returns the new (stored) top node.
        """
        if key_height > tree_height: # we need to grow the tree
            # Wrap the existing tree as the sole subtree of a new, key-less
            # parent one layer up, then retry. The existing node must be kept
            # as a child: starting over from an empty root would silently drop
            # every record already in the tree. The parent is stored so that a
            # further growth step can look it up by CID when splitting.
            parent = self.ns.put_node(MSTNode(
                keys=(),
                vals=(),
                subtrees=(node.cid,)
            ))
            return self._put_recursive(parent, key, val, key_height, tree_height + 1)

        if key_height < tree_height: # we need to look below
            i = node.gte_index(key)
            child = self._put_recursive(
                self.ns.get_node(node.subtrees[i]),
                key, val, key_height, tree_height - 1
            )
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(node.subtrees, i, child.cid)
            ))

        # we can insert here
        assert(key_height == tree_height)
        return self._put_here(node, key, val)

    def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]:
        """
        Split the subtree at `node_cid` into a (left, right) pair of subtrees
        holding the keys before and from `key` respectively. Halves go through
        `_to_optional()` (presumably mapping an empty node to None - confirm
        against MSTNode).
        """
        if node_cid is None:
            return None, None
        node = self.ns.get_node(node_cid)
        i = node.gte_index(key)
        lsub, rsub = self._split_on_key(node.subtrees[i], key)
        left = self.ns.put_node(MSTNode(
            keys=node.keys[:i],
            vals=node.vals[:i],
            subtrees=node.subtrees[:i] + (lsub,)
        ))
        right = self.ns.put_node(MSTNode(
            keys=node.keys[i:],
            vals=node.vals[i:],
            subtrees=(rsub,) + node.subtrees[i + 1:],
        ))
        return left._to_optional(), right._to_optional()

    def _squash_top(self, node_cid: Optional[CID]) -> Optional[CID]:
        """
        strip empty nodes from the top of the tree
        """
        node = self.ns.get_node(node_cid)
        if node.keys:
            return node_cid
        if node.subtrees[0] is None:
            return node_cid
        return self._squash_top(node.subtrees[0])

    def _delete_recursive(self, node: MSTNode, key: str, key_height: int, tree_height: int) -> Optional[CID]:
        """
        Delete `key` from the tree whose top node is `node` (at height
        `tree_height`). Returns the `_to_optional()` form of the resulting node.
        """
        if key_height > tree_height: # the key cannot possibly be in this tree, no change needed
            return node._to_optional()

        i = node.gte_index(key)
        if key_height < tree_height: # the key must be deleted from a subtree
            if node.subtrees[i] is None:
                return node._to_optional() # the key cannot be in this subtree, no change needed
            return self.ns.put_node(MSTNode(
                keys=node.keys,
                vals=node.vals,
                subtrees=tuple_replace_at(
                    node.subtrees,
                    i,
                    self._delete_recursive(self.ns.get_node(node.subtrees[i]), key, key_height, tree_height - 1)
                )
            ))._to_optional()

        if i == len(node.keys) or node.keys[i] != key:
            return node._to_optional() # key already not present

        assert(node.keys[i] == key) # sanity check (should always be true)

        # removing the key leaves its two neighbouring subtrees adjacent, so merge them
        return self.ns.put_node(MSTNode(
            keys=tuple_remove_at(node.keys, i),
            vals=tuple_remove_at(node.vals, i),
            subtrees=node.subtrees[:i] + (
                self._merge(node.subtrees[i], node.subtrees[i + 1]),
            ) + node.subtrees[i + 2:]
        ))._to_optional()

    def _merge(self, left_cid: Optional[CID], right_cid: Optional[CID]) -> Optional[CID]:
        """Merge two adjacent subtrees into one, recursively merging the subtrees at the seam."""
        if left_cid is None:
            return right_cid # includes the case where left == right == None
        if right_cid is None:
            return left_cid
        left = self.ns.get_node(left_cid)
        right = self.ns.get_node(right_cid)
        return self.ns.put_node(MSTNode(
            keys=left.keys + right.keys,
            vals=left.vals + right.vals,
            subtrees=left.subtrees[:-1] + (
                self._merge(
                    left.subtrees[-1],
                    right.subtrees[0]
                ),
            ) + right.subtrees[1:]
        ))._to_optional()