-- | This library provides a mechanism for overloading an identifier with multiple definitions.
--   The set of overloads is finite and has to be defined all at once.
--
--   The advantage of this library over the regular typeclass approach is that it behaves very well
--   with type inference.
--
-- @
--   {-\# LANGUAGE TemplateHaskell, TypeFamilies, FlexibleInstances \#-}
--   module Overload.Example where
--
--   import Data.Maybe
--   import Overload
--
--   f1 :: Bool
--   f1 = True
--
--   f2 :: Int -> Int
--   f2 x = x + 1
--
--   f3 :: Num a => Maybe a
--   f3 = Just 0
--
--   'overload' "f" [\'f1, \'f2, \'f3]
--
--   test :: IO ()
--   test = do
--       print (f 1)
--       print (f && True)
--       print (fromMaybe 10 f)
-- @
--
--   Notice that we didn't have to annotate anything. For the function case it was enough to use
--   'f' as a function. Since there's only one overload that's a function, the argument and
--   the return value are inferred as 'Int's.
{-# LANGUAGE TupleSections, LambdaCase #-}
module Overload (overload) where

import Language.Haskell.TH
import Data.Char
import Control.Arrow
import Data.Function
import Control.Monad
import Language.Haskell.TH.ExpandSyns
import Data.List (tails)
import Data.Semigroup ((<>))

import Overload.Normal
import Overload.TypeTree
import Overload.General
import qualified Overload.Diff as Diff

-- diff :: TypeTree a -> TypeTree a ->

-- | Pair every element of the list with all the other elements, keeping
--   their original relative order.
--
--   > withouts [1, 2, 3] == [(1, [2, 3]), (2, [1, 3]), (3, [1, 2])]
withouts :: [a] -> [(a, [a])]
withouts = go id
  where
    -- @before@ is a difference list holding the elements already visited.
    go _ [] = []
    go before (y : ys) = (y, before ys) : go (before . (y :)) ys

-- | For every input type, compute (via 'Diff.deciders') the trees that tell it
--   apart from all the other input types. Each result is normalized and then
--   passed through 'minimize'.
--
--   The inputs are first lifted with 'Just' and normalized. The result list is
--   positionally aligned with the input list.
allDeciders :: Eq a => [TypeTree a] -> [[TypeTree Normal]]
allDeciders types =
    [ minimize (fmap normalizeTypeTree (Diff.deciders current others))
    | (current, others) <- withouts prepared
    ]
  where
    prepared = map (normalizeTypeTree . fmap Just) types

-- | Turn each @(v, tree)@ pair into the equality constraint @v ~ tree@,
--   producing a context that can be attached to an instance declaration.
equalityToCxt :: [(Name, TypeTree Name)] -> Cxt
equalityToCxt pairs =
    [ EqualityT `AppT` VarT var `AppT` typeTreeToType tree
    | (var, tree) <- pairs
    ]

-- | Compute the instance heads for a set of overloads.
--
--   Each input pairs an overload's name with its type. A single top-level
--   @forall@ is peeled off so that its context can be carried into the
--   generated instance.
--
--   For every overload, each decider tree produced by 'allDeciders' becomes one
--   @(name, context, head type)@ triple. The context is the overload's own
--   context extended with the equalities that 'getEqualities' derives between
--   the overload's concrete type tree and that decider.
--
--   If any overload's type is more general than another's, one error is
--   reported per offending ordered pair, and the splice then fails.
deciders :: [(Name, Type)] -> Q [(Name, Cxt, Type)]
deciders cases = do
    seeds <- forM triplets $ \(_, _, ty) ->
        normalizeTypeTree <$> typeToTypeTree (const (Var Nothing)) Just ty
    -- Every unordered pair of seeds, tested in both directions.
    let conflicts =
            [ pair
            | x : later <- tails seeds
            , y <- later
            , pair <- [(x, y), (y, x)]
            , uncurry isMoreGeneralThan pair
            ]
    unless (null conflicts) $ do
        mapM_ reportConflict conflicts
        fail "Not all types are distinguishable."
    concat <$> zipWithM instancesFor triplets (allDeciders seeds)
  where
    triplets = map splitForall cases

    -- Separate an optional outermost forall's context from the type body.
    splitForall (name, ForallT _ ctx ty) = (name, ctx, ty)
    splitForall (name, ty)               = (name, [], ty)

    reportConflict (general, specific) = reportError $ concat
        [ "Type ", showTypeTree general
        , " is more general than ", showTypeTree specific, ".\n"
        , "There is no way to construct a set of non-overlapping instances.\n"
        , "To fix this, make the first type more specific."
        ]

    instancesFor triplet shapes =
        forM shapes (instanceFor triplet . typeTreeWithNames)

    instanceFor (name, ctx, ty) shape = do
        tree <- typeToTypeTree Concrete id ty
        pure ( name
             , ctx ++ equalityToCxt (getEqualities tree shape)
             , typeTreeToType shape
             )


-- | Build the declaration @instance ctx => Class ty@, marked @INCOHERENT@.
--   Its only method is defined directly as the overload:
--   @method = overloadName@.
makeInstance :: Name -> Name -> (Name, Cxt, Type) -> Dec
makeInstance className methodName (overloadName, c, t) =
    InstanceD (Just Incoherent) c instanceHead [methodBinding]
  where
    instanceHead  = ConT className `AppT` t
    methodBinding = FunD methodName [Clause [] (NormalB (VarE overloadName)) []]

-- | Generates a new function with the given name that can behave like multiple functions.
--
--   The splice emits a single-method class. The class is named @functionName@
--   with its first letter upper-cased, and its method is @functionName@ itself.
--   The class is followed by one instance per decider computed from the types
--   of the given overloads.
--
--   @functionName@ must start with a lowercase letter. Otherwise the splice
--   fails with an explanatory error instead of producing an invalid class or
--   method name.
--
--   Only names that reify to plain variables can take part in the overload
--   set. Any other name (data constructor, class method, type, ...) is skipped,
--   and a warning names it.
overload :: String -> [Name] -> Q [Dec]
overload functionName overloadNames = do
    className <- case functionName of
        c : cs | isLower c -> return (mkName (toUpper c : cs))
        _ -> fail $ "overload: the function name must start with a lowercase letter, but got "
                 ++ show functionName
    infos <- mapM reify overloadNames
    overloads <- concat <$> zipWithM toOverload overloadNames infos
    let methodName = mkName functionName
        tyVar      = mkName "t"
        classDec   = ClassD [] className [PlainTV tyVar] []
                            [SigD methodName (VarT tyVar)]
    instances <- map (makeInstance className methodName) <$> deciders overloads
    return (classDec : instances)
  where
    -- Only plain variables carry a type to overload on. Other entries are
    -- dropped, but visibly, so the user knows why they are missing.
    toOverload _ (VarI n t _) = (\t' -> [(n, t')]) <$> expandSynsWith noWarnTypeFamilies t
    toOverload name _ = do
        reportWarning $ "overload: " ++ show name
            ++ " is not a plain variable and will not be part of the overload set."
        return []