{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module KleenePlugin.TypeEq where
import Control.Monad.EitherK (EitherKT, runEitherKT)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT (..), get, put)
import Control.Unification (UTerm (..), applyBindingsAll, unify)
import Control.Unification.IntVar
(IntBindingT, IntVar (..), evalIntBindingT)
import Control.Unification.Types
(BindingMonad (..), UFailure (..), Unifiable (..))
import Data.Bifunctor (first)
import Data.Either (isRight)
import Data.Functor.Identity (Identity, runIdentity)
import Data.Traversable (for)
import qualified Data.Map.Strict as Map
import qualified GhcPlugins as GHC
import qualified Outputable as PP
import KleenePlugin.Names
-- | Compare two GHC types, deferring undecided cases to unification.
--
-- * @Right True@: the types are equal according to 'GHC.eqType'.
-- * @Left (MayUnify l r)@: the types are not equal, but their elaborations
--   @l@ and @r@ (see 'elaborateType') unify with each other on their own.
--   The resulting 'UnifResult' can later be checked jointly with others by
--   'manyMayUnify'.
-- * @Right False@: the elaborated types cannot unify.
maybeEqType
    :: KleNames
    -> Map.Map GHC.TyVar GHC.FastString
    -> GHC.Type -> GHC.Type
    -> Either UnifResult Bool
maybeEqType kle labels lhs rhs =
    if GHC.eqType lhs rhs
        then Right True
        else classify (elaborate lhs) (elaborate rhs)
  where
    elaborate = elaborateType kle labels

    -- Only reached when the types are not syntactically equal.
    classify l r
        | mayUnify l r = Left (MayUnify l r)
        | otherwise    = Right False
-- | Two elaborated types (see 'elaborateType') that are not syntactically
-- equal but may still unify, as produced by 'maybeEqType'.
-- 'manyMayUnify' checks a list of these jointly.
data UnifResult
    = MayUnify
        (UTerm Mono GHC.TyVar)  -- elaborated left-hand type
        (UTerm Mono GHC.TyVar)  -- elaborated right-hand type
-- | Translate a GHC type into a unification term over 'Mono'.
--
-- * @kleKey kle@ applied to a single string literal becomes 'MonoSym' of
--   that literal.
-- * Any other type application becomes 'MonoApp' of the elaborated function
--   and argument.
-- * A type variable present in @labels@ becomes 'MonoSym' of its label; any
--   other type variable becomes a unification variable ('UVar').
-- * Everything else is kept opaque as 'MonoC'.
elaborateType
    :: KleNames
    -> Map.Map GHC.TyVar GHC.FastString
    -> GHC.Type
    -> UTerm Mono GHC.TyVar
elaborateType kle labels = elaborate
  where
    elaborate ty
        | Just sym <- kleLabel ty =
            UTerm (MonoSym sym)
        | Just (fun, arg) <- GHC.splitAppTy_maybe ty =
            UTerm (MonoApp (elaborate fun) (elaborate arg))
        | Just tv <- GHC.getTyVar_maybe ty =
            maybe (UVar tv) (UTerm . MonoSym) (Map.lookup tv labels)
        | otherwise =
            UTerm (MonoC ty)

    -- The literal @sym@ when @ty@ is exactly @kleKey kle "sym"@.
    kleLabel ty = case GHC.splitTyConApp_maybe ty of
        Just (tc, [lit]) | tc == kleKey kle -> GHC.isStrLitTy lit
        _                                   -> Nothing
-- | One layer of the term language used for unification. The parameter @a@
-- marks the recursive positions; 'UTerm' from unification-fd supplies the
-- recursion and the variables.
data Mono a
    = MonoC GHC.Type          -- ^ An opaque GHC type, matched with 'GHC.eqType'.
    | MonoApp a a             -- ^ A type application: function, then argument.
    | MonoSym GHC.FastString  -- ^ A label symbol.
    deriving (Functor, Foldable, Traversable)
-- | Structural matching of one 'Mono' layer.
--
-- * Opaque types match when 'GHC.eqType' holds.
-- * Symbols match when they are equal.
-- * Two applications always match; their function parts and their argument
--   parts are returned as pairs still to be unified ('Right').
-- * Every other combination fails.
instance Unifiable Mono where
    zipMatch lhs rhs = case (lhs, rhs) of
        (MonoC a, MonoC b)
            | GHC.eqType a b -> Just (MonoC a)
        (MonoApp f x, MonoApp g y) ->
            Just (MonoApp (Right (f, g)) (Right (x, y)))
        (MonoSym s, MonoSym t)
            | s == t -> Just (MonoSym s)
        _ -> Nothing
