this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

even more restructuring

+739 -691
+2 -2
mst_test.py
··· 1 1 import random 2 - from atmst.mst import mst_diff, very_slow_mst_diff, NodeStore, NodeWrangler, hash_to_cid 3 - from atmst.blockstore import MemoryBlockStore 2 + from atmst import MemoryBlockStore, NodeStore, NodeWrangler, mst_diff, very_slow_mst_diff 3 + from atmst.util import hash_to_cid 4 4 import time 5 5 6 6 PERF_BENCH = False
+10 -1
src/atmst/__init__.py
··· 1 1 from .blockstore import BlockStore, MemoryBlockStore 2 2 from .blockstore.car_reader import ReadOnlyCARBlockStore 3 - from .mst import NodeStore, NodeWalker, NodeWrangler 3 + from .mst.node_walker import NodeWalker 4 + from .mst.node_store import NodeStore 5 + from .mst.wrangler import NodeWrangler 6 + from .mst.diff import mst_diff, very_slow_mst_diff, record_diff 7 + 8 + __all__ = [ 9 + "BlockStore", "MemoryBlockStore", "ReadOnlyCARBlockStore", 10 + "NodeWalker", "NodeStore", "NodeWrangler", 11 + "mst_diff", "very_slow_mst_diff", "record_diff", 12 + ]
+3 -2
src/atmst/blockstore/__init__.py
··· 1 1 from abc import ABC, abstractmethod 2 - from typing import Self, Optional, Dict, BinaryIO 2 + from typing import Optional, Dict 3 3 import sqlite3 4 4 5 5 ··· 113 113 self.upper.del_block(key) 114 114 115 115 116 - 116 + """ 117 117 if __name__ == "__main__": 118 118 import os 119 119 ··· 160 160 pass 161 161 162 162 os.remove(TEST_DB) # clean up 163 + """
+3 -1
src/atmst/blockstore/car_reader.py
··· 1 - from typing import Self, Optional, Dict, List, Tuple, BinaryIO 1 + from typing import Dict, List, Tuple, BinaryIO 2 2 from multiformats import varint, CID 3 3 import dag_cbor 4 4 ··· 62 62 raise NotImplementedError("ReadOnlyCARBlockStore does not support delete()") 63 63 64 64 65 + """ 65 66 if __name__ == "__main__": 66 67 f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb") 67 68 bs = ReadOnlyCARBlockStore(f) ··· 72 73 from ..mst import NodeStore 73 74 ns = NodeStore(bs) 74 75 print(ns.get_node(mst_root)) 76 + """
-685
src/atmst/mst.py
··· 1 - import hashlib 2 - import dag_cbor 3 - import operator 4 - from multiformats import multihash, CID 5 - from functools import cached_property, reduce 6 - from more_itertools import ilen 7 - from itertools import takewhile 8 - from dataclasses import dataclass 9 - from typing import Tuple, Self, Optional, Any, Dict, List, Set, Type, Iterable 10 - from collections import namedtuple 11 - 12 - from .util import indent, hash_to_cid 13 - from .blockstore import BlockStore 14 - 15 - # tuple helpers 16 - def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple: 17 - return original[:i] + (value,) + original[i + 1:] 18 - 19 - def tuple_insert_at(original: tuple, i: int, value: Any) -> tuple: 20 - return original[:i] + (value,) + original[i:] 21 - 22 - def tuple_remove_at(original: tuple, i: int) -> tuple: 23 - return original[:i] + original[i + 1:] 24 - 25 - 26 - @dataclass(frozen=True) # frozen == immutable == win 27 - class MSTNode: 28 - """ 29 - k/v pairs are interleaved between subtrees like so: :: 30 - 31 - keys: (0, 1, 2, 3) 32 - vals: (0, 1, 2, 3) 33 - subtrees: (0, 1, 2, 3, 4) 34 - 35 - If a method is implemented in this class, it's because it's a function/property 36 - of a single node, as opposed to a whole tree 37 - """ 38 - keys: Tuple[str] # collection/rkey 39 - vals: Tuple[CID] # record CIDs 40 - subtrees: Tuple[Optional[CID]] # a None value represents an empty subtree 41 - 42 - 43 - # NB: __init__ is auto-generated by dataclass decorator 44 - 45 - # these checks should never fail, and could be skipped for performance 46 - def __post_init__(self) -> None: 47 - # TODO: maybe check that they're tuples here? 48 - # implicitly, the length of self.subtrees must be at least 1 49 - if len(self.subtrees) != len(self.keys) + 1: 50 - raise ValueError("Invalid subtree count") 51 - if len(self.keys) != len(self.vals): 52 - raise ValueError("Mismatched keys/vals lengths") 53 - 54 - @classmethod 55 - def empty_root(cls) -> Self: 56 - return cls( 57 - subtrees=(None,), 58 - keys=(), 59 - vals=() 60 - ) 61 - 62 - # this should maybe not be implemented here? 63 - @staticmethod 64 - def key_height(key: str) -> int: 65 - digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big") 66 - leading_zeroes = 256 - digest.bit_length() 67 - return leading_zeroes // 2 68 - 69 - # since we're immutable, this can be cached 70 - @cached_property 71 - def cid(self) -> CID: 72 - digest = multihash.digest(self.serialised, "sha2-256") 73 - cid = CID("base32", 1, "dag-cbor", digest) 74 - return cid 75 - 76 - # likewise 77 - @cached_property 78 - def serialised(self) -> bytes: 79 - e = [] 80 - prev_key = b"" 81 - for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals): 82 - key_bytes = key_str.encode() 83 - shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes))) # I love functional programming 84 - e.append({ 85 - "k": key_bytes[shared_prefix_len:], 86 - "p": shared_prefix_len, 87 - "t": subtree, 88 - "v": value, 89 - }) 90 - prev_key = key_bytes 91 - return dag_cbor.encode({ 92 - "e": e, 93 - "l": self.subtrees[0] 94 - }) 95 - 96 - @classmethod 97 - def deserialise(cls, data: bytes) -> Self: 98 - cbor = dag_cbor.decode(data) 99 - if len(cbor) != 2: # e, l 100 - raise ValueError("malformed MST node") 101 - subtrees = [cbor["l"]] 102 - keys = [] 103 - vals = [] 104 - prev_key = b"" 105 - for e in cbor["e"]: # TODO: make extra sure that these checks are watertight wrt non-canonical representations 106 - if len(e) != 4: # k, p, t, v 107 - raise ValueError("malformed MST node") 108 - prefix_len: int = e["p"] 109 - suffix: bytes = e["k"] 110 - if prefix_len > len(prev_key): 111 - raise ValueError("invalid MST key prefix len") 112 - if prev_key[prefix_len:prefix_len+1] == suffix[:1]: 113 - raise ValueError("non-optimal MST key prefix len") 114 - this_key = prev_key[:prefix_len] + suffix 115 - if this_key <= prev_key: 116 - raise ValueError("invalid MST key sort order") 117 - keys.append(this_key.decode()) 118 - vals.append(e["v"]) 119 - subtrees.append(e["t"]) 120 - prev_key = this_key 121 - 122 - return cls( 123 - subtrees=tuple(subtrees), 124 - keys=tuple(keys), 125 - vals=tuple(vals) 126 - ) 127 - 128 - def is_empty(self) -> bool: 129 - return self.subtrees == (None,) 130 - 131 - def _to_optional(self) -> Optional[CID]: 132 - """ 133 - returns None if the node is empty 134 - """ 135 - if self.is_empty(): 136 - return None 137 - return self.cid 138 - 139 - 140 - @cached_property 141 - def height(self) -> int: 142 - # if there are keys at this level, query one directly 143 - if self.keys: 144 - return self.key_height(self.keys[0]) 145 - 146 - # we're an empty tree 147 - if self.subtrees[0] is None: 148 - return 0 149 - 150 - # this should only happen for non-root nodes with no keys 151 - raise Exception("cannot determine node height") 152 - 153 - def gte_index(self, key: str) -> int: 154 - """ 155 - find the index of the first key greater than or equal to the specified key 156 - if all keys are smaller, it returns len(keys) 157 - """ 158 - i = 0 # this loop could be a binary search but not worth it for small fanouts 159 - while i < len(self.keys) and key > self.keys[i]: 160 - i += 1 161 - return i 162 - 163 - 164 - class NodeStore: 165 - """ 166 - NodeStore wraps a BlockStore to provide a more ergonomic interface 167 - for loading and storing MSTNodes 168 - """ 169 - bs: BlockStore 170 - cache: Dict[Optional[CID], MSTNode] # XXX: this cache will grow forever! 171 - #cache_counts: Dict[Optional[CID], int] 172 - 173 - def __init__(self, bs: BlockStore) -> None: 174 - self.bs = bs 175 - self.cache = {} 176 - #self.cache_counts = {} 177 - 178 - # TODO: LRU cache this - this package looks ideal: https://github.com/amitdev/lru-dict 179 - def get_node(self, cid: Optional[CID]) -> MSTNode: 180 - cached = self.cache.get(cid) 181 - if cached: 182 - return cached 183 - """ 184 - if cid is None, returns an empty MST node 185 - """ 186 - if cid is None: 187 - return self.put_node(MSTNode.empty_root()) 188 - 189 - res = MSTNode.deserialise(self.bs.get_block(bytes(cid))) 190 - self.cache[cid] = res 191 - return res 192 - 193 - # TODO: also put in cache 194 - def put_node(self, node: MSTNode) -> MSTNode: 195 - self.cache[node.cid] = node 196 - self.bs.put_block(bytes(node.cid), node.serialised) 197 - return node # this is convenient 198 - 199 - # MST pretty-printing 200 - # this should maybe not be implemented here 201 - def pretty(self, node_cid: Optional[CID]) -> str: 202 - if node_cid is None: 203 - return "<empty>" 204 - node = self.get_node(node_cid) 205 - res = f"MSTNode<cid={node.cid.encode("base32")}>(\n{indent(self.pretty(node.subtrees[0]))},\n" 206 - for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]): 207 - res += f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode("base32")},\n" 208 - res += indent(self.pretty(t)) + ",\n" 209 - res += ")" 210 - return res 211 - 212 - 213 - class NodeWrangler: 214 - """ 215 - NodeWrangler is where core MST transformation ops are implemented, backed 216 - by a NodeStore 217 - 218 - The external APIs take a CID (the MST root) and return a CID (the new root), 219 - while storing any newly created nodes in the NodeStore. 220 - 221 - Neither method should ever fail - deleting a node that doesn't exist is a nop, 222 - and adding the same node twice with the same value is also a nop. Callers 223 - can detect these cases by seeing if the initial and final CIDs changed. 224 - """ 225 - ns: NodeStore 226 - 227 - def __init__(self, ns: NodeStore) -> None: 228 - self.ns = ns 229 - 230 - def put_record(self, root_cid: CID, key: str, val: CID) -> CID: 231 - root = self.ns.get_node(root_cid) 232 - if root.is_empty(): # special case for empty tree 233 - return self._put_here(root, key, val).cid 234 - return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid 235 - 236 - def del_record(self, root_cid: CID, key: str) -> CID: 237 - root = self.ns.get_node(root_cid) 238 - 239 - # Note: the seemingly redundant outer .get().cid is required to transform 240 - # a None cid into the cid representing an empty node (we could maybe find a more elegant 241 - # way of doing this...) 242 - return self.ns.get_node(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid 243 - 244 - 245 - 246 - def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode: 247 - i = node.gte_index(key) 248 - 249 - # the key is already present! 250 - if i < len(node.keys) and node.keys[i] == key: 251 - if node.vals[i] == val: 252 - return node # we can return our old self if there is no change 253 - return self.ns.put_node(MSTNode( 254 - keys=node.keys, 255 - vals=tuple_replace_at(node.vals, i, val), 256 - subtrees=node.subtrees 257 - )) 258 - 259 - return self.ns.put_node(MSTNode( 260 - keys=tuple_insert_at(node.keys, i, key), 261 - vals=tuple_insert_at(node.vals, i, val), 262 - subtrees = node.subtrees[:i] + \ 263 - self._split_on_key(node.subtrees[i], key) + \ 264 - node.subtrees[i + 1:], 265 - )) 266 - 267 - def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode: 268 - if key_height > tree_height: # we need to grow the tree 269 - return self.ns.put_node(self._put_recursive( 270 - MSTNode.empty_root(), 271 - key, val, key_height, tree_height + 1 272 - )) 273 - 274 - if key_height < tree_height: # we need to look below 275 - i = node.gte_index(key) 276 - return self.ns.put_node(MSTNode( 277 - keys=node.keys, 278 - vals=node.vals, 279 - subtrees=tuple_replace_at( 280 - node.subtrees, i, 281 - self._put_recursive( 282 - self.ns.get_node(node.subtrees[i]), 283 - key, val, key_height, tree_height - 1 284 - ).cid 285 - ) 286 - )) 287 - 288 - # we can insert here 289 - assert(key_height == tree_height) 290 - return self._put_here(node, key, val) 291 - 292 - def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]: 293 - if node_cid is None: 294 - return None, None 295 - node = self.ns.get_node(node_cid) 296 - i = node.gte_index(key) 297 - lsub, rsub = self._split_on_key(node.subtrees[i], key) 298 - return self.ns.put_node(MSTNode( 299 - keys=node.keys[:i], 300 - vals=node.vals[:i], 301 - subtrees=node.subtrees[:i] + (lsub,) 302 - ))._to_optional(), self.ns.put_node(MSTNode( 303 - keys=node.keys[i:], 304 - vals=node.vals[i:], 305 - subtrees=(rsub,) + node.subtrees[i + 1:], 306 - ))._to_optional() 307 - 308 - def _squash_top(self, node_cid: Optional[CID]) -> Optional[CID]: 309 - """ 310 - strip empty nodes from the top of the tree 311 - """ 312 - node = self.ns.get_node(node_cid) 313 - if node.keys: 314 - return node_cid 315 - if node.subtrees[0] is None: 316 - return node_cid 317 - return self._squash_top(node.subtrees[0]) 318 - 319 - def _delete_recursive(self, node: MSTNode, key: str, key_height: int, tree_height: int) -> Optional[CID]: 320 - if key_height > tree_height: # the key cannot possibly be in this tree, no change needed 321 - return node._to_optional() 322 - 323 - i = node.gte_index(key) 324 - if key_height < tree_height: # the key must be deleted from a subtree 325 - if node.subtrees[i] is None: 326 - return node._to_optional() # the key cannot be in this subtree, no change needed 327 - return self.ns.put_node(MSTNode( 328 - keys=node.keys, 329 - vals=node.vals, 330 - subtrees=tuple_replace_at( 331 - node.subtrees, 332 - i, 333 - self._delete_recursive(self.ns.get_node(node.subtrees[i]), key, key_height, tree_height - 1) 334 - ) 335 - ))._to_optional() 336 - 337 - i = node.gte_index(key) 338 - if i == len(node.keys) or node.keys[i] != key: 339 - return node._to_optional() # key already not present 340 - 341 - assert(node.keys[i] == key) # sanity check (should always be true) 342 - 343 - return self.ns.put_node(MSTNode( 344 - keys=tuple_remove_at(node.keys, i), 345 - vals=tuple_remove_at(node.vals, i), 346 - subtrees=node.subtrees[:i] + ( 347 - self._merge(node.subtrees[i], node.subtrees[i + 1]), 348 - ) + node.subtrees[i + 2:] 349 - ))._to_optional() 350 - 351 - def _merge(self, left_cid: Optional[CID], right_cid: Optional[CID]) -> Optional[CID]: 352 - if left_cid is None: 353 - return right_cid # includes the case where left == right == None 354 - if right_cid is None: 355 - return left_cid 356 - left = self.ns.get_node(left_cid) 357 - right = self.ns.get_node(right_cid) 358 - return self.ns.put_node(MSTNode( 359 - keys=left.keys + right.keys, 360 - vals=left.vals + right.vals, 361 - subtrees=left.subtrees[:-1] + ( 362 - self._merge( 363 - left.subtrees[-1], 364 - right.subtrees[0] 365 - ), 366 - ) + right.subtrees[1:] 367 - ))._to_optional() 368 - 369 - 370 - class NodeWalker: 371 - """ 372 - NodeWalker makes implementing tree diffing and other MST query ops more 373 - convenient (but it does not, itself, implement them). 374 - 375 - A NodeWalker starts off at the root of a tree, and can walk along or recurse 376 - down into subtrees. 377 - 378 - Walking "off the end" of a subtree brings you back up to its next non-empty parent. 379 - 380 - Recall MSTNode layout: :: 381 - 382 - keys: (lkey) (0, 1, 2, 3) (rkey) 383 - vals: (0, 1, 2, 3) 384 - subtrees: (0, 1, 2, 3, 4) 385 - 386 - """ 387 - KEY_MIN = "" # string that compares less than all legal key strings 388 - KEY_MAX = "\xff" # string that compares greater than all legal key strings 389 - 390 - @dataclass 391 - class StackFrame: 392 - node: MSTNode # could store CIDs only to save memory, in theory, but not much point 393 - lkey: str 394 - rkey: str 395 - idx: int 396 - 397 - ns: NodeStore 398 - stack: List[StackFrame] 399 - 400 - def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None: 401 - self.ns = ns 402 - self.stack = [self.StackFrame( 403 - node=self.ns.get_node(root_cid), 404 - lkey=lkey, 405 - rkey=rkey, 406 - idx=0 407 - )] 408 - 409 - def subtree_walker(self) -> Self: 410 - return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey) 411 - 412 - @property 413 - def frame(self) -> StackFrame: 414 - return self.stack[-1] 415 - 416 - @property 417 - def lkey(self) -> str: 418 - return self.frame.lkey if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1] 419 - 420 - @property 421 - def lval(self) -> Optional[CID]: 422 - return None if self.frame.idx == 0 else self.frame.node.vals[self.frame.idx - 1] 423 - 424 - @property 425 - def subtree(self) -> Optional[CID]: 426 - return self.frame.node.subtrees[self.frame.idx] 427 - 428 - # hmmmm rkey is overloaded here... "right key" not "record key"... 429 - @property 430 - def rkey(self) -> str: 431 - return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx] 432 - 433 - @property 434 - def rval(self) -> Optional[CID]: 435 - return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx] 436 - 437 - @property 438 - def is_final(self) -> bool: 439 - return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey) 440 - 441 - def right(self) -> None: 442 - if (self.frame.idx + 1) >= len(self.frame.node.subtrees): 443 - # we reached the end of this node, go up a level 444 - self.stack.pop() 445 - if not self.stack: 446 - raise StopIteration # you probably want to check .final instead of hitting this 447 - return self.right() # we need to recurse, to skip over empty intermediates on the way back up 448 - self.frame.idx += 1 449 - 450 - def down(self) -> None: 451 - subtree = self.frame.node.subtrees[self.frame.idx] 452 - if subtree is None: 453 - raise Exception("oi, you can't recurse here mate") 454 - 455 - self.stack.append(self.StackFrame( 456 - node=self.ns.get_node(subtree), 457 - lkey=self.lkey, 458 - rkey=self.rkey, 459 - idx=0 460 - )) 461 - 462 - # everything above here is core tree walking logic 463 - # everything below here is helper functions 464 - 465 - def next_kv(self) -> Tuple[str, CID]: 466 - while self.subtree: # recurse down every subtree 467 - self.down() 468 - self.right() 469 - return self.lkey, self.lval # the kv pair we just jumped over 470 - 471 - # iterate over every k/v pair in key-sorted order 472 - def iter_kv(self): 473 - while not self.is_final: 474 - yield self.next_kv() 475 - 476 - # get all mst nodes down and to the right of the current position 477 - def iter_node_cids(self): 478 - yield self.frame.node.cid 479 - while not self.is_final: 480 - while self.subtree: # recurse down every subtree 481 - self.down() 482 - yield self.frame.node.cid 483 - self.right() 484 - 485 - 486 - def enumerate_mst(ns: NodeStore, root_cid: CID): 487 - for k, v in NodeWalker(ns, root_cid).iter_kv(): 488 - print(k, "->", v.encode("base32")) 489 - 490 - # start inclusive, end exclusive 491 - def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str): 492 - cur = NodeWalker(ns, root_cid) 493 - while True: 494 - while cur.rkey < start: 495 - cur.right() 496 - if not cur.subtree: 497 - break 498 - cur.down() 499 - 500 - for k, v, in cur.iter_kv(): 501 - if k >= end: 502 - break 503 - print(k, "->", v.encode("base32")) 504 - 505 - def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]): 506 - created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, created)), {}) 507 - deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, deleted)), {}) 508 - for created_key in created_kv.keys() - deleted_kv.keys(): 509 - yield ("created", created_key, created_kv[created_key].encode("base32")) 510 - for updated_key in created_kv.keys() & deleted_kv.keys(): 511 - v1 = created_kv[updated_key] 512 - v2 = deleted_kv[updated_key] 513 - if v1 != v2: 514 - yield ("updated", updated_key, v1.encode("base32"), v2.encode("base32")) 515 - for deleted_key in deleted_kv.keys() - created_kv.keys(): 516 - yield ("deleted", deleted_key, deleted_kv[deleted_key].encode("base32")) #XXX: encode() is just for debugging 517 - 518 - def very_slow_mst_diff(ns, root_a: CID, root_b: CID): 519 - """ 520 - This should return the same result as mst_diff, but it gets there in a very slow 521 - yet less error-prone way, so it's useful for testing. 522 - """ 523 - a_nodes = set(NodeWalker(ns, root_a).iter_node_cids()) 524 - b_nodes = set(NodeWalker(ns, root_b).iter_node_cids()) 525 - return b_nodes - a_nodes, a_nodes - b_nodes 526 - 527 - EMPTY_NODE_CID = MSTNode.empty_root().cid 528 - 529 - def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created_deleted 530 - created = set() # MST nodes in b but not in a 531 - deleted = set() # MST nodes in a but not in b 532 - mst_diff_recursive(created, deleted, NodeWalker(ns, root_a), NodeWalker(ns, root_b)) 533 - middle = created & deleted # my algorithm has occasional false-positives 534 - #assert(not middle) # this fails 535 - #print("middle", len(middle)) 536 - created -= middle 537 - deleted -= middle 538 - # special case: if one of the root nodes was empty 539 - if root_a == EMPTY_NODE_CID and root_b != EMPTY_NODE_CID: 540 - deleted.add(EMPTY_NODE_CID) 541 - if root_b == EMPTY_NODE_CID and root_a != EMPTY_NODE_CID: 542 - created.add(EMPTY_NODE_CID) 543 - return created, deleted 544 - 545 - def mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted 546 - # the easiest of all cases 547 - if a.frame.node.cid == b.frame.node.cid: 548 - return # no difference 549 - 550 - # trivial 551 - if a.frame.node.is_empty(): 552 - #mst_deleted.add(a.frame.node.cid) # this doesn't work because it might've been a null subtree node 553 - created |= set(b.iter_node_cids()) 554 - return 555 - 556 - # likewise 557 - if b.frame.node.is_empty(): 558 - #mst_created.add(b.frame.node.cid) 559 - deleted |= set(a.iter_node_cids()) 560 - return 561 - 562 - # now we're onto the hard part 563 - 564 - """ 565 - theory: most trees that get compared will have lots of shared blocks (which we can skip over, due to identical CIDs) 566 - completely different trees will inevitably have to visit every node. 567 - 568 - general idea: 569 - 1. if one cursor is "behind" the other, catch it up 570 - 2. when we're matched up, skip over identical subtrees (and recursively diff non-identical subtrees) 571 - 572 - XXX: this seems to work nicely but I'm not sure if it's necessarily efficient for all tree layouts? 573 - """ 574 - 575 - # NB: these will end up as false-positives if one tree is a subtree of the other 576 - created.add(b.frame.node.cid) 577 - deleted.add(a.frame.node.cid) 578 - 579 - while True: 580 - while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other 581 - # "catch up" cursor a, if it's behind 582 - while a.rkey < b.rkey and not a.is_final: 583 - if a.subtree: # recurse down every subtree 584 - a.down() 585 - deleted.add(a.frame.node.cid) 586 - else: 587 - a.right() 588 - 589 - # catch up cursor b, likewise 590 - while b.rkey < a.rkey and not b.is_final: 591 - if b.subtree: # recurse down every subtree 592 - b.down() 593 - created.add(b.frame.node.cid) 594 - else: 595 - b.right() 596 - 597 - # the rkeys now match, but the subrees below us might not 598 - 599 - mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker()) 600 - 601 - # check if we can still go right XXX: do we need to care about the case where one can, but the other can't? 602 - # To consider: maybe if I just step a, b will catch up automagically 603 - if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey: 604 - break 605 - 606 - a.right() 607 - b.right() 608 - 609 - 610 - if __name__ == "__main__": 611 - from .blockstore import MemoryBlockStore, OverlayBlockStore 612 - from .blockstore.car_reader import ReadOnlyCARBlockStore 613 - 614 - if 0: 615 - import sys 616 - sys.setrecursionlimit(999999999) 617 - f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb") 618 - bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f)) 619 - commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0]))) 620 - mst_root: CID = commit_obj["data"] 621 - ns = NodeStore(bs) 622 - wrangler = NodeWrangler(ns) 623 - #print(wrangler) 624 - #enumerate_mst(ns, mst_root) 625 - enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff") 626 - 627 - root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext") 628 - root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22") 629 - root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah")) 630 - 631 - c, d = mst_diff(ns, mst_root, root2) 632 - print("CREATED:") 633 - for x in c: 634 - print("created", x.encode("base32")) 635 - print("DELETED:") 636 - for x in d: 637 - print("deleted", x.encode("base32")) 638 - 639 - for op in record_diff(ns, c, d): 640 - print(op) 641 - 642 - e, f = very_slow_mst_diff(ns, mst_root, root2) 643 - assert(e == c) 644 - assert(f == d) 645 - else: 646 - bs = MemoryBlockStore() 647 - ns = NodeStore(bs) 648 - wrangler = NodeWrangler(ns) 649 - root = ns.get_node(None).cid 650 - print(ns.pretty(root)) 651 - root = wrangler.put_record(root, "hello", hash_to_cid(b"blah")) 652 - print(ns.pretty(root)) 653 - root = wrangler.put_record(root, "foo", hash_to_cid(b"bar")) 654 - print(ns.pretty(root)) 655 - root_a = root 656 - root = wrangler.put_record(root, "bar", hash_to_cid(b"bat")) 657 - root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat")) 658 - root = wrangler.del_record(root, "foo") 659 - print("=============") 660 - print(ns.pretty(root_a)) 661 - print("=============") 662 - print(ns.pretty(root)) 663 - #exit() 664 - print("=============") 665 - enumerate_mst(ns, root) 666 - c, d = mst_diff(ns, root_a, root) 667 - print("CREATED:") 668 - for x in c: 669 - print("created", x.encode("base32")) 670 - print("DELETED:") 671 - for x in d: 672 - print("deleted", x.encode("base32")) 673 - 674 - e, f = very_slow_mst_diff(ns, root_a, root) 675 - assert(e == c) 676 - assert(f == d) 677 - 678 - exit() 679 - root = wrangler.delete(root, "foo") 680 - root = wrangler.delete(root, "hello") 681 - print(ns.pretty(root)) 682 - root = wrangler.delete(root, "bar") 683 - print(ns.pretty(root)) 684 - root = wrangler.delete(root, "bar") 685 - print(ns.pretty(root))
+227
src/atmst/mst/__init__.py
··· 1 + import hashlib 2 + import dag_cbor 3 + import operator 4 + from multiformats import multihash, CID 5 + from functools import cached_property 6 + from more_itertools import ilen 7 + from itertools import takewhile 8 + from dataclasses import dataclass 9 + from typing import Tuple, Self, Optional 10 + 11 + 12 + @dataclass(frozen=True) # frozen == immutable == win 13 + class MSTNode: 14 + """ 15 + k/v pairs are interleaved between subtrees like so: :: 16 + 17 + keys: (0, 1, 2, 3) 18 + vals: (0, 1, 2, 3) 19 + subtrees: (0, 1, 2, 3, 4) 20 + 21 + If a method is implemented in this class, it's because it's a function/property 22 + of a single node, as opposed to a whole tree 23 + """ 24 + keys: Tuple[str] # collection/rkey 25 + vals: Tuple[CID] # record CIDs 26 + subtrees: Tuple[Optional[CID]] # a None value represents an empty subtree 27 + 28 + 29 + # NB: __init__ is auto-generated by dataclass decorator 30 + 31 + # these checks should never fail, and could be skipped for performance 32 + def __post_init__(self) -> None: 33 + # TODO: maybe check that they're tuples here? 34 + # implicitly, the length of self.subtrees must be at least 1 35 + if len(self.subtrees) != len(self.keys) + 1: 36 + raise ValueError("Invalid subtree count") 37 + if len(self.keys) != len(self.vals): 38 + raise ValueError("Mismatched keys/vals lengths") 39 + 40 + @classmethod 41 + def empty_root(cls) -> Self: 42 + return cls( 43 + subtrees=(None,), 44 + keys=(), 45 + vals=() 46 + ) 47 + 48 + # this should maybe not be implemented here? 49 + @staticmethod 50 + def key_height(key: str) -> int: 51 + digest = int.from_bytes(hashlib.sha256(key.encode()).digest(), "big") 52 + leading_zeroes = 256 - digest.bit_length() 53 + return leading_zeroes // 2 54 + 55 + # since we're immutable, this can be cached 56 + @cached_property 57 + def cid(self) -> CID: 58 + digest = multihash.digest(self.serialised, "sha2-256") 59 + cid = CID("base32", 1, "dag-cbor", digest) 60 + return cid 61 + 62 + # likewise 63 + @cached_property 64 + def serialised(self) -> bytes: 65 + e = [] 66 + prev_key = b"" 67 + for subtree, key_str, value in zip(self.subtrees[1:], self.keys, self.vals): 68 + key_bytes = key_str.encode() 69 + shared_prefix_len = ilen(takewhile(bool, map(operator.eq, prev_key, key_bytes))) # I love functional programming 70 + e.append({ 71 + "k": key_bytes[shared_prefix_len:], 72 + "p": shared_prefix_len, 73 + "t": subtree, 74 + "v": value, 75 + }) 76 + prev_key = key_bytes 77 + return dag_cbor.encode({ 78 + "e": e, 79 + "l": self.subtrees[0] 80 + }) 81 + 82 + @classmethod 83 + def deserialise(cls, data: bytes) -> Self: 84 + cbor = dag_cbor.decode(data) 85 + if len(cbor) != 2: # e, l 86 + raise ValueError("malformed MST node") 87 + subtrees = [cbor["l"]] 88 + keys = [] 89 + vals = [] 90 + prev_key = b"" 91 + for e in cbor["e"]: # TODO: make extra sure that these checks are watertight wrt non-canonical representations 92 + if len(e) != 4: # k, p, t, v 93 + raise ValueError("malformed MST node") 94 + prefix_len: int = e["p"] 95 + suffix: bytes = e["k"] 96 + if prefix_len > len(prev_key): 97 + raise ValueError("invalid MST key prefix len") 98 + if prev_key[prefix_len:prefix_len+1] == suffix[:1]: 99 + raise ValueError("non-optimal MST key prefix len") 100 + this_key = prev_key[:prefix_len] + suffix 101 + if this_key <= prev_key: 102 + raise ValueError("invalid MST key sort order") 103 + keys.append(this_key.decode()) 104 + vals.append(e["v"]) 105 + subtrees.append(e["t"]) 106 + prev_key = this_key 107 + 108 + return cls( 109 + subtrees=tuple(subtrees), 110 + keys=tuple(keys), 111 + vals=tuple(vals) 112 + ) 113 + 114 + def is_empty(self) -> bool: 115 + return self.subtrees == (None,) 116 + 117 + def _to_optional(self) -> Optional[CID]: 118 + """ 119 + returns None if the node is empty 120 + """ 121 + if self.is_empty(): 122 + return None 123 + return self.cid 124 + 125 + 126 + @cached_property 127 + def height(self) -> int: 128 + # if there are keys at this level, query one directly 129 + if self.keys: 130 + return self.key_height(self.keys[0]) 131 + 132 + # we're an empty tree 133 + if self.subtrees[0] is None: 134 + return 0 135 + 136 + # this should only happen for non-root nodes with no keys 137 + raise Exception("cannot determine node height") 138 + 139 + def gte_index(self, key: str) -> int: 140 + """ 141 + find the index of the first key greater than or equal to the specified key 142 + if all keys are smaller, it returns len(keys) 143 + """ 144 + i = 0 # this loop could be a binary search but not worth it for small fanouts 145 + while i < len(self.keys) and key > self.keys[i]: 146 + i += 1 147 + return i 148 + 149 + 150 + """ 151 + if __name__ == "__main__": 152 + from .blockstore import MemoryBlockStore, OverlayBlockStore 153 + from .blockstore.car_reader import ReadOnlyCARBlockStore 154 + 155 + if 0: 156 + import sys 157 + sys.setrecursionlimit(999999999) 158 + f = open("/home/david/programming/python/bskyclient/retr0id.car", "rb") 159 + bs = OverlayBlockStore(MemoryBlockStore(), ReadOnlyCARBlockStore(f)) 160 + commit_obj = dag_cbor.decode(bs.get_block(bytes(bs.lower.car_roots[0]))) 161 + mst_root: CID = commit_obj["data"] 162 + ns = NodeStore(bs) 163 + wrangler = NodeWrangler(ns) 164 + #print(wrangler) 165 + #enumerate_mst(ns, mst_root) 166 + enumerate_mst_range(ns, mst_root, "app.bsky.feed.generator/", "app.bsky.feed.generator/\xff") 167 + 168 + root2 = wrangler.del_record(mst_root, "app.bsky.feed.generator/alttext") 169 + root2 = wrangler.del_record(root2, "app.bsky.feed.like/3kas3fyvkti22") 170 + root2 = wrangler.put_record(root2, "app.bsky.feed.like/3kc3brpic2z2p", hash_to_cid(b"blah")) 171 + 172 + c, d = mst_diff(ns, mst_root, root2) 173 + print("CREATED:") 174 + for x in c: 175 + print("created", x.encode("base32")) 176 + print("DELETED:") 177 + for x in d: 178 + print("deleted", x.encode("base32")) 179 + 180 + for op in record_diff(ns, c, d): 181 + print(op) 182 + 183 + e, f = very_slow_mst_diff(ns, mst_root, root2) 184 + assert(e == c) 185 + assert(f == d) 186 + else: 187 + bs = MemoryBlockStore() 188 + ns = NodeStore(bs) 189 + wrangler = NodeWrangler(ns) 190 + root = ns.get_node(None).cid 191 + print(ns.pretty(root)) 192 + root = wrangler.put_record(root, "hello", hash_to_cid(b"blah")) 193 + print(ns.pretty(root)) 194 + root = wrangler.put_record(root, "foo", hash_to_cid(b"bar")) 195 + print(ns.pretty(root)) 196 + root_a = root 197 + root = wrangler.put_record(root, "bar", hash_to_cid(b"bat")) 198 + root = wrangler.put_record(root, "xyzz", hash_to_cid(b"bat")) 199 + root = wrangler.del_record(root, "foo") 200 + print("=============") 201 + print(ns.pretty(root_a)) 202 + print("=============") 203 + print(ns.pretty(root)) 204 + #exit() 205 + print("=============") 206 + enumerate_mst(ns, root) 207 + c, d = mst_diff(ns, root_a, root) 208 + print("CREATED:") 209 + for x in c: 210 + print("created", x.encode("base32")) 211 + print("DELETED:") 212 + for x in d: 213 + print("deleted", x.encode("base32")) 214 + 215 + e, f = very_slow_mst_diff(ns, root_a, root) 216 + assert(e == c) 217 + assert(f == d) 218 + 219 + exit() 220 + root = wrangler.delete(root, "foo") 221 + root = wrangler.delete(root, "hello") 222 + print(ns.pretty(root)) 223 + root = wrangler.delete(root, "bar") 224 + print(ns.pretty(root)) 225 + root = wrangler.delete(root, "bar") 226 + print(ns.pretty(root)) 227 + """
+124
src/atmst/mst/diff.py
··· 1 + import operator 2 + from typing import Tuple, Set, Iterable 3 + from functools import reduce 4 + 5 + from multiformats import CID 6 + 7 + from . import MSTNode 8 + from .node_store import NodeStore 9 + from .node_walker import NodeWalker 10 + 11 + 12 + def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]) -> Iterable[tuple]: 13 + """ 14 + Given two sets of MST nodes (for example, the result of `mst_diff`), this 15 + returns an iterator of record changes, in one of 3 formats: 16 + 17 + ("created", key, value) 18 + ("updated", key, old_value, new_value) 19 + ("deleted", key, value) 20 + """ 21 + created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, created)), {}) 22 + deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, deleted)), {}) 23 + for created_key in created_kv.keys() - deleted_kv.keys(): 24 + yield ("created", created_key, created_kv[created_key].encode("base32")) 25 + for updated_key in created_kv.keys() & deleted_kv.keys(): 26 + v1 = created_kv[updated_key] 27 + v2 = deleted_kv[updated_key] 28 + if v1 != v2: 29 + yield ("updated", updated_key, v1.encode("base32"), v2.encode("base32")) 30 + for deleted_key in deleted_kv.keys() - created_kv.keys(): 31 + yield ("deleted", deleted_key, deleted_kv[deleted_key].encode("base32")) #XXX: encode() is just for debugging 32 + 33 + def very_slow_mst_diff(ns: NodeStore, root_a: CID, root_b: CID): 34 + """ 35 + This should return the same result as mst_diff, but it gets there in a very slow 36 + yet less error-prone way, so it's useful for testing. 37 + 38 + It's actually faster for smaller trees, but it chokes on trees with thousands of nodes (especially if the NodeStore is slow). 39 + """ 40 + a_nodes = set(NodeWalker(ns, root_a).iter_node_cids()) 41 + b_nodes = set(NodeWalker(ns, root_b).iter_node_cids()) 42 + return b_nodes - a_nodes, a_nodes - b_nodes 43 + 44 + EMPTY_NODE_CID = MSTNode.empty_root().cid 45 + 46 + def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created, deleted 47 + created = set() # MST nodes in b but not in a 48 + deleted = set() # MST nodes in a but not in b 49 + mst_diff_recursive(created, deleted, NodeWalker(ns, root_a), NodeWalker(ns, root_b)) 50 + middle = created & deleted # my algorithm has occasional false-positives 51 + #assert(not middle) # this fails 52 + #print("middle", len(middle)) 53 + created -= middle 54 + deleted -= middle 55 + # special case: if one of the root nodes was empty 56 + if root_a == EMPTY_NODE_CID and root_b != EMPTY_NODE_CID: 57 + deleted.add(EMPTY_NODE_CID) 58 + if root_b == EMPTY_NODE_CID and root_a != EMPTY_NODE_CID: 59 + created.add(EMPTY_NODE_CID) 60 + return created, deleted 61 + 62 + def mst_diff_recursive(created: Set[CID], deleted: Set[CID], a: NodeWalker, b: NodeWalker): # created, deleted 63 + # the easiest of all cases 64 + if a.frame.node.cid == b.frame.node.cid: 65 + return # no difference 66 + 67 + # trivial 68 + if a.frame.node.is_empty(): 69 + #mst_deleted.add(a.frame.node.cid) # this doesn't work because it might've been a null subtree node 70 + created |= set(b.iter_node_cids()) 71 + return 72 + 73 + # likewise 74 + if b.frame.node.is_empty(): 75 + #mst_created.add(b.frame.node.cid) 76 + deleted |= set(a.iter_node_cids()) 77 + return 78 + 79 + # now we're onto the hard part 80 + 81 + """ 82 + theory: most trees that get compared will have lots of shared blocks (which we can skip over, due to identical CIDs) 83 + completely different trees will inevitably have to visit every node. 84 + 85 + general idea: 86 + 1. if one cursor is "behind" the other, catch it up 87 + 2. when we're matched up, skip over identical subtrees (and recursively diff non-identical subtrees) 88 + 89 + XXX: this seems to work nicely but I'm not sure if it's necessarily efficient for all tree layouts? 90 + """ 91 + 92 + # NB: these will end up as false-positives if one tree is a subtree of the other 93 + created.add(b.frame.node.cid) 94 + deleted.add(a.frame.node.cid) 95 + 96 + while True: 97 + while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other 98 + # "catch up" cursor a, if it's behind 99 + while a.rkey < b.rkey and not a.is_final: 100 + if a.subtree: # recurse down every subtree 101 + a.down() 102 + deleted.add(a.frame.node.cid) 103 + else: 104 + a.right() 105 + 106 + # catch up cursor b, likewise 107 + while b.rkey < a.rkey and not b.is_final: 108 + if b.subtree: # recurse down every subtree 109 + b.down() 110 + created.add(b.frame.node.cid) 111 + else: 112 + b.right() 113 + 114 + # the rkeys now match, but the subrees below us might not 115 + 116 + mst_diff_recursive(created, deleted, a.subtree_walker(), b.subtree_walker()) 117 + 118 + # check if we can still go right XXX: do we need to care about the case where one can, but the other can't? 119 + # To consider: maybe if I just step a, b will catch up automagically 120 + if a.rkey == a.stack[0].rkey and b.rkey == a.stack[0].rkey: 121 + break 122 + 123 + a.right() 124 + b.right()
+55
src/atmst/mst/node_store.py
··· 1 + from typing import Optional, Dict 2 + 3 + from multiformats import CID 4 + 5 + from ..blockstore import BlockStore 6 + from ..util import indent 7 + from . import MSTNode 8 + 9 + class NodeStore: 10 + """ 11 + NodeStore wraps a BlockStore to provide a more ergonomic interface 12 + for loading and storing MSTNodes 13 + """ 14 + bs: BlockStore 15 + cache: Dict[Optional[CID], MSTNode] # XXX: this cache will grow forever! 16 + #cache_counts: Dict[Optional[CID], int] 17 + 18 + def __init__(self, bs: BlockStore) -> None: 19 + self.bs = bs 20 + self.cache = {} 21 + #self.cache_counts = {} 22 + 23 + # TODO: LRU cache this - this package looks ideal: https://github.com/amitdev/lru-dict 24 + def get_node(self, cid: Optional[CID]) -> MSTNode: 25 + cached = self.cache.get(cid) 26 + if cached: 27 + return cached 28 + """ 29 + if cid is None, returns an empty MST node 30 + """ 31 + if cid is None: 32 + return self.put_node(MSTNode.empty_root()) 33 + 34 + res = MSTNode.deserialise(self.bs.get_block(bytes(cid))) 35 + self.cache[cid] = res 36 + return res 37 + 38 + # TODO: also put in cache 39 + def put_node(self, node: MSTNode) -> MSTNode: 40 + self.cache[node.cid] = node 41 + self.bs.put_block(bytes(node.cid), node.serialised) 42 + return node # this is convenient 43 + 44 + # MST pretty-printing 45 + # this should maybe not be implemented here 46 + def pretty(self, node_cid: Optional[CID]) -> str: 47 + if node_cid is None: 48 + return "<empty>" 49 + node = self.get_node(node_cid) 50 + res = f"MSTNode<cid={node.cid.encode("base32")}>(\n{indent(self.pretty(node.subtrees[0]))},\n" 51 + for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]): 52 + res += f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode("base32")},\n" 53 + res += indent(self.pretty(t)) + ",\n" 54 + res += ")" 55 + return res
+142
src/atmst/mst/node_walker.py
··· 1 + from dataclasses import dataclass 2 + from typing import Tuple, Self, Optional, List 3 + 4 + from multiformats import CID 5 + 6 + from . import MSTNode 7 + from .node_store import NodeStore 8 + 9 + class NodeWalker: 10 + """ 11 + NodeWalker makes implementing tree diffing and other MST query ops more 12 + convenient (but it does not, itself, implement them). 13 + 14 + A NodeWalker starts off at the root of a tree, and can walk along or recurse 15 + down into subtrees. 16 + 17 + Walking "off the end" of a subtree brings you back up to its next non-empty parent. 18 + 19 + Recall MSTNode layout: :: 20 + 21 + keys: (lkey) (0, 1, 2, 3) (rkey) 22 + vals: (0, 1, 2, 3) 23 + subtrees: (0, 1, 2, 3, 4) 24 + 25 + """ 26 + KEY_MIN = "" # string that compares less than all legal key strings 27 + KEY_MAX = "\xff" # string that compares greater than all legal key strings 28 + 29 + @dataclass 30 + class StackFrame: 31 + node: MSTNode # could store CIDs only to save memory, in theory, but not much point 32 + lkey: str 33 + rkey: str 34 + idx: int 35 + 36 + ns: NodeStore 37 + stack: List[StackFrame] 38 + 39 + def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None: 40 + self.ns = ns 41 + self.stack = [self.StackFrame( 42 + node=self.ns.get_node(root_cid), 43 + lkey=lkey, 44 + rkey=rkey, 45 + idx=0 46 + )] 47 + 48 + def subtree_walker(self) -> Self: 49 + return NodeWalker(self.ns, self.subtree, self.lkey, self.rkey) 50 + 51 + @property 52 + def frame(self) -> StackFrame: 53 + return self.stack[-1] 54 + 55 + @property 56 + def lkey(self) -> str: 57 + return self.frame.lkey if self.frame.idx == 0 else self.frame.node.keys[self.frame.idx - 1] 58 + 59 + @property 60 + def lval(self) -> Optional[CID]: 61 + return None if self.frame.idx == 0 else self.frame.node.vals[self.frame.idx - 1] 62 + 63 + @property 64 + def subtree(self) -> Optional[CID]: 65 + return self.frame.node.subtrees[self.frame.idx] 66 + 67 + # hmmmm rkey is overloaded here... "right key" not "record key"... 68 + @property 69 + def rkey(self) -> str: 70 + return self.frame.rkey if self.frame.idx == len(self.frame.node.keys) else self.frame.node.keys[self.frame.idx] 71 + 72 + @property 73 + def rval(self) -> Optional[CID]: 74 + return None if self.frame.idx == len(self.frame.node.vals) else self.frame.node.vals[self.frame.idx] 75 + 76 + @property 77 + def is_final(self) -> bool: 78 + return (not self.stack) or (self.subtree is None and self.rkey == self.stack[0].rkey) 79 + 80 + def right(self) -> None: 81 + if (self.frame.idx + 1) >= len(self.frame.node.subtrees): 82 + # we reached the end of this node, go up a level 83 + self.stack.pop() 84 + if not self.stack: 85 + raise StopIteration # you probably want to check .final instead of hitting this 86 + return self.right() # we need to recurse, to skip over empty intermediates on the way back up 87 + self.frame.idx += 1 88 + 89 + def down(self) -> None: 90 + subtree = self.frame.node.subtrees[self.frame.idx] 91 + if subtree is None: 92 + raise Exception("oi, you can't recurse here mate") 93 + 94 + self.stack.append(self.StackFrame( 95 + node=self.ns.get_node(subtree), 96 + lkey=self.lkey, 97 + rkey=self.rkey, 98 + idx=0 99 + )) 100 + 101 + # everything above here is core tree walking logic 102 + # everything below here is helper functions 103 + 104 + def next_kv(self) -> Tuple[str, CID]: 105 + while self.subtree: # recurse down every subtree 106 + self.down() 107 + self.right() 108 + return self.lkey, self.lval # the kv pair we just jumped over 109 + 110 + # iterate over every k/v pair in key-sorted order 111 + def iter_kv(self): 112 + while not self.is_final: 113 + yield self.next_kv() 114 + 115 + # get all mst nodes down and to the right of the current position 116 + def iter_node_cids(self): 117 + yield self.frame.node.cid 118 + while not self.is_final: 119 + while self.subtree: # recurse down every subtree 120 + self.down() 121 + yield self.frame.node.cid 122 + self.right() 123 + 124 + 125 + def enumerate_mst(ns: NodeStore, root_cid: CID): 126 + for k, v in NodeWalker(ns, root_cid).iter_kv(): 127 + print(k, "->", v.encode("base32")) 128 + 129 + # start inclusive, end exclusive 130 + def enumerate_mst_range(ns: NodeStore, root_cid: CID, start: str, end: str): 131 + cur = NodeWalker(ns, root_cid) 132 + while True: 133 + while cur.rkey < start: 134 + cur.right() 135 + if not cur.subtree: 136 + break 137 + cur.down() 138 + 139 + for k, v, in cur.iter_kv(): 140 + if k >= end: 141 + break 142 + print(k, "->", v.encode("base32"))
+173
src/atmst/mst/wrangler.py
··· 1 + from typing import Tuple, Optional, Any 2 + 3 + from multiformats import CID 4 + 5 + from . import MSTNode 6 + from .node_store import NodeStore 7 + 8 + # tuple helpers 9 + def tuple_replace_at(original: tuple, i: int, value: Any) -> tuple: 10 + return original[:i] + (value,) + original[i + 1:] 11 + 12 + def tuple_insert_at(original: tuple, i: int, value: Any) -> tuple: 13 + return original[:i] + (value,) + original[i:] 14 + 15 + def tuple_remove_at(original: tuple, i: int) -> tuple: 16 + return original[:i] + original[i + 1:] 17 + 18 + 19 + class NodeWrangler: 20 + """ 21 + NodeWrangler is where core MST transformation ops are implemented, backed 22 + by a NodeStore 23 + 24 + The external APIs take a CID (the MST root) and return a CID (the new root), 25 + while storing any newly created nodes in the NodeStore. 26 + 27 + Neither method should ever fail - deleting a node that doesn't exist is a nop, 28 + and adding the same node twice with the same value is also a nop. Callers 29 + can detect these cases by seeing if the initial and final CIDs changed. 30 + """ 31 + ns: NodeStore 32 + 33 + def __init__(self, ns: NodeStore) -> None: 34 + self.ns = ns 35 + 36 + def put_record(self, root_cid: CID, key: str, val: CID) -> CID: 37 + root = self.ns.get_node(root_cid) 38 + if root.is_empty(): # special case for empty tree 39 + return self._put_here(root, key, val).cid 40 + return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid 41 + 42 + def del_record(self, root_cid: CID, key: str) -> CID: 43 + root = self.ns.get_node(root_cid) 44 + 45 + # Note: the seemingly redundant outer .get().cid is required to transform 46 + # a None cid into the cid representing an empty node (we could maybe find a more elegant 47 + # way of doing this...) 48 + return self.ns.get_node(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid 49 + 50 + 51 + 52 + def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode: 53 + i = node.gte_index(key) 54 + 55 + # the key is already present! 56 + if i < len(node.keys) and node.keys[i] == key: 57 + if node.vals[i] == val: 58 + return node # we can return our old self if there is no change 59 + return self.ns.put_node(MSTNode( 60 + keys=node.keys, 61 + vals=tuple_replace_at(node.vals, i, val), 62 + subtrees=node.subtrees 63 + )) 64 + 65 + return self.ns.put_node(MSTNode( 66 + keys=tuple_insert_at(node.keys, i, key), 67 + vals=tuple_insert_at(node.vals, i, val), 68 + subtrees = node.subtrees[:i] + \ 69 + self._split_on_key(node.subtrees[i], key) + \ 70 + node.subtrees[i + 1:], 71 + )) 72 + 73 + def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode: 74 + if key_height > tree_height: # we need to grow the tree 75 + return self.ns.put_node(self._put_recursive( 76 + MSTNode.empty_root(), 77 + key, val, key_height, tree_height + 1 78 + )) 79 + 80 + if key_height < tree_height: # we need to look below 81 + i = node.gte_index(key) 82 + return self.ns.put_node(MSTNode( 83 + keys=node.keys, 84 + vals=node.vals, 85 + subtrees=tuple_replace_at( 86 + node.subtrees, i, 87 + self._put_recursive( 88 + self.ns.get_node(node.subtrees[i]), 89 + key, val, key_height, tree_height - 1 90 + ).cid 91 + ) 92 + )) 93 + 94 + # we can insert here 95 + assert(key_height == tree_height) 96 + return self._put_here(node, key, val) 97 + 98 + def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]: 99 + if node_cid is None: 100 + return None, None 101 + node = self.ns.get_node(node_cid) 102 + i = node.gte_index(key) 103 + lsub, rsub = self._split_on_key(node.subtrees[i], key) 104 + return self.ns.put_node(MSTNode( 105 + keys=node.keys[:i], 106 + vals=node.vals[:i], 107 + subtrees=node.subtrees[:i] + (lsub,) 108 + ))._to_optional(), self.ns.put_node(MSTNode( 109 + keys=node.keys[i:], 110 + vals=node.vals[i:], 111 + subtrees=(rsub,) + node.subtrees[i + 1:], 112 + ))._to_optional() 113 + 114 + def _squash_top(self, node_cid: Optional[CID]) -> Optional[CID]: 115 + """ 116 + strip empty nodes from the top of the tree 117 + """ 118 + node = self.ns.get_node(node_cid) 119 + if node.keys: 120 + return node_cid 121 + if node.subtrees[0] is None: 122 + return node_cid 123 + return self._squash_top(node.subtrees[0]) 124 + 125 + def _delete_recursive(self, node: MSTNode, key: str, key_height: int, tree_height: int) -> Optional[CID]: 126 + if key_height > tree_height: # the key cannot possibly be in this tree, no change needed 127 + return node._to_optional() 128 + 129 + i = node.gte_index(key) 130 + if key_height < tree_height: # the key must be deleted from a subtree 131 + if node.subtrees[i] is None: 132 + return node._to_optional() # the key cannot be in this subtree, no change needed 133 + return self.ns.put_node(MSTNode( 134 + keys=node.keys, 135 + vals=node.vals, 136 + subtrees=tuple_replace_at( 137 + node.subtrees, 138 + i, 139 + self._delete_recursive(self.ns.get_node(node.subtrees[i]), key, key_height, tree_height - 1) 140 + ) 141 + ))._to_optional() 142 + 143 + i = node.gte_index(key) 144 + if i == len(node.keys) or node.keys[i] != key: 145 + return node._to_optional() # key already not present 146 + 147 + assert(node.keys[i] == key) # sanity check (should always be true) 148 + 149 + return self.ns.put_node(MSTNode( 150 + keys=tuple_remove_at(node.keys, i), 151 + vals=tuple_remove_at(node.vals, i), 152 + subtrees=node.subtrees[:i] + ( 153 + self._merge(node.subtrees[i], node.subtrees[i + 1]), 154 + ) + node.subtrees[i + 2:] 155 + ))._to_optional() 156 + 157 + def _merge(self, left_cid: Optional[CID], right_cid: Optional[CID]) -> Optional[CID]: 158 + if left_cid is None: 159 + return right_cid # includes the case where left == right == None 160 + if right_cid is None: 161 + return left_cid 162 + left = self.ns.get_node(left_cid) 163 + right = self.ns.get_node(right_cid) 164 + return self.ns.put_node(MSTNode( 165 + keys=left.keys + right.keys, 166 + vals=left.vals + right.vals, 167 + subtrees=left.subtrees[:-1] + ( 168 + self._merge( 169 + left.subtrees[-1], 170 + right.subtrees[0] 171 + ), 172 + ) + right.subtrees[1:] 173 + ))._to_optional()