{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveGeneric #-}

-- |
-- Data types for PureScript types, together with helpers for inspecting,
-- substituting into, and traversing them.
--
module Language.PureScript.Types where

import Prelude.Compat
import Protolude (ordNub)

import Control.Arrow (first)
import Control.DeepSeq (NFData)
import Control.Monad ((<=<))
import qualified Data.Aeson as A
import qualified Data.Aeson.TH as A
import Data.List (sortBy)
import Data.Ord (comparing)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.Text (Text)
import qualified Data.Text as T
import GHC.Generics (Generic)

import Language.PureScript.AST.SourcePos
import Language.PureScript.Kinds
import Language.PureScript.Names
import Language.PureScript.Label (Label)
import Language.PureScript.PSString (PSString)

-- |
-- An identifier for the scope of a skolem variable.
--
-- This is a newtype over 'Int'. Because @GeneralizedNewtypeDeriving@ is
-- enabled, the JSON instances are those of the underlying 'Int'.
--
newtype SkolemScope = SkolemScope { runSkolemScope :: Int }
  deriving (Show, Eq, Ord, A.ToJSON, A.FromJSON, Generic)

-- Relies on NFData's default (Generic-based) method.
instance NFData SkolemScope

-- |
-- The type of types: the compiler's representation of PureScript types.
--
-- Some constructors are transient. 'TypeOp', 'BinaryNoParensType' and
-- 'ParensInType' are removed during desugaring. The @PrettyPrint*@
-- constructors are placeholders used only by the pretty printer.
--
-- Note: the derived 'Ord' instance depends on constructor order, so
-- reordering constructors changes how types compare.
--
data Type
  -- | A unification variable of type Type
  = TUnknown Int
  -- | A named type variable
  | TypeVar Text
  -- | A type-level string
  | TypeLevelString PSString
  -- | A type wildcard, as would appear in a partial type synonym
  | TypeWildcard SourceSpan
  -- | A type constructor
  | TypeConstructor (Qualified (ProperName 'TypeName))
  -- | A type operator. This will be desugared into a type constructor during the
  -- "operators" phase of desugaring.
  | TypeOp (Qualified (OpName 'TypeOpName))
  -- | A type application
  | TypeApp Type Type
  -- | Forall quantifier: the bound variable name, the body, and an optional
  -- skolem scope
  | ForAll Text Type (Maybe SkolemScope)
  -- | A type with a set of type class constraints
  | ConstrainedType Constraint Type
  -- | A skolem constant
  | Skolem Text Int SkolemScope (Maybe SourceSpan)
  -- | An empty row
  | REmpty
  -- | A non-empty row: a labelled field type followed by the rest of the row
  | RCons Label Type Type
  -- | A type with a kind annotation
  | KindedType Type Kind
  -- | A placeholder used in pretty printing
  | PrettyPrintFunction Type Type
  -- | A placeholder used in pretty printing
  | PrettyPrintObject Type
  -- | A placeholder used in pretty printing
  | PrettyPrintForAll [Text] Type
  -- | Binary operator application. During the rebracketing phase of desugaring,
  -- this data constructor will be removed.
  | BinaryNoParensType Type Type Type
  -- | Explicit parentheses. During the rebracketing phase of desugaring, this
  -- data constructor will be removed.
  --
  -- Note: although it seems this constructor is not used, it _is_ useful,
  -- since it prevents certain traversals from matching.
  | ParensInType Type
  deriving (Show, Eq, Ord, Generic)

-- Relies on NFData's default (Generic-based) method.
instance NFData Type

-- | Additional data relevant to type class constraints
data ConstraintData
  = PartialConstraintData [[Text]] Bool
  -- ^ Data to accompany a Partial constraint generated by the exhaustivity checker.
  -- It contains (rendered) binder information for those binders which were
  -- not matched, and a flag indicating whether the list was truncated or not.
  -- Note: we use 'Text' here because using 'Binder' would introduce a cyclic
  -- dependency in the module graph.
  deriving (Show, Eq, Ord, Generic)

-- Relies on NFData's default (Generic-based) method.
instance NFData ConstraintData

-- | A typeclass constraint: a class name applied to type arguments, with
-- optional extra data.
data Constraint = Constraint
  { constraintClass :: Qualified (ProperName 'ClassName)
  -- ^ constraint class name
  , constraintArgs  :: [Type]
  -- ^ type arguments
  , constraintData  :: Maybe ConstraintData
  -- ^ additional data relevant to this constraint
  } deriving (Show, Eq, Ord, Generic)

-- Relies on NFData's default (Generic-based) method.
instance NFData Constraint

-- | Apply a function to the type arguments of a constraint. The class name
-- and constraint data are left untouched.
mapConstraintArgs :: ([Type] -> [Type]) -> Constraint -> Constraint
mapConstraintArgs f con@Constraint{ constraintArgs = args } =
  con { constraintArgs = f args }

-- | Lens-shaped accessor for the type arguments of a constraint. It runs an
-- effectful update on the arguments and puts the result back into the
-- constraint.
overConstraintArgs :: Functor f => ([Type] -> f [Type]) -> Constraint -> f Constraint
overConstraintArgs f c = fmap setArgs (f (constraintArgs c)) where
  setArgs newArgs = c { constraintArgs = newArgs }

-- Derive Aeson ToJSON/FromJSON instances for the types above using Aeson's
-- default options. The splices reify the named data declarations, so they
-- must appear after those declarations in the module.
$(A.deriveJSON A.defaultOptions ''Type)
$(A.deriveJSON A.defaultOptions ''Constraint)
$(A.deriveJSON A.defaultOptions ''ConstraintData)

-- | Split a row into its labelled fields, in order of appearance, and the
-- tail type: whatever non-'RCons' type terminates the row.
rowToList :: Type -> ([(Label, Type)], Type)
rowToList (RCons label fieldTy rest) =
  let (fields, rowTail) = rowToList rest
  in ((label, fieldTy) : fields, rowTail)
rowToList rowTail = ([], rowTail)

-- | Like 'rowToList', but with the fields sorted by label. The sort is
-- stable, so fields with equal labels keep their original relative order.
rowToSortedList :: Type -> ([(Label, Type)], Type)
rowToSortedList row =
  let (fields, rowTail) = rowToList row
  in (sortBy (comparing fst) fields, rowTail)

-- | Rebuild a row from labelled fields and a tail type; this is the converse
-- of 'rowToList'. The first field in the list becomes the outermost 'RCons'.
rowFromList :: ([(Label, Type)], Type) -> Type
rowFromList (fields, rowTail) = build fields where
  build [] = rowTail
  build (field : rest) =
    let (label, fieldTy) = field
    in RCons label fieldTy (build rest)

-- | Check whether a type is a monotype, i.e. it has no top-level 'ForAll'.
-- Parentheses and kind annotations are looked through. Nested quantifiers
-- elsewhere are not inspected.
isMonoType :: Type -> Bool
isMonoType ty = case ty of
  ForAll{}        -> False
  ParensInType t' -> isMonoType t'
  KindedType t' _ -> isMonoType t'
  _               -> True

-- | Universally quantify a type over the given variable names, with no
-- skolem scope attached.
--
-- Each name wraps the result built so far, so the /last/ name in the list
-- becomes the outermost quantifier:
-- @mkForAll ["a", "b"] t == ForAll "b" (ForAll "a" t Nothing) Nothing@.
mkForAll :: [Text] -> Type -> Type
mkForAll names = wrap names where
  wrap [] body = body
  wrap (name : rest) body = wrap rest (ForAll name body Nothing)

-- | Replace a single named type variable with a type, respecting variable
-- shadowing. This is 'replaceAllTypeVars' with a one-entry substitution.
replaceTypeVars :: Text -> Type -> Type -> Type
replaceTypeVars name replacement ty = replaceAllTypeVars [(name, replacement)] ty

-- | Replace named type variables with types, taking variable shadowing into
-- account.
--
-- * A 'ForAll' whose binder equals a substitution key removes that key from
--   the substitution for its body.
-- * A 'ForAll' whose binder occurs in one of the replacement types is
--   renamed to a fresh name, so the replacement is not captured.
--
-- The pretty-printing placeholder constructors are not traversed.
replaceAllTypeVars :: [(Text, Type)] -> Type -> Type
replaceAllTypeVars = go [] where
  -- bs: binder names introduced on the path from the root to this node.
  go :: [Text] -> [(Text, Type)] -> Type -> Type
  go _  m (TypeVar v) = fromMaybe (TypeVar v) (v `lookup` m)
  go bs m (TypeApp t1 t2) = TypeApp (go bs m t1) (go bs m t2)
  go bs m f@(ForAll v t sco)
    -- v shadows a key: the key no longer applies inside this binder.
    | v `elem` keys = go bs (filter ((/= v) . fst) m) f
    -- v appears in a replacement: rename the binder before substituting.
    -- The fresh name must also avoid the body's own variables. Otherwise a
    -- free variable of the body that happens to equal the fresh name would
    -- be captured by the renamed binder.
    | v `elem` usedVars =
      let v' = genName v (keys ++ bs ++ usedVars ++ usedTypeVariables t)
          t' = go bs [(v, TypeVar v')] t
      in ForAll v' (go (v' : bs) m t') sco
    | otherwise = ForAll v (go (v : bs) m t) sco
    where
      keys = map fst m
      usedVars = concatMap (usedTypeVariables . snd) m
  go bs m (ConstrainedType c t) = ConstrainedType (mapConstraintArgs (map (go bs m)) c) (go bs m t)
  go bs m (RCons name' t r) = RCons name' (go bs m t) (go bs m r)
  go bs m (KindedType t k) = KindedType (go bs m t) k
  go bs m (BinaryNoParensType t1 t2 t3) = BinaryNoParensType (go bs m t1) (go bs m t2) (go bs m t3)
  go bs m (ParensInType t) = ParensInType (go bs m t)
  go _  _ ty = ty

  -- Append the smallest non-negative integer suffix that is not in use.
  genName :: Text -> [Text] -> Text
  genName orig inUse = try' 0 where
    try' :: Integer -> Text
    try' n | candidate `elem` inUse = try' (n + 1)
           | otherwise = candidate
      where candidate = orig <> T.pack (show n)

-- | Collect the names of every 'TypeVar' occurring in a type, free or bound,
-- without duplicates and in order of first occurrence. Binder names that
-- never occur as a 'TypeVar' are not included.
usedTypeVariables :: Type -> [Text]
usedTypeVariables ty = ordNub (everythingOnTypes (++) varName ty) where
  varName (TypeVar name) = [name]
  varName _ = []

-- | Collect the type variables that occur free in a type (not bound by an
-- enclosing 'ForAll'), without duplicates and in order of first occurrence.
-- The pretty-printing placeholder constructors are not traversed.
freeTypeVariables :: Type -> [Text]
freeTypeVariables ty = ordNub (collect [] ty) where
  -- scope: names bound by the ForAlls enclosing the current node.
  collect :: [Text] -> Type -> [Text]
  collect scope node = case node of
    TypeVar name | name `notElem` scope -> [name]
    TypeApp t1 t2 -> collect scope t1 ++ collect scope t2
    ForAll name body _ -> collect (name : scope) body
    ConstrainedType con body ->
      concatMap (collect scope) (constraintArgs con) ++ collect scope body
    RCons _ fieldTy rest -> collect scope fieldTy ++ collect scope rest
    KindedType inner _ -> collect scope inner
    BinaryNoParensType t1 t2 t3 ->
      collect scope t1 ++ collect scope t2 ++ collect scope t3
    ParensInType inner -> collect scope inner
    _ -> []

-- | Universally quantify over every free type variable of a type. The first
-- free variable, in order of occurrence, becomes the outermost quantifier.
quantify :: Type -> Type
quantify ty = foldr bind ty (freeTypeVariables ty) where
  bind name body = ForAll name body Nothing

-- | Reorganise the leading spine of 'ForAll' / 'ConstrainedType' nodes.
-- The result has all quantifiers first, then all constraints, then the
-- remaining type.
--
-- Relative order is preserved within each group. Only the consecutive
-- quantifiers and constraints at the top of the type are considered.
moveQuantifiersToFront :: Type -> Type
moveQuantifiersToFront = peel [] [] where
  -- Strip the spine, recording each piece with the most recently seen first.
  peel foralls constraints ty = case ty of
    ForAll name body scope -> peel ((name, scope) : foralls) constraints body
    ConstrainedType con body -> peel foralls (con : constraints) body
    _ -> rebuild foralls constraints ty
  -- Re-wrap the core with constraints innermost and quantifiers outside them.
  -- Left folds over the most-recent-first lists restore the original order.
  rebuild foralls constraints core =
    let constrained = foldl (flip ConstrainedType) core constraints
    in foldl (\acc (name, scope) -> ForAll name acc scope) constrained foralls

-- | Check whether a 'TypeWildcard' occurs anywhere in a type.
containsWildcards :: Type -> Bool
containsWildcards = everythingOnTypes (||) isWildcard where
  isWildcard :: Type -> Bool
  isWildcard ty = case ty of
    TypeWildcard{} -> True
    _ -> False

-- | Check whether a 'ForAll' quantifier occurs anywhere in a type.
containsForAll :: Type -> Bool
containsForAll = everythingOnTypes (||) isForAll where
  isForAll :: Type -> Bool
  isForAll ty = case ty of
    ForAll{} -> True
    _ -> False

-- | Bottom-up transformation. Every node is rebuilt from its
-- already-transformed children, then @f@ is applied to the rebuilt node.
-- @f@ is applied exactly once per node.
everywhereOnTypes :: (Type -> Type) -> Type -> Type
everywhereOnTypes f = rewrite where
  -- Force the rebuilt node before handing it to f. This keeps the strictness
  -- of pattern matching directly on the input.
  rewrite ty = f $! descend ty
  -- Rebuild a node with each immediate child rewritten; leaves are unchanged.
  descend (TypeApp t1 t2) = TypeApp (rewrite t1) (rewrite t2)
  descend (ForAll name body scope) = ForAll name (rewrite body) scope
  descend (ConstrainedType con body) =
    ConstrainedType (mapConstraintArgs (map rewrite) con) (rewrite body)
  descend (RCons label fieldTy rest) = RCons label (rewrite fieldTy) (rewrite rest)
  descend (KindedType inner k) = KindedType (rewrite inner) k
  descend (PrettyPrintFunction t1 t2) = PrettyPrintFunction (rewrite t1) (rewrite t2)
  descend (PrettyPrintObject inner) = PrettyPrintObject (rewrite inner)
  descend (PrettyPrintForAll names inner) = PrettyPrintForAll names (rewrite inner)
  descend (BinaryNoParensType t1 t2 t3) =
    BinaryNoParensType (rewrite t1) (rewrite t2) (rewrite t3)
  descend (ParensInType inner) = ParensInType (rewrite inner)
  descend leaf = leaf

-- | Top-down transformation. @f@ is applied to a node first, then the
-- traversal recurses into the children of the node @f@ returned. @f@ is
-- applied exactly once per visited node.
everywhereOnTypesTopDown :: (Type -> Type) -> Type -> Type
everywhereOnTypesTopDown f = go . f where
  -- Invariant: the argument of go has already been passed through f.
  go (TypeApp t1 t2) = TypeApp (go (f t1)) (go (f t2))
  go (ForAll arg ty sco) = ForAll arg (go (f ty)) sco
  go (ConstrainedType c ty) = ConstrainedType (mapConstraintArgs (map (go . f)) c) (go (f ty))
  go (RCons name ty rest) = RCons name (go (f ty)) (go (f rest))
  go (KindedType ty k) = KindedType (go (f ty)) k
  go (PrettyPrintFunction t1 t2) = PrettyPrintFunction (go (f t1)) (go (f t2))
  go (PrettyPrintObject t) = PrettyPrintObject (go (f t))
  go (PrettyPrintForAll args t) = PrettyPrintForAll args (go (f t))
  go (BinaryNoParensType t1 t2 t3) = BinaryNoParensType (go (f t1)) (go (f t2)) (go (f t3))
  go (ParensInType t) = ParensInType (go (f t))
  -- f has already been applied to this leaf; applying it again here would
  -- transform every leaf twice.
  go other = other

-- | Monadic bottom-up transformation. Children are transformed first, in
-- left-to-right effect order, then the node is rebuilt, and finally @f@ runs
-- on the rebuilt node.
everywhereOnTypesM :: Monad m => (Type -> m Type) -> Type -> m Type
everywhereOnTypesM f = visit where
  -- Run f on a node whose children have already been rebuilt.
  thenF rebuilt = rebuilt >>= f
  visit ty = case ty of
    TypeApp t1 t2 -> thenF (TypeApp <$> visit t1 <*> visit t2)
    ForAll name body scope -> thenF (ForAll name <$> visit body <*> pure scope)
    ConstrainedType con body ->
      thenF (ConstrainedType <$> overConstraintArgs (mapM visit) con <*> visit body)
    RCons label fieldTy rest -> thenF (RCons label <$> visit fieldTy <*> visit rest)
    KindedType inner k -> thenF (KindedType <$> visit inner <*> pure k)
    PrettyPrintFunction t1 t2 -> thenF (PrettyPrintFunction <$> visit t1 <*> visit t2)
    PrettyPrintObject inner -> thenF (PrettyPrintObject <$> visit inner)
    PrettyPrintForAll names inner -> thenF (PrettyPrintForAll names <$> visit inner)
    BinaryNoParensType t1 t2 t3 ->
      thenF (BinaryNoParensType <$> visit t1 <*> visit t2 <*> visit t3)
    ParensInType inner -> thenF (ParensInType <$> visit inner)
    _ -> f ty

-- | Monadic top-down transformation. @f@ runs on a node first, then the
-- traversal recurses into the children of the node @f@ returned. Children are
-- processed left to right, and @f@ runs exactly once per visited node.
everywhereOnTypesTopDownM :: Monad m => (Type -> m Type) -> Type -> m Type
everywhereOnTypesTopDownM f = go <=< f where
  -- Invariant: the argument of go has already been passed through f.
  go (TypeApp t1 t2) = TypeApp <$> (f t1 >>= go) <*> (f t2 >>= go)
  go (ForAll arg ty sco) = ForAll arg <$> (f ty >>= go) <*> pure sco
  go (ConstrainedType c ty) = ConstrainedType <$> overConstraintArgs (mapM (go <=< f)) c <*> (f ty >>= go)
  go (RCons name ty rest) = RCons name <$> (f ty >>= go) <*> (f rest >>= go)
  go (KindedType ty k) = KindedType <$> (f ty >>= go) <*> pure k
  go (PrettyPrintFunction t1 t2) = PrettyPrintFunction <$> (f t1 >>= go) <*> (f t2 >>= go)
  go (PrettyPrintObject t) = PrettyPrintObject <$> (f t >>= go)
  go (PrettyPrintForAll args t) = PrettyPrintForAll args <$> (f t >>= go)
  go (BinaryNoParensType t1 t2 t3) = BinaryNoParensType <$> (f t1 >>= go) <*> (f t2 >>= go) <*> (f t3 >>= go)
  go (ParensInType t) = ParensInType <$> (f t >>= go)
  -- f has already run on this leaf; running it again here would apply the
  -- transformation (and its effects) twice.
  go other = pure other

-- | Fold over every node of a type in pre-order. @f@ is applied to a node,
-- then the results for its children (left to right) are combined with
-- @(<+>)@, associating to the left.
everythingOnTypes :: (r -> r -> r) -> (Type -> r) -> Type -> r
everythingOnTypes (<+>) f = visit where
  visit ty = foldl (<+>) (f ty) (map visit (children ty))
  -- Immediate sub-types of a node, in visiting order.
  children (TypeApp t1 t2) = [t1, t2]
  children (ForAll _ body _) = [body]
  children (ConstrainedType con body) = constraintArgs con ++ [body]
  children (RCons _ fieldTy rest) = [fieldTy, rest]
  children (KindedType inner _) = [inner]
  children (PrettyPrintFunction t1 t2) = [t1, t2]
  children (PrettyPrintObject inner) = [inner]
  children (PrettyPrintForAll _ inner) = [inner]
  children (BinaryNoParensType t1 t2 t3) = [t1, t2, t3]
  children (ParensInType inner) = [inner]
  children _ = []

-- | Like 'everythingOnTypes', but threads a context value down the tree.
--
-- At each node, @f@ receives the context inherited from the parent. It
-- returns the context to pass to the node's children together with this
-- node's result, which is combined with the children's results via
-- @(<+>)@.
--
-- Leaves contribute @r0@ in place of child results, and the fold over a
-- constraint's arguments is seeded with @r0@.
everythingWithContextOnTypes :: s -> r -> (r -> r -> r) -> (s -> Type -> (s, r)) -> Type -> r
everythingWithContextOnTypes s0 r0 (<+>) f = visit s0 where
  -- Apply f here, then combine its result with those of the children.
  visit ctx ty =
    let (childCtx, here) = f ctx ty
    in here <+> descend childCtx ty
  -- Combined results for the immediate children, all visited with ctx.
  descend ctx ty = case ty of
    TypeApp t1 t2 -> visit ctx t1 <+> visit ctx t2
    ForAll _ body _ -> visit ctx body
    ConstrainedType con body ->
      foldl (<+>) r0 (map (visit ctx) (constraintArgs con)) <+> visit ctx body
    RCons _ fieldTy rest -> visit ctx fieldTy <+> visit ctx rest
    KindedType inner _ -> visit ctx inner
    PrettyPrintFunction t1 t2 -> visit ctx t1 <+> visit ctx t2
    PrettyPrintObject inner -> visit ctx inner
    PrettyPrintForAll _ inner -> visit ctx inner
    BinaryNoParensType t1 t2 t3 -> visit ctx t1 <+> visit ctx t2 <+> visit ctx t3
    ParensInType inner -> visit ctx inner
    _ -> r0