{-# LANGUAGE FlexibleContexts, DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# LANGUAGE ScopedTypeVariables, TypeApplications, AllowAmbiguousTypes, TemplateHaskell #-}
module Overload (overload) where
import Language.Haskell.TH
import Data.Char
import Data.List (lookup, nub)
import Control.Arrow
import Data.Functor.Identity
import Data.Function
import Control.Effects.State
-- | A simplified representation of a Haskell type. Type variables carry a
-- label of the caller's choosing (@name@); type constructors are always
-- Template Haskell 'Name's.
data TypeTree name
  = Var name
    -- ^ A type variable.
  | Concrete Name
    -- ^ A type constructor, e.g. @Int@, @Maybe@ or @(->)@.
  | App (TypeTree name) (TypeTree name)
    -- ^ Application of one type to another.
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | State used to number type variables: an association list from each
-- name seen so far to the number assigned to it, paired with the next
-- unused number.
type FreshSource a = ([(a, Int)], Int)
-- | Return the number already assigned to @name@. If @name@ has not been
-- seen yet, give it the next unused number, record it, and return it.
lookupName :: (MonadEffectState (FreshSource a) m, Eq a) => a -> m Int
lookupName name = do
  (table, top) <- getState
  let assignNew = setState ((name, top) : table, top + 1) >> return top
  maybe assignNew return (lookup name table)
-- | Hand out the next unused number without recording it under any name.
-- The type argument @a@ (supplied by type application) selects which
-- 'FreshSource' state is used.
freshVar :: forall a m. MonadEffectState (FreshSource a) m => m Int
freshVar = do
  (names, next) :: FreshSource a <- getState
  setState (names, next + 1)
  pure next
-- | Convert a Template Haskell 'Type' into a 'TypeTree'.
--
-- The supported forms are:
--
-- * type constructors, type variables and type application;
-- * infix application and the function arrow;
-- * the list and tuple type constructors;
-- * parenthesised or kind-annotated types (the parentheses and the kind
--   annotation are dropped).
--
-- Any other form is rejected with 'error'.
typeToTypeTree :: Type -> TypeTree Name
typeToTypeTree (ConT n) = Concrete n
typeToTypeTree (AppT t1 t2) = App (typeToTypeTree t1) (typeToTypeTree t2)
typeToTypeTree (VarT n) = Var n
typeToTypeTree (InfixT t1 n t2) =
  App (App (Concrete n) (typeToTypeTree t1)) (typeToTypeTree t2)
typeToTypeTree ArrowT = Concrete ''(->)
-- Special-syntax constructors become ordinary named constructors, so a
-- reified @[a]@ ('ListT') and @[] a@ (@ConT ''[]@) produce equal trees.
typeToTypeTree ListT = Concrete ''[]
-- Arity 1 is deliberately excluded. On some template-haskell versions
-- @tupleTypeName 1@ does not name a one-tuple (NOTE(review): confirm for the
-- versions targeted), so it is left to the fallback error below.
typeToTypeTree (TupleT n) | n /= 1 = Concrete (tupleTypeName n)
typeToTypeTree (ParensT t) = typeToTypeTree t
typeToTypeTree (SigT t _) = typeToTypeTree t
typeToTypeTree t = error ("Non supported type " ++ show t)
-- | Every way of generalizing a tree by replacing sub-trees with anonymous
-- variables (@Var Nothing@). Original variables are kept as @Var (Just n)@.
-- The fully anonymous variable always comes first.
allGeneralizations :: TypeTree a -> [TypeTree (Maybe a)]
allGeneralizations tree = Var Nothing : keepingRoot tree
  where
    -- Generalizations that keep the root node of the tree.
    keepingRoot (Var n) = [Var (Just n)]
    keepingRoot (Concrete n) = [Concrete n]
    keepingRoot (App f x) =
      [ App f' x' | f' <- allGeneralizations f, x' <- allGeneralizations x ]
-- | Renumber the variables of a tree from 0, in left-to-right traversal
-- order. Each anonymous variable (@Nothing@) gets its own number; repeated
-- occurrences of the same named variable share one number.
normalizeTypeTree :: forall a. Eq a => TypeTree (Maybe a) -> TypeTree Int
normalizeTypeTree tree =
  runIdentity (handleStateT emptySource (traverse renumber tree))
  where
    emptySource = ([], 0) :: FreshSource a
    renumber = maybe (freshVar @a) lookupName
-- | Bindings collected while matching a general tree against a specific one:
-- each variable of the general tree paired with the sub-tree it was bound to.
type VariableMapping a b = [(a, TypeTree b)]
-- | Bind @name@ to @typ@ in the current mapping.
--
-- If @name@ is unbound, the binding is added and 'True' is returned. If it is
-- already bound, the mapping is left unchanged and the result says whether
-- the existing binding equals @typ@.
trySetVar :: (MonadEffectState (VariableMapping a b) m, Eq a, Eq b) => a -> TypeTree b -> m Bool
trySetVar name typ = do
  mapping <- getState
  let bindFresh = True <$ setState ((name, typ) : mapping)
  maybe bindFresh (return . (typ ==)) (lookup name mapping)
-- | Whether the first tree can be instantiated to the second by consistently
-- substituting its variables. The match starts from an empty mapping.
isMoreGeneralThan :: forall a b. (Eq a, Eq b) => TypeTree a -> TypeTree b -> Bool
isMoreGeneralThan general specific =
  runIdentity $ handleStateT noBindings $ isMoreGeneralThan' general specific
  where
    noBindings :: VariableMapping a b
    noBindings = []
-- | Stateful matcher behind 'isMoreGeneralThan'.
--
-- A variable of the general tree matches any sub-tree, consistently with
-- earlier bindings (see 'trySetVar'). Constructors must be identical.
-- Applications match component-wise: the head is matched first, then the
-- argument, and both are always run. Every other combination fails.
isMoreGeneralThan' :: (MonadEffectState (VariableMapping a b) m, Eq a, Eq b)
                   => TypeTree a -> TypeTree b -> m Bool
isMoreGeneralThan' general specific = case (general, specific) of
  (Var v, _) -> trySetVar v specific
  (Concrete g, Concrete s) | g == s -> return True
  (App gHead gArg, App sHead sArg) -> do
    headMatches <- isMoreGeneralThan' gHead sHead
    argMatches <- isMoreGeneralThan' gArg sArg
    return (headMatches && argMatches)
  _ -> return False
-- | Pair each element with the list of all the other elements, keeping
-- their original order.
withouts :: [a] -> [(a, [a])]
withouts = go []
  where
    -- @before@ holds the already-visited elements in reverse order.
    go _ [] = []
    go before (x : after) = (x, reverse before ++ after) : go (x : before) after
-- | Keep only the types that no other type in the list is more general than.
minimize :: [TypeTree Int] -> [TypeTree Int]
minimize types =
  [ candidate
  | (candidate, others) <- withouts types
  , all (\other -> not (other `isMoreGeneralThan` candidate)) others
  ]
-- | For each input type, the minimized list of its (normalized)
-- generalizations that are not more general than any of the other input
-- types. Results are in the same order as the input.
findDeciders :: Eq a => [TypeTree a] -> [[TypeTree Int]]
findDeciders types =
  [ minimize (viableFor t rest) | (t, rest) <- withouts normalized ]
  where
    normalized = map (normalizeTypeTree . fmap Just) types
    viableFor t rest =
      filter (\g -> not (any (g `isMoreGeneralThan`) rest))
        (nub (map normalizeTypeTree (allGeneralizations t)))
-- | Turn each variable label into a Template Haskell name: @t@ followed by
-- the 'show'n label (e.g. 3 becomes @t3@).
typeTreeWithNames :: Show a => TypeTree a -> TypeTree Name
typeTreeWithNames = fmap toName
  where
    toName label = mkName ('t' : show label)
-- | The substitution that instantiates @general@ to @specific@: each variable
-- of @general@ paired with the sub-tree of @specific@ it stands for.
-- Calls 'error' if @general@ is not more general than @specific@.
getEqualities :: forall a b. (Eq a, Eq b) => TypeTree a -> TypeTree b -> [(b, TypeTree a)]
getEqualities specific general =
  runIdentity (handleStateT noBindings matchAndCollect)
  where
    noBindings :: VariableMapping b a
    noBindings = []
    matchAndCollect = do
      matches <- isMoreGeneralThan' general specific
      if matches
        then getState
        else error "Can't get equalities because the second type isn't more general than the first"
-- | Convert a 'TypeTree' back into a Template Haskell 'Type'. The @(->)@
-- constructor is emitted as 'ArrowT'; every other constructor becomes 'ConT'.
typeTreeToType :: TypeTree Name -> Type
typeTreeToType tree = case tree of
  Var n -> VarT n
  Concrete n
    | n == ''(->) -> ArrowT
    | otherwise -> ConT n
  App f x -> AppT (typeTreeToType f) (typeTreeToType x)
-- | Turn each @(variable, tree)@ binding into an equality constraint
-- @variable ~ tree@.
equalityToCxt :: [(Name, TypeTree Name)] -> Cxt
equalityToCxt equalities =
  [ AppT (AppT EqualityT (VarT var)) (typeTreeToType tree)
  | (var, tree) <- equalities
  ]
-- | Compute the instances to generate for a set of overloaded functions.
--
-- Each function contributes one @(function, context, instance head)@ triple
-- per decider type found by 'findDeciders'. The context is the function's
-- own top-level @forall@ context (if any), extended with equalities tying
-- the decider's variables to the function's actual type.
deciders :: [(Name, Type)] -> [(Name, Cxt, Type)]
deciders cases = concat (zipWith instancesFor decidersPerCase triplets)
  where
    triplets = map toTriplet cases
    toTriplet (n, ForallT _ c t) = (n, c, typeToTypeTree t)
    toTriplet (n, t) = (n, [], typeToTypeTree t)
    decidersPerCase = findDeciders [ tree | (_, _, tree) <- triplets ]
    instancesFor decs triplet = map (instanceFor triplet . typeTreeWithNames) decs
    instanceFor (n, c, tree) dec =
      (n, c ++ equalityToCxt (getEqualities tree dec), typeTreeToType dec)
-- | Build an incoherent instance of @className@ for the given head type. The
-- instance defines @methodName@ as the overloaded function, under the given
-- context.
makeInstance :: Name -> Name -> (Name, Cxt, Type) -> Dec
makeInstance className methodName (overloadName, c, t) =
  InstanceD (Just Incoherent) c instanceHead [methodDef]
  where
    instanceHead = ConT className `AppT` t
    methodDef = FunD methodName [Clause [] (NormalB (VarE overloadName)) []]
-- | Generates a new function with the given name that can behave like multiple functions.
--
-- For a name such as @"combine"@ this declares:
--
-- * a class @Combine t@ with the single method @combine :: t@;
-- * one incoherent instance per instance head computed by 'deciders' from the
--   types of the given functions. Each instance's method is the
--   corresponding original function.
--
-- The name must start with a lowercase letter, because it is used as-is for
-- the method and with its first letter upper-cased for the class. Otherwise
-- splicing fails with a descriptive error. Names that do not reify to a
-- variable binding ('VarI') are skipped with a compile-time warning.
overload :: String -> [Name] -> Q [Dec]
overload functionName overloadNames = do
  className <- case functionName of
    c : cs | isLower c -> return (mkName (toUpper c : cs))
    _ -> fail ("overload: the function name must start with a lowercase letter, got " ++ show functionName)
  overloads <- concat <$> mapM reifyValue overloadNames
  let methodName = mkName functionName
      typeVar = mkName "t"
      -- 'plainTV' builds the binder in a form accepted by both older and
      -- newer (>= 2.17) template-haskell, unlike a bare one-argument 'PlainTV'.
      classDec = ClassD [] className [plainTV typeVar] []
                   [SigD methodName (VarT typeVar)]
      instances = map (makeInstance className methodName) (deciders overloads)
  return (classDec : instances)
  where
    -- Reify a name and keep its type if it is a variable binding. Other names
    -- are reported instead of being dropped silently.
    reifyValue :: Name -> Q [(Name, Type)]
    reifyValue name = do
      info <- reify name
      case info of
        VarI n t _ -> return [(n, t)]
        _ -> do
          reportWarning ("overload: ignoring " ++ show name ++ ", which is not a variable binding")
          return []