Michael PJ recently wrote a post about Lenses for Tree Traversals. In the r/haskell discussion there is a comment that caught my attention:
And here is the problem. With mutually recursive datatypes even with generics we can't write generic type-safe traversal. We have to do with boilerplate.
Challenge accepted.
As pointed out on Twitter, the approach below is similar to what Russell O'Connor's multiplate library provides combinators for. It can be used with lens too.
{-# LANGUAGE RankNTypes #-}
-- For GPlated
{-# LANGUAGE DeriveGeneric, TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
module MutualTraversals where
import Control.Lens (transformOf)
import GHC.Generics
I will use the same example as Michael: a simply typed lambda calculus.
type Name = String
data Type = IntegerType | FunType Type Type deriving (Eq, Show)
But instead of the direct variant
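-- The direct (unidirectional) variant, for comparison only: its constructor
-- names clash with Syn and Chk below, so it cannot live in this module.
data Term
  = Var Name
  | Lam Name Type Term
  | App Term Term
  | Plus Term Term
  | Constant Integer
let us use a bidirectional version. (I just wrote a post about Bidirectional Pure Type Systems, so it is an obvious choice.)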
Bidirectionality forces us to think about Plus. One solution is to add an extra constructor, UnaryPlus.
data Syn
  = Var Name
  | App Syn Chk
  | Ann Chk Type
  deriving (Show, Generic)
data Chk
  = Lam Name Chk            -- note, no type annotation
  | Constant Integer
  | UnaryPlus Integer Syn   -- additional constructor
  | Plus Syn Syn
  | Conv Syn
  deriving (Show, Generic)
Note how UnaryPlus and Plus are "stuck" on Syn terms.
The goal is to fold constants. The interesting stuff happens in checkable terms, Chk.
cfChk :: Chk -> Chk
cfChk t = case t of
  UnaryPlus n (Ann (UnaryPlus m p) _) -> UnaryPlus (n + m) p
  UnaryPlus n (Ann (Constant m) _)    -> Constant (n + m)
  Plus (Ann (UnaryPlus n m) ty) p     -> UnaryPlus n (Ann (Plus m p) ty)
  Plus n (Ann (UnaryPlus m p) ty)     -> UnaryPlus m (Ann (Plus n p) ty)
  Plus (Ann (Constant n) _) m         -> UnaryPlus n m
  Plus n (Ann (Constant m) _)         -> UnaryPlus m n
  _                                   -> t
We have more rules because we added UnaryPlus, but we can fold more constants by exploiting the commutativity and associativity of addition.
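To see the rules in isolation, here is a standalone copy of the two types and cfChk. Deriving Eq is an addition (so results can be compared), and annZ is a hypothetical helper. Each rule only looks at the root and its immediately annotated operands, so chained rewrites need cfChk to be applied again.

```haskell
type Name = String

data Type = IntegerType | FunType Type Type
  deriving (Eq, Show)

-- Same constructors as above; Eq is added so results can be compared.
data Syn
  = Var Name
  | App Syn Chk
  | Ann Chk Type
  deriving (Eq, Show)

data Chk
  = Lam Name Chk
  | Constant Integer
  | UnaryPlus Integer Syn
  | Plus Syn Syn
  | Conv Syn
  deriving (Eq, Show)

cfChk :: Chk -> Chk
cfChk t = case t of
  -- n + (m + p) becomes (n + m) + p (associativity)
  UnaryPlus n (Ann (UnaryPlus m p) _) -> UnaryPlus (n + m) p
  -- two literals fold into one
  UnaryPlus n (Ann (Constant m) _)    -> Constant (n + m)
  -- float a constant out of either operand of Plus
  Plus (Ann (UnaryPlus n m) ty) p     -> UnaryPlus n (Ann (Plus m p) ty)
  Plus n (Ann (UnaryPlus m p) ty)     -> UnaryPlus m (Ann (Plus n p) ty)
  -- a literal operand becomes a UnaryPlus (commutativity for the right one)
  Plus (Ann (Constant n) _) m         -> UnaryPlus n m
  Plus n (Ann (Constant m) _)         -> UnaryPlus m n
  _                                   -> t

-- Hypothetical helper: annotate a checkable term with IntegerType.
annZ :: Chk -> Syn
annZ c = Ann c IntegerType
```

For example, cfChk (Plus (annZ (Constant 2)) (annZ (Constant 3))) gives UnaryPlus 2 (Ann (Constant 3) IntegerType), and applying cfChk once more gives Constant 5.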
But how do we write constantFold when we have two mutually recursive types? The answer is obvious once you hear it: if you have two mutually recursive types, then you have two mutually recursive traversals.
They look like a monomorphic Bitraversable.
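To make the analogy concrete, here is a small sketch comparing bitraverse from base's Data.Bitraversable with a monomorphic counterpart. The MonoBitraversal alias, the Pair type and the example values are hypothetical, made up for illustration; the alias has the same shape as the BilensLike' alias defined next.

```haskell
import Data.Bitraversable (bitraverse)

-- Data.Bitraversable provides
--   bitraverse :: (Bitraversable t, Applicative f)
--              => (a -> f c) -> (b -> f d) -> t a b -> f (t c d)
-- A monomorphic analogue (hypothetical alias) fixes the container type s
-- and requires both functions to preserve the types they visit.
type MonoBitraversal f s a b = (a -> f a) -> (b -> f b) -> s -> f s

-- Hypothetical monomorphic type with two sorts of children.
data Pair = Pair Int String
  deriving (Eq, Show)

pairBitraverse :: Applicative f => MonoBitraversal f Pair Int String
pairBitraverse f g (Pair n s) = Pair <$> f n <*> g s

-- Polymorphic: the element types may change (Int to String, String to Int).
tupleExample :: Maybe (String, Int)
tupleExample = bitraverse (Just . show) (Just . length) (1 :: Int, "abc")

-- Monomorphic: Int stays Int and String stays String.
pairExample :: Maybe Pair
pairExample = pairBitraverse (Just . (+ 1)) (Just . reverse) (Pair 1 "abc")
```

The polymorphic version may change both element types, while the monomorphic one must return the types it was given, which is exactly the restriction we have with Chk and Syn.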
Let me define some helper type aliases:
type Star f a b = a -> f b
type LensLike' f s a = Star f a a -> Star f s s
type BilensLike' f s a b = Star f a a -> Star f b b -> Star f s s
type Traversal' s a = forall f. Applicative f => LensLike' f s a
Then we can define the bitraversals:
chkSubterms' :: Applicative f => BilensLike' f Chk Syn Chk
chkSubterms' _syn chk (Lam n x) = Lam n <$> chk x
chkSubterms' _syn _chk t@Constant{} = pure t
chkSubterms' syn _chk (UnaryPlus n x) = UnaryPlus n <$> syn x
chkSubterms' syn _chk (Plus x y) = Plus <$> syn x <*> syn y
chkSubterms' syn _chk (Conv x) = Conv <$> syn x
synSubterms' :: Applicative f => BilensLike' f Syn Syn Chk
synSubterms' syn chk (App f x) = App <$> syn f <*> chk x
synSubterms' _syn chk (Ann x t) = Ann <$> chk x <*> pure t
synSubterms' _syn _chk t@Var {} = pure t
And using these we can define
chkSubterms :: Traversal' Chk Chk
chkSubterms f = chkSubterms' aux f where aux = synSubterms' aux f
The above definition is slightly complicated: we need a recursive aux that drills through Syn terms until it finds Chk terms.
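A quick way to check what chkSubterms visits is to traverse with the Const applicative and collect the targets. The sketch below repeats the definitions above, with Eq derived so results can be compared and a lens-free Traversal' alias; chkChildren is a hypothetical helper, not part of the post.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Const (Const (..))

type Name = String
data Type = IntegerType | FunType Type Type deriving (Eq, Show)
data Syn = Var Name | App Syn Chk | Ann Chk Type deriving (Eq, Show)
data Chk
  = Lam Name Chk | Constant Integer | UnaryPlus Integer Syn
  | Plus Syn Syn | Conv Syn
  deriving (Eq, Show)

type Star f a b = a -> f b
type BilensLike' f s a b = Star f a a -> Star f b b -> Star f s s
type Traversal' s a = forall f. Applicative f => Star f a a -> Star f s s

chkSubterms' :: Applicative f => BilensLike' f Chk Syn Chk
chkSubterms' _syn chk (Lam n x)       = Lam n <$> chk x
chkSubterms' _syn _chk t@Constant{}   = pure t
chkSubterms' syn _chk (UnaryPlus n x) = UnaryPlus n <$> syn x
chkSubterms' syn _chk (Plus x y)      = Plus <$> syn x <*> syn y
chkSubterms' syn _chk (Conv x)        = Conv <$> syn x

synSubterms' :: Applicative f => BilensLike' f Syn Syn Chk
synSubterms' syn chk (App f x)  = App <$> syn f <*> chk x
synSubterms' _syn chk (Ann x t) = Ann <$> chk x <*> pure t
synSubterms' _syn _chk t@Var{}  = pure t

chkSubterms :: Traversal' Chk Chk
chkSubterms f = chkSubterms' aux f where aux = synSubterms' aux f

-- Hypothetical helper: the nearest Chk subterms, reached through any number
-- of Syn layers, collected with the Const [Chk] applicative.
chkChildren :: Chk -> [Chk]
chkChildren = getConst . chkSubterms (\c -> Const [c])
```

For example, chkChildren (Conv (App (Ann (Constant 1) IntegerType) (Constant 2))) returns [Constant 1,Constant 2]: aux passes through Conv's Syn argument, the App node and the Ann node, and stops at the first Chk terms it meets.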
After all that setup, we can define constantFold.
constantFold :: Chk -> Chk
constantFold = transformOf chkSubterms cfChk
Let us also try it out. We are going to write a redundant program; the many type annotations highlight that.
expr1 :: Chk
expr1
    = Plus (annZ (Constant 2))
    $ annZ $ Plus (Var "x")
    $ annZ (Constant 3)
  where
    annZ n = Ann n IntegerType
After two iterations of constantFold we get a completely simplified result:
*MutualTraversals> constantFold expr1
UnaryPlus 3 (Ann (Plus (Ann (Constant 2) IntegerType) (Var "x")) IntegerType)
*MutualTraversals> constantFold $ constantFold expr1
UnaryPlus 5 (Var "x")
It works.
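For readers without lens at hand, here is a lens-free sketch that reproduces the two iterations above. overT and transformOfT are hypothetical stand-ins written against Identity; transformOfT follows the same bottom-up scheme as lens's transformOf (rewrite the children, then apply the function once at the node), which is why one pass does not fold everything.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))

type Name = String
data Type = IntegerType | FunType Type Type deriving (Eq, Show)
data Syn = Var Name | App Syn Chk | Ann Chk Type deriving (Eq, Show)
data Chk
  = Lam Name Chk | Constant Integer | UnaryPlus Integer Syn
  | Plus Syn Syn | Conv Syn
  deriving (Eq, Show)

cfChk :: Chk -> Chk
cfChk t = case t of
  UnaryPlus n (Ann (UnaryPlus m p) _) -> UnaryPlus (n + m) p
  UnaryPlus n (Ann (Constant m) _)    -> Constant (n + m)
  Plus (Ann (UnaryPlus n m) ty) p     -> UnaryPlus n (Ann (Plus m p) ty)
  Plus n (Ann (UnaryPlus m p) ty)     -> UnaryPlus m (Ann (Plus n p) ty)
  Plus (Ann (Constant n) _) m         -> UnaryPlus n m
  Plus n (Ann (Constant m) _)         -> UnaryPlus m n
  _                                   -> t

type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

chkSubterms' :: Applicative f => (Syn -> f Syn) -> (Chk -> f Chk) -> Chk -> f Chk
chkSubterms' _ chk (Lam n x)       = Lam n <$> chk x
chkSubterms' _ _   t@Constant{}    = pure t
chkSubterms' syn _ (UnaryPlus n x) = UnaryPlus n <$> syn x
chkSubterms' syn _ (Plus x y)      = Plus <$> syn x <*> syn y
chkSubterms' syn _ (Conv x)        = Conv <$> syn x

synSubterms' :: Applicative f => (Syn -> f Syn) -> (Chk -> f Chk) -> Syn -> f Syn
synSubterms' syn chk (App f x) = App <$> syn f <*> chk x
synSubterms' _   chk (Ann x t) = Ann <$> chk x <*> pure t
synSubterms' _   _   t@Var{}   = pure t

chkSubterms :: Traversal' Chk Chk
chkSubterms f = chkSubterms' aux f where aux = synSubterms' aux f

-- Hypothetical lens-free stand-in for lens's over, specialised to Identity.
overT :: ((a -> Identity a) -> s -> Identity s) -> (a -> a) -> s -> s
overT l f = runIdentity . l (Identity . f)

-- Bottom-up: rewrite all children first, then apply f once at this node.
transformOfT :: ((a -> Identity a) -> a -> Identity a) -> (a -> a) -> a -> a
transformOfT l f = go where go = f . overT l go

constantFold :: Chk -> Chk
constantFold = transformOfT chkSubterms cfChk

expr1 :: Chk
expr1
    = Plus (annZ (Constant 2))
    $ annZ $ Plus (Var "x")
    $ annZ (Constant 3)
  where
    annZ n = Ann n IntegerType
```

A third application leaves UnaryPlus 5 (Var "x") unchanged, so two passes reach a fixed point for this input.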
To complete the challenge, we need to write chkSubterms' and synSubterms' generically. If we are allowed to use Template Haskell, that would be as straightforward as writing Template Haskell ever is. Nor do I see any immediate problems in generalizing the GPlated definitions to generate bitraversals.
EDIT: Later today I added GPlated2 in the appendix. It is a straightforward generalization of the GPlated implementation in lens.
chkSubterms2' :: Applicative f => BilensLike' f Chk Syn Chk
synSubterms2' :: Applicative f => BilensLike' f Syn Syn Chk
chkSubterms2' f g = gplate2 g f
synSubterms2' f g = gplate2 f g
chkSubterms2 :: Traversal' Chk Chk
chkSubterms2 f = chkSubterms2' aux f where aux = synSubterms2' aux f
constantFold2 :: Chk -> Chk
constantFold2 = transformOf chkSubterms2 cfChk
It works.
*MutualTraversals> constantFold2 expr1
UnaryPlus 3 (Ann (Plus (Ann (Constant 2) IntegerType) (Var "x")) IntegerType)
*MutualTraversals> constantFold2 $ constantFold2 expr1
UnaryPlus 5 (Var "x")
So you don't even need to write the boilerplate by hand.
We can also define a Plate type, as the multiplate library advises,
data Plate f = Plate
  { chkPlate :: Star f Chk Chk
  , synPlate :: Star f Syn Syn
  }
and define a value
synChkPlate :: Applicative f => Plate f -> Plate f
synChkPlate p = Plate
  { chkPlate = chkSubterms' (synPlate p) (chkPlate p)
  , synPlate = synSubterms' (synPlate p) (chkPlate p)
  }
Now it's easy to see how you would add Type traversals to the mix.
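A sketch of that extension, assuming we also want to visit the Type annotation inside Ann: the Plate gets a third field. The typPlate field, synChkTypPlate and the countTypes example are hypothetical. The Chk and Syn cases mirror chkSubterms' and synSubterms', except that Ann now hands its annotation to typPlate instead of pure; countTypes ties the knot by passing itself to synChkTypPlate and adds one at every Type node.

```haskell
import Data.Functor.Const (Const (..))
import Data.Monoid (Sum (..))

type Name = String
data Type = IntegerType | FunType Type Type deriving (Eq, Show)
data Syn = Var Name | App Syn Chk | Ann Chk Type deriving (Eq, Show)
data Chk
  = Lam Name Chk | Constant Integer | UnaryPlus Integer Syn
  | Plus Syn Syn | Conv Syn
  deriving (Eq, Show)

-- Plate with a hypothetical third field for Type.
data Plate f = Plate
  { chkPlate :: Chk -> f Chk
  , synPlate :: Syn -> f Syn
  , typPlate :: Type -> f Type
  }

-- One layer: hand the immediate children of each sort to the given plate.
synChkTypPlate :: Applicative f => Plate f -> Plate f
synChkTypPlate p = Plate
  { chkPlate = \t -> case t of
      Lam n x       -> Lam n <$> chkPlate p x
      Constant _    -> pure t
      UnaryPlus n x -> UnaryPlus n <$> synPlate p x
      Plus x y      -> Plus <$> synPlate p x <*> synPlate p y
      Conv x        -> Conv <$> synPlate p x
  , synPlate = \s -> case s of
      Var _    -> pure s
      App f x  -> App <$> synPlate p f <*> chkPlate p x
      Ann x ty -> Ann <$> chkPlate p x <*> typPlate p ty
  , typPlate = \ty -> case ty of
      IntegerType -> pure ty
      FunType a b -> FunType <$> typPlate p a <*> typPlate p b
  }

-- Hypothetical example: count every Type node, at any depth, in one traversal.
countTypes :: Plate (Const (Sum Int))
countTypes = Plate
  { chkPlate = chkPlate layer
  , synPlate = synPlate layer
  , typPlate = \ty -> Const (Sum 1) *> typPlate layer ty
  }
  where
    layer = synChkTypPlate countTypes

countTypeNodes :: Chk -> Int
countTypeNodes = getSum . getConst . chkPlate countTypes
```

With this, countTypeNodes (Conv (Ann (Lam "x" (Conv (Var "x"))) (FunType IntegerType IntegerType))) evaluates to 3.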
I could also refute Michael's comment:
recursion-schemes does badly with mutually recursive types. If this is a problem for you, you’ll realize pretty quickly.
The recursion-schemes library itself cannot deal with mutually recursive types, but the approach can be generalized. In this post, too, we used things beyond lens.
I'll leave that for a future post. (Or you can look into https://hackage.haskell.org/package/multirec and the paper that explains it.)
Note that Michael slightly cheats when counting nodes with Folds:
-- ... plus the number of nodes in all the subterms
<> foldMapOf termSubterms countTermNodes t
-- ... plus the number of nodes in all the subtypes
<> foldMapOf termSubtypes countTypeNodes t
Here he examines the same t twice: first looking for subterms, and then for subtypes.
With his definition of the (unidirectional) Term, he could use a bitraversal to look for types and terms simultaneously!
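A sketch of that single-pass count, assuming a Term like the direct variant near the start of this post. termBisubterms, countTermNodes and countTypeNodes are hypothetical names written for illustration, not Michael's code; folding with Const (Sum Int) reaches both the Type children and the Term children of each node in one traversal.

```haskell
import Data.Functor.Const (Const (..))
import Data.Monoid (Sum (..))

type Name = String

data Type = IntegerType | FunType Type Type
  deriving (Eq, Show)

data Term
  = Var Name
  | Lam Name Type Term
  | App Term Term
  | Plus Term Term
  | Constant Integer
  deriving (Eq, Show)

-- Hypothetical bitraversal over the immediate Type and Term children.
termBisubterms :: Applicative f => (Type -> f Type) -> (Term -> f Term) -> Term -> f Term
termBisubterms _   _    t@Var{}      = pure t
termBisubterms typ term (Lam n ty b) = Lam n <$> typ ty <*> term b
termBisubterms _   term (App f x)    = App <$> term f <*> term x
termBisubterms _   term (Plus x y)   = Plus <$> term x <*> term y
termBisubterms _   _    t@Constant{} = pure t

countTypeNodes :: Type -> Sum Int
countTypeNodes IntegerType   = Sum 1
countTypeNodes (FunType a b) = Sum 1 <> countTypeNodes a <> countTypeNodes b

-- One for this node, plus all Type and Term children found in a single pass.
countTermNodes :: Term -> Sum Int
countTermNodes t =
  Sum 1 <> getConst (termBisubterms (Const . countTypeNodes) (Const . countTermNodes) t)
```

For example, getSum (countTermNodes (Lam "x" (FunType IntegerType IntegerType) (Plus (Var "x") (Constant 1)))) is 7: four term nodes and three type nodes.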
Appendix: GPlated2
-- | Implement 'plate' operation for a type using its 'Generic' instance.
gplate2
  :: (Generic a, GPlated2 a b (Rep a), Applicative f)
  => BilensLike' f a a b
gplate2 f g x = GHC.Generics.to <$> gplate2' f g (GHC.Generics.from x)
class GPlated2 a b g where
  gplate2' :: Applicative f => BilensLike' f (g p) a b
instance GPlated2 a b f => GPlated2 a b (M1 i c f) where
  gplate2' f g (M1 x) = M1 <$> gplate2' f g x
instance (GPlated2 a b f, GPlated2 a b g) => GPlated2 a b (f :+: g) where
  gplate2' f g (L1 x) = L1 <$> gplate2' f g x
  gplate2' f g (R1 x) = R1 <$> gplate2' f g x
instance (GPlated2 a b f, GPlated2 a b g) => GPlated2 a b (f :*: g) where
  gplate2' f g (x :*: y) = (:*:) <$> gplate2' f g x <*> gplate2' f g y
instance {-# OVERLAPPING #-} GPlated2 a b (K1 i a) where
  gplate2' f _ (K1 x) = K1 <$> f x
instance {-# OVERLAPPING #-} GPlated2 a b (K1 i b) where
  gplate2' _ g (K1 x) = K1 <$> g x
instance GPlated2 a b (K1 i c) where
  gplate2' _ _ = pure
instance GPlated2 a b U1 where
  gplate2' _ _ = pure
instance GPlated2 a b V1 where
  gplate2' _ _ v = v `seq` error "GPlated2/V1"
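As a sanity check that the generic version really replaces the hand-written boilerplate, the sketch below condenses the appendix into one file (with Eq derived) and compares what chkSubterms' and chkSubterms2', and their Syn counterparts, visit and in which order. tagVisits is a hypothetical helper that tags Syn children with Left and Chk children with Right.

```haskell
{-# LANGUAGE DeriveGeneric, TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
import Data.Functor.Const (Const (..))
import GHC.Generics

type Name = String
data Type = IntegerType | FunType Type Type deriving (Eq, Show)
data Syn = Var Name | App Syn Chk | Ann Chk Type deriving (Eq, Show, Generic)
data Chk
  = Lam Name Chk | Constant Integer | UnaryPlus Integer Syn
  | Plus Syn Syn | Conv Syn
  deriving (Eq, Show, Generic)

type Star f a b = a -> f b
type BilensLike' f s a b = Star f a a -> Star f b b -> Star f s s

-- Condensed copy of the appendix.
gplate2 :: (Generic a, GPlated2 a b (Rep a), Applicative f) => BilensLike' f a a b
gplate2 f g x = to <$> gplate2' f g (from x)

class GPlated2 a b g where
  gplate2' :: Applicative f => BilensLike' f (g p) a b

instance GPlated2 a b f => GPlated2 a b (M1 i c f) where
  gplate2' f g (M1 x) = M1 <$> gplate2' f g x
instance (GPlated2 a b f, GPlated2 a b g) => GPlated2 a b (f :+: g) where
  gplate2' f g (L1 x) = L1 <$> gplate2' f g x
  gplate2' f g (R1 x) = R1 <$> gplate2' f g x
instance (GPlated2 a b f, GPlated2 a b g) => GPlated2 a b (f :*: g) where
  gplate2' f g (x :*: y) = (:*:) <$> gplate2' f g x <*> gplate2' f g y
instance {-# OVERLAPPING #-} GPlated2 a b (K1 i a) where
  gplate2' f _ (K1 x) = K1 <$> f x
instance {-# OVERLAPPING #-} GPlated2 a b (K1 i b) where
  gplate2' _ g (K1 x) = K1 <$> g x
instance GPlated2 a b (K1 i c) where
  gplate2' _ _ = pure
instance GPlated2 a b U1 where
  gplate2' _ _ = pure
instance GPlated2 a b V1 where
  gplate2' _ _ v = v `seq` error "GPlated2/V1"

-- Hand-written bitraversals from the post.
chkSubterms' :: Applicative f => BilensLike' f Chk Syn Chk
chkSubterms' _syn chk (Lam n x)       = Lam n <$> chk x
chkSubterms' _syn _chk t@Constant{}   = pure t
chkSubterms' syn _chk (UnaryPlus n x) = UnaryPlus n <$> syn x
chkSubterms' syn _chk (Plus x y)      = Plus <$> syn x <*> syn y
chkSubterms' syn _chk (Conv x)        = Conv <$> syn x

synSubterms' :: Applicative f => BilensLike' f Syn Syn Chk
synSubterms' syn chk (App f x)  = App <$> syn f <*> chk x
synSubterms' _syn chk (Ann x t) = Ann <$> chk x <*> pure t
synSubterms' _syn _chk t@Var{}  = pure t

-- Generic versions, wired up as in the post.
chkSubterms2' :: Applicative f => BilensLike' f Chk Syn Chk
chkSubterms2' f g = gplate2 g f

synSubterms2' :: Applicative f => BilensLike' f Syn Syn Chk
synSubterms2' f g = gplate2 f g

-- Hypothetical helper: the children a bitraversal visits, in order.
tagVisits :: BilensLike' (Const [Either Syn Chk]) s Syn Chk -> s -> [Either Syn Chk]
tagVisits trav = getConst . trav (\s -> Const [Left s]) (\c -> Const [Right c])
```

For instance, tagVisits chkSubterms2' (Plus (Var "a") (Var "b")) gives [Left (Var "a"),Left (Var "b")], the same as the hand-written chkSubterms'.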