diff --git a/trees/redblacktree/redblacktree.go b/trees/redblacktree/redblacktree.go index 7ea2f5b..d4d3dbc 100644 --- a/trees/redblacktree/redblacktree.go +++ b/trees/redblacktree/redblacktree.go @@ -23,6 +23,7 @@ package redblacktree import ( "github.com/emirpasic/gods/utils" + "log" ) type Color bool @@ -45,17 +46,17 @@ type Node struct { parent *Node } -// Instantiates a red-black tree with the custom comparator +// Instantiates a red-black tree with the custom comparator. func NewWith(comparator utils.Comparator) *Tree { return &Tree{comparator: comparator} } -// Instantiates a red-black tree with the IntComparator, i.e. keys are of type int +// Instantiates a red-black tree with the IntComparator, i.e. keys are of type int. func NewWithIntComparator() *Tree { return &Tree{comparator: utils.IntComparator} } -// Instantiates a red-black tree with the StringComparator, i.e. keys are of type string +// Instantiates a red-black tree with the StringComparator, i.e. keys are of type string. func NewWithStringComparator() *Tree { return &Tree{comparator: utils.StringComparator} } @@ -98,7 +99,7 @@ func (tree *Tree) Put(key interface{}, value interface{}) { // Searches the node in the tree by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. -// Key should adhere to the comparator's type assertion, otherwise method panics +// Key should adhere to the comparator's type assertion, otherwise method panics. func (tree *Tree) Get(key interface{}) (interface{}, bool) { node := tree.lookup(key) if node != nil { @@ -107,8 +108,33 @@ func (tree *Tree) Get(key interface{}) (interface{}, bool) { return nil, false } +// Remove the node from the tree by key. +// Key should adhere to the comparator's type assertion, otherwise method panics. func (tree *Tree) Remove(key interface{}) { - + var child *Node + node := tree.lookup(key) + if node == nil { + return + } + if node.left != nil && node.right != nil { + pred := node.left.maximumNode() + node.key = pred.key + node.value = pred.value + node = pred + } + if node.right == nil { + child = node.left + } else { + child = node.right + } + if node.color == BLACK { + node.color = child.color + tree.deleteCase1(node) + } + tree.replaceNode(node, child) + if node.parent == nil && child != nil { + child.color = BLACK + } } // Returns true if tree does not contain any nodes @@ -204,6 +230,7 @@ func (tree *Tree) insertCase2(node *Node) { } func (tree *Tree) insertCase3(node *Node) { + log.Printf("%#v\n", node) if node.uncle().color == RED { node.parent.color = BLACK node.uncle().color = BLACK @@ -234,3 +261,94 @@ func (tree *Tree) insertCase5(node *Node) { tree.rotateLeft(node.grandparent()) } } + +func (node *Node) maximumNode() *Node { + for node.right != nil { + node = node.right + } + return node +} + +func (node *Node) sibling() *Node { + if node == node.parent.left { + return node.parent.right + } else { + return node.parent.left + } +} + +func (tree *Tree) deleteCase1(node *Node) { + if node.parent == nil { + return + } else { + tree.deleteCase2(node) + } +} + +func (tree *Tree) deleteCase2(node *Node) { + if node.sibling().color == RED { + node.parent.color = RED + node.sibling().color = BLACK + if node == node.parent.left { + tree.rotateLeft(node.parent) + } else { + tree.rotateRight(node.parent) + } + } + tree.deleteCase3(node) +} + +func (tree *Tree) deleteCase3(node *Node) { + if node.parent.color == BLACK && + node.sibling().color == BLACK && + node.sibling().left.color == BLACK && + node.sibling().right.color == BLACK { + node.sibling().color = RED + tree.deleteCase1(node.parent) + } else { + tree.deleteCase4(node) + } +} + +func (tree *Tree) deleteCase4(node *Node) { + if node.parent.color == RED && + node.sibling().color == BLACK && + node.sibling().left.color == BLACK && + node.sibling().right.color == BLACK { + node.sibling().color = RED + node.parent.color = BLACK + } else { + tree.deleteCase5(node) + } +} + +func (tree *Tree) deleteCase5(node *Node) { + if node == node.parent.left && + node.sibling().color == BLACK && + node.sibling().left.color == RED && + node.sibling().right.color == BLACK { + node.sibling().color = RED + node.sibling().left.color = BLACK + tree.rotateRight(node.sibling()) + } else if node == node.parent.right && + node.sibling().color == BLACK && + node.sibling().right.color == RED && + node.sibling().left.color == BLACK { + node.sibling().color = RED + node.sibling().right.color = BLACK + tree.rotateLeft(node.sibling()) + } + tree.deleteCase6(node) +} + +func (tree *Tree) deleteCase6(node *Node) { + node.sibling().color = node.parent.color + node.parent.color = BLACK + if node == node.parent.left { + node.sibling().right.color = BLACK + tree.rotateLeft(node.parent) + } else { + node.sibling().left.color = BLACK + tree.rotateRight(node.parent) + } +} diff --git a/trees/redblacktree/redblacktree_test.go b/trees/redblacktree/redblacktree_test.go index 031c762..75e2389 100644 --- a/trees/redblacktree/redblacktree_test.go +++ b/trees/redblacktree/redblacktree_test.go @@ -8,6 +8,7 @@ func TestPutGet(t *testing.T) { tree := NewWithIntComparator() + tree.Put(5, "e") tree.Put(3, "c") tree.Put(4, "d") tree.Put(1, "x")