Binary Search Tree

Binary search tree implementation from interactivepython.org

from typing import Optional
from nagini_contracts.contracts import *

# NOTE(review): Requires / Ensures / Acc / Fold / Unfold / Unfolding / Implies /
# Old / Pure / Predicate all come from the star import above. They are used
# here as specification constructs (pre/postconditions, permissions,
# predicates); see the Nagini documentation for their exact semantics.


# A single node of the binary search tree: a key, its payload, links to the
# left and right children, and a back-link to the parent.
class TreeNode:
    # Store the arguments in the node's fields. `left`, `right` and `parent`
    # default to None. The postcondition is written after the assignments and
    # states, for every field, Acc(...) on the field and that the field is
    # identical (`is`) to the corresponding argument.
    def __init__(self, key: int, val: str, left:'TreeNode'=None,
                 right:'TreeNode'=None, parent:'TreeNode'=None) -> None:
        self.key = key
        self.payload = val
        self.leftChild = left
        self.rightChild = right
        self.parent = parent
        Ensures(Acc(self.key) and self.key is key and
                Acc(self.payload) and self.payload is val and
                Acc(self.leftChild) and self.leftChild is left and
                Acc(self.rightChild) and self.rightChild is right and
                Acc(self.parent) and self.parent is parent)

    # Return the left child itself (or None), not a bool; callers test the
    # result for truthiness. Requires only Acc(self.leftChild) rather than the
    # whole tree predicate; _put calls it after unfolding tree(currentNode).
    @Pure
    def hasLeftChild(self) -> Optional['TreeNode']:
        Requires(Acc(self.leftChild))
        return self.leftChild

    # Return the right child itself (or None); mirror image of hasLeftChild.
    @Pure
    def hasRightChild(self) -> Optional['TreeNode']:
        Requires(Acc(self.rightChild))
        return self.rightChild

    # True when the node has no parent (self.parent is falsy).
    @Pure
    def isRoot(self) -> bool:
        Requires(tree(self))
        return Unfolding(tree(self), not self.parent)

    # True when neither the right nor the left child is set.
    @Pure
    def isLeaf(self) -> bool:
        Requires(tree(self))
        return Unfolding(tree(self), not (self.rightChild or self.leftChild))

    # Return the right child if it is truthy, otherwise the left child
    # (which may be None when the node has no children).
    @Pure
    def hasAnyChildren(self) -> Optional['TreeNode']:
        Requires(tree(self))
        return Unfolding(tree(self), self.rightChild or self.leftChild)

    # Return the left child when the right child is truthy, otherwise the
    # (falsy) right child; the result is truthy only if both children exist.
    @Pure
    def hasBothChildren(self) -> Optional['TreeNode']:
        Requires(tree(self))
        return Unfolding(tree(self), self.rightChild and self.leftChild)


# Recursive predicate describing the subtree rooted at `n`: Acc(...) on all
# five fields of `n`, and for each child that is not None, the tree predicate
# of that child together with the fact that the child's parent is `n`.
@Predicate
def tree(n : TreeNode) -> bool:
    return (Acc(n.key) and Acc(n.payload) and Acc(n.leftChild) and Acc(n.rightChild) and Acc(n.parent) and
            Implies(n.leftChild is not None, tree(n.leftChild) and getParent(n.leftChild) is n) and
            Implies(n.rightChild is not None, tree(n.rightChild) and getParent(n.rightChild) is n))


# Ordering property of the subtree rooted at `n`. Note the parameter order is
# (upper, lower); a bound of None means "no bound on that side".
#   - n.key is strictly below `upper` and strictly above `lower` when given;
#   - the left subtree is sorted with n.key as its new upper bound;
#   - the right subtree is sorted with n.key as its new lower bound.
# This definition shadows the builtin `sorted` inside this module.
@Pure
def sorted(n: TreeNode, upper: Optional[int], lower: Optional[int]) -> bool:
    Requires(tree(n))
    return (Unfolding(tree(n),
                      Implies(upper is not None, n.key < upper) and
                      Implies(lower is not None, n.key > lower) and
                      Implies(n.leftChild is not None, sorted(n.leftChild, n.key, lower)) and
                      Implies(n.rightChild is not None, sorted(n.rightChild, upper, n.key))))


# Read node.parent from under the tree(node) predicate. Used by `tree` itself
# (child/parent consistency) and by the postcondition of _put.
@Pure
def getParent(node: TreeNode) -> Optional['TreeNode']:
    Requires(tree(node))
    return Unfolding(tree(node), node.parent)


