this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

working MST diffing

+126 -61
+1 -1
carfile.py
··· 71 71 72 72 from mst import NodeStore 73 73 ns = NodeStore(bs) 74 - print(ns.get(mst_root)) 74 + print(ns.get_node(mst_root))
+78 -60
mst.py
··· 167 167 for loading and storing MSTNodes 168 168 """ 169 169 bs: BlockStore 170 - #cache: Dict[Optional[CID], MSTNode] 170 + cache: Dict[Optional[CID], MSTNode] # XXX: this cache will grow forever! 171 171 #cache_counts: Dict[Optional[CID], int] 172 172 173 173 def __init__(self, bs: BlockStore) -> None: 174 174 self.bs = bs 175 - #self.cache = {} 175 + self.cache = {} 176 176 #self.cache_counts = {} 177 177 178 178 # TODO: LRU cache this - this package looks ideal: https://github.com/amitdev/lru-dict 179 - def get(self, cid: Optional[CID]) -> MSTNode: 179 + def get_node(self, cid: Optional[CID]) -> MSTNode: 180 + cached = self.cache.get(cid) 181 + if cached: 182 + return cached 180 183 """ 181 184 if cid is None, returns an empty MST node 182 185 """ 183 186 if cid is None: 184 - return self.put(MSTNode.empty_root()) 187 + return self.put_node(MSTNode.empty_root()) 185 188 186 - return MSTNode.deserialise(self.bs.get(bytes(cid))) 189 + res = MSTNode.deserialise(self.bs.get(bytes(cid))) 190 + self.cache[cid] = res 191 + return res 187 192 188 193 # TODO: also put in cache 189 - def put(self, node: MSTNode) -> MSTNode: 194 + def put_node(self, node: MSTNode) -> MSTNode: 195 + self.cache[node.cid] = node 190 196 self.bs.put(bytes(node.cid), node.serialised) 191 197 return node # this is convenient 192 198 ··· 195 201 def pretty(self, node_cid: Optional[CID]) -> str: 196 202 if node_cid is None: 197 203 return "<empty>" 198 - node = self.get(node_cid) 204 + node = self.get_node(node_cid) 199 205 res = f"MSTNode<cid={node.cid.encode("base32")}>(\n{indent(self.pretty(node.subtrees[0]))},\n" 200 206 for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]): 201 207 res += f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode("base32")},\n" ··· 222 228 self.ns = ns 223 229 224 230 def put(self, root_cid: CID, key: str, val: CID) -> CID: 225 - root = ns.get(root_cid) 231 + root = self.ns.get_node(root_cid) 226 232 if root.is_empty(): # special case for empty tree 227 - return self._put_here(root, key, val) 228 - return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height) 233 + return self._put_here(root, key, val).cid 234 + return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid 229 235 230 236 def delete(self, root_cid: CID, key: str) -> CID: 231 - root = ns.get(root_cid) 237 + root = self.ns.get_node(root_cid) 232 238 233 239 # Note: the seemingly redundant outer .get().cid is required to transform 234 240 # a None cid into the cid representing an empty node (we could maybe find a more elegant 235 241 # way of doing this...) 236 - return self.ns.get(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid 242 + return self.ns.get_node(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid 237 243 238 244 239 245 240 - def _put_here(self, node: MSTNode, key: str, val: CID) -> CID: 246 + def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode: 241 247 i = node.gte_index(key) 242 248 243 249 # the key is already present! 244 250 if i < len(node.keys) and node.keys[i] == key: 245 251 if node.vals[i] == val: 246 - return node.cid # we can return our old self if there is no change 247 - return self.ns.put(MSTNode( 252 + return node # we can return our old self if there is no change 253 + return self.ns.put_node(MSTNode( 248 254 keys=node.keys, 249 255 vals=tuple_replace_at(node.vals, i, val), 250 256 subtrees=node.subtrees 251 - )).cid 257 + )) 252 258 253 - return self.ns.put(MSTNode( 259 + return self.ns.put_node(MSTNode( 254 260 keys=tuple_insert_at(node.keys, i, key), 255 261 vals=tuple_insert_at(node.vals, i, val), 256 262 subtrees = node.subtrees[:i] + \ 257 263 self._split_on_key(node.subtrees[i], key) + \ 258 264 node.subtrees[i + 1:], 259 - )).cid 265 + )) 260 266 261 - def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> CID: 267 + def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode: 262 268 if key_height > tree_height: # we need to grow the tree 263 - return self.ns.put(self._put_recursive( 269 + return self.ns.put_node(self._put_recursive( 264 270 MSTNode.empty_root(), 265 271 key, val, key_height, tree_height + 1 266 - )).cid 272 + )) 267 273 268 274 if key_height < tree_height: # we need to look below 269 275 i = node.gte_index(key) 270 - return self.ns.put(MSTNode( 276 + return self.ns.put_node(MSTNode( 271 277 keys=node.keys, 272 278 vals=node.vals, 273 279 subtrees=tuple_replace_at( 274 280 node.subtrees, i, 275 281 self._put_recursive( 276 - self.ns.get(node.subtrees[i]), 282 + self.ns.get_node(node.subtrees[i]), 277 283 key, val, key_height, tree_height - 1 278 - ) 284 + ).cid 279 285 ) 280 - )).cid 286 + )) 281 287 282 288 # we can insert here 283 289 assert(key_height == tree_height) ··· 286 292 def _split_on_key(self, node_cid: Optional[CID], key: str) -> Tuple[Optional[CID], Optional[CID]]: 287 293 if node_cid is None: 288 294 return None, None 289 - node = ns.get(node_cid) 295 + node = self.ns.get_node(node_cid) 290 296 i = node.gte_index(key) 291 297 lsub, rsub = self._split_on_key(node.subtrees[i], key) 292 - return self.ns.put(MSTNode( 298 + return self.ns.put_node(MSTNode( 293 299 keys=node.keys[:i], 294 300 vals=node.vals[:i], 295 301 subtrees=node.subtrees[:i] + (lsub,) 296 - ))._to_optional(), self.ns.put(MSTNode( 302 + ))._to_optional(), self.ns.put_node(MSTNode( 297 303 keys=node.keys[i:], 298 304 vals=node.vals[i:], 299 305 subtrees=(rsub,) + node.subtrees[i + 1:], ··· 303 309 """ 304 310 strip empty nodes from the top of the tree 305 311 """ 306 - node = self.ns.get(node_cid) 312 + node = self.ns.get_node(node_cid) 307 313 if node.keys: 308 314 return node_cid 309 315 if node.subtrees[0] is None: ··· 318 324 if key_height < tree_height: # the key must be deleted from a subtree 319 325 if node.subtrees[i] is None: 320 326 return node._to_optional() # the key cannot be in this subtree, no change needed 321 - return self.ns.put(MSTNode( 327 + return self.ns.put_node(MSTNode( 322 328 keys=node.keys, 323 329 vals=node.vals, 324 330 subtrees=tuple_replace_at( 325 331 node.subtrees, 326 332 i, 327 - self._delete_recursive(self.ns.get(node.subtrees[i]), key, key_height, tree_height - 1) 333 + self._delete_recursive(self.ns.get_node(node.subtrees[i]), key, key_height, tree_height - 1) 328 334 ) 329 335 ))._to_optional() 330 336 ··· 334 340 335 341 assert(node.keys[i] == key) # sanity check (should always be true) 336 342 337 - return self.ns.put(MSTNode( 343 + return self.ns.put_node(MSTNode( 338 344 keys=tuple_remove_at(node.keys, i), 339 345 vals=tuple_remove_at(node.vals, i), 340 346 subtrees=node.subtrees[:i] + ( ··· 347 353 return right_cid # includes the case where left == right == None 348 354 if right_cid is None: 349 355 return left_cid 350 - left = self.ns.get(left_cid) 351 - right = self.ns.get(right_cid) 352 - return self.ns.put(MSTNode( 356 + left = self.ns.get_node(left_cid) 357 + right = self.ns.get_node(right_cid) 358 + return self.ns.put_node(MSTNode( 353 359 keys=left.keys + right.keys, 354 360 vals=left.vals + right.vals, 355 361 subtrees=left.subtrees[:-1] + ( ··· 393 399 def __init__(self, ns: NodeStore, root_cid: CID, lkey: Optional[str]=KEY_MIN, rkey: Optional[str]=KEY_MAX) -> None: 394 400 self.ns = ns 395 401 self.stack = [self.StackFrame( 396 - node=self.ns.get(root_cid), 402 + node=self.ns.get_node(root_cid), 397 403 lkey=lkey, 398 404 rkey=rkey, 399 405 idx=0 ··· 446 452 raise Exception("oi, you can't recurse here mate") 447 453 448 454 self.stack.append(self.StackFrame( 449 - node=self.ns.get(subtree), 455 + node=self.ns.get_node(subtree), 450 456 lkey=self.lkey, 451 457 rkey=self.rkey, 452 458 idx=0 ··· 496 502 print(k, "->", v.encode("base32")) 497 503 498 504 def record_diff(ns: NodeStore, created: set[CID], deleted: set[CID]): 499 - created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get, created)), {}) 500 - deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get, deleted)), {}) 505 + created_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, created)), {}) 506 + deleted_kv = reduce(operator.__or__, ({ k: v for k, v in zip(node.keys, node.vals)} for node in map(ns.get_node, deleted)), {}) 501 507 for created_key in created_kv.keys() - deleted_kv.keys(): 502 508 yield ("created", created_key, created_kv[created_key].encode("base32")) 503 509 for updated_key in created_kv.keys() & deleted_kv.keys(): ··· 508 514 for deleted_key in deleted_kv.keys() - created_kv.keys(): 509 515 yield ("deleted", deleted_key, deleted_kv[deleted_key].encode("base32")) #XXX: encode() is just for debugging 510 516 517 + EMPTY_NODE_CID = MSTNode.empty_root().cid 518 + 511 519 def mst_diff(ns: NodeStore, root_a: CID, root_b: CID) -> Tuple[Set[CID], Set[CID]]: # created_deleted 512 520 created, deleted = mst_diff_recursive(NodeWalker(ns, root_a), NodeWalker(ns, root_b)) 513 - middle = created & deleted 514 - #assert(not middle) # should be no intersection!!! 515 - return created - middle, deleted - middle 521 + middle = created & deleted # my algorithm has occasional false-positives 522 + #assert(not middle) # this fails 523 + #print("middle", len(middle)) 524 + created -= middle 525 + deleted -= middle 526 + # special case: if one of the root nodes was empty 527 + if root_a == EMPTY_NODE_CID and root_b != EMPTY_NODE_CID: 528 + deleted.add(EMPTY_NODE_CID) 529 + if root_b == EMPTY_NODE_CID and root_a != EMPTY_NODE_CID: 530 + created.add(EMPTY_NODE_CID) 531 + return created, deleted 516 532 517 533 def very_slow_mst_diff(ns, root_a: CID, root_b: CID): 518 534 """ ··· 528 544 mst_deleted = set() # MST nodes in a but not in b 529 545 530 546 # the easiest of all cases 531 - if a.frame.node.cid == b.frame.node.cid: # includes the case where they're both None 547 + if a.frame.node.cid == b.frame.node.cid: 532 548 return mst_created, mst_deleted # no difference 533 549 534 550 # trivial ··· 559 575 mst_deleted.add(a.frame.node.cid) 560 576 561 577 while True: 562 - # "catch up" cursor a, if it's behind 563 - while a.rkey < b.rkey and not a.is_final: 564 - if a.subtree: # recurse down every subtree 565 - a.down() 566 - mst_deleted.add(a.frame.node.cid) 567 - else: 568 - a.right() 569 - 570 - # catch up cursor b, likewise 571 - while b.rkey < a.rkey and not b.is_final: 572 - if b.subtree: # recurse down every subtree 573 - b.down() 574 - mst_created.add(b.frame.node.cid) 575 - else: 576 - b.right() 578 + while a.rkey != b.rkey: # we need a loop because they might "leapfrog" each other 579 + # "catch up" cursor a, if it's behind 580 + while a.rkey < b.rkey and not a.is_final: 581 + if a.subtree: # recurse down every subtree 582 + a.down() 583 + mst_deleted.add(a.frame.node.cid) 584 + else: 585 + a.right() 586 + 587 + # catch up cursor b, likewise 588 + while b.rkey < a.rkey and not b.is_final: 589 + if b.subtree: # recurse down every subtree 590 + b.down() 591 + mst_created.add(b.frame.node.cid) 592 + else: 593 + b.right() 577 594 578 - assert(b.rkey == a.rkey) 595 + #print(a.rkey, a.stack[0].rkey, b.rkey, a.stack[0].rkey) 596 + #assert(b.rkey == a.rkey) 579 597 # the rkeys match, but the subrees below us might not 580 598 581 599 c, d = mst_diff_recursive(a.subtree_walker(), b.subtree_walker()) ··· 629 647 bs = MemoryBlockStore() 630 648 ns = NodeStore(bs) 631 649 wrangler = NodeWrangler(ns) 632 - root = ns.get(None).cid 650 + root = ns.get_node(None).cid 633 651 print(ns.pretty(root)) 634 652 root = wrangler.put(root, "hello", hash_to_cid(b"blah")) 635 653 print(ns.pretty(root))
+47
mst_test.py
··· 1 + import random 2 + from mst import mst_diff, very_slow_mst_diff, MemoryBlockStore, NodeStore, NodeWrangler, hash_to_cid 3 + import time 4 + 5 + PERF_BENCH = False 6 + 7 + def random_test(): 8 + bs = MemoryBlockStore() 9 + ns = NodeStore(bs) 10 + nw = NodeWrangler(ns) 11 + root = ns.get_node(None).cid 12 + keys = [] 13 + for _ in range(10240 if PERF_BENCH else random.randrange(0, 32)): 14 + k = random.randbytes(8).hex() 15 + keys.append(k) 16 + root = nw.put(root, k, hash_to_cid(random.randbytes(8))) 17 + root_a = root 18 + for _ in range(8 if PERF_BENCH else random.randrange(0, 8)): 19 + # some random additions 20 + root = nw.put(root, random.randbytes(8).hex(), hash_to_cid(random.randbytes(8))) 21 + if keys: 22 + # some random modifications 23 + for _ in range(4 if PERF_BENCH else random.randrange(0, 4)): 24 + for k in random.choice(keys): 25 + root = nw.put(root, k, hash_to_cid(random.randbytes(8))) 26 + # some random deletions 27 + for _ in range(4 if PERF_BENCH else random.randrange(0, 4)): 28 + for k in random.choice(keys): 29 + root = nw.delete(root, k) 30 + 31 + diff_start = time.time() 32 + c, d = mst_diff(ns, root_a, root) 33 + #c, d = very_slow_mst_diff(ns, root_a, root) 34 + diff_duration = time.time()-diff_start 35 + e, f = mst_diff(ns, root, root_a) 36 + assert(c == f) # compare with reverse 37 + assert(e == d) # compare with reverse 38 + g, h = very_slow_mst_diff(ns, root_a, root) 39 + assert(c == g) # compare with known-good 40 + assert(d == h) # compare with known-good 41 + return diff_duration 42 + 43 + if __name__ == "__main__": 44 + duration = 0 45 + for _ in range(1 if PERF_BENCH else 200): 46 + duration += random_test() 47 + print("time spent diffing (ms):", duration*1000)