module Overload (overload) where
import Language.Haskell.TH
import Data.Char
import Control.Arrow
import Data.Function
import Control.Monad
import Language.Haskell.TH.ExpandSyns
import Data.List (tails)
import Data.Semigroup ((<>))
import Overload.Normal
import Overload.TypeTree
import Overload.General
import qualified Overload.Diff as Diff
-- | Pair every element with the list of all the other elements. Both the
-- selected elements and each list of remaining elements keep the original
-- order.
--
-- > withouts "abc" == [('a', "bc"), ('b', "ac"), ('c', "ab")]
withouts :: [a] -> [(a, [a])]
withouts = go []
  where
    -- @seen@ holds the elements already visited, most recent first.
    go _ [] = []
    go seen (y : ys) = (y, reverse seen ++ ys) : go (y : seen) ys
-- | Compute the deciders for every input type, one result per input, in
-- input order.
--
-- Each input is first normalised after mapping 'Just' over it. Then, for
-- every normalised tree, 'Diff.deciders' is run against all of the other
-- normalised trees. Its results are normalised again and passed through
-- 'minimize'.
allDeciders :: Eq a => [TypeTree a] -> [[TypeTree Normal]]
allDeciders types =
  [ minimize (fmap normalizeTypeTree (Diff.deciders target rivals))
  | (target, rivals) <- withouts normalized
  ]
  where
    normalized = map (normalizeTypeTree . fmap Just) types
-- | Turn each @(variable, tree)@ binding into the equality constraint
-- @variable ~ tree@ (built with 'EqualityT'), producing an instance context.
equalityToCxt :: [(Name, TypeTree Name)] -> Cxt
equalityToCxt bindings =
  [ VarT var `equals` typeTreeToType tree | (var, tree) <- bindings ]
  where
    equals lhs rhs = AppT (AppT EqualityT lhs) rhs
-- | Build the @(name, context, instance head)@ triples for a set of overloads.
--
-- Each input pairs an overload's name with its type. If the type starts
-- with a top-level @forall@, that quantifier is peeled off and its context
-- is kept.
--
-- The remaining bodies are converted to normalised type trees ("seeds").
-- If any seed 'isMoreGeneralThan' another, an error is reported for every
-- such ordered pair and the computation fails in 'Q'.
--
-- Otherwise each overload yields one triple per decider from
-- 'allDeciders':
--
-- * the head is the decider (after 'typeTreeWithNames') converted back
--   with 'typeTreeToType';
-- * the context is the overload's own context followed by the equality
--   constraints derived from 'getEqualities'.
deciders :: [(Name, Type)] -> Q [(Name, Cxt, Type)]
deciders cases = do
  -- NOTE(review): type variables are handled via @const (Var Nothing)@ here,
  -- presumably turning them into wildcard leaves — confirm in Overload.TypeTree.
  seeds <-
    mapM
      (\(_, _, ty) -> normalizeTypeTree <$> typeToTypeTree (const (Var Nothing)) Just ty)
      triplets
  -- Every ordered pair of distinct positions, so generality is checked both ways.
  let orderedPairs =
        [ pair
        | x : rest <- tails seeds
        , y <- rest
        , pair <- [(x, y), (y, x)]
        ]
      conflicts = filter (uncurry isMoreGeneralThan) orderedPairs
  case conflicts of
    [] -> pure ()
    _ -> do
      mapM_ reportConflict conflicts
      fail "Not all types are distinguishable."
  perOverload <- zipWithM instancesFor (allDeciders seeds) triplets
  pure (concat perOverload)
  where
    -- (name, context, body), with the context lifted out of a top-level forall.
    triplets = map splitForall cases

    splitForall (name, ForallT _ ctx body) = (name, ctx, body)
    splitForall (name, ty) = (name, [], ty)

    reportConflict (general, specific) =
      reportError $
        concat
          [ "Type "
          , showTypeTree general
          , " is more general than "
          , showTypeTree specific
          , ".\n"
          , "There is no way to construct a set of non-overlapping instances.\n"
          , "To fix this, make the first type more specific."
          ]

    -- One triple per decider of a single overload, in decider order.
    instancesFor decs triplet =
      forM decs $ \dec -> instanceFor triplet (typeTreeWithNames dec)

    instanceFor (name, ctx, ty) named = do
      concrete <- typeToTypeTree Concrete id ty
      pure
        ( name
        , ctx ++ equalityToCxt (getEqualities concrete named)
        , typeTreeToType named
        )
-- | Build an @INCOHERENT@ instance of @className@ for the given context and
-- type. The instance defines its single method @methodName@ as a direct
-- reference to @overloadName@.
makeInstance :: Name -> Name -> (Name, Cxt, Type) -> Dec
makeInstance className methodName (overloadName, c, t) =
  InstanceD (Just Incoherent) c instanceHead [methodBody]
  where
    instanceHead = ConT className `AppT` t
    methodBody = FunD methodName [Clause [] (NormalB (VarE overloadName)) []]
-- | Generate a single-method type class plus instances that dispatch to the
-- given overloads.
--
-- The class:
--
-- * is named after @functionName@ with its first character upper-cased;
-- * has one type parameter @t@;
-- * declares the method @functionName :: t@.
--
-- Each name in @overloadNames@ that reifies to a variable ('VarI') has its
-- type synonyms expanded. It then contributes the @INCOHERENT@ instances
-- computed by 'deciders', each defining the method as that name.
--
-- Fails with a descriptive message if @functionName@ is empty. A name that
-- does not reify to a variable (for example a data constructor or a class
-- method) gets no instance, and a compile-time warning is emitted for it.
overload :: String -> [Name] -> Q [Dec]
overload functionName overloadNames = do
  -- Pattern-match rather than head/tail so an empty name gives a readable error.
  className <- case functionName of
    first : rest -> pure (mkName (toUpper first : rest))
    [] -> fail "overload: the function name must not be empty."
  infos <- mapM reify overloadNames
  overloads <- concat <$> zipWithM toOverload overloadNames infos
  let methodName = mkName functionName
      tyVar = mkName "t"
      classDec =
        ClassD [] className [PlainTV tyVar] [] [SigD methodName (VarT tyVar)]
  instances <- map (makeInstance className methodName) <$> deciders overloads
  pure (classDec : instances)
  where
    -- Keep variables (with synonyms expanded); warn about anything else
    -- instead of silently dropping it.
    toOverload _ (VarI n t _) =
      (\t' -> [(n, t')]) <$> expandSynsWith noWarnTypeFamilies t
    toOverload requested _ = do
      reportWarning $
        "overload: " <> show requested
          <> " is not a function or variable; no instance is generated for it."
      pure []