{-# LANGUAGE FlexibleContexts, PatternGuards #-}

{-
  Function adjustTypeInfos annotates every declaration,
  identifier, and application with exact type information.

  This information is derived from the more general information
  found in the AST.

  (c) 2009, Holger Siegel.

-}

module Curry.ExtendedFlat.TypeInference
    ( dispType,
      adjustTypeInfo,
      labelVarsWithTypes,
      uniqueTypeIndices,
      genEquations,
      elimFreeTypes
    ) where

import Debug.Trace

import Text.PrettyPrint.HughesPJ
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe
import qualified  Data.IntMap as IntMap

import Curry.ExtendedFlat.Type
import Curry.ExtendedFlat.Goodies


trace' msg x = x -- trace msg x 

-- | For every identifier that occurs in the right hand side
--   of a declaration, the polymorphic type variables in its
--   type label are replaced by concrete types.
adjustTypeInfo :: Prog -> Prog
adjustTypeInfo = -- elimFreeTypes .
                 genEquations . 
                 uniqueTypeIndices .
                 labelVarsWithTypes

-- | Displays a TypeExpr as a string
dispType :: TypeExpr -> String
dispType = render . prettyType

prettyType :: TypeExpr -> Doc
prettyType (TVar i) = text ('t':show i)
prettyType (FuncType f x) = parens (prettyType f) <+> text "->" <+> prettyType x
prettyType (TCons qn ts) = let  n = let (m,l) = qnOf qn in m ++ '.' : l
                           in text n <+> hsep (map (parens . prettyType) ts)

prettyAllEqns = render . prettyEqns
    where
      prettyEqn ::(TVarIndex, TypeExpr)  -> Doc
      prettyEqn (l, r) = (char 't' <> int l <+> text "->" <+> prettyType r)

      prettyEqns ((m,l), t, eqns)
          = text m <> char '.' <> text l <+> text "::" <+> prettyType t <> char ':'
            $$ (nest 5 (vcat (map prettyEqn eqns)))


postOrderExpr :: Monad m => (Expr -> m Expr) -> Expr -> m Expr
postOrderExpr f = po
    where po e@(Var _) = f e
          po e@(Lit _) = f e
          po (Comb t n es) = do es' <- mapM po es
                                f (Comb t n es')
          po (Free vs e) = do e' <- po e
                              f (Free vs e')
          po (Let bs e) = do bs' <- mapM poBind bs
                             e'  <- po e
                             f (Let bs' e')
          po (Or l r) = liftM2 Or (po l) (po r) >>= f
          po (Case p t e bs) = do e' <- po e
                                  bs' <- mapM poBranch bs
                                  f (Case p t e' bs')
          poBind  (v, rhs) = do rhs' <- po rhs
                                return (v, rhs')
          poBranch (Branch p rhs) = do rhs' <- po rhs
                                       return (Branch p rhs')




postOrderType :: Monad m => (TypeExpr -> m TypeExpr) -> TypeExpr -> m TypeExpr
postOrderType f = po
    where po e@(TVar _) = f e
          po (FuncType t1 t2) = do t1' <- po t1
                                   t2' <- po t2
                                   f (FuncType t1' t2')
          po (TCons qn ts) = do ts' <- mapM po ts
                                f (TCons qn ts')


visitTVars :: Monad m => (TVarIndex -> m TypeExpr) -> TypeExpr -> m TypeExpr
visitTVars f = postOrderType f'
    where f' (TVar i) = f i
          f' t = return t


-- ----------------------------------------------------------------------
-- ----------------------------------------------------------------------

-- | All identifiers that do not have type annotations are
--   labelled with new type variables
labelVarsWithTypes :: Prog -> Prog
labelVarsWithTypes = updProgFuncs updateFunc
    where 
      updateFunc = map (\func -> let maxtvi = maxFuncTV func + 1 
                                 in trFunc (foo maxtvi) func)
      foo maxtv qn arity visty te r@(External _) = Func qn arity visty te r
      foo maxtv qn arity visty te r@(Rule vs expr) 
          = let expr' = evalState (runReaderT (withVS vs (po expr)) typeMap) maxtv
                typeMap = trace' (show argTypes) $ IntMap.fromList argTypes
                argTypes = [ (vi, t) | VarIndex (Just t) vi <- vs ]
            in Func qn arity visty te (Rule vs expr')

      po :: Expr -> ReaderT TypeMap (State Int) Expr
      -- type information from vi is superseded by type information
      -- from the map. This is okay in the current context, but for
      -- general type inference this would result in loss of information.
      -- (Fix by unifying both types in a later version)
      po e@(Var vi)
          = do vt <- asks (IntMap.lookup $ idxOf vi)
               case vt of
                 Just t -> return (Var vi { typeofVar = Just t })
                 Nothing -> case typeofVar vi of
                              Nothing -> error $ "no type for var " ++ show e
                              _ -> liftM Var (poVarIndex vi)
      po e@(Lit _)
          = return e
      po (Comb t n es)
          = do es' <- mapM po es
               n' <- poQName n
               return (Comb t n' es')
      po (Free vs e) 
          = do vs' <- mapM poVarIndex vs
               e' <- po e
               return (Free vs' e')
      po (Let bs e)
          = do let (vs, es) = unzip bs
               vs' <- mapM poVarIndex vs
               withVS vs' (do es' <- mapM po es
                              e'  <- po e
                              return (Let (zip vs' es') e'))
      po (Or l r)
          = liftM2 Or (po l) (po r)
      po (Case p t e bs)
          = do e' <- po e
               bs' <- mapM poBranch bs
               return (Case p t e' bs')
      poBranch (Branch (Pattern qn vs) rhs) 
          = do qn' <- poQName qn
               vs' <- mapM poVarIndex vs
               withVS vs' (do rhs' <- po rhs
                              return (Branch (Pattern qn' vs') rhs'))
      poBranch (Branch (LPattern l) e) 
          = do rhs' <- po e
               return (Branch (LPattern l) e)
      poVarIndex vi
          = do t <- maybe (lift$freshTVar) return . typeofVar $ vi
               return vi{typeofVar = Just t }

      poQName qn
          = do t <- maybe (lift$freshTVar) 
                        return . typeofQName $ qn
               return qn{typeofQName = Just t }

      withVS :: MonadReader TypeMap m => [VarIndex] -> m a -> m a
      withVS vs action = local (\ m -> foldr (\ v -> IntMap.insert (idxOf v) (fromJust $ typeofVar v)) m vs) action

-- ----------------------------------------------------------------------
-- ----------------------------------------------------------------------

-- | Type variables that occur in the type annotations of QNames
--   are replaced by newly introduced type variables, so that further
--   unification steps will not interfere with parametric polymorphism
uniqueTypeIndices :: Prog -> Prog
uniqueTypeIndices = updProgFuncs (map updateFunc)
    where
      updateFunc func = let firstfree = maxFuncTV func + 1
                        in (updFuncRule (trRule (ruleFoo firstfree) External)) func
      ruleFoo firstfree args expr
          = let expr' = evalState (postOrderExpr relabelTypes expr) firstfree
            in  Rule args expr'

relabelTypes :: Expr ->  State TVarIndex Expr
relabelTypes (Comb ct qname args)
    = do t' <- case typeofQName qname of
                 Just lt -> relabelType lt
                 Nothing -> freshTVar
         return (Comb ct qname {typeofQName = Just t'} args)
relabelTypes (Var v)
    | typeofVar v == Nothing
    = do t <- freshTVar
         return (Var v{typeofVar = Just t})
relabelTypes (Case p t e bs)
    = do bs' <- mapM relabelPatType bs
         return (Case p t e bs')
    where relabelPatType (Branch (Pattern qn vis) e)
              = do t' <- case typeofQName qn of
                           Just lt -> relabelType lt
                           Nothing -> freshTVar
                   return (Branch (Pattern qn {typeofQName = Just t'} vis) e)
          relabelPatType be = return be
relabelTypes t = return t

relabelType :: TypeExpr -> State TVarIndex TypeExpr
relabelType t = evalStateT (visitTVars typeFoo t) IntMap.empty
    where typeFoo i = do m <- get
                         case IntMap.lookup i m of
                           Just v -> return v
                           Nothing -> do v <- lift freshTVar 
                                         modify (IntMap.insert i v)
                                         return v


-- ----------------------------------------------------------------------
-- ----------------------------------------------------------------------                                

type TypeMap =  IntMap.IntMap TypeExpr

type EqnMonad = StateT TypeMap (State TVarIndex)


-- | Specialises all type variables (part of adjustTypeInfo)
genEquations  :: Prog -> Prog
genEquations = updProgFuncs updateFunc
    where 
      updateFunc = map (\func -> let maxtvi = maxFuncTV func + 1 
                                 in trFunc (foo maxtvi) func)
      foo maxtv qn arity visty te r@(External _) = Func qn arity visty te r
      foo maxtv qn arity visty te r@(Rule vs expr) 
          = let h = evalState (execStateT (do argTypes <- mapM varIndexType vs
                                              etype <- equations expr
                                              qnt <- qnType qn
                                              qnt =:= foldr FuncType etype argTypes
                                              return()
                                          ) IntMap.empty) maxtv
            in trace' (prettyAllEqns (qnOf qn,te,IntMap.toList h)) Func qn arity visty (specialiseType h te) (specInRule h (Rule vs expr))
          

equations :: Expr -> EqnMonad TypeExpr
equations = trExpr varIndexType (return . typeofLiteral) combEqn letEqn frEqn orEqn casEqn branchEqn
    where
      combEqn :: (CombType -> QName -> [EqnMonad TypeExpr] -> EqnMonad TypeExpr)
      combEqn _ qn args
          = do resultType <- lift$freshTVar
               argTypes <- sequence args
               tqn <- qnType qn
               tqn =:= foldr FuncType resultType argTypes
               return resultType

      letEqn _ e = e

      frEqn _ e = e

      orEqn l r = do l' <- l
                     r' <- r
                     l' =:= r'

      casEqn :: SrcRef -> CaseType -> EqnMonad TypeExpr -> [(Pattern, EqnMonad TypeExpr)] -> EqnMonad TypeExpr
      casEqn _ _ scr [] = scr >> (lift$freshTVar)
      casEqn _ _ scr ps = do scrt <- scr
                             -- unify patterns with scrutinee
                             mapM_ (unifLhs scrt) ps
                             -- unify right hand sides
                             (p:ps') <- sequence $ map snd ps
                             foldM (=:=) p ps'

      unifLhs scrt (LPattern lit, _)
          = typeofLiteral lit =:= scrt
      unifLhs scrt (Pattern qn vs, _)
          = do qnt <- qnType qn
               argTypes <- mapM varIndexType vs
               qnt =:= foldr FuncType scrt argTypes


      branchEqn :: Pattern -> EqnMonad TypeExpr -> (Pattern, EqnMonad TypeExpr)
      branchEqn p e = (p, e)


unify :: TypeExpr -> TypeExpr -> TypeMap -> TypeMap
-- t =:= u = return t

unify (TVar i) t tm
    | Just s <- IntMap.lookup i tm 
    = unify s t tm
unify s (TVar j) tm
    | Just t <- IntMap.lookup j tm
    = unify s t tm
unify s@(TVar i) t@(TVar j) tm
    | i == j    = tm
    | i < j     = IntMap.insert j s tm
    | i > j     = IntMap.insert i t tm
unify (TVar i) t tm
    = IntMap.insert i t tm
unify s (TVar j) tm
    = IntMap.insert j s tm

unify (FuncType f x) (FuncType g y) tm
    = unify x y (unify f g tm)
unify (TCons m as) (TCons n bs) tm
    | m == n  = foldr ($) tm (zipWith unify as bs)
unify s t _
    = error . render $
      text "Types differ: " <+> prettyType s <+> text "/=" <+> prettyType t


(=:=) :: TypeExpr -> TypeExpr -> EqnMonad TypeExpr
a =:= b = modify (unify a b) >> return a


varIndexType :: VarIndex -> EqnMonad TypeExpr
varIndexType = maybe (lift$freshTVar) return . typeofVar


qnType :: QName -> EqnMonad TypeExpr
qnType = maybe (lift$freshTVar) return . typeofQName

      
freshTVar :: MonadState Int m => m TypeExpr
freshTVar = do nextIdx <- get
               modify succ
               return (TVar nextIdx)

---------------------------------------------------------------------

-- | Type variables that occur in the right hand side of a declaration
--   but not in its type signature are replaced by the unit type ().
--   This function requires that proper type information has been made
--   available by function @adjustTypeInfo@

elimFreeTypes :: Prog -> Prog
elimFreeTypes = updProgFuncs updateFunc
    where 
      updateFunc = map (trFunc foo)
      foo qn arity visty te r@(External _) = Func qn arity visty te r
      foo qn arity visty te r@(Rule vs expr) 
          = let tvs = tvars te
                tvars (TVar vi) = [vi]
                tvars (FuncType t1 t2) = tvars t1 ++ tvars t2
                tvars (TCons _ ts) = concatMap tvars ts
                tfoo t@(TVar vi)
                    | vi `elem` tvs = t
                    | otherwise = TCons (mkQName ("Prelude", "()")) []
                tfoo (FuncType t1 t2) = FuncType (tfoo t1) (tfoo t2)
                tfoo (TCons qn ts) = TCons qn (map tfoo ts)
            in Func qn arity visty te (modifyType tfoo (Rule vs expr))


---------------------------------------------------------------------

maxFuncTV = trFunc (\qn _ _ te r -> max (maxQNameTV qn) (max (maxTypeTV te) (maxRuleTV r)))
    where 
      maxRuleTV = trRule (\vis e -> maximum (maxExprTV e : map maxVarIndexTV vis)) (const (-1))

      maxExprTV :: Expr -> Int
      maxExprTV = trExpr var lit comb lt fr max cas branch
          where var  = maxVarIndexTV
                lit  = const (-1)
                comb _ qn ms = maximum (maxQNameTV qn : ms)
                lt bs e = maximum (e : map maxBindTV bs)
                fr vs e = maximum (e : map maxVarIndexTV vs)
                cas _ _ e ps = maximum (e : ps)
                branch p e = max e (maxPatternTV p)

      maxQNameTV = maybe (-1) maxTypeTV . typeofQName

      maxVarIndexTV = maybe (-1) maxTypeTV . typeofVar

      maxBindTV (vi, e) = max e (maxVarIndexTV vi)

      maxPatternTV (Pattern qn vis) = maximum (maxQNameTV qn : map maxVarIndexTV vis)
      maxPatternTV (LPattern _) = -1

      maxTypeTV = trTypeExpr id tapp max
          where tapp _ args = maximum (-1:args)

--------------------


specialiseType :: TypeMap -> TypeExpr -> TypeExpr
specialiseType m t = trTypeExpr (foo m) TCons FuncType t
    where foo m i = maybe (TVar i) (specialiseType m) (IntMap.lookup i m)


-- boilerplate
specInRule :: TypeMap -> Rule -> Rule
specInRule tm = modifyType (specialiseType tm)



-- boilerplate
modifyType :: (TypeExpr -> TypeExpr) -> Rule -> Rule
modifyType f = updRule (map specInVarIndex) specInExpr id
    where specInExpr
              = trExpr var Lit comb letexp free Or Case alt
          var vi
              = Var (specInVarIndex vi)
          comb ct qn as
              = Comb ct (specInQName qn) as
          letexp bs e
              = Let (map specInBind bs) e
          free vis e
              = Free (map specInVarIndex vis) e
          alt p e
              = Branch (specInPattern p) e

          specInBind (vi, e)
              = (specInVarIndex vi, e)

          specInPattern (Pattern qn vis)
              = Pattern (specInQName qn) (map specInVarIndex vis)
          specInPattern p = p

          specInVarIndex vi
              = vi { typeofVar = fmap f (typeofVar vi)}

          specInQName qn
              = qn { typeofQName = fmap f (typeofQName qn)}