-- | Monad in which unification runs: failures are 'UFailure's over 'Mono'
-- with 'IntVar' variables, layered over the 'IntBindingT' binding monad.
type M =
    EitherKT
        (UFailure Mono IntVar)       -- failure type
        (IntBindingT Mono Identity)  -- binding monad
-- | A unification failure wrapped in a local newtype so this module can give
-- it a 'PP.Outputable' instance (presumably to avoid an orphan instance).
newtype UF
    = UF (UFailure Mono IntVar)
-- | Check that all pairs can be unified /simultaneously/.
--
-- Every GHC type variable is replaced by an 'IntVar'. A type variable that
-- occurs in several pairs gets the same 'IntVar' everywhere, so the pairs
-- constrain each other. Returns @Left@ with the unification failure, or
-- @Right ()@ when every pair unifies (trivially so for an empty list).
manyMayUnify :: [UnifResult] -> Either UF ()
manyMayUnify results =
    first UF (runIdentity (evalIntBindingT (runEitherKT solve)))
  where
    solve :: M ()
    solve = do
        -- A single variable map is threaded through all pairs; the final map
        -- is not needed afterwards.
        pairs <- fst <$> runStateT (for results freshen) Map.empty
        terms <- for pairs (uncurry unify)
        -- applyBindingsAll fully resolves the bindings and can itself fail
        -- (e.g. OccursFailure on a cyclic binding); the resolved terms are
        -- discarded.
        () <$ applyBindingsAll terms

    -- Swap GHC type variables for unification variables, left term first.
    freshen (MayUnify l r) = (,) <$> traverse makeVar l <*> traverse makeVar r

    -- Reuse the 'IntVar' already assigned to a key, or allocate a fresh one
    -- and remember it.
    makeVar :: Ord a => a -> StateT (Map.Map a IntVar) M IntVar
    makeVar key = do
        seen <- get
        maybe (allocate seen) return (Map.lookup key seen)
      where
        allocate seen = do
            iv <- lift (lift freeVar)
            put (Map.insert key iv seen)
            return iv
-- | Whether two elaborated terms unify with each other, considered in
-- isolation (a one-element 'manyMayUnify').
mayUnify :: UTerm Mono GHC.TyVar -> UTerm Mono GHC.TyVar -> Bool
mayUnify lhs = isRight . manyMayUnify . (: []) . MayUnify lhs
-- | Human-readable rendering of a unification failure. Unification variables
-- are shown as @?n@ (see 'pprMono'').
instance PP.Outputable UF where
    -- Show the offending variable and the term it would occur in, instead of
    -- a bare "Occurs failure:" with no details.
    ppr (UF (OccursFailure v t)) =
        PP.text "Occurs failure:" PP.<+>
        pprMono' 0 (UVar v) PP.<+>
        PP.text "occurs in" PP.<+>
        pprMono' 0 t
    ppr (UF (MismatchFailure a b)) =
        PP.text "Couldn't match types" PP.<+>
        pprMono 0 a PP.<+>
        PP.text "and" PP.<+>
        pprMono 0 b
-- | Pretty-print one 'Mono' layer at the given precedence.
--
-- * Opaque types defer to GHC's 'PP.pprPrec'.
-- * Symbols get a leading @#@.
-- * Applications are printed as left-nested juxtaposition (function at
--   precedence 10, argument at 11), parenthesised when the surrounding
--   precedence exceeds 10.
pprMono :: Rational -> Mono (UTerm Mono IntVar) -> PP.SDoc
pprMono prec layer = case layer of
    MonoC ty      -> PP.pprPrec prec ty
    MonoSym label -> PP.char '#' PP.<> PP.ppr label
    MonoApp f x   ->
        PP.cparen (prec > 10) (pprMono' 10 f PP.<+> pprMono' 11 x)
-- | Pretty-print a full unification term at the given precedence.
-- Unification variables are rendered as @?n@, where @n@ is the variable's
-- integer id.
pprMono' :: Rational -> UTerm Mono IntVar -> PP.SDoc
pprMono' prec term = case term of
    UVar (IntVar n) -> PP.char '?' PP.<> PP.int n
    UTerm layer     -> pprMono prec layer