# Binary search tree mapping int keys to str payloads. `root` is the top node
# (None while the tree is empty) and `size` counts the inserted keys. Every
# public method requires and re-establishes the `bst` predicate defined below.
class BinarySearchTree:

    # Create an empty tree: no root, size 0, then fold bst(self) so that the
    # postcondition written after it holds.
    def __init__(self) -> None:
        self.root = None # type: Optional[TreeNode]
        self.size = 0
        Fold(bst(self))
        Ensures(bst(self))

    # Insert `key` with payload `val`, or replace the payload when `key` is
    # already present. `size` is incremented only when a new node was created.
    def put(self, key: int, val: str) -> None:
        Requires(bst(self))
        Ensures(bst(self))
        Unfold(bst(self))
        if self.root:
            # Non-empty tree: descend from the root with no key bounds.
            increased_size = self._put(key, val, self.root, None, None)
        else:
            # Empty tree: the new node becomes the root.
            self.root = TreeNode(key,val)
            Fold(tree(self.root))
            increased_size = True
        if increased_size:
            self.size = self.size + 1
        Fold(bst(self))

    # Recursive insertion into the subtree rooted at `currentNode`.
    # `upper` / `lower` are the exclusive key bounds of that subtree (None for
    # unbounded); the preconditions require `key` to lie within them.
    # Returns True when a new node was created and False when an existing
    # key's payload was overwritten. The postconditions restore the tree and
    # sorted properties and state that currentNode's parent is unchanged.
    def _put(self, key: int, val: str, currentNode: TreeNode,
             upper: Optional[int], lower: Optional[int]) -> bool:
        Requires(tree(currentNode) and sorted(currentNode, upper, lower))
        Requires(Implies(upper is not None, key < upper))
        Requires(Implies(lower is not None, key > lower))
        Ensures(tree(currentNode) and sorted(currentNode, upper, lower))
        Ensures(getParent(currentNode) is Old(getParent(currentNode)))
        Unfold(tree(currentNode))
        res = True
        if key < currentNode.key:
            if currentNode.hasLeftChild():
                # Go left; currentNode.key becomes the subtree's upper bound.
                res = self._put(key, val, currentNode.leftChild, currentNode.key, lower)
            else:
                currentNode.leftChild = TreeNode(key, val, parent=currentNode)
                Fold(tree(currentNode.leftChild))
        elif key > currentNode.key:
            if currentNode.hasRightChild():
                # Go right; currentNode.key becomes the subtree's lower bound.
                res = self._put(key, val, currentNode.rightChild, upper, currentNode.key)
            else:
                currentNode.rightChild = TreeNode(key, val, parent=currentNode)
                Fold(tree(currentNode.rightChild))
        else:
            # Key already present: only the payload changes, no new node.
            currentNode.payload = val
            res = False
        Fold(tree(currentNode))
        return res

    # Support `tree[k] = v`; delegates to put.
    def __setitem__(self, k: int, v: str) -> None:
        Requires(bst(self))
        Ensures(bst(self))
        self.put(k,v)

    # Look up `key`; returns its payload, or None when the tree is empty or
    # the key is not found. bst(self) is folded again on both paths.
    def get(self, key: int) -> Optional[str]:
        Requires(Acc(bst(self)))
        Ensures(Acc(bst(self)))
        Unfold(bst(self))
        if self.root:
            # perm=2: _get is given Acc(tree(root), 1/2) per its contract.
            res = self._get(key, self.root, 2)
            Fold(bst(self))
            return res
        else:
            Fold(bst(self))
            return None

    # Recursive lookup in the subtree rooted at `currentNode` (may be None).
    # `perm` must be positive; the contract takes and gives back
    # Acc(tree(currentNode), 1/perm). Each recursive call passes perm * 2,
    # i.e. half of the fraction held at this level.
    # Returns the payload of the matching node, or None if no node matches.
    def _get(self, key: int, currentNode: Optional[TreeNode], perm: int) -> Optional[str]:
        Requires(perm > 0)
        Requires(Implies(currentNode is not None, Acc(tree(currentNode), 1/perm)))
        Ensures(Implies(currentNode is not None, Acc(tree(currentNode), 1/perm)))
        if not currentNode:
            return None
        Unfold(Acc(tree(currentNode), 1/perm))
        if currentNode.key == key:
            res = currentNode.payload
        elif key < currentNode.key:
            res = self._get(key, currentNode.leftChild, perm * 2)
        else:
            res = self._get(key, currentNode.rightChild, perm * 2)
        Fold(Acc(tree(currentNode), 1/perm))
        return res

    # Support `tree[key]`; delegates to get (so a missing key yields None
    # rather than raising).
    def __getitem__(self, key: int) -> Optional[str]:
        Requires(Acc(bst(self)))
        Ensures(Acc(bst(self)))
        return self.get(key)


# Predicate for a whole BinarySearchTree: Acc(...) on `root` and `size`, and,
# when the tree is non-empty, the root's tree predicate plus the sorted
# property with no outer bounds.
@Predicate
def bst(t: BinarySearchTree) -> bool:
    return (Acc(t.root) and Acc(t.size) and
            Implies(t.root is not None, tree(t.root) and sorted(t.root, None, None)))


# No-op replacement that shadows the builtin `print` in this module.
# NOTE(review): presumably a stub so the verifier has a definition to check
# the calls below against -- confirm. As written, the print calls at the
# bottom of the file produce no output.
def print(o: object) -> None:
    pass


# Demo client: insert four keys via __setitem__, then look two of them up
# via __getitem__.
mytree = BinarySearchTree()
mytree[3]="red"
mytree[4]="blue"
mytree[6]="yellow"
mytree[2]="at"
print(mytree[6])
print(mytree[2])