DEV Community

Alessandro Diaferia


A concatenated iterator example in Kotlin

Sometimes you need to iterate through multiple lists, for example when implementing a tree traversal algorithm. Thanks to Kotlin's support for tail recursion (the tailrec modifier), a trivial implementation could look like this:

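// Assumes a Node type that exposes its children as children: List<Node>.
tailrec fun countAll(nodes: List<Node>, acc: Int): Int {
    return when {
        nodes.isEmpty() -> acc
        // Count the first node, then continue with its children followed by
        // the remaining nodes. The + operator copies both lists into a new one.
        else -> countAll(nodes.first().children + nodes.subList(1, nodes.size), acc + 1)
    }
}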

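You start by invoking countAll on a list containing only the root node (e.g. countAll(listOf(root), 0)). The function then keeps prepending each visited node's children to the pending nodes until all of them have been traversed and counted. Every visited node, the root included, adds one to the accumulator, so the accumulator starts at 0.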
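To try the naive version end to end, here is a minimal runnable sketch. The article does not show its Node type, so the data class below is a hypothetical stand-in that only holds a list of children. The tiny four-node tree is likewise made up for illustration. The accumulator starts at 0 because every visited node, the root included, adds one to it.

```kotlin
// Hypothetical Node type: the article does not show its definition,
// so we assume each node simply holds a list of children.
data class Node(val children: List<Node> = emptyList())

// Naive tail-recursive count: pop the first node, prepend its children
// to the remaining nodes, and add one to the accumulator.
tailrec fun countAll(nodes: List<Node>, acc: Int): Int {
    return when {
        nodes.isEmpty() -> acc
        // `+` allocates a brand-new list and copies both operands into it.
        else -> countAll(nodes.first().children + nodes.subList(1, nodes.size), acc + 1)
    }
}

fun main() {
    // root -> (a -> (c), b): four nodes in total.
    val tree = Node(listOf(Node(listOf(Node())), Node()))
    println(countAll(listOf(tree), 0)) // prints 4
}
```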

The problem with this approach is that it spends most of its time on the concatenation itself. Every step copies all the elements of both operands into a new list.

Flamegraph - naive countAll

As the flame graph shows, most of the time is spent constructing the ArrayList instance and invoking addAll.

We only concatenate lists here so that we can keep iterating through our collection of nodes as more items are added.
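To get a feel for why the copying dominates, the following sketch reuses the naive traversal but adds up how many elements each + concatenation copies, instead of counting nodes. It is an illustration under assumptions, not a benchmark. Node and copiedElements are hypothetical names, and the example tree is simply a root with 1,000 leaf children.

```kotlin
// Hypothetical Node type (the article does not show its definition).
data class Node(val children: List<Node> = emptyList())

// Same traversal as the naive countAll, but instead of counting nodes it
// sums the size of every list produced by a `+` concatenation.
tailrec fun copiedElements(nodes: List<Node>, copied: Long): Long {
    if (nodes.isEmpty()) return copied
    val next = nodes.first().children + nodes.subList(1, nodes.size)
    return copiedElements(next, copied + next.size)
}

fun main() {
    // A root with 1_000 leaf children: only 1_001 nodes in total...
    val wide = Node(List(1_000) { Node() })
    // ...but the concatenations copy 1_000 + 999 + ... + 1 + 0 elements.
    println(copiedElements(listOf(wide), 0)) // prints 500500
}
```

For that wide node, 1,001 visits trigger 500,500 element copies, because every step rebuilds the whole pending list. For this tree shape, the copy count grows quadratically with the number of children.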

Let's make it better

An alternative that avoids the overhead of concatenating lists is to use an iterator that can cross list boundaries.

If we had one, we could rewrite our countAll function as follows:

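// Start with countAll(listOf(root).iterator(), 0). The + operator appends the
// children's iterator instead of copying any list (see the extension below).
tailrec fun countAll(nodes: Iterator<Node>, acc: Int): Int {
    return when {
        !nodes.hasNext() -> acc
        else -> {
            val node = nodes.next()
            countAll(nodes + node.children.iterator(), acc + 1)
        }
    }
}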

🏃‍♂️ Time to code!

In our ideal implementation, the plus operator stores the iterator for the children list without eagerly copying all the items into a new list.
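For comparison, Kotlin's standard library can already chain iterators lazily through the iterator { } builder and SequenceScope.yieldAll. This is a swapped-in technique, not the approach this article takes. chain is a hypothetical helper name.

```kotlin
// Lazily chain two iterators with the standard library's `iterator` builder.
// Nothing is copied up front: `yieldAll` pulls from each source only when
// the consumer asks for the next element.
fun <T> chain(first: Iterator<T>, second: Iterator<T>): Iterator<T> = iterator {
    yieldAll(first)
    yieldAll(second)
}

fun main() {
    val combined = chain(listOf(1, 2).iterator(), listOf(3, 4).iterator())
    println(combined.asSequence().toList()) // prints [1, 2, 3, 4]
}
```

Each call to chain wraps its inputs in a new generator. Calling it repeatedly inside a traversal loop, the way countAll uses +, would therefore likely stack wrappers on top of one another. A flat queue of iterators sidesteps that nesting.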

Let's take a look at the iterator implementation:

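import java.util.ArrayDeque

class ConcatIterator<T>(iterator: Iterator<T>) : Iterator<T> {
    // Queue of pending iterators. Invariant: every iterator in the store
    // still has at least one element left.
    private val store = ArrayDeque<Iterator<T>>()

    init {
        if (iterator.hasNext())
            store.add(iterator)
    }

    // Thanks to the invariant, a non-empty store always has a next element.
    override fun hasNext(): Boolean = when {
        store.isEmpty() -> false
        else -> store.first.hasNext()
    }

    override fun next(): T {
        // Throws NoSuchElementException when the store is empty.
        val t = store.first.next()

        // Drop the head iterator as soon as it is exhausted.
        if (!store.first.hasNext())
            store.removeFirst()

        return t
    }

    // Appends an iterator to the end of the queue (empty ones are skipped)
    // and returns this same, now mutated, instance.
    operator fun plus(iterator: Iterator<T>): ConcatIterator<T> {
        if (iterator.hasNext())
            store.add(iterator)
        return this
    }
}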

As iterators are concatenated, they are enqueued in the store. Here the store is a java.util.ArrayDeque: a resizable-array, non-thread-safe implementation of Deque (and therefore of Queue) whose operations mostly run in amortized constant time.
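The same queue-of-iterators idea can also be written as a plain loop, without the operator machinery. This is a sketch, not the article's code. Node and countWithQueue are hypothetical names, and the sketch uses a while loop instead of tail recursion. It keeps ConcatIterator's two rules: only non-empty iterators are enqueued, and the head is dropped as soon as it is exhausted.

```kotlin
import java.util.ArrayDeque

// Hypothetical Node type (the article does not show its definition).
data class Node(val children: List<Node> = emptyList())

// Counts nodes using a FIFO queue of iterators, mirroring ConcatIterator's store.
fun countWithQueue(root: Node): Int {
    val store = ArrayDeque<Iterator<Node>>()
    store.add(listOf(root).iterator())
    var count = 0
    while (store.isNotEmpty()) {
        val head = store.first
        val node = head.next()
        // Drop the head iterator once it has no elements left.
        if (!head.hasNext()) store.removeFirst()
        count++
        // Enqueue the children's iterator instead of copying the children.
        val children = node.children.iterator()
        if (children.hasNext()) store.add(children)
    }
    return count
}

fun main() {
    // root -> (a -> (c), b): four nodes in total.
    val tree = Node(listOf(Node(listOf(Node())), Node()))
    println(countWithQueue(tree)) // prints 4
}
```

Each node is visited once, and each children list enters the queue as an iterator, so no list is ever copied.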

The last touch

The final touch, which makes the iterator a little more comfortable to work with, is a plus operator defined as an extension on Iterator.

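operator fun <T> Iterator<T>.plus(iterator: Iterator<T>): ConcatIterator<T> =
    when {
        // Reuse the left operand: append the right one to its queue.
        this is ConcatIterator<T> -> this.plus(iterator)
        // Otherwise wrap the left operand and append the right one after it.
        // A right-hand ConcatIterator is not reused, because appending this
        // to its queue would emit this's elements after its own.
        else -> ConcatIterator(this).plus(iterator)
    }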

This extension function makes it easy to concatenate two existing iterators into a ConcatIterator. The nice thing about it is that it reuses an existing ConcatIterator instance when possible instead of allocating a new one. Keep in mind that this means a + b mutates a when a is already a ConcatIterator.

Let the numbers speak

Now that we have implemented the alternative version of countAll, let's see how it performs.

I tested my assumptions using a small Kotlin playground I created to experiment with trees. You can find it here.

The following results come from testing the two implementations against a tree with 65201277 nodes.

Total count (countAll without ConcatIterator): 65201277 nodes (9227 ms)
Total count (countAll with ConcatIterator): 65201277 nodes (1288 ms)

As you can see, the ConcatIterator version is roughly 7 times faster (1288 ms versus 9227 ms). We no longer incur the overhead of concatenating lists, so most of the computation goes into the counting itself.
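For reference, the speedup implied by the two timings is:

```latex
\text{speedup} = \frac{9227\ \text{ms}}{1288\ \text{ms}} \approx 7.2
```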

Conclusion

I hope you enjoyed it. Let me know how you approach this kind of problem and whether there are better ways of achieving this in Kotlin.

Cheers!

Top comments (2)

Max

Hm...what about

fun countAll(nodes: List<Node>): Int =
    nodes.size + nodes.sumOf { countAll(it.children) }

?

I'm more curious about how you came up with that flamegraph :) Looks useful.

Alessandro Diaferia

Hi Max A. That recursive version is definitely more efficient than my tail recursive one as its stack size is smaller and doesn't have the overhead of the accumulation. What I wanted to focus on is the iteration across multiple lists and couldn't come up with a better example. Regarding the profiling tool, it's a pretty useful one and it's here: github.com/jvm-profiling-tools/asy...

Enjoy :)