{-# OPTIONS_GHC -O2 -fglasgow-exts #-}
{-# LANGUAGE BangPatterns, TemplateHaskell #-}

{- |
  A clusterfrolick at the moment. Among other things,
  the @Exists@ con in the @Type@ type isn't handled
  correctly (or at all really, it's just pretended
  that it's @Forall@ because I don't know how to
  handle it).
-}

module Language.Haskell.Derive.Gadt.Unify where

import Data.List
import Data.Tree
import Control.Monad
import Control.Applicative
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Monoid(Monoid(..))
import Data.Either
import Text.PrettyPrint
import Control.Monad
import Control.Monad.Fix
import Data.IORef
import System.IO.Unsafe
        (unsafePerformIO)
import Language.Haskell.Meta.Utils(pretty)
ppHs a = text . pretty $ a

-----------------------------------------------------------------------------

intType = ConT (mkNameG "Int")
doubleType = ConT (mkNameG "Double")

t0 = tupT [varT "x", varT "b"]
t1 = tupT [varT "a", intType]
t2 = tupT [varT "a", doubleType]
t3 = tupT [intType, varT "a"]
t4 = tupT [doubleType, varT "a"]
t5 = tupT [varT "b", intType, t1]
t6 = tupT [doubleType, varT "c", t0]
t7 = tupT [doubleType, varT "c", varT "d"]



t8 = tupT [varT "x", varT "x"]
t9 = tupT [varT "x", t8]


{-
ghci> pprType t5
(((,,) b) Int) (((,) a) Int)
ghci> pprType t6
(((,,) Double) c) (((,) x) b)
ghci> ppHs <$> runQ (unify t5 t6)
Right
  (fromList
     [(NameL "a", [VarT (NameU "a" 2)]),
      (NameL "b", [ConT (NameG "Double")])],
   fromList
     [(NameL "b", [ConT (NameG "Int")]),
      (NameL "c", [ConT (NameG "Int")]),
      (NameL "x", [VarT (NameU "a" 2)])])
-}

-----------------------------------------------------------------------------



-- | .
unify :: Type -> Type -> Q (Either String (Substs,Substs))
unify a b = extractSubsts <$> splitSubsts (matchTypes a b)

type Substs = Map Name [Type]
type SubstMap = Map Name [Name]
type UnifyMap = Map Name [Type]
data UnifyEnv = UnifyEnv
  {substMaps  :: (SubstMap, SubstMap)
  ,isubstMaps :: (SubstMap, SubstMap)
  ,noDupsMap  :: Either String (Map Name (Name,Name))
  ,finalMap   :: Either String (Map Name Name)
  ,unifyMap   :: (UnifyMap, UnifyMap)
  ,stragglers :: [(Type, Type)]}
  deriving(Eq,Show)


varT :: String -> Type
varT s = VarT (mkNameL s)

tupT :: [Type] -> Type
tupT []  = ConT (mkNameG "()")
tupT [t] = t
tupT ts = let n = length ts
          in foldl AppT (tupCon n) ts


listT :: Type -> Type
listT = (listCon `AppT`)



invertSubstMap :: SubstMap -> SubstMap
invertSubstMap = foldl' (\m (a,b) -> M.insertWith' (++) a [b] m) mempty
  . concatMap (\(x,xs) -> fmap (flip (,) x) xs) . M.toList


invertSubsts :: SubstMap -> SubstMap -> Map Name ([Name],[Name])
invertSubsts lmap rmap =
  let lmap2 = fmap (\xs->(xs,[])) (invertSubstMap lmap)
      rmap2 = fmap (\xs->([],xs)) (invertSubstMap rmap)
  in M.unionWith mappend lmap2 rmap2


checkForDups :: Map Name ([Name],[Name]) -> Either String (Map Name (Name,Name))
checkForDups = go [] . M.toList
  where go acc    []                = Right (M.fromList (reverse acc))
        go acc ((x,([a],[b])):rest) = go ((x,(a,b)):acc) rest
        go _   (err          :_)    = Left (show err)

extractSubsts :: UnifyEnv -> Either String (Substs, Substs)
extractSubsts (UnifyEnv {substMaps  = (lmap,rmap)
                        ,isubstMaps = _
                        ,noDupsMap  = _
                        ,finalMap   = fin
                        ,unifyMap   = (lumap,rumap)
                        ,stragglers = ss})
  | (not . null) ss = Left (show ss)
  | otherwise = case fin of
      Left e -> Left e
      Right fin -> let f = fmap VarT fin
                       l = lumap `M.union` fmap (:[]) (f |.| fmap head lmap) -- XXX
                       r = rumap `M.union` fmap (:[]) (f |.| fmap head rmap) -- XXX
                    in Right (l,r)

(|.|) :: (Ord a, Ord b) => Map b c -> Map a b -> Map a c
g |.| f = foldl' (\m (a,b) -> maybe m
                  (flip (M.insert a) m)
                        (M.lookup b g)) mempty
                                        (M.toList f)

