···22import dag_cbor
33import operator
44from multiformats import multihash, CID
55-from functools import cached_property
55+from functools import cached_property, reduce
66from more_itertools import ilen
77from itertools import takewhile
88from dataclasses import dataclass
99-from typing import Tuple, Self, Optional, Any, Dict, List, Type, Iterable
99+from typing import Tuple, Self, Optional, Any, Dict, List, Set, Type, Iterable
1010from collections import namedtuple
11111212from util import indent, hash_to_cid
1313-from blockstore import BlockStore, MemoryBlockStore
1313+from blockstore import BlockStore, MemoryBlockStore, OverlayBlockStore
14141515# tuple helpers
1616def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple:
···369369 A NodeWalker starts off at the root of a tree, and can walk along or recurse
370370 down into subtrees.
371371372372- Walking "off the end" of a subtree brings you back up to its parent.
373373-374374- At any point in time, the current node is given by node_stack[-1], and its current position
375375- within that node is given by idx_stack[-1], which corresponds to a subtree index.
372372+ Walking "off the end" of a subtree brings you back up to its next non-empty parent.
376373377374 Recall MSTNode layout:
378375···380377 vals: (0, 1, 2, 3)
381378 subtrees: (0, 1, 2, 3, 4)
382379 """
383383- ns: NodeStore
380380+ KEY_MIN = "" # string that compares less than all legal key strings
381381+ KEY_MAX = "\xff" # string that compares greater than all legal key strings
384382385383 @dataclass
386384 class StackFrame:
···389387 rkey: str
390388 idx: int
391389392392- KEY_MIN = "" # string that compares less than all legal key strings
393393- KEY_MAX = "\xff" # string that compares greater than all legal key strings
394394-395395- @dataclass
396396- class State:
397397- lkey: str
398398- lval: Optional[CID]
399399- subtree: Optional[CID]
400400- rkey: str
401401- rval: Optional[CID]
402402-390390+ ns: NodeStore
403391 stack: List[StackFrame]
404392405405- def __init__(self, ns: NodeStore, root_cid: CID) -> None:
393393+ def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None:
406394 self.ns = ns
407395 self.stack = [self.StackFrame(
408396 node=self.ns.get(root_cid),
409409- lkey=self.KEY_MIN,
410410- rkey=self.KEY_MAX,
397397+ lkey=lkey,
398398+ rkey=rkey,
411399 idx=0
412400 )]
401401+402402+ def subtree_walker(self) -> Self:
403403+ return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey)
413404414405 @property
415406 def frame(self) -> StackFrame:
···427418 def subtree(self) -> Optional[CID]:
428419 return self.frame.node.subtrees[self.frame.idx]
429420421421+ # hmmmm rkey is overloaded here... "right key" not "record key"...
430422 @property
431423 def rkey(self) -> str:
432424 return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx]
···436428 return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx]
437429438430 @property
439439- def final(self) -> bool:
440440- return self.subtree is None and self.rkey == NodeWalker.KEY_MAX
431431+ def is_final(self) -> bool:
432432+ return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey)
441433442434 def right(self) -> None:
443435 if (self.frame.idx + 1) >= len(self.frame.node.subtrees):
···459451 rkey=self.rkey,
460452 idx=0
461453 ))
454454+455455+ # everything above here is core tree walking logic
456456+ # everything below here is helper functions
457457+458458+ def next_kv(self) -> Tuple[str, CID]:
459459+ while self.subtree: # recurse down every subtree
460460+ self.down()
461461+ self.right()
462462+ return self.lkey, self.lval # the kv pair we just jumped over
463463+464464+ # iterate over every k/v pair in key-sorted order
465465+ def iter_kv(self):
466466+ while not self.is_final:
467467+ yield self.next_kv()
468468+469469+ # get all mst nodes down and to the right of the current position
470470+ def iter_node_cids(self):
471471+ yield self.frame.node.cid
472472+ while not self.is_final:
473473+ while self.subtree: # recurse down every subtree
474474+ self.down()
475475+ yield self.frame.node.cid
476476+ self.right()
def enumerate_mst(ns: NodeStore, root_cid: CID):
    """Print every key/value pair of the tree rooted at root_cid, one "key -> base32 value" line each."""
    walker = NodeWalker(ns, root_cid)
    for key, value in walker.iter_kv():
        print(key, "->", value.encode("base32"))
def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str):
    """
    Print the key/value pairs of the tree rooted at root_cid whose keys fall
    in the range [start, end) - start inclusive, end exclusive.
    """
    cur = NodeWalker(ns, root_cid)

    # Seek: within each node, step right until the right key is no longer
    # below `start`, then descend if there is a subtree there and repeat.
    while True:
        while cur.rkey < start:
            cur.right()
        if cur.subtree:
            cur.down()
            continue
        break

    # Emit pairs in order, stopping at the first key that reaches `end`.
    in_range = takewhile(lambda pair: not (pair[0] >= end), cur.iter_kv())
    for key, value in in_range:
        print(key, "->", value.encode("base32"))
def _collect_records(ns: NodeStore, node_cids: Iterable[CID]) -> Dict[str, CID]:
    """Merge the key -> value pairs held directly in the given MST nodes into one dict."""
    # Single pass; if a key appears in several nodes the last one visited wins,
    # the same as folding the per-node dicts together with `|`.
    return {
        key: val
        for node in map(ns.get, node_cids)
        for key, val in zip(node.keys, node.vals)
    }


def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]):
    """
    Turn an MST-node-level diff into a record-level diff.

    ns:      node store used to load each node via ns.get().
    created: CIDs of MST nodes that were created.
    deleted: CIDs of MST nodes that were deleted.

    Yields tuples, each group in ascending key order:
      ("created", key, value)              - key only present in created nodes
      ("updated", key, new_value, old_value) - key in both, with differing values
      ("deleted", key, value)              - key only present in deleted nodes
    Keys present on both sides with equal values yield nothing.
    """
    created_kv = _collect_records(ns, created)
    deleted_kv = _collect_records(ns, deleted)

    # sorted() makes the output order deterministic (set iteration order is not)
    for key in sorted(created_kv.keys() - deleted_kv.keys()):
        yield ("created", key, created_kv[key].encode("base32"))

    for key in sorted(created_kv.keys() & deleted_kv.keys()):
        new_val = created_kv[key]
        old_val = deleted_kv[key]
        if new_val != old_val:
            yield ("updated", key, new_val.encode("base32"), old_val.encode("base32"))

    for key in sorted(deleted_kv.keys() - created_kv.keys()):
        yield ("deleted", key, deleted_kv[key].encode("base32"))  # XXX: encode() is just for debugging
def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]:
    """
    Diff the trees rooted at root_a and root_b.

    Returns (created, deleted): sets of MST node CIDs as reported by
    mst_diff_recursive, with any CID that shows up in both sets removed
    from each.
    """
    walker_a = NodeWalker(ns, root_a)
    walker_b = NodeWalker(ns, root_b)
    created, deleted = mst_diff_recursive(walker_a, walker_b)

    # Ideally the two sets would never overlap (the original author left an
    # assert for that disabled); strip the overlap instead of failing.
    overlap = created & deleted
    only_created = created - overlap
    only_deleted = deleted - overlap
    return only_created, only_deleted
def very_slow_mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]:
    """
    This should return the same result as mst_diff, but it gets there in a very slow
    yet less error-prone way, so it's useful for testing.

    Both trees are walked in full to collect their node CIDs; the result is
    (created, deleted) = (CIDs only in tree b, CIDs only in tree a).
    """
    a_nodes = set(NodeWalker(ns, root_a).iter_node_cids())
    b_nodes = set(NodeWalker(ns, root_b).iter_node_cids())
    return b_nodes - a_nodes, a_nodes - b_nodes
def _catch_up(behind: NodeWalker, ahead: NodeWalker, visited: Set[CID]) -> None:
    """
    Advance `behind` while its rkey is less than `ahead`'s and it is not final,
    adding the CID of every node it descends into to `visited`.
    """
    while behind.rkey < ahead.rkey and not behind.is_final:
        if behind.subtree:  # recurse down every subtree
            behind.down()
            visited.add(behind.frame.node.cid)
        else:
            behind.right()


def mst_diff_recursive(a: NodeWalker, b: NodeWalker) -> Tuple[Set[CID], Set[CID]]:
    """
    Walk two trees in lockstep and collect the MST nodes that differ.

    Returns (created, deleted): CIDs of MST nodes seen in b but not a, and in
    a but not b. The two sets can overlap (see the NB below), so callers may
    need to remove the intersection. Both walkers are advanced as a side effect.

    theory: most trees that get compared will have lots of shared blocks (which
    we can skip over, due to identical CIDs); completely different trees will
    inevitably have to visit every node.

    general idea:
      1. if one cursor is "behind" the other, catch it up
      2. when we're matched up, skip over identical subtrees (and recursively
         diff non-identical subtrees)

    XXX: this seems to work nicely but I'm not sure if it's necessarily
    efficient for all tree layouts?
    """
    mst_created: Set[CID] = set()  # MST nodes in b but not in a
    mst_deleted: Set[CID] = set()  # MST nodes in a but not in b

    # the easiest of all cases: identical nodes (per the original note, this
    # also covers the case where they're both None) means no difference
    if a.frame.node.cid == b.frame.node.cid:
        return mst_created, mst_deleted

    # trivial: a is empty, so everything reachable in b was created
    if a.frame.node.is_empty():
        mst_created.update(b.iter_node_cids())
        return mst_created, mst_deleted

    # likewise: b is empty, so everything reachable in a was deleted
    if b.frame.node.is_empty():
        mst_deleted.update(a.iter_node_cids())
        return mst_created, mst_deleted

    # now we're onto the hard part

    # NB: these will end up as false-positives if one tree is a subtree of the other
    mst_created.add(b.frame.node.cid)
    mst_deleted.add(a.frame.node.cid)

    while True:
        # "catch up" cursor a if it's behind, then cursor b likewise
        _catch_up(a, b, mst_deleted)
        _catch_up(b, a, mst_created)

        assert b.rkey == a.rkey
        # the rkeys match, but the subtrees below us might not

        sub_created, sub_deleted = mst_diff_recursive(a.subtree_walker(), b.subtree_walker())
        mst_created |= sub_created
        mst_deleted |= sub_deleted

        # check if we can still go right
        # XXX: do we need to care about the case where one can, but the other can't?
        # To consider: maybe if I just step a, b will catch up automagically
        if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey:
            break

        a.right()
        b.right()

    return mst_created, mst_deleted
471594472595if __name__ == "__main__":
473473- if 1:
596596+ if 0:
597597+ import sys
598598+ sys.setrecursionlimit(999999999)
474599 from carfile import ReadOnlyCARBlockStore
475600 f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb")
476476- bs = ReadOnlyCARBlockStore(f)
477477- commit_obj = dag_cbor.decode(bs.get(bytes(bs.car_roots[0])))
601601+ bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f))
602602+ commit_obj = dag_cbor.decode(bs.get(bytes(bs.lower.car_roots[0])))
478603 mst_root: CID = commit_obj["data"]
479604 ns = NodeStore(bs)
480480- #wrangler = NodeWrangler(ns)
605605+ wrangler = NodeWrangler(ns)
481606 #print(wrangler)
482482- enumerate_mst(ns, mst_root)
607607+ #enumerate_mst(ns, mst_root)
608608+ enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff")
609609+610610+ root2 = wrangler.delete(mst_root, "app.bsky.feed.generator/alttext")
611611+ root2 = wrangler.delete(root2, "app.bsky.feed.like/3kas3fyvkti22")
612612+ root2 = wrangler.put(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah"))
613613+614614+ c, d = mst_diff(ns, mst_root, root2)
615615+ print("CREATED:")
616616+ for x in c:
617617+ print("created", x.encode("base32"))
618618+ print("DELETED:")
619619+ for x in d:
620620+ print("deleted", x.encode("base32"))
621621+622622+ for op in record_diff(ns, c, d):
623623+ print(op)
624624+625625+ e, f = very_slow_mst_diff(ns, mst_root, root2)
626626+ assert(e == c)
627627+ assert(f == d)
483628 else:
484629 bs = MemoryBlockStore()
485630 ns = NodeStore(bs)
···490635 print(ns.pretty(root))
491636 root = wrangler.put(root, "foo", hash_to_cid(b"bar"))
492637 print(ns.pretty(root))
638638+ root_a = root
493639 root = wrangler.put(root, "bar", hash_to_cid(b"bat"))
494640 root = wrangler.put(root, "xyzz", hash_to_cid(b"bat"))
641641+ root = wrangler.delete(root, "foo")
642642+ print("=============")
643643+ print(ns.pretty(root_a))
644644+ print("=============")
495645 print(ns.pretty(root))
496646 #exit()
647647+ print("=============")
497648 enumerate_mst(ns, root)
649649+ c, d = mst_diff(ns, root_a, root)
650650+ print("CREATED:")
651651+ for x in c:
652652+ print("created", x.encode("base32"))
653653+ print("DELETED:")
654654+ for x in d:
655655+ print("deleted", x.encode("base32"))
656656+657657+ e, f = very_slow_mst_diff(ns, root_a, root)
658658+ assert(e == c)
659659+ assert(f == d)
660660+498661 exit()
499662 root = wrangler.delete(root, "foo")
500663 root = wrangler.delete(root, "hello")