diff --git a/src/rbtree.py b/src/rbtree.py index f910cd4..22e85f6 100644 --- a/src/rbtree.py +++ b/src/rbtree.py @@ -1,7 +1,6 @@ # Implementing Red-Black Tree in Python # Adapted from https://www.programiz.com/dsa/red-black-tree -import sys from typing import Type, TypeVar, Iterator @@ -81,6 +80,18 @@ def __getitem__(self: T, key: int) -> int: def __setitem__(self: T, key: int, value: int) -> None: self.search(key).value = value + def __str__(self: T) -> str: + node = self.root + output = "" + s_color = "RED" if node.is_red() else "BLACK" + output += str(node.get_key()) + "(" + s_color + ")\n" + output += self.__print_helper(node.left, " ", False) + output += self.__print_helper(node.right, " ", True) + return output + + def __len__(self: T) -> int: + return self.size + # Setters and Getters # def get_root(self: T) -> Node: return self.root @@ -295,20 +306,22 @@ def fix_insert(self: T, node: Node) -> None: self.root.set_color("black") # Printing the tree - def __print_helper(self: T, node: Node, indent: str, last: bool) -> None: + def __print_helper(self: T, node: Node, indent: str, last: bool) -> str: + output = "" if not node.is_null(): - sys.stdout.write(indent) + output += indent if last: - sys.stdout.write("R---- ") + output += "R---- " indent += " " else: - sys.stdout.write("L---- ") + output += "L---- " indent += "| " s_color = "RED" if node.is_red() else "BLACK" - print(str(node.get_key()) + "(" + s_color + ")") - self.__print_helper(node.left, indent, False) - self.__print_helper(node.right, indent, True) + output += str(node.get_key()) + "(" + s_color + ")" + '\n' + output += self.__print_helper(node.left, indent, False) + output += self.__print_helper(node.right, indent, True) + return output def search(self: T, key: int) -> Node: return self.search_tree_helper(self.root, key) @@ -423,4 +436,4 @@ def delete(self: T, key: int) -> None: self.delete_node_helper(self.root, key) def print_tree(self: T) -> None: - self.__print_helper(self.root, "", True) + print(str(self)) \ No newline at end of file diff --git a/tests/test_node.py b/tests/test_node.py index 0ad9ffc..f165e5e 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -66,3 +66,14 @@ def test_null_node() -> None: null = Node.null() assert null.is_null() assert null.is_black() + + +def test_repr() -> None: + """ + Test: + Key only + Key/Value + Null + """ + node = Node(2) + assert repr(node) == "Key: 2 Value: None" \ No newline at end of file diff --git a/tests/test_rbtree.py b/tests/test_rbtree.py index 7bdb21f..fbfc66c 100644 --- a/tests/test_rbtree.py +++ b/tests/test_rbtree.py @@ -80,7 +80,7 @@ def test_insert() -> None: bst.insert(109) bst.insert(102) - assert bst.size == 23 + assert len(bst) == 23 check_valid(bst) @@ -112,24 +112,25 @@ def test_delete() -> None: bst.insert(58) bst.insert(42) - assert bst.size == 12 + assert len(bst) == 12 bst.delete(48) - assert bst.size == 11 + assert len(bst) == 11 bst.delete(42) - assert bst.size == 10 + assert len(bst) == 10 bst.delete(42) - assert bst.size == 9 + assert len(bst) == 9 + assert bst.search(42).get_key() == 42 bst.delete(42) assert bst.search(42).is_null() - assert bst.size == 8 + assert len(bst) == 8 bst.delete(100) - assert bst.size == 7 + assert len(bst) == 7 bst.delete(100) - assert bst.size == 7 + assert len(bst) == 7 check_valid(bst) @@ -192,14 +193,13 @@ def test_accessors() -> None: bst.insert(57) assert bst.predecessor(bst.search(57)).get_key() == 55 - - + + def test_preorder() -> None: bst = RedBlackTree() bst.insert(1) bst.insert(2) bst.insert(3) - nodes = bst.preorder() keys = [] for node in nodes: @@ -257,25 +257,6 @@ def test_iterator_exception() -> None: bst.set_iteration_style("spam") -def test_print() -> None: - bst = RedBlackTree() - bst.insert(73) - print(bst.get_root()) - bst.insert(48) - bst.insert(100) - bst.insert(42) - bst.insert(55) - bst.insert(40) - bst.insert(58) - bst.insert(42) - bst.insert(55) - bst.insert(40) - bst.insert(58) - bst.insert(42) - - bst.print_tree() - - def test_elaborate_delete() -> None: bst = RedBlackTree() bst.insert(55) @@ -355,3 +336,38 @@ def test_duplicates() -> None: bst.delete(42) bst.delete(42) check_valid(bst) + + +print_data = [ + [ + [1, 2, 3], + ("" + + "2(BLACK)\n" + + " L---- 1(RED)\n" + + " R---- 3(RED)\n") + ], + [ + [73, 48, 100, 42, 55, 40, 58, 42, 55, 40, 58, 42], + ("" + + "48(BLACK)\n" + + " L---- 42(BLACK)\n" + + " | L---- 40(BLACK)\n" + + " | | R---- 40(RED)\n" + + " | R---- 42(BLACK)\n" + + " | R---- 42(RED)\n" + + " R---- 73(BLACK)\n" + + " L---- 55(RED)\n" + + " | L---- 55(BLACK)\n" + + " | R---- 58(BLACK)\n" + + " | R---- 58(RED)\n" + + " R---- 100(BLACK)\n") + ] +] + + +@pytest.mark.parametrize("input, expected", print_data) +def test_print(input: list, expected: str) -> None: + bst = RedBlackTree() + for key in input: + bst.insert(key) + assert str(bst) == expected