renameT :: Type -> Type
renameT = unQ . go
  where go (ForallT ns t) = do
          xs <- replicateM (length ns) (newName "a")
          ForallT xs <$> go (substT (zip ns (fmap VarT xs)) t)
        go (ExistsT ns t) = do
          xs <- replicateM (length ns) (newName "a")
          ExistsT xs <$> go (substT (zip ns (fmap VarT xs)) t)
        go (AppT a b) = AppT <$> go a <*> go b
        go  t = return t

{- UNUSED CURRENTLY
windFunT :: Type -> [Type] -> Type
windFunT = foldr (.->.)

prepFunApp :: Type -> Type -> Maybe ((Type,Type),(Type,[Type]))
prepFunApp f x
  = case unwindFunT (openT f) of
      (_,[]) -> Nothing
      (res,t:args) -> Just ((x,t),(res,args))
-}

noDupsMapToFinalSubsts :: Map Name (Name,Name) -> Either String (Map Name Name)
noDupsMapToFinalSubsts = go [] . M.toList
  where go acc    []         = Right (M.fromList (reverse acc))
        go acc ((x,ns):rest) = check x ns (\n -> go ((x,n):acc) rest)
        check x (NameL _, NameL _) k = k x
        check _ (NameL a, NameG b) k = k (NameG b)
        check _ (NameG a, NameL b) k = k (NameG a)
        check _ (NameG a, NameG b) k
          | a==b = k (NameG a)
          | otherwise = Left (show (NameG a) ++" /= "++ show (NameG b))

buildUMap :: [(Type, Type)] -> ((UnifyMap,UnifyMap),[(Type, Type)])
buildUMap ts = let (ls,xs) = partition varOnLeft ts
                   (rs,ys) = partition varOnRight xs
                   lmap = M.fromList (fmap (\(VarT a,t)->(a,[t])) ls)
                   rmap = M.fromList (fmap (\(t,VarT a)->(a,[t])) rs)
                in ((lmap,rmap),ys) -- wlog, since xs and ys are the same

splitSubsts :: [(Type,Type)] -> Q UnifyEnv
splitSubsts xs = do
  let (vars,tys) = partition bothVars xs
      names = fmap (\(VarT a,VarT b)->(a,b)) vars
      (lnames,rnames) = unzip names --(filter (not . uncurry (==)) names)
  ns <- fmap (:[]) <$> replicateM (length vars) (newName "a")
  let lmap = M.fromListWith (++) (zip lnames ns)
      rmap = M.fromListWith (++) (zip rnames ns)
      (lts,rts) = unzip (filter (not . uncurry (==)) tys)
      rntys = zip (fmap (substT (firstNameMap lmap)) lts)
                  (fmap (substT (firstNameMap rmap)) rts)
      (umaps,others) = buildUMap rntys
      ndm = checkForDups (invertSubsts lmap rmap)
      fsm = either Left noDupsMapToFinalSubsts ndm
  return (UnifyEnv {substMaps  = (lmap,rmap)
                    ,isubstMaps = (invertSubstMap lmap, invertSubstMap rmap)
                    ,noDupsMap  = ndm
                    ,finalMap   = fsm
                    ,unifyMap   = umaps
                    ,stragglers = others})

matchTypes :: Type -> Type -> [(Type, Type)]
matchTypes a b = let c = renameT a
                     d = renameT b
                  in c `seq` d `seq`
                      match typeViaT typeViaT (,)
                                  (openT c)
                                  (openT d)
typeViaT :: ViaT Type Type Type
typeViaT = ViaT
            typeToT
            typeFromT
