Generalizing Dijkstra's Algorithm

Earlier this week, I wrote a simplified implementation of Dijkstra’s algorithm. You can read the article for details or just look at the full code implementation here on GitHub. This implementation is fine to use within a particular project for a particular purpose, but it doesn’t generalize very well. Today we’ll explore how to make this idea more general.

I chose to write this without looking at any existing implementations of Dijkstra’s algorithm in Haskell libraries, so that I could see how my approach would differ. At the end of this series I’ll also spend some time comparing my approach to some other ideas that exist in the Haskell world.

Parameterizing the Graph Type

So why doesn’t this approach generalize? Well, for the obvious reason that my module defines a specific type for the graph:

data Graph = Graph
  { edges :: HashMap String [(String, Int)] }

So, for someone else to re-use this code from the perspective of a different project, they would have to take whatever graph information they had, and turn it into this specific type. And their data might not map very well into String values for the nodes, and they might also have a different cost type in mind than the simple Int value, with Double being the most obvious example.

So we could parameterize the graph type and allow more customization of the underlying values.

data Graph node cost = Graph
  { edges :: HashMap node [(node, cost)] }
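As a quick illustration of what the parameterization buys, here is a sketch that instantiates the type with Int nodes and Double costs. It assumes Data.Map from the containers package in place of HashMap (so nodes need Ord rather than Hashable), which keeps it runnable with only GHC's bundled libraries; roadGraph and neighborsOf are hypothetical names.

```haskell
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

-- The graph is parameterized over both the node type and the cost type.
-- (Data.Map stands in for HashMap here, so nodes need Ord instead of Hashable.)
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

-- A hypothetical graph with Int node IDs and Double edge costs.
roadGraph :: Graph Int Double
roadGraph = Graph $ M.fromList
  [ (1, [(2, 2.5), (3, 7.25)])
  , (2, [(3, 1.5)])
  , (3, [])
  ]

-- Look up the outgoing edges of a node; a node missing from the map has none.
neighborsOf :: Ord node => Graph node cost -> node -> [(node, cost)]
neighborsOf g n = fromMaybe [] (M.lookup n (edges g))
```

Nothing about String or Int is baked in any more; the caller chooses both types when building the graph.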

The function signature would have to change to reflect this, and we would have to impose some additional constraints on these types:

findShortestDistance :: (Hashable node, Eq node, Num cost, Ord cost) =>
  Graph node cost -> node -> node -> Distance cost

Graph Enumeration

But this would still leave us with an important problem: sometimes you don’t want to enumerate the whole graph. As is, the expression you submit as the "graph" to the function must have every edge enumerated, or it won’t give you the right answer. But often you won’t want to list every edge, because there are so many of them. Instead, you want to be able to list just the edges leaving a particular node. For example:

edgesForNode :: Graph node cost -> node -> [(node, cost)]
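To make the idea concrete, here is a sketch of a graph that is never stored at all: the edges leaving a node are computed on demand. It assumes a hypothetical 10x10 grid where each cell connects to its in-bounds orthogonal neighbors at cost 1; Cell, gridSize, and gridEdges are illustrative names, not part of the original code.

```haskell
-- Nodes are (row, column) cells of a square grid; nothing is enumerated up front.
type Cell = (Int, Int)

gridSize :: Int
gridSize = 10

-- Compute the outgoing edges of a cell on demand: each in-bounds orthogonal
-- neighbor is reachable at a cost of 1.
gridEdges :: Cell -> [(Cell, Int)]
gridEdges (r, c) =
  [ ((r', c'), 1)
  | (dr, dc) <- [(-1, 0), (1, 0), (0, -1), (0, 1)]
  , let (r', c') = (r + dr, c + dc)
  , r' >= 0, r' < gridSize, c' >= 0, c' < gridSize
  ]
```

A graph with 100 nodes and a few hundred edges is described by one small function, which is exactly the shape a fully enumerated HashMap can't capture cheaply for large grids.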

How can we capture this behavior more generally in Haskell?

Using a Typeclass

Well, one of the tools we typically turn to for this task is a typeclass. We might want to define something like this:

class DijkstraGraph graph where
  dijkstraEdges :: graph node cost -> node -> [(node, cost)]

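However, a simple typeclass quickly gets awkward because of the node and cost parameters. They appear only in the method’s type, so dijkstraEdges has to work for every possible node and cost. That means an instance has no way to demand the constraints it actually needs, such as Eq and Hashable on node for a HashMap lookup, because these parameters aren’t really part of our class.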
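One workaround shows where the trouble lies. In the sketch below, the instance compiles only because the Ord node constraint (standing in for Eq and Hashable, since this sketch assumes Data.Map instead of HashMap) is written into the method signature itself. Without it, GHC rejects the instance because the lookup needs Ord node. The cost is that every instance of the class now inherits that constraint, whether or not it is map-backed.

```haskell
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

-- The class ranges over a two-parameter type constructor. Because node and
-- cost are not class parameters, the constraint the lookup needs has to be
-- written into the method type, and every instance is stuck with it.
class DijkstraGraph graph where
  dijkstraEdges :: Ord node => graph node cost -> node -> [(node, cost)]

instance DijkstraGraph Graph where
  dijkstraEdges g n = fromMaybe [] (M.lookup n (edges g))
```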

Using a Multi-Param Typeclass

We could instead try having a multi-param typeclass like this:

{-# LANGUAGE MultiParamTypeClasses #-}

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

This actually works more smoothly than the previous try. We can construct an instance (if we allow flexible instances).

{-# LANGUAGE FlexibleInstances #-}

import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))
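Here is a sketch of the multi-parameter class with two instances side by side: the map-backed Graph from the text and a hypothetical Grid type whose edges are computed on demand, the kind of graph that motivated the class in the first place. It assumes Data.Map in place of HashMap so that only GHC's bundled containers package is needed; Grid is an illustrative name.

```haskell
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

-- Map-backed graph, mirroring the instance in the text.
instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (M.lookup n (edges g))

-- Hypothetical implicit graph: a square grid (the Int is its side length)
-- whose edges are generated on demand instead of being stored.
newtype Grid = Grid Int

instance DijkstraGraph Grid (Int, Int) Int where
  dijkstraEdges (Grid size) (r, c) =
    [ ((r', c'), 1)
    | (dr, dc) <- [(-1, 0), (1, 0), (0, -1), (0, 1)]
    , let (r', c') = (r + dr, c + dc)
    , r' >= 0, r' < size, c' >= 0, c' < size
    ]
```

Because nothing in the class determines node and cost from graph, calls generally need annotations, for example dijkstraEdges (Grid 3) ((0, 0) :: (Int, Int)) :: [((Int, Int), Int)].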

And we can actually use this class in our function now! It mainly requires changing around a few of our type signatures. We can start with our DijkstraState type, which must now be parameterized by the node and cost:

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

And, of course, we would also like to generalize the type signature of our findShortestDistance function. In its simplest form, we would like to use this:

findShortestDistance :: graph -> node -> node -> Distance cost

However, a couple of extra items are necessary to make this work. First, as above, our function is the correct place to put constraints on the node and cost types. The node type must fit into our hashing structures, so it must satisfy Eq and Hashable. The cost type must be Num and Ord so that we can add costs together and order distances in the heap. Finally, we have to add the constraint regarding the DijkstraGraph class itself:

findShortestDistance ::
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
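The Num and Ord requirements on cost come from adding distances with addDist and from ordering Distance values in the heap. The sketch below copies the Distance type, its Ord instance, and addDist from the appendix (only renaming unused pattern variables to _) and exercises them at Double, one of the cost types the parameterization is meant to allow.

```haskell
-- Distance, its Ord instance and addDist, as defined in the appendix.
data Distance a = Dist a | Infinity
  deriving (Show, Eq)

-- Ord on the cost type is what lets us compare finite distances;
-- every finite distance is smaller than Infinity.
instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

-- Num on the cost type supplies (+); anything plus Infinity stays Infinity.
addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity
```

With Double costs, addDist (Dist 1.5) (Dist 2.25) gives Dist 3.75, and minimum [Infinity, Dist 2.0, Dist 0.5] gives Dist 0.5.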

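Now, if we want to use the graph, node, and cost types within the “inner” type signatures of our function (such as the signature of processQueue in its where clause), we need one more thing: an explicit forall on the function’s signature. With the ScopedTypeVariables extension, this brings the type variables into scope over the whole function body, so the compiler knows the inner signatures refer to the same types rather than fresh ones.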

{-# LANGUAGE ScopedTypeVariables #-}

findShortestDistance :: forall graph node cost.
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
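The same scoping rule can be seen in isolation with a small hypothetical function, countMatches, whose local helper has a signature mentioning the outer a. With the forall and ScopedTypeVariables, that a is the outer one, so the helper can use target; without them, GHC would treat the helper’s a as a fresh type variable and reject the comparison with target.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Hypothetical example: count the elements of a list equal to a target.
countMatches :: forall a. Eq a => a -> [a] -> Int
countMatches target = length . filter matches
  where
    -- Thanks to the forall, this 'a' is the same 'a' as in the outer
    -- signature, so comparing against 'target' (of the outer type) is allowed
    -- and the outer Eq a constraint is available here.
    matches :: a -> Bool
    matches x = x == target
```

This mirrors processQueue, whose signature mentions node and cost and whose body uses the outer graph and dest.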

We can now make one change to our function so that it works with our class.

processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
processQueue = ...
  -- Previously
  -- allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
  -- Updated
  allNeighbors = dijkstraEdges graph node

And now we’re done! We can again verify the behavior. However, we do run into one difficulty: the compiler needs some extra type annotations to figure everything out.

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

>> :set -XFlexibleContexts
>> findShortestDistance graph1 "A" "D" :: Distance Int
Dist 40
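For checking this result without installing the HashMap, HashSet, and MinPrioHeap dependencies, here is a condensed, self-contained sketch of the same generalized algorithm. It is not the appendix code: it assumes Data.Map and Data.Set from containers in place of HashMap and HashSet (so nodes need Ord rather than Hashable), and it swaps in a Data.Set of (distance, node) pairs as the priority queue instead of MinPrioHeap. The helper names go, relax, and lookupDist are illustrative.

```haskell
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.List (foldl')
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance Ord a => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: Num a => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

newtype Graph node cost = Graph { edges :: M.Map node [(node, cost)] }

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (M.lookup n (edges g))

findShortestDistance :: forall graph node cost.
  (Ord node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
findShortestDistance graph src dest =
  go S.empty (M.singleton src (Dist 0)) (S.singleton (Dist 0, src))
  where
    -- Unknown nodes are infinitely far away.
    lookupDist :: M.Map node (Distance cost) -> node -> Distance cost
    lookupDist ds n = fromMaybe Infinity (M.lookup n ds)

    -- The Set acts as a min-priority queue: minView pops the smallest pair.
    go :: S.Set node -> M.Map node (Distance cost)
       -> S.Set (Distance cost, node) -> Distance cost
    go visited dists queue = case S.minView queue of
      Nothing -> lookupDist dists dest
      Just ((_, node), rest)
        | node == dest -> lookupDist dists dest
        | node `S.member` visited -> go visited dists rest  -- stale entry
        | otherwise ->
            let visited' = S.insert node visited
                -- Edges come from the class, so any instance can be searched.
                neighbors = [e | e@(n, _) <- dijkstraEdges graph node, not (S.member n visited')]
                (dists', queue') = foldl' (relax node) (dists, rest) neighbors
            in go visited' dists' queue'

    -- Record a shorter path to a neighbor and enqueue it.
    relax :: node -> (M.Map node (Distance cost), S.Set (Distance cost, node))
          -> (node, cost) -> (M.Map node (Distance cost), S.Set (Distance cost, node))
    relax current (ds, q) (n, cost) =
      let alt = addDist (lookupDist ds current) (Dist cost)
      in if alt < lookupDist ds n
           then (M.insert n alt ds, S.insert (alt, n) q)
           else (ds, q)

graph1 :: Graph String Int
graph1 = Graph $ M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

As in the GHCi session, the call needs an annotation: findShortestDistance graph1 "A" "D" :: Distance Int evaluates to Dist 40 (A to C to D), and asking for the distance from "D" to "A" gives Infinity because D has no outgoing edges.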

Conclusion

Below in the appendix is the full code for this part. You can also take a look at it on GitHub here.

For various reasons, I don’t love this attempt at generalizing. I especially don't like the "re-statement" of the parameter types in the instance: the parameters are part of the Graph type and are separately parameters of the class. Nothing in the class ties node and cost to graph, so the compiler can't work out the cost type from the graph alone. This is what forces us to specify the Distance Int type in the GHCi session above. We could avoid this if we don't parameterize our Graph type, which is definitely an option.

In the next part of this series, we'll make a second attempt at generalizing this algorithm!

Appendix

{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra2 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph node cost where
    dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
    dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance :: forall graph node cost.
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]