Two weeks ago I wrote about Overloaded:Categories
, and mentioned two brief examples. In this post I will walk through a more complete example. The topic is bitonic sorters.
Such domain "networks with N-inputs and M-outputs" are unfairly well suited for category representation. But first let me show few pictures^{1}. We'll see how this pictures are drawn from the same code as computation. The horizontal lines represent signals. The vertical lines represent compare-and-swap operator. I.e. for this two inputs we apply a function
cas :: Ord a => (a,a) -> (a,a)
cas (x,y) = if x < y then (x,y) else (y,x)
For four inputs:
Eight inputs:
Or sixteen inputs:
In the last network, there are 80 comparators. The network is not optimal, there is known network using just 60 comparators, you can find its diagram in The Art Of Computer Programming. Knuth also explains how bitonic sort networks are constructed i.e. why they work.
I should immediately point out that this wouldn't be doable with arrows, as the categories have Functor
s as objects i.e. type constructors of kind Type -> Type
, not just Type
s.
Lets define a super-concise names for well-known functors. We will see soon why this is a good idea.
newtype I a = I a -- Identity
data U a = U -- Proxy
data (f :*: g) a = f a :*: g a -- Product
Next we define one useful example category type:
newtype A x f g = A { runA :: f x -> g x }
Note, this is not a natural transformation, as the x
argument is not closed over (no RankNTypes
). We'll need it "outside", so we can require Ord x
needed for sorting. The type argument is unorthodox, but this way it fits into Category
class. The A x
instances resemble the (->)
instances:
instance Category (A x) where
id = A id
A f . A g = A (f . g)
instance CategoryWith1 (A x) where
type Terminal (A x) = U
terminal = A (\_ -> U)
instance CartesianCategory (A x) where
type Product (A x) = (:*:)
proj1 = A (\(x :*: _) -> x)
proj2 = A (\(_ :*: y) -> y)
fanout (A f) (A g) = A (\x -> f x :*: g x)
We also define a new ad-hoc type-class, SortingNetwork
. It basically says that sorting networks are cartesian categories, where the categorical product is the functor product :*:
. The block
operator is used to place wires on the diagrams, otherwise it does nothing.
class (CartesianCategory cat, Product cat ~ (:*:))
=> SortingNetwork cat
where
block :: cat a b -> cat a b
block = id
Bitonic sorter networks are defined recursively. We'll use Haskell's type-class machinery to synthesize them for us. Before we can dive into actual sorter, we need two helper primitives:
class Pairwise f g where
pairwise :: SortingNetwork cat
=> cat (I :*: I) (I :*: I)
-> cat (f :*: g) (f :*: g)
class Opposite f g where
opposite :: SortingNetwork cat
=> cat (I :*: I) (I :*: I)
-> cat (f :*: g) (f :*: g)
These primitives are quite similar. Given some primitive two-wire function (like compare-and-swap), the pairwise
connects wires in two blocks in order:
and opposite
connects them in opposite order:
Both operations are parallel if the "hardware" supports it. (Unfortunately we cannot take advantage of that in Haskell, at least easily).
The instances are simple to define. The base case is I I
. Then we can return the primitive function directly:
instance Pairwise I I where
pairwise = id
The induction case is for product :*:
. We require that halves of product can be connected pairwise. Because of this definition we essentially can make only networks with size . That is not a real issue, as we will see later.
The instance definition is easy using Overloaded:Categories
. Manual definition wouldn't be as pretty:
instance (Pairwise x u, Pairwise y v) => Pairwise (x :*: y) (u :*: v) where
pairwise f = proc ((x,y), (u,v)) -> do
(x',u') <- pairwise f -< (x,u)
(y',v') <- pairwise f -< (y,v)
identity -< ((x',y'), (u',v'))
The Opposite
instances are defined similarly.
If you have already examined the bitonic sorter networks, you noticed that they are composed from smaller Pairwise
and Opposite
blocks. So this is all we need to define bitonic sorter networks.
Bitonic sorter is kind-of merge-sort. First we sort the two halves, and then we merge them. Thus the type-class for them has two methods:
class BitonicSort f where
bitonicSort :: SortingNetwork cat
=> cat (I :*: I) (I :*: I)
-> cat f f
bitonicMerge :: SortingNetwork cat
=> cat (I :*: I) (I :*: I)
-> cat f f
The base case for I
is easy. With only one wire there is nothing to sort:
instance BitonicSort I where
bitonicSort _ = identity
bitonicMerge _ = identity
The induction step shows the structure.
The sort is recursive:
The merge is also recursive:
instance (Pairwise f g, Opposite f g, BitonicSort f, BitonicSort g)
=> BitonicSort (f :*: g)
where
bitonicSort cas =
block (bitonicSort cas *** bitonicSort cas) >>>
block (opposite cas) >>>
block (bitonicMerge cas *** bitonicMerge cas)
bitonicMerge cas =
block (pairwise cas) >>>
block (bitonicMerge cas *** bitonicMerge cas)
(***) :: CartesianCategory cat
=> cat a b -> cat c d -> cat (Product cat a c) (Product cat b d)
f *** g = fanout (proj1 >>> f) (proj2 >>> g)
Here I haven't used Overloaded:Categories
, as it doesn't offer much help. Here point-free approach is clearer.
As we have have networks defined for all sufficient categories, we can use the already define concrete instance A x
. We simply need to instantiate bitonicSort
, with a primitive compare-and-swap element:
cas :: A a (I :*: I) (I :*: I)
cas = A $ \(I x :*: I y) ->
if x <= y
then I x :*: I y
else I y :*: I x
Then the sort is just plumbing of data constructors. To avoid using tuples in the result I define function in continuation passing style.
sort4 :: Ord a
=> a -> a -> a -> a
-> (a -> a -> a -> a -> r)
-> r
sort4 a b c d kont =
let input = (I a :*: I b ) :*: (I c :*: I d)
(I a' :*: I b') :*: (I c' :*: I d') = runA (bitonicSort cas) input
in kont a' b' c' d'
We can QuickCheck
this function:
quickCheck $ \a b c d ->
sort4 a b c d $ \a' b' c' d' ->
sort [a,b,c,d] === [a',b',c',d' :: Int]
It works perfectly.
Below is the Core for sort4
. There are no allocations, just jump
s around (and six compare-and-swaps as in the diagram). There are few casts as types of some continuations are mentioning I
. The :*:
constructors are however gone. (I was positively surprised how well GHC does job here).
$workerSort4
$workerSort4
= \ @ a
@ r
lessEqual
x1 x2 x3 x4
kont ->
join {
$kont1 y1 y2
= join {
$kont2 z1 z2 z3 z4
= join {
$kont3 w1 w2 w3 w4
= join {
$kont4 q1 q2
= case lessEqual (w2 `cast` <Co:2>) (w4 `cast` <Co:2>)
of {
False ->
kont
(q1 `cast` <Co:2>)
(q2 `cast` <Co:2>)
(w4 `cast` <Co:2>)
(w2 `cast` <Co:2>);
True ->
kont
(q1 `cast` <Co:2>)
(q2 `cast` <Co:2>)
(w2 `cast` <Co:2>)
(w4 `cast` <Co:2>)
} } in
case lessEqual (w3 `cast` <Co:2>) (w1 `cast` <Co:2>) of {
False -> jump $kont4 w1 w3;
True -> jump $kont4 w3 w1
} } in
case lessEqual (z1 `cast` <Co:2>) (z4 `cast` <Co:2>) of {
False ->
case lessEqual (z2 `cast` <Co:2>) (z3 `cast` <Co:2>) of {
False -> jump $kont3 z3 z2 z4 z1;
True -> jump $kont3 z2 z3 z4 z1
};
True ->
case lessEqual (z2 `cast` <Co:2>) (z3 `cast` <Co:2>) of {
False -> jump $kont3 z3 z2 z1 z4;
True -> jump $kont3 z2 z3 z1 z4
}
} } in
case lessEqual x3 x4 of {
False -> jump $kont2 y1 y2 (x4 `cast` <Co:3>) (x3 `cast` <Co:3>);
True -> jump $kont2 y1 y2 (x3 `cast` <Co:3>) (x4 `cast` <Co:3>)
} } in
case lessEqual x1 x2 of {
False -> jump $kont1 (x2 `cast` <Co:3>) (x1 `cast` <Co:3>);
True -> jump $kont1 (x1 `cast` <Co:3>) (x2 `cast` <Co:3>)
}
It's great that GHC optimises the above example. But will it do as great job for network of 16 inputs? You don't know, untill you try and inspect the produced Core.
There is a way to be sure, by constructing good code: Template Haskell. The commonly used Template Haskell is untyped and error prone, but there is also Typed Template Haskell variant. It is quite simple (to the first approximation). Instead of single delimiters like [| Just $x |] :: ExpQ
, we use double delimiters:
[|| Just $$x ||] :: TExpQ (Maybe Int) -- if x :: TExpQ Int
There is one trick to know. Instead of direct result returning as we had with A x
we need to use continuation-passing-style.
type Code a = TExpQ a
newtype Staged x f g = St
{ runStaged :: forall r. f (Code x) -> (g (Code x) -> Code r) -> Code r }
Rewriting Category
and CartesianCategory
instances is an easy exercise. But why we need CPS? That can be illustrated by an implementation of cas
.
cas :: Code (a -> a -> Bool) -> Staged a (I :*: I) (I :*: I)
cas le = St $ \(I x :*: I y) kont ->
[|| let joinP a b = $$(kont (I [|| a ||] :*: I [|| b ||]))
in
if $$le $$x $$y
then joinP $$x $$y
else joinP $$y $$x
||]
The continuation provides the code "which will happen later", so we can generate let
(which is recognized as a join point). If we used direct style, there wouldn't be a way to generate code like that.
Notice that we pass comparator as Code (a -> a -> Bool)
. This is because currently TTH doesn't work well with type-class methods. This is a workaround.
With cas
done and Staged
having category instances we can instantiate bitonicSort
. This code is the same as with non-TH version previously.
sort4code
:: Code (a -> a -> Bool)
-> Code a -> Code a -> Code a -> Code a
-> (Code a -> Code a -> Code a -> Code a -> Code r)
-> Code r
sort4code le a b c d kont =
let input = (I a :*: I b ) :*: (I c :*: I d)
in runStaged (bitonicSort $ cas le) input $
\((I a' :*: I b') :*: (I c' :*: I d')) ->
kont a' b' c' d'
And then we can use it. We generate the sorting code using local names bound by sort4
definition:
sort4 :: (a -> a -> Bool)
-> a -> a -> a -> a
-> (a -> a -> a -> a -> r)
-> r
sort4 le a b c d kont =
$$(sort4code [|| le ||] [|| a ||] [|| b ||] [|| c ||] [|| d ||] $
\a' b' c' d' -> [|| kont $$a' $$b' $$c' $$d' ||]
)
We can examine first what we generate, using -ddump-splices
(The names are cleaned up):
lessThan x1 x2 x3 x4 kont =
let
joinP0 y1 y2
= let
joinP1 y3 y4
= let
joinP2 z1 z2
= let
joinP3 z3 z4
= let
joinP4 w1 w2
= let
joinP5 w3 w4
= kont w1 w2 w3 w4
in
if lessThan z4 z2 then
joinP5 z4 z2
else
joinP5 z2 z4
in
if lessThan z1 z3 then
joinP4 z1 z3
else
joinP4 z3 z1
in
if lessThan y2 y3 then
joinP3 y2 y3
else
joinP3 y3 y2
in
if lessThan y1 y4 then
joinP2 y1 y4
else
joinP2 y4 y1
in
if lessThan x3 x4 then
joinP1 x3 x4
else
joinP1 x4 x3
in
if lessThan x1 x2 then
joinP0 x1 x2
else
joinP0 x2 x1
It looks very nice. The GHC correctly recognizes the jump points, otherwise the generated Core (with -ddump-simpl
) is essentially the same. There weren't left much for GHC to do. The produced Core is similar to the variant GHC optimised from direct definition, except this one is more regular (I'm not sure why the Core above is slightly irregular). But most importantly, we know we will get this good result using staged programming / TTH as it is constructed explicitly. We are not relying on compiler being sufficiently smart, we tell it what to do.
It's good to have smart compiler, but in some specific use-cases general compiler general heuristics might not be well suited or enough. In particular, one might want to use even exponential-time optimization routine in some specific case, but no-one would like that used for everything.
sort4
= \ @ a_ad1P
@ r_ad1Q
lessThan
x1 x2 x3 x4
kont ->
let {
joinP0
= \ y1 y2 ->
let {
joinP1
= \ y3 y4 ->
let {
joinP2
= \ z1 z2 ->
let {
joinP3
= \ z3 z4 ->
let {
joinP4
= \ w1 w2 ->
let {
joinP5
= \ w3 w4 ->
kont w1 w2 w3 w4
} in
case lessThan z4 z2 of {
False -> joinP5 z2 z4;
True -> joinP5 z4 z2
} } in
case lessThan z1 z3 of {
False -> joinP4 z3 z1;
True -> joinP4 z1 z3
} } in
case lessThan y2 y3 of {
False -> joinP3 y3 y2;
True -> joinP3 y2 y3
} } in
case lessThan y1 y4 of {
False -> joinP2 y4 y1;
True -> joinP2 y1 y4
} } in
case lessThan x3 x4 of {
False -> joinP1 x4 x3;
True -> joinP1 x3 x4
} } in
case lessThan x1 x2 of {
False -> joinP0 x2 x1;
True -> joinP0 x1 x2
}
Another important observation is that we didn't need to touch bitonicSort
definition. We only created a new instance of Category
with a cas
primitive. Then the code was generated for us (truly at compile time).
The diagram drawing code is yet another instantiation. The type I used is quite simple:
newtype Diagram f g =
D { runD :: forall r. f Int -> (g Int -> [Piece] -> r) -> r }
data Piece = CAS Int Int
| Pieces [Piece]
We provide f
with distinct integers for each wire, and at the end we'll get which one are compared (and in the proper order). The block
operator of SortingNetwork
will create Pieces
, so we get more nicely aligned diagrams.
There is also a neat trick for generating the initial wire numbers. Recall that we used three Functor
s. They are also Traversable
and Applicative
. Thus we can generate initial f Int
using a traversal with State
, or more conveniently using mapAccumL
.
input :: (Applicative f, Traversable f) => f Int
input = snd $ mapAccumL (\s () -> (s + 1 ,s)) 0 (pure ())
So far we have been limited to networks of size. We can get all other networks by simply omitting top or bottom wires, treating them as always smaller or always bigger values. This is the second reason why we bothered with Functor kind. First one is that staging example wouldn't been possible either.^{2}
For example if we need a sorting network of six inputs, we can instantiate it for a functor
type Six = ((U :*: I) :*: (I :*: I)) :*: ((I :*: I) :*: (I :*: U))
The resulting network is a truncated version of the eight network:
This network has 13 comparators (optimal has 12). Not bad.
Programming with categories is fun. They offer inspectable static-shape computation as does Applicative
s, but they have quite different ergonomics.
Where Applicative
concentrates on "the objects", category (or Arrow
) approach is concentrated on "the arrows". Which one suits your particular domain depends. Categories work very well for sorting networks, sorting network is "an arrow", not the end result. I have hard time imagining how this can be represented with Applicative
s.
We defined a bitonic sort using categories (with a little help of Overloaded:Categories
), which we than instantiated three times:
We can (and maybe should) to continue and add an optimization layer - even the simple - one. If we further remove a wire from six-network above we will get a network with a redundant comparator. I leave the implementation of an optimizer as a tough but fun and rewarding exercise for a reader ;)
Hints: The optimizer usage looks like:
sort5optDiagram :: Diagram Five Five
sort5optDiagram = runOptimize $
bitonicSort (optimizedCas cas) >>>
bitonicSort (optimizedCas cas)
Also https://www.well-typed.com/blog/2014/12/simple-smt-solver/ and and zero-one principle may be useful.
So even we sort twice, we get only nine comparator network:
Acknowledgments I want to thank Andres Löh for introducing me to Typed Template Haskell, and Matthew Pickering for working on it in GHC.