typeToT :: Type -> T Type
typeToT (t `AppT` t')
  = typeToT t `T` typeToT t'
typeToT t = Tip t
typeFromT :: T Type -> Type
typeFromT (l `T` r)
  = typeFromT l `AppT` typeFromT r
typeFromT (Tip t) = t

firstNameMap :: Map Name [Name] -> [(Name, Type)]
firstNameMap m =
  (flip concatMap (M.toList m)
    (\(a,as)-> case as of
                [] -> []
                n:_ -> [(a,VarT n)]))

isInf :: Name -> Type -> Bool
isInf a t = a `S.member` ftvs t

bothVars :: (Type, Type) -> Bool
bothVars (VarT{},VarT{}) = True
bothVars    _            = False

varOnLeft :: (Type, Type) -> Bool
varOnLeft (VarT{},_) = True
varOnLeft     _      = False

varOnRight :: (Type, Type) -> Bool
varOnRight (_,VarT{}) = True
varOnRight    _       = False

{-
ghci> Right (a,b) <- runQ (unifyTop' (expType[|(,)|]) (expType[|\x y->([x],y)|]))
ghci> pprType a
b -> a -> (b, a)
ghci> fmap (\(a,b)->(a,pprType b)) b
[(b,[] b)]
ghci> ftvs a
fromList [a,b]
-}

-----------------------------------------------------------------------------

data Name
  = NameG String
  | NameL String
  | NameU String !Int
  deriving(Eq,Ord,Read,Show)

data Type
  = ArrowT
  | VarT Name
  | ConT Name
  | AppT Type Type
  | ForallT [Name] Type
  | ExistsT [Name] Type
  deriving(Eq,Ord,Read,Show)




ftvs :: Type -> Set Name
ftvs (VarT n) = S.singleton n
-- ftvs (ConT n) = S.singleton n
ftvs (AppT t t') = ftvs t `S.union` ftvs t'
ftvs (ForallT ns t) = ftvs t `S.difference` S.fromList ns
ftvs (ExistsT ns t) = ftvs t `S.difference` S.fromList ns
ftvs  _ = mempty

btvs :: Type -> Set Name
btvs (ForallT ns t) = S.fromList ns `S.union` btvs t
btvs (ExistsT ns t) = S.fromList ns `S.union` btvs t
btvs _ = mempty


closeT :: Type -> Type
closeT t = case S.toList (ftvs t) of
            [] -> t
            ns -> case t of
                    ForallT ms t -> ForallT (ms++ns) t
                    _ -> ForallT ns t

openT :: Type -> Type
openT (ForallT _ t) = t
openT    t          = t

-- ghci> let t = let a = mkName "a" in AppT (AppT ArrowT (VarT a)) (VarT a)
--
-- ghci> pprType (foldl (.->.) t (replicate 4 t))
-- ((((a -> a) -> a -> a) -> a -> a) -> a -> a) -> a -> a
--
-- ghci> pprType (foldr (.->.) t (replicate 4 t))
-- (a -> a) -> (a -> a) -> (a -> a) -> (a -> a) -> a -> a
(.->.) :: Type -> Type -> Type
a .->. b = (ArrowT `AppT` a) `AppT` b

unwindFunT :: Type -> (Type,[Type])
unwindFunT = go []
  where go acc ((ArrowT `AppT` a)
                        `AppT` b)
                  = go (a:acc) b
        go acc  t = (t, reverse acc)

unwindAppT :: Type -> (Type,[Type])
unwindAppT = go []
  where go acc (AppT a b) = go (b:acc) a
        go acc    t       = (t, acc)

substT :: [(Name, Type)] -> Type -> Type
substT env t = runSubstM (go t) (initSubstEnv env)
  where go (VarT a) = substM a
        go (AppT a b) = AppT <$> localM (go a)
                             <*> localM (go b)
        go (ForallT ns t) = do mapM_ bindM ns
                               ForallT ns <$> go t
        go (ExistsT ns t) = do mapM_ bindM ns
                               ExistsT ns <$> go t
        go  t             = return t

-----------------------------------------------------------------------------

listName :: Name
listName = NameG "[]"

tupName :: Int -> Name
tupName n = NameG
  ("(" ++ replicate (n-1) ',' ++ ")")

listCon :: Type
listCon = ConT listName

tupCon :: Int -> Type
tupCon = ConT . tupName

-----------------------------------------------------------------------------

testType0 :: Type
testType0 = let a = mkName "a"
            in ForallT [a] (AppT (AppT ArrowT (VarT a)) (VarT a))


idType :: Type -> Type
idType t = closeT (AppT (AppT ArrowT t) t)



testType1 :: Type
testType1 = let b = mkName "b"
            in substT [(mkName "a", VarT b)] testType0

testType2 :: Type
testType2 = let a = mkName "a"
                b = mkName "b"
                ta = AppT (AppT ArrowT (VarT a)) (VarT a)
                tb = substT [(a, VarT b)] ta
            in ForallT [a]
                (foldr (.->.)
                       (foldl (.->.)
                              tb
                              (replicate 4 ta))
                       (replicate 2 tb))

-----------------------------------------------------------------------------

pprName :: Name -> Doc
pprName (NameL a) = text a
pprName (NameG a) = text a
pprName (NameU a u) = text a <> char '_' <> int u

pprType :: Type -> Doc
pprType t
  = case unwindFunT t of
      (_,[]) -> go t
      (x,xs) -> hsep . punctuate (space <> text "->") . fmap pprParenType $ xs++[x]
    where go (VarT a)       = pprName a
          go (ConT a)       = pprName a
          go  ArrowT        = text "->" -- ummm
          go (AppT a b)     = pprParenType a <+> pprParenType b
          go (ForallT ns t) = text "forall"
                                <+> hsep (fmap pprName ns)
                                <> char '.'
                                <+> pprType t
          go (ExistsT ns t) = text "exists"
                                <+> hsep (fmap pprName ns)
                                <> char '.'
                                <+> pprType t

pprParenType :: Type -> Doc
pprParenType t
  = case unwindFunT t of
      (_,[]) -> go t
      (x,xs) -> parens . hsep . punctuate (space <> text "->") . fmap pprParenType $ xs++[x]
    where go t@(VarT{}) = pprType t
          go t@(ConT{}) = pprType t
          go t          = parens (pprType t)

-----------------------------------------------------------------------------

(***) f g = \(a,b) -> (f a,g b)
(&&&) f g = \a -> (f a, g a)
mapfst f = \(a,b) -> (f a, b)
mapsnd f = \(a,b) -> (a, f b)

-----------------------------------------------------------------------------

newtype Q a = Q (IO a)
runQ :: Q a -> IO a
runQ (Q io) = io
runIO :: IO a -> Q a
runIO = Q
unQ :: Q a -> a
unQ = unsafePerformIO . runQ

newName :: String -> Q Name
newName s = do
  u <- tick gensymQ
  return (NameU s u)
mkName :: String -> Name
mkName = mkNameL
mkNameG :: String -> Name
mkNameL :: String -> Name
mkNameU :: String -> Int -> Name
mkNameG = NameG
mkNameL = NameL
mkNameU = NameU

newUniq :: Q Int
newUniq = tick gensymQ

instance Functor Q where
  fmap f (Q io) = Q (fmap f io)
instance Monad Q where
  return a = Q (return a)
  Q io >>= k = Q (runQ . k =<< io)
instance Applicative Q where
  pure = return
  (<*>) = ap
gensymQ :: IORef Int
{-# NOINLINE gensymQ #-}
gensymQ = unsafePerformIO (newIORef 0)
tick :: IORef Int -> Q Int
tick ref = runIO (atomicModifyIORef ref (\n -> (n+1,n)))
reset :: IORef Int -> Q ()
reset ref = runIO (writeIORef ref 0)
resetQ :: Q ()
resetQ = reset gensymQ

-----------------------------------------------------------------------------




newtype S s a = S {unS :: forall o. (a -> s -> o) -> s -> o}
instance Functor (S s) where
  fmap f (S g) = S (\k -> g (k . f))
instance Monad (S s) where
  return a = S (\k -> k a)
  S g >>= f = S (\k -> g (\a -> unS (f a) k))
instance MonadFix (S s) where
  mfix f = S (\k s -> let (a,s') = unS (f a) (,) s in k a s')
instance Applicative (S s) where
  pure = return
  (<*>) = ap
get :: S s s
get = S (\k s -> k s s)
gets :: (s -> a) -> S s a
gets f = S (\k s -> k (f s) s)
set :: s -> S s ()
set s = S (\k _ -> k () s)
modify :: (s -> s) -> S s ()
modify f = S (\k -> k () . f)
runS :: S s a -> s -> (a, s)
runS (S g) = g (,)
evalS :: S s a -> s -> a
evalS (S g) = g const
execS :: S s a -> s -> s
execS (S g) = g (flip const)






type SubstM a = S SubstEnv a
data SubstEnv = SubstEnv
  {boundSet :: Set Name
  ,substMap :: Map Name Type}
  deriving(Eq,Ord,Read,Show)
initSubstEnv :: [(Name, Type)] -> SubstEnv
initSubstEnv = SubstEnv mempty . M.fromList
runSubstM :: SubstM a -> SubstEnv -> a
runSubstM m env = fst (runS m env)
bindM :: Name -> SubstM ()
bindM n = do
  bs <- gets boundSet
  modify(\e->e{boundSet=n`S.insert`bs})
substM :: Name -> SubstM Type
substM n = do
  bs <- gets boundSet
  case n `S.member` bs of
    True -> return (VarT n)
    False -> do o <- gets (M.lookup n . substMap)
                case o of
                  Nothing -> return (VarT n)
                  Just t  -> return t
localM :: SubstM a -> SubstM a
localM m = do s <- get
              a <- m
              set s
              return a

-----------------------------------------------------------------------------

data T a
  = Tip a
  | T (T a) (T a)
  deriving (Eq,Ord,Show,Read)

instance Functor T where
  fmap f (Tip a) = Tip (f a)
  fmap f (T t t') = T (fmap f t) (fmap f t')

-----------------------------------------------------------------------------

toListT :: T a -> [a]
toListT = foldrT (:) []

fromListT :: [a] -> [T a]
fromListT [] = []
fromListT [a] = [Tip a]
fromListT [a,b] = [T (Tip a) (Tip b)]
fromListT xs = let (ys, zs) = splitAt (length xs `div` 2) xs
                   [y] = fromListT ys
                   [z] = fromListT zs
                in [T y z]

toTreeT :: (Maybe a -> b) -> T a -> Tree b
toTreeT f (Tip a) = Node (f (Just a)) []
toTreeT f (T l r) = Node (f Nothing) (fmap (toTreeT f) [l,r])

-----------------------------------------------------------------------------

foldrT :: (a -> b -> b) -> b -> T a -> b
foldrT (<>) b t = go (<>) b t id
  where go (<>) b (Tip a)  k = a <> k b
        go (<>) b (T t t') k = go (<>) b t (\b ->
                                go (<>) b t' k)

foldlT :: (a -> b -> a) -> a -> T b -> a
foldlT (<>) a t = go (<>) a t id
  where go (<>) a (Tip b)  k = k (a <> b)
        go (<>) a (T t t') k = go (<>) a t (\a ->
                                go (<>) a t' k)

foldl'T :: (a -> b -> a) -> a -> T b -> a
foldl'T (<>) !a t = go (<>) a t id
  where go (<>) !a (Tip b)  k = k (a <> b)
        go (<>) !a (T t t') k = go (<>) a t (\a ->
                                  go (<>) a t' k)

sumT :: (Num a) => T a -> a
sumT = foldl'T (+) 0

prodT :: (Num a) => T a -> a
prodT = foldl'T (*) 1

andT :: T Bool -> Bool
andT = foldrT (&&) True

orT :: T Bool -> Bool
orT = foldrT (||) False

-----------------------------------------------------------------------------

unifyT :: (Either a (T a) -> c)
       -> (Either b (T b) -> d)
       -> (c -> d -> e)
       -> (T a -> T b -> T e)
unifyT f g (<>) (Tip a) (Tip b)   = Tip (f (Left a) <> g (Left b))
unifyT f g (<>) (Tip a) b@(T _ _) = Tip (f (Left a) <> g (Right b))
unifyT f g (<>) a@(T _ _) (Tip b) = Tip (f (Right a) <> g (Left b))
unifyT f g (<>) (T a a') (T b b') = let (><) = unifyT f g (<>)
                                    in T (a >< b) (a' >< b')

zipT :: (T a -> T b -> c)
     -> (T a -> T b -> T c)
zipT = unifyT (either Tip id)
              (either Tip id)

matchT :: (a -> T x)
       -> (b -> T y)
       -> (T x -> c)
       -> (T y -> d)
       -> (c -> d -> e)
       -> (a -> b -> T e)
matchT ia ib px py (<>) a b = unifyT (either (px . Tip) px)
                                     (either (py . Tip) py)
                                     (<>)
                                     (ia a)
                                     (ib b)

-----------------------------------------------------------------------------

data ViaT a b c
  = ViaT {toT   :: a -> T b
         ,fromT :: T b -> c}

match :: ViaT a x c
      -> ViaT b y d
      -> (c -> d -> e)
      -> (a -> b -> [e])
match da db = ((toListT .) .)
  . matchT (toT da)
           (toT db)
           (fromT da)
           (fromT db)

-----------------------------------------------------------------------------