diff --git a/trees/binaryheap/binaryheap.go b/trees/binaryheap/binaryheap.go index 9320fc4..c5b2e39 100644 --- a/trees/binaryheap/binaryheap.go +++ b/trees/binaryheap/binaryheap.go @@ -33,6 +33,7 @@ package binaryheap import ( "fmt" + "github.com/emirpasic/gods/containers" "github.com/emirpasic/gods/lists/arraylist" "github.com/emirpasic/gods/trees" "github.com/emirpasic/gods/utils" @@ -41,6 +42,7 @@ import ( func assertInterfaceImplementation() { var _ trees.Tree = (*Heap)(nil) + var _ containers.Iterator = (*Iterator)(nil) } type Heap struct { @@ -109,6 +111,29 @@ func (heap *Heap) Values() []interface{} { return heap.list.Values() } +type Iterator struct { + heap *Heap + index int +} + +func (heap *Heap) Iterator() Iterator { + return Iterator{heap: heap, index: -1} +} + +func (iterator *Iterator) Next() bool { + iterator.index += 1 + return iterator.heap.withinRange(iterator.index) +} + +func (iterator *Iterator) Value() interface{} { + value, _ := iterator.heap.list.Get(iterator.index) + return value +} + +func (iterator *Iterator) Index() interface{} { + return iterator.index +} + func (heap *Heap) String() string { str := "BinaryHeap\n" values := []string{} @@ -158,3 +183,8 @@ func (heap *Heap) bubbleUp() { index = parentIndex } } + +// Check that the index is withing bounds of the list +func (heap *Heap) withinRange(index int) bool { + return index >= 0 && index < heap.list.Size() +} diff --git a/trees/binaryheap/binaryheap_test.go b/trees/binaryheap/binaryheap_test.go index 473526f..66c80e1 100644 --- a/trees/binaryheap/binaryheap_test.go +++ b/trees/binaryheap/binaryheap_test.go @@ -103,7 +103,50 @@ func TestBinaryHeap(t *testing.T) { } prev = curr } +} + +func TestBinaryHeapIterator(t *testing.T) { + heap := NewWithIntComparator() + + if actualValue := heap.Empty(); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } + + // insertions + heap.Push(3) + // [3] + heap.Push(2) + // [2,3] + heap.Push(1) + // [1,3,2](2 swapped with 1, hence last) + // Iterator + it := heap.Iterator() + for it.Next() { + index := it.Index() + value := it.Value() + switch index { + case 0: + if actualValue, expectedValue := value, 1; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + case 1: + if actualValue, expectedValue := value, 3; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + case 2: + if actualValue, expectedValue := value, 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + default: + t.Errorf("Too many") + } + } + heap.Clear() + it = heap.Iterator() + for it.Next() { + t.Errorf("Shouldn't iterate on empty stack") + } } func BenchmarkBinaryHeap(b *testing.B) {