-----------------------------------------------------------------------------
--
-- Module      :  PureScript.TypeChecker.Kinds
-- Copyright   :  (c) Phil Freeman 2013
-- License     :  MIT
--
-- Maintainer  :  Phil Freeman <paf31@cantab.net>
-- Stability   :  experimental
-- Portability :
--
-- |
--
-----------------------------------------------------------------------------

{-# LANGUAGE DeriveDataTypeable #-}

module PureScript.TypeChecker.Kinds (
    KindConstraint(..),
    KindSolution(..),
    kindsOf,
    kindOf
) where

import Data.List
import Data.Maybe (fromMaybe)
import Data.Function
import Data.Data

import PureScript.Types
import PureScript.Kinds
import PureScript.Declarations
import PureScript.TypeChecker.Monad
import PureScript.CodeGen.Pretty.Kinds
import PureScript.CodeGen.Pretty.Types

import Control.Monad.State
import Control.Monad.Error

import Control.Applicative
import Control.Arrow (Kleisli(..), (***))
import qualified Control.Category as C

import qualified Data.Map as M

data KindConstraintOrigin
  = DataDeclOrigin
  | TypeOrigin Type
  | RowOrigin Row deriving (Show, Data, Typeable)

prettyPrintKindConstraintOrigin :: KindConstraintOrigin -> String
prettyPrintKindConstraintOrigin (DataDeclOrigin) = "data declaration"
prettyPrintKindConstraintOrigin (TypeOrigin ty) = prettyPrintType ty
prettyPrintKindConstraintOrigin (RowOrigin row) = prettyPrintRow row

data KindConstraint = KindConstraint Int Kind KindConstraintOrigin deriving (Show, Data, Typeable)

newtype KindSolution = KindSolution { runKindSolution :: Int -> Kind }

emptyKindSolution :: KindSolution
emptyKindSolution = KindSolution KUnknown

kindOf :: PolyType -> Check Kind
kindOf (PolyType idents ty) = do
  ns <- replicateM (length idents) fresh
  (cs, n, m) <- kindConstraints (M.fromList (zip idents ns)) ty
  solution <- solveKindConstraints cs emptyKindSolution
  return $ starIfUnknown $ runKindSolution solution n

kindsOf :: Maybe String -> [String] -> [Type] -> Check Kind
kindsOf name args ts = do
  tyCon <- fresh
  nargs <- replicateM (length args) fresh
  (cs, ns, m) <- kindConstraintsAll (maybe id (`M.insert` tyCon) name $ M.fromList (zip args nargs)) ts
  let extraConstraints =
        KindConstraint tyCon (foldr (FunKind . KUnknown) Star nargs) DataDeclOrigin
        : zipWith (\n arg -> KindConstraint n Star (TypeOrigin arg)) ns ts
  solution <- solveKindConstraints (extraConstraints ++ cs) emptyKindSolution
  return $ starIfUnknown $ runKindSolution solution tyCon

starIfUnknown :: Kind -> Kind
starIfUnknown (KUnknown _) = Star
starIfUnknown (FunKind k1 k2) = FunKind (starIfUnknown k1) (starIfUnknown k2)
starIfUnknown k = k

kindConstraintsAll :: M.Map String Int -> [Type] -> Check ([KindConstraint], [Int], M.Map String Int)
kindConstraintsAll m [] = return ([], [], m)
kindConstraintsAll m (t:ts) = do
  (cs, n1, m') <- kindConstraints m t
  (cs', ns, m'') <- kindConstraintsAll m' ts
  return (KindConstraint n1 Star (TypeOrigin t) : cs ++ cs', n1:ns, m'')

kindConstraints :: M.Map String Int -> Type -> Check ([KindConstraint], Int, M.Map String Int)
kindConstraints m a@(Array t) = do
  me <- fresh
  (cs, n1, m') <- kindConstraints m t
  return (KindConstraint n1 Star (TypeOrigin t) : KindConstraint me Star (TypeOrigin a) : cs, me, m')
kindConstraints m o@(Object row) = do
  me <- fresh
  (cs, r, m') <- kindConstraintsForRow m row
  return (KindConstraint me Star (TypeOrigin o) : KindConstraint r Row (RowOrigin row) : cs, me, m')
kindConstraints m f@(Function args ret) = do
  me <- fresh
  (cs, ns, m') <- kindConstraintsAll m args
  (cs', retN, m'') <- kindConstraints m' ret
  return (KindConstraint retN Star (TypeOrigin ret) : KindConstraint me Star (TypeOrigin f) : zipWith (\n arg -> KindConstraint n Star (TypeOrigin arg)) ns args ++ cs ++ cs', me, m'')
kindConstraints m (TypeVar v) =
  case M.lookup v m of
    Just u -> return ([], u, m)
    Nothing -> throwError $ "Unbound type variable " ++ v
kindConstraints m c@(TypeConstructor v) = do
  env <- getEnv
  me <- fresh
  case M.lookup v m of
    Nothing -> case M.lookup v (types env) of
      Nothing -> throwError $ "Unknown type constructor '" ++ v ++ "'"
      Just (kind, _) -> return ([KindConstraint me kind (TypeOrigin c)], me, m)
    Just u -> return ([KindConstraint me (KUnknown u) (TypeOrigin c)], me, m)
kindConstraints m a@(TypeApp t1 t2) = do
  me <- fresh
  (cs1, n1, m1) <- kindConstraints m t1
  (cs2, n2, m2) <- kindConstraints m1 t2
  return (KindConstraint n1 (FunKind (KUnknown n2) (KUnknown me)) (TypeOrigin a) : cs1 ++ cs2, me, m2)
kindConstraints m t = do
  me <- fresh
  return ([KindConstraint me Star (TypeOrigin t)], me, m)

kindConstraintsForRow :: M.Map String Int -> Row -> Check ([KindConstraint], Int, M.Map String Int)
kindConstraintsForRow m r@(RowVar v) = do
  me <- case M.lookup v m of
    Just u -> return u
    Nothing -> fresh
  return ([KindConstraint me Row (RowOrigin r)], me, M.insert v me m)
kindConstraintsForRow m r@REmpty = do
  me <- fresh
  return ([KindConstraint me Row (RowOrigin r)], me, m)
kindConstraintsForRow m r@(RCons _ ty row) = do
  me <- fresh
  (cs1, n1, m1) <- kindConstraints m ty
  (cs2, n2, m2) <- kindConstraintsForRow m1 row
  return (KindConstraint me Row (RowOrigin r) : KindConstraint n1 Star (TypeOrigin ty) : KindConstraint n2 Row (RowOrigin r) : cs1 ++ cs2, me, m2)

solveKindConstraints :: [KindConstraint] -> KindSolution -> Check KindSolution
solveKindConstraints [] s = return s
solveKindConstraints all@(KindConstraint n k t : cs) s = do
  (cs', s') <- rethrow (\err -> "Error in " ++ prettyPrintKindConstraintOrigin t ++ ": " ++ err) $ do
    guardWith "Occurs check failed" $ not $ kindOccursCheck False n k
    let s' = KindSolution $ replaceUnknownKind n k . runKindSolution s
    cs' <- fmap concat $ mapM (substituteKindConstraint n k) cs
    return (cs', s')
  solveKindConstraints cs' s'

substituteKindConstraint :: Int -> Kind -> KindConstraint -> Check [KindConstraint]
substituteKindConstraint n k (KindConstraint m l t)
  | n == m = unifyKinds t k l
  | otherwise = return [KindConstraint m (replaceUnknownKind n k l) t]

replaceUnknownKind :: Int -> Kind -> Kind -> Kind
replaceUnknownKind n k = f
  where
  f (KUnknown m) | m == n = k
  f (FunKind k1 k2) = FunKind (f k2) (f k2)
  f other = other

unifyKinds :: KindConstraintOrigin -> Kind -> Kind -> Check [KindConstraint]
unifyKinds _ (KUnknown u1) (KUnknown u2) | u1 == u2 = return []
unifyKinds t (KUnknown u) k = do
  guardWith "Occurs check failed" $ not $ kindOccursCheck False u k
  return [KindConstraint u k t]
unifyKinds t k (KUnknown u) = do
  guardWith "Occurs check failed" $ not $ kindOccursCheck False u k
  return [KindConstraint u k t]
unifyKinds _ Star Star = return []
unifyKinds _ Row Row = return []
unifyKinds t (FunKind k1 k2) (FunKind k3 k4) = do
  cs1 <- unifyKinds t k1 k3
  cs2 <- unifyKinds t k2 k4
  return $ cs1 ++ cs2
unifyKinds _ k1 k2 = throwError $ "Cannot unify " ++ prettyPrintKind k1 ++ " with " ++ prettyPrintKind k2 ++ "."

kindOccursCheck :: Bool -> Int -> Kind -> Bool
kindOccursCheck b u (KUnknown u') | u == u' = b
kindOccursCheck _ u (FunKind k1 k2) = kindOccursCheck True u k1 || kindOccursCheck True u k2
kindOccursCheck _ _ _ = False