-- | This library provides a mechanism for overloading an identifier with multiple definitions.
--   The set of overloads is finite and has to be defined all at once.
--
--   The advantage of this library over the regular typeclass approach is that it behaves very well
--   with type inference.
--
-- @
--   {-\# LANGUAGE TemplateHaskell, TypeFamilies, FlexibleInstances \#-}
--   module Overload.Example where
--
--   import Data.Maybe
--   import Overload
--
--   f1 :: Bool
--   f1 = True
--
--   f2 :: Int -> Int
--   f2 x = x + 1
--
--   f3 :: Num a => Maybe a
--   f3 = Just 0
--
--   'overload' "f" [\'f1, \'f2, \'f3]
--
--   test :: IO ()
--   test = do
--       print (f 1)
--       print (f && True)
--       print (fromMaybe 10 f)
-- @
--
--   Notice that we didn't have to annotate anything. For the function case it was enough to use
--   'f' as a function. Since there's only one overload that's a function, the argument and
--   the return value are inferred as 'Int's.
{-# LANGUAGE FlexibleContexts, DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# LANGUAGE ScopedTypeVariables, TypeApplications, AllowAmbiguousTypes, TemplateHaskell #-}
module Overload (overload) where

import Language.Haskell.TH
import Data.Char
import Data.List (lookup, nub)
import Control.Arrow
import Data.Functor.Identity
import Data.Function

import Control.Effects.State

-- | A simplified view of a Template Haskell type: type variables (labelled with
--   an arbitrary @name@), concrete type constructors, and type application.
data TypeTree name
    = Var name                            -- ^ a type variable
    | Concrete Name                       -- ^ a type constructor
    | App (TypeTree name) (TypeTree name) -- ^ one type applied to another
    deriving (Eq, Ord, Show, Functor, Foldable, Traversable)

-- | State used to number names: the table of names that already have a number,
--   and the next number to hand out.
type FreshSource a =
    ( [(a, Int)] -- names numbered so far
    , Int        -- next unused number
    )

-- | Returns the number already assigned to @name@. The first time a name is
--   seen, it is given the next unused number and recorded in the table.
lookupName :: (MonadEffectState (FreshSource a) m, Eq a) => a -> m Int
lookupName name = do
    (table, next) <- getState
    let remember = do
            setState ((name, next) : table, next + 1)
            return next
    maybe remember return (lookup name table)

-- | Hands out the next unused number without associating it with any name.
--   The result type does not mention @a@, so the 'FreshSource' state to use
--   has to be selected with a type application (@freshVar \@a@).
freshVar :: forall a m. MonadEffectState (FreshSource a) m => m Int
freshVar = do
    (known, counter) :: FreshSource a <- getState
    counter <$ setState (known, counter + 1)

-- | Converts a Template Haskell 'Type' into a 'TypeTree'.
--
--   * Type variables become 'Var'.
--   * Type constructors become 'Concrete'. This includes the built-in function
--     arrow, list and tuple constructors, which Template Haskell represents with
--     dedicated 'Type' forms.
--   * Type application becomes 'App'. Infix applications are curried into
--     prefix form, and explicit parentheses are dropped.
--
--   Any other form (e.g. a nested @forall@, a kind signature or a type-level
--   literal) is rejected with an error.
typeToTypeTree :: Type -> TypeTree Name
typeToTypeTree (ConT n) = Concrete n
typeToTypeTree (AppT t1 t2) = App (typeToTypeTree t1) (typeToTypeTree t2)
typeToTypeTree (VarT n) = Var n
typeToTypeTree (InfixT t1 n t2) =
    App (App (Concrete n) (typeToTypeTree t1)) (typeToTypeTree t2)
typeToTypeTree (ParensT t) = typeToTypeTree t
typeToTypeTree ArrowT = Concrete (''(->))
-- Reified signatures such as @[Int] -> Int@ or @(a, b) -> a@ use 'ListT' and
-- 'TupleT'. Map them to the ordinary constructor names so they are handled like
-- any other applied type constructor instead of hitting the error below.
typeToTypeTree ListT = Concrete (''[])
typeToTypeTree (TupleT arity) = Concrete (tupleTypeName arity)
typeToTypeTree t = error ("Non supported type " ++ show t)

-- | Enumerates every way to generalize a type by replacing subtrees with
--   variables. @'Var' 'Nothing'@ stands for a fresh, anonymous variable, and
--   @'Var' ('Just' n)@ keeps the original variable @n@. The fully general
--   @'Var' 'Nothing'@ always comes first.
allGeneralizations :: TypeTree a -> [TypeTree (Maybe a)]
allGeneralizations tree = Var Nothing : keepShape tree
  where
    -- Generalizations that keep the outermost constructor of the tree.
    keepShape (Var n)      = [Var (Just n)]
    keepShape (Concrete n) = [Concrete n]
    keepShape (App f x)    =
        [ App f' x' | f' <- allGeneralizations f, x' <- allGeneralizations x ]

-- | Renames the variables of a generalized type to integers, starting from 0,
--   in order of first occurrence. Every anonymous ('Nothing') variable gets its
--   own number. All occurrences of the same named variable share one number.
normalizeTypeTree :: forall a. Eq a => TypeTree (Maybe a) -> TypeTree Int
normalizeTypeTree tree =
    runIdentity $ handleStateT emptySource $ traverse (maybe (freshVar @a) lookupName) tree
  where
    emptySource :: FreshSource a
    emptySource = ([], 0)

-- | A substitution built while matching one 'TypeTree' against another. Each
--   variable of the more general tree is paired with the subtree it stands for.
type VariableMapping a b =
    [(a, TypeTree b)]

-- | Tries to bind the variable @name@ to @typ@.
--
--   * If the variable is unbound, the binding is recorded and 'True' is returned.
--   * If it is already bound, the mapping is left unchanged and the result is
--     whether the existing binding equals @typ@.
trySetVar :: (MonadEffectState (VariableMapping a b) m, Eq a, Eq b) => a -> TypeTree b -> m Bool
trySetVar name typ = do
    mapping <- getState
    case lookup name mapping of
        Nothing    -> True <$ setState ((name, typ) : mapping)
        Just bound -> return (bound == typ)

-- | Holds when some substitution for the variables of the first tree turns it
--   into the second tree. Variables of the second tree are treated as fixed
--   (one-way matching, starting from an empty substitution).
isMoreGeneralThan :: forall a b. (Eq a, Eq b) => TypeTree a -> TypeTree b -> Bool
isMoreGeneralThan general specific =
    runIdentity $ handleStateT noBindings $ isMoreGeneralThan' general specific
  where
    noBindings :: VariableMapping a b
    noBindings = []

-- | The stateful matcher behind 'isMoreGeneralThan'. Variables of the first
--   tree are bound in the 'VariableMapping' state. A variable that is already
--   bound only matches a subtree equal to its binding. Both sides of an
--   application are always matched, left side first.
isMoreGeneralThan' :: (MonadEffectState (VariableMapping a b) m, Eq a, Eq b)
                   => TypeTree a -> TypeTree b -> m Bool
isMoreGeneralThan' general specific = case (general, specific) of
    (Var n, t) -> trySetVar n t
    (Concrete n1, Concrete n2) -> return (n1 == n2)
    (App f1 x1, App f2 x2) -> do
        headsMatch <- isMoreGeneralThan' f1 f2
        argsMatch <- isMoreGeneralThan' x1 x2
        return (headsMatch && argsMatch)
    _ -> return False

-- | Pairs every element of a list with all the other elements. The original
--   order is kept both in the outer list and in each remainder.
--
--   > withouts [1, 2, 3] == [(1, [2, 3]), (2, [1, 3]), (3, [1, 2])]
withouts :: [a] -> [(a, [a])]
withouts = go id
  where
    -- @before@ is a difference list holding the elements already passed.
    go _ [] = []
    go before (x : after) = (x, before after) : go (before . (x :)) after

-- | Keeps only the maximally general types: those that no other type in the
--   list is more general than. Input order is preserved. Two types that are
--   each more general than the other (e.g. duplicates) eliminate each other.
minimize :: [TypeTree Int] -> [TypeTree Int]
minimize types =
    [ t | (t, others) <- withouts types, not (any (`isMoreGeneralThan` t) others) ]

-- | For each input type, computes its deciders. A decider is one of the most
--   general generalizations of that type that is not more general than any of
--   the other input types. Results are normalized (integer variables), and the
--   outer list lines up with the input list.
findDeciders :: Eq a => [TypeTree a] -> [[TypeTree Int]]
findDeciders types =
    [ minimize (filter (distinguishes others) (candidates t))
    | (t, others) <- withouts (map (normalizeTypeTree . fmap Just) types)
    ]
  where
    -- Distinct generalizations of a type. Normalizing makes alpha-equivalent
    -- generalizations syntactically equal, so nub removes them.
    candidates t = nub (map normalizeTypeTree (allGeneralizations t))
    -- A generalization is usable only if it does not also cover another type.
    distinguishes others g = not (any (g `isMoreGeneralThan`) others)

-- | Turns each variable label into a Template Haskell name by prefixing its
--   'show'n form with @t@ (so the label @0@ becomes the name @t0@).
typeTreeWithNames :: Show a => TypeTree a -> TypeTree Name
typeTreeWithNames tree = mkName . ('t' :) . show <$> tree

-- | Matches @general@ against @specific@ and returns the resulting
--   substitution: each variable of @general@ paired with the subtree of
--   @specific@ it stands for. Calls 'error' if @general@ is not more general
--   than @specific@.
getEqualities :: forall a b. (Eq a, Eq b) => TypeTree a -> TypeTree b -> [(b, TypeTree a)]
getEqualities specific general =
    runIdentity $ handleStateT noBindings $ do
        matched <- general `isMoreGeneralThan'` specific
        if matched
            then getState
            else error "Can't get equalities because the second type isn't more general than the first"
  where
    noBindings :: VariableMapping b a
    noBindings = []

-- | Converts a 'TypeTree' back into a Template Haskell 'Type'. The function
--   arrow constructor becomes 'ArrowT'; every other constructor becomes 'ConT'.
typeTreeToType :: TypeTree Name -> Type
typeTreeToType tree = case tree of
    Var n -> VarT n
    Concrete n
        | n == ''(->) -> ArrowT
        | otherwise   -> ConT n
    App f x -> AppT (typeTreeToType f) (typeTreeToType x)

-- | Turns a substitution into equality constraints, one @n ~ t@ per binding,
--   in the same order as the bindings.
equalityToCxt :: [(Name, TypeTree Name)] -> Cxt
equalityToCxt bindings =
    [ AppT (AppT EqualityT (VarT n)) (typeTreeToType t) | (n, t) <- bindings ]

-- | Computes instance data for a list of overloads. For each overload
--   @(name, type)@ it yields one @(name, context, head)@ triple per decider of
--   its type:
--
--   * @head@ is the decider with its variables named @t0@, @t1@, ...
--   * @context@ is the overload's own constraints (when its type has a
--     top-level @forall@), extended with equalities that tie the head's
--     variables to the overload's type.
deciders :: [(Name, Type)] -> [(Name, Cxt, Type)]
deciders cases = concat (zipWith instancesFor (findDeciders trees) triplets)
  where
    triplets = map split cases
    trees = [ tree | (_, _, tree) <- triplets ]
    -- Separate a top-level forall's constraints from the type it quantifies.
    split (n, ForallT _ c t') = (n, c, typeToTypeTree t')
    split (n, t')             = (n, [], typeToTypeTree t')
    instancesFor decs triplet = map (instanceFor triplet . typeTreeWithNames) decs
    instanceFor (n, c, t) dec =
        (n, c ++ equalityToCxt (getEqualities t dec), typeTreeToType dec)


-- | Builds an incoherent instance of class @className@ at the given type, under
--   the given context. The instance defines the class method @methodName@ to be
--   the overload itself.
makeInstance :: Name -> Name -> (Name, Cxt, Type) -> Dec
makeInstance className methodName (overloadName, c, t) =
    InstanceD overlap c instanceHead [methodDef]
  where
    overlap = Just Incoherent
    instanceHead = AppT (ConT className) t
    methodDef = FunD methodName [Clause [] (NormalB (VarE overloadName)) []]

-- | Generates a new function with the given name that can behave like multiple functions.
overload :: String -> [Name] -> Q [Dec]
overload functionName overloadNames = do
    infos <- mapM reify overloadNames
    let overloads = [(n, t) | VarI n t _ <- infos]
        className = toUpper (head functionName) : tail functionName
        classDec = ClassD [] (mkName className) [PlainTV (mkName "t")] []
                          [SigD (mkName functionName) (VarT (mkName "t"))]
        instances = fmap (makeInstance (mkName className) (mkName functionName))
                         (deciders overloads)
    return (classDec : instances)