Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions src/rbtree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Implementing Red-Black Tree in Python
# Adapted from https://www.programiz.com/dsa/red-black-tree

import sys
from typing import Type, TypeVar, Iterator


Expand Down Expand Up @@ -81,6 +80,18 @@ def __getitem__(self: T, key: int) -> int:
def __setitem__(self: T, key: int, value: int) -> None:
self.search(key).value = value

def __str__(self: T) -> str:
node = self.root
output = ""
s_color = "RED" if node.is_red() else "BLACK"
output += str(node.get_key()) + "(" + s_color + ")\n"
output += self.__print_helper(node.left, " ", False)
output += self.__print_helper(node.right, " ", True)
return output

def __len__(self: T) -> int:
return self.size

# Setters and Getters #
def get_root(self: T) -> Node:
return self.root
Expand Down Expand Up @@ -295,20 +306,22 @@ def fix_insert(self: T, node: Node) -> None:
self.root.set_color("black")

# Printing the tree
def __print_helper(self: T, node: Node, indent: str, last: bool) -> None:
def __print_helper(self: T, node: Node, indent: str, last: bool) -> str:
output = ""
if not node.is_null():
sys.stdout.write(indent)
output += indent
if last:
sys.stdout.write("R---- ")
output += "R---- "
indent += " "
else:
sys.stdout.write("L---- ")
output += "L---- "
indent += "| "

s_color = "RED" if node.is_red() else "BLACK"
print(str(node.get_key()) + "(" + s_color + ")")
self.__print_helper(node.left, indent, False)
self.__print_helper(node.right, indent, True)
output += str(node.get_key()) + "(" + s_color + ")" + '\n'
output += self.__print_helper(node.left, indent, False)
output += self.__print_helper(node.right, indent, True)
return output

def search(self: T, key: int) -> Node:
return self.search_tree_helper(self.root, key)
Expand Down Expand Up @@ -423,4 +436,4 @@ def delete(self: T, key: int) -> None:
self.delete_node_helper(self.root, key)

def print_tree(self: T) -> None:
self.__print_helper(self.root, "", True)
print(str(self))
11 changes: 11 additions & 0 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ def test_null_node() -> None:
null = Node.null()
assert null.is_null()
assert null.is_black()


def test_repr() -> None:
"""
Test:
Key only
Key/Value
Null
"""
node = Node(2)
assert repr(node) == "Key: 2 Value: None"
76 changes: 46 additions & 30 deletions tests/test_rbtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_insert() -> None:
bst.insert(109)
bst.insert(102)

assert bst.size == 23
assert len(bst) == 23

check_valid(bst)

Expand Down Expand Up @@ -112,24 +112,25 @@ def test_delete() -> None:
bst.insert(58)
bst.insert(42)

assert bst.size == 12
assert len(bst) == 12

bst.delete(48)
assert bst.size == 11
assert len(bst) == 11
bst.delete(42)
assert bst.size == 10
assert len(bst) == 10
bst.delete(42)
assert bst.size == 9
assert len(bst) == 9

assert bst.search(42).get_key() == 42
bst.delete(42)
assert bst.search(42).is_null()
assert bst.size == 8
assert len(bst) == 8
bst.delete(100)
assert bst.size == 7
assert len(bst) == 7

bst.delete(100)

assert bst.size == 7
assert len(bst) == 7
check_valid(bst)


Expand Down Expand Up @@ -192,14 +193,13 @@ def test_accessors() -> None:

bst.insert(57)
assert bst.predecessor(bst.search(57)).get_key() == 55


def test_preorder() -> None:
bst = RedBlackTree()
bst.insert(1)
bst.insert(2)
bst.insert(3)

nodes = bst.preorder()
keys = []
for node in nodes:
Expand Down Expand Up @@ -257,25 +257,6 @@ def test_iterator_exception() -> None:
bst.set_iteration_style("spam")


def test_print() -> None:
bst = RedBlackTree()
bst.insert(73)
print(bst.get_root())
bst.insert(48)
bst.insert(100)
bst.insert(42)
bst.insert(55)
bst.insert(40)
bst.insert(58)
bst.insert(42)
bst.insert(55)
bst.insert(40)
bst.insert(58)
bst.insert(42)

bst.print_tree()


def test_elaborate_delete() -> None:
bst = RedBlackTree()
bst.insert(55)
Expand Down Expand Up @@ -355,3 +336,38 @@ def test_duplicates() -> None:
bst.delete(42)
bst.delete(42)
check_valid(bst)


print_data = [
[
[1, 2, 3],
(""
+ "2(BLACK)\n"
+ " L---- 1(RED)\n"
+ " R---- 3(RED)\n")
],
[
[73, 48, 100, 42, 55, 40, 58, 42, 55, 40, 58, 42],
(""
+ "48(BLACK)\n"
+ " L---- 42(BLACK)\n"
+ " | L---- 40(BLACK)\n"
+ " | | R---- 40(RED)\n"
+ " | R---- 42(BLACK)\n"
+ " | R---- 42(RED)\n"
+ " R---- 73(BLACK)\n"
+ " L---- 55(RED)\n"
+ " | L---- 55(BLACK)\n"
+ " | R---- 58(BLACK)\n"
+ " | R---- 58(RED)\n"
+ " R---- 100(BLACK)\n")
]
]


@pytest.mark.parametrize("input, expected", print_data)
def test_print(input: list, expected: str) -> None:
bst = RedBlackTree()
for key in input:
bst.insert(key)
assert str(bst) == expected