this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

maybe-working mst diffing

+197 -34
+197 -34
mst.py
··· 2 2 import dag_cbor 3 3 import operator 4 4 from multiformats import multihash, CID 5 - from functools import cached_property 5 + from functools import cached_property, reduce 6 6 from more_itertools import ilen 7 7 from itertools import takewhile 8 8 from dataclasses import dataclass 9 - from typing import Tuple, Self, Optional, Any, Dict, List, Type, Iterable 9 + from typing import Tuple, Self, Optional, Any, Dict, List, Set, Type, Iterable 10 10 from collections import namedtuple 11 11 12 12 from util import indent, hash_to_cid 13 - from blockstore import BlockStore, MemoryBlockStore 13 + from blockstore import BlockStore, MemoryBlockStore, OverlayBlockStore 14 14 15 15 # tuple helpers 16 16 def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple: ··· 369 369 A NodeWalker starts off at the root of a tree, and can walk along or recurse 370 370 down into subtrees. 371 371 372 - Walking "off the end" of a subtree brings you back up to its parent. 373 - 374 - At any point in time, the current node is given by node_stack[-1], and its current position 375 - within that node is given by idx_stack[-1], which corresponds to a subtree index. 372 + Walking "off the end" of a subtree brings you back up to its next non-empty parent. 376 373 377 374 Recall MSTNode layout: 378 375 ··· 380 377 vals: (0, 1, 2, 3) 381 378 subtrees: (0, 1, 2, 3, 4) 382 379 """ 383 - ns: NodeStore 380 + KEY_MIN = "" # string that compares less than all legal key strings 381 + KEY_MAX = "\xff" # string that compares greater than all legal key strings 384 382 385 383 @dataclass 386 384 class StackFrame: ··· 389 387 rkey: str 390 388 idx: int 391 389 392 - KEY_MIN = "" # string that compares less than all legal key strings 393 - KEY_MAX = "\xff" # string that compares greater than all legal key strings 394 - 395 - @dataclass 396 - class State: 397 - lkey: str 398 - lval: Optional[CID] 399 - subtree: Optional[CID] 400 - rkey: str 401 - rval: Optional[CID] 402 - 390 + ns: NodeStore 403 391 stack: List[StackFrame] 404 392 405 - def __init__(self, ns: NodeStore, root_cid: CID) -> None: 393 + def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None: 406 394 self.ns = ns 407 395 self.stack = [self.StackFrame( 408 396 node=self.ns.get(root_cid), 409 - lkey=self.KEY_MIN, 410 - rkey=self.KEY_MAX, 397 + lkey=lkey, 398 + rkey=rkey, 411 399 idx=0 412 400 )] 401 + 402 + def subtree_walker(self) -> Self: 403 + return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey) 413 404 414 405 @property 415 406 def frame(self) -> StackFrame: ··· 427 418 def subtree(self) -> Optional[CID]: 428 419 return self.frame.node.subtrees[self.frame.idx] 429 420 421 + # hmmmm rkey is overloaded here... "right key" not "record key"... 430 422 @property 431 423 def rkey(self) -> str: 432 424 return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx] ··· 436 428 return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx] 437 429 438 430 @property 439 - def final(self) -> bool: 440 - return self.subtree is None and self.rkey == NodeWalker.KEY_MAX 431 + def is_final(self) -> bool: 432 + return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey) 441 433 442 434 def right(self) -> None: 443 435 if (self.frame.idx + 1) >= len(self.frame.node.subtrees): ··· 459 451 rkey=self.rkey, 460 452 idx=0 461 453 )) 454 + 455 + # everything above here is core tree walking logic 456 + # everything below here is helper functions 457 + 458 + def next_kv(self) -> Tuple[str, CID]: 459 + while self.subtree: # recurse down every subtree 460 + self.down() 461 + self.right() 462 + return self.lkey, self.lval # the kv pair we just jumped over 463 + 464 + # iterate over every k/v pair in key-sorted order 465 + def iter_kv(self): 466 + while not self.is_final: 467 + yield self.next_kv() 468 + 469 + # get all mst nodes down and to the right of the current position 470 + def iter_node_cids(self): 471 + yield self.frame.node.cid 472 + while not self.is_final: 473 + while self.subtree: # recurse down every subtree 474 + self.down() 475 + yield self.frame.node.cid 476 + self.right() 477 + 462 478 463 479 def enumerate_mst(ns: NodeStore, root_cid: CID): 480 + for k, v in NodeWalker(ns, root_cid).iter_kv(): 481 + print(k, "->", v.encode("base32")) 482 + 483 + # start inclusive, end exclusive 484 + def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str): 464 485 cur = NodeWalker(ns, root_cid) 465 - while not cur.final: 466 - while cur.subtree: # recurse down every subtree 467 - cur.down() 468 - cur.right() 469 - print(cur.lkey, "->", cur.lval.encode("base32")) # print the kv pair we just jumped over 486 + while True: 487 + while cur.rkey < start: 488 + cur.right() 489 + if not cur.subtree: 490 + break 491 + cur.down() 492 + 493 + for k, v, in cur.iter_kv(): 494 + if k >= end: 495 + break 496 + print(k, "->", v.encode("base32")) 497 + 498 + def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]): 499 + created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get, created)), {}) 500 + deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get, deleted)), {}) 501 + for created_key in created_kv.keys() - deleted_kv.keys(): 502 + yield ("created", created_key, created_kv[created_key].encode("base32")) 503 + for updated_key in created_kv.keys() & deleted_kv.keys(): 504 + v1 = created_kv[updated_key] 505 + v2 = deleted_kv[updated_key] 506 + if v1 != v2: 507 + yield ("updated", updated_key, v1.encode("base32"), v2.encode("base32")) 508 + for deleted_key in deleted_kv.keys() - created_kv.keys(): 509 + yield ("deleted", deleted_key, deleted_kv[deleted_key].encode("base32")) #XXX: encode() is just for debugging 510 + 511 + def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created_deleted 512 + created, deleted = mst_diff_recursive(NodeWalker(ns, root_a), NodeWalker(ns, root_b)) 513 + middle = created & deleted 514 + #assert(not middle) # should be no intersection!!! 515 + return created - middle, deleted - middle 516 + 517 + def very_slow_mst_diff(ns, root_a: CID, root_b: CID): 518 + """ 519 + This should return the same result as mst_diff, but it gets there in a very slow 520 + yet less error-prone way, so it's useful for testing. 521 + """ 522 + a_nodes = set(NodeWalker(ns, root_a).iter_node_cids()) 523 + b_nodes = set(NodeWalker(ns, root_b).iter_node_cids()) 524 + return b_nodes - a_nodes, a_nodes - b_nodes 525 + 526 + def mst_diff_recursive(a: NodeWalker, b: NodeWalker) -> Tuple[Set[CID], Set[CID]]: # created, deleted 527 + mst_created = set() # MST nodes in b but not in a 528 + mst_deleted = set() # MST nodes in a but not in b 529 + 530 + # the easiest of all cases 531 + if a.frame.node.cid == b.frame.node.cid: # includes the case where they're both None 532 + return mst_created, mst_deleted # no difference 533 + 534 + # trivial 535 + if a.frame.node.is_empty(): 536 + mst_created |= set(b.iter_node_cids()) 537 + return mst_created, mst_deleted 538 + 539 + # likewise 540 + if b.frame.node.is_empty(): 541 + mst_deleted |= set(a.iter_node_cids()) 542 + return mst_created, mst_deleted 543 + 544 + # now we're onto the hard part 470 545 546 + """ 547 + theory: most trees that get compared will have lots of shared blocks (which we can skip over, due to identical CIDs) 548 + completely different trees will inevitably have to visit every node. 549 + 550 + general idea: 551 + 1. if one cursor is "behind" the other, catch it up 552 + 2. when we're matched up, skip over identical subtrees (and recursively diff non-identical subtrees) 553 + 554 + XXX: this seems to work nicely but I'm not sure if it's necessarily efficient for all tree layouts? 555 + """ 556 + 557 + # NB: these will end up as false-positives if one tree is a subtree of the other 558 + mst_created.add(b.frame.node.cid) 559 + mst_deleted.add(a.frame.node.cid) 560 + 561 + while True: 562 + # "catch up" cursor a, if it's behind 563 + while a.rkey < b.rkey and not a.is_final: 564 + if a.subtree: # recurse down every subtree 565 + a.down() 566 + mst_deleted.add(a.frame.node.cid) 567 + else: 568 + a.right() 569 + 570 + # catch up cursor b, likewise 571 + while b.rkey < a.rkey and not b.is_final: 572 + if b.subtree: # recurse down every subtree 573 + b.down() 574 + mst_created.add(b.frame.node.cid) 575 + else: 576 + b.right() 577 + 578 + assert(b.rkey == a.rkey) 579 + # the rkeys match, but the subrees below us might not 580 + 581 + c, d = mst_diff_recursive(a.subtree_walker(), b.subtree_walker()) 582 + mst_created |= c 583 + mst_deleted |= d 584 + 585 + # check if we can still go right XXX: do we need to care about the case where one can, but the other can't? 586 + # To consider: maybe if I just step a, b will catch up automagically 587 + if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey: 588 + break 589 + 590 + a.right() 591 + b.right() 592 + 593 + return mst_created, mst_deleted 471 594 472 595 if __name__ == "__main__": 473 - if 1: 596 + if 0: 597 + import sys 598 + sys.setrecursionlimit(999999999) 474 599 from carfile import ReadOnlyCARBlockStore 475 600 f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb") 476 - bs = ReadOnlyCARBlockStore(f) 477 - commit_obj = dag_cbor.decode(bs.get(bytes(bs.car_roots[0]))) 601 + bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f)) 602 + commit_obj = dag_cbor.decode(bs.get(bytes(bs.lower.car_roots[0]))) 478 603 mst_root: CID = commit_obj["data"] 479 604 ns = NodeStore(bs) 480 - #wrangler = NodeWrangler(ns) 605 + wrangler = NodeWrangler(ns) 481 606 #print(wrangler) 482 - enumerate_mst(ns, mst_root) 607 + #enumerate_mst(ns, mst_root) 608 + enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff") 609 + 610 + root2 = wrangler.delete(mst_root, "app.bsky.feed.generator/alttext") 611 + root2 = wrangler.delete(root2, "app.bsky.feed.like/3kas3fyvkti22") 612 + root2 = wrangler.put(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah")) 613 + 614 + c, d = mst_diff(ns, mst_root, root2) 615 + print("CREATED:") 616 + for x in c: 617 + print("created", x.encode("base32")) 618 + print("DELETED:") 619 + for x in d: 620 + print("deleted", x.encode("base32")) 621 + 622 + for op in record_diff(ns, c, d): 623 + print(op) 624 + 625 + e, f = very_slow_mst_diff(ns, mst_root, root2) 626 + assert(e == c) 627 + assert(f == d) 483 628 else: 484 629 bs = MemoryBlockStore() 485 630 ns = NodeStore(bs) ··· 490 635 print(ns.pretty(root)) 491 636 root = wrangler.put(root, "foo", hash_to_cid(b"bar")) 492 637 print(ns.pretty(root)) 638 + root_a = root 493 639 root = wrangler.put(root, "bar", hash_to_cid(b"bat")) 494 640 root = wrangler.put(root, "xyzz", hash_to_cid(b"bat")) 641 + root = wrangler.delete(root, "foo") 642 + print("=============") 643 + print(ns.pretty(root_a)) 644 + print("=============") 495 645 print(ns.pretty(root)) 496 646 #exit() 647 + print("=============") 497 648 enumerate_mst(ns, root) 649 + c, d = mst_diff(ns, root_a, root) 650 + print("CREATED:") 651 + for x in c: 652 + print("created", x.encode("base32")) 653 + print("DELETED:") 654 + for x in d: 655 + print("deleted", x.encode("base32")) 656 + 657 + e, f = very_slow_mst_diff(ns, root_a, root) 658 + assert(e == c) 659 + assert(f == d) 660 + 498 661 exit() 499 662 root = wrangler.delete(root, "foo") 500 663 root = wrangler.delete(root, "hello")