{- |
    Module      :  $Header$
    Description :  Lifting of lambda-expressions and local functions
    Copyright   :  (c) 2001 - 2003 Wolfgang Lux
                       2011 - 2015 Björn Peemöller
                       2016 - 2017 Finn Teegen
    License     :  BSD-3-clause

    Maintainer  :  bjp@informatik.uni-kiel.de
    Stability   :  experimental
    Portability :  portable

   After desugaring and simplifying the code, the compiler lifts all local
   function declarations to the top-level keeping only local variable
   declarations. The algorithm used here is similar to Johnsson's, consisting
   of two phases. First, we abstract each local function declaration,
   adding its free variables as initial parameters and updating all calls
   to take these variables into account. Second, all local function
   declarations are collected and lifted to the top-level.
-}
{-# LANGUAGE CPP #-}
module Transformations.Lift (lift) where

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative        ((<$>), (<*>))
#endif
import           Control.Arrow              (first)
import qualified Control.Monad.State as S   (State, runState, gets, modify)
import           Data.List
import qualified Data.Map            as Map (Map, empty, insert, lookup)
import           Data.Maybe                 (mapMaybe, fromJust)
import qualified Data.Set            as Set (fromList, toList, unions)

import Curry.Base.Ident
import Curry.Base.SpanInfo
import Curry.Syntax

import Base.AnnotExpr
import Base.Expr
import Base.Messages                        (internalError)
import Base.SCC
import Base.Types
import Base.TypeSubst
import Base.Typing
import Base.Utils

import Env.Value

-- | Lift all local function declarations of a module to the top level.
--
-- The declarations are first run through the abstraction pass ('absDecl'),
-- starting with an empty name prefix, no local variables and an empty
-- abstraction environment; afterwards 'liftFunDecl' moves the abstracted
-- local functions to the top level. Returns the transformed module together
-- with the value environment as it is at the end of the abstraction pass.
lift :: ValueEnv -> Module Type -> (Module Type, ValueEnv)
lift vEnv (Module spi ps m es is ds) =
  let startState          = LiftState { moduleIdent = m
                                      , valueEnv    = vEnv
                                      , abstractEnv = Map.empty
                                      }
      (absDs, finalState) = S.runState (mapM (absDecl "" []) ds) startState
      topLevelDs          = concatMap liftFunDecl absDs
  in  (Module spi ps m es is topLevelDs, valueEnv finalState)

-- -----------------------------------------------------------------------------
-- Abstraction
-- -----------------------------------------------------------------------------

-- Besides adding the free variables to every (local) function, the
-- abstraction pass also has to update the type environment in order to
-- reflect the new types of the abstracted functions. As usual, we use a
-- state monad in order to pass the type environment through. The
-- abstraction environment constructed in the abstraction phase maps
-- each local function declaration onto its replacement expression,
-- i.e. the function applied to its free variables. In order to generate
-- correct type annotations for an inserted replacement expression, we also
-- save a function's original type. The original type is later unified with
-- the concrete type of the replaced expression to obtain a type substitution
-- which is then applied to the replacement expression.

-- | Maps the name of an abstracted local function to its replacement
-- expression (the lifted function applied to its free variables) paired
-- with the function's original type.
type AbstractEnv = Map.Map Ident (Expression Type, Type)

-- | State threaded through the abstraction pass.
data LiftState = LiftState
  { moduleIdent :: ModuleIdent -- ^ identifier of the module being processed
  , valueEnv    :: ValueEnv    -- ^ value environment; gains an entry for
                               --   every function that is abstracted
  , abstractEnv :: AbstractEnv -- ^ replacement expressions for the local
                               --   functions currently in scope
  }

-- | The monad of the abstraction pass: a state monad over 'LiftState'.
type LiftM a = S.State LiftState a

-- | Read the identifier of the module being processed from the state.
getModuleIdent :: LiftM ModuleIdent
getModuleIdent = S.gets (\st -> moduleIdent st)

-- | Read the current value environment from the state.
getValueEnv :: LiftM ValueEnv
getValueEnv = S.gets (\st -> valueEnv st)

-- | Apply a transformation to the value environment stored in the state.
modifyValueEnv :: (ValueEnv -> ValueEnv) -> LiftM ()
modifyValueEnv f = S.modify updateEnv
  where updateEnv st = st { valueEnv = f (valueEnv st) }

-- | Read the current abstraction environment from the state.
getAbstractEnv :: LiftM AbstractEnv
getAbstractEnv = S.gets (\st -> abstractEnv st)

-- | Run an action with the given abstraction environment installed and
-- restore the previously active environment afterwards. All other parts
-- of the state keep the changes made by the action.
withLocalAbstractEnv :: AbstractEnv -> LiftM a -> LiftM a
withLocalAbstractEnv ae act = do
  saved <- getAbstractEnv
  install ae
  result <- act
  install saved
  return result
  where install env = S.modify (\st -> st { abstractEnv = env })

-- | Abstract the right-hand sides of a declaration.
--
-- For a function declaration every equation is abstracted (the prefix is
-- not used here; 'absEquation' derives its own from the function name).
-- For a pattern declaration the right-hand side is abstracted with the
-- given prefix. Any other declaration is returned unchanged.
absDecl :: String -> [Ident] -> Decl Type -> LiftM (Decl Type)
absDecl pre lvs decl = case decl of
  FunctionDecl p ty f eqs -> FunctionDecl p ty f <$> mapM (absEquation lvs) eqs
  PatternDecl  p t rhs    -> PatternDecl p t <$> absRhs pre lvs rhs
  _                       -> return decl

-- | Abstract the right-hand side of a function equation. The variables
-- bound by the argument patterns are added to the local variables, and the
-- function's name followed by a dot becomes the prefix for functions
-- defined locally within the equation.
absEquation :: [Ident] -> Equation Type -> LiftM (Equation Type)
absEquation lvs eqn = case eqn of
  Equation p lhs@(FunLhs _ f ts) rhs ->
    Equation p lhs <$> absRhs (idName f ++ ".") (lvs ++ bv ts) rhs
  _ -> error "Lift.absEquation: no pattern match"

-- | Abstract the expression of a simple right-hand side. Any other form
-- of right-hand side is an error.
absRhs :: String -> [Ident] -> Rhs Type -> LiftM (Rhs Type)
absRhs pre lvs rhs = case rhs of
  SimpleRhs p e _ -> simpleRhs p <$> absExpr pre lvs e
  _               -> error "Lift.absRhs: no simple RHS"

-- Within a declaration group we have to split the list of declarations
-- into the function and value declarations. Only the function
-- declarations are affected by the abstraction algorithm; the value
-- declarations are left unchanged except for abstracting their right
-- hand sides.

-- The abstraction of a recursive declaration group is complicated by the
-- fact that not all functions need to call each other in a recursive
-- declaration group. E.g., in the following example neither 'g' nor 'h'
-- calls the other directly.
--
--   f = g True
--     where x = h 1
--           h z = y + z
--           y = g False
--           g z = if z then x else 0
--
-- Because of this fact, 'g' and 'h' can be abstracted separately by adding
-- only 'y' to 'h' and 'x' to 'g'. On the other hand, in the following example
--
--   f x y = g 4
--     where g p = h p + x
--           h q = k + y + q
--           k = g x
--
-- the local function 'g' uses 'h', so the free variables
-- of 'h' have to be added to 'g' as well. However, because
-- 'h' does not call 'g' it is sufficient to add only
-- 'k' and 'y' (and not 'x') to its definition. We handle this by computing
-- the dependency graph between the functions and splitting this graph into
-- its strongly connected components. Each component is then processed
-- separately, adding the free variables in the group to its functions.

-- We have to be careful with local declarations within desugared case
-- expressions. If some of the cases have guards, e.g.,
--
--   case e of
--     x | x < 1 -> 1
--     x -> let double y = y * y in double x
--
-- the desugarer at present may duplicate code. While there is no problem
-- with local variable declarations being duplicated, we must avoid
-- lifting local function declarations more than once. Therefore
-- 'absFunDecls' transforms only those function declarations
-- that have not been lifted and discards the other declarations. Note
-- that it is easy to check whether a function has been lifted by
-- checking whether an entry for its transformed name is present
-- in the value environment.

-- | Abstract a local declaration group together with the expression it
-- scopes over. The group is split into function and value declarations;
-- the function declarations are divided into strongly connected components
-- that 'absFunDecls' processes one after another. The variables bound by
-- the value declarations count as local variables for the whole group.
absDeclGroup :: String -> [Ident] -> [Decl Type] -> Expression Type
             -> LiftM (Expression Type)
absDeclGroup pre lvs ds e = do
  m <- getModuleIdent
  let (funDs, valDs) = partition isFunDecl ds
      components     = scc bv (qfv m) funDs
  absFunDecls pre (lvs ++ bv valDs) components valDs e

-- | Abstract the function declarations of a declaration group, one
-- component at a time.
--
-- Arguments: the name prefix used for lifted functions, the local variables
-- in scope, the remaining components of function declarations, the value
-- declarations of the group, and the expression the group scopes over.
--
-- Once no component is left, the value declarations and the body are
-- abstracted and wrapped in a 'Let'. For each component, the local free
-- variables of its functions are computed, a replacement expression is
-- registered for every function of the component, and the remaining
-- components are processed with that extended abstraction environment.
absFunDecls :: String -> [Ident] -> [[Decl Type]] -> [Decl Type]
            -> Expression Type -> LiftM (Expression Type)
absFunDecls pre lvs []         vds e = do
  vds' <- mapM (absDecl pre lvs) vds
  e' <- absExpr pre lvs e
  return (Let NoSpanInfo vds' e')
absFunDecls pre lvs (fds:fdss) vds e = do
  m <- getModuleIdent
  env <- getAbstractEnv
  vEnv <- getValueEnv
  let -- defined functions
      fs      = bv fds
      -- function types
      ftys    = map extractFty fds
      -- the type is read off the first equation: argument pattern types
      -- followed by the type of the right-hand side
      extractFty (FunctionDecl _ _ f (Equation _ (FunLhs _ _ ts) rhs : _)) =
        (f, foldr TypeArrow (typeOf rhs) $ map typeOf ts)
      extractFty _                                                         =
        internalError "Lift.absFunDecls.extractFty"
      -- typed free variables on the right-hand sides; a variable that names
      -- an already abstracted function contributes the free variables of
      -- its replacement expression instead of itself
      fvsRhs  = Set.unions
                  [ Set.fromList (filter (not . isDummyType . fst)
                                         (maybe [(ty, v)]
                                                (qafv' ty)
                                                (Map.lookup v env)))
                  | (ty, v) <- concatMap (qafv m) fds ]
      -- !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      -- !!! HACK: When calculating the typed free variables on the     !!!
      -- !!! right-hand side, we have to filter out the ones annotated  !!!
      -- !!! with dummy types (see below). Additionally, we have to be  !!!
      -- !!! careful when we calculate the typed free variables in a    !!!
      -- !!! replacement expression: We have to unify the original      !!!
      -- !!! function type with the instantiated function type in order !!!
      -- !!! to obtain a type substitution that can then be applied to  !!!
      -- !!! the typed free variables in the replacement expression.    !!!
      -- !!! This is analogous to the procedure when inserting a        !!!
      -- !!! replacement expression with a correct type annotation      !!!
      -- !!! (see 'absType' in 'absExpr' below).                        !!!
      -- !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      qafv' ty (re, fty) =
        let unifier = matchType fty ty idSubst
        in  map (\(ty', v) -> (subst unifier ty', v)) $ qafv m re
      -- free variables that are local
      fvs     = filter ((`elem` lvs) . snd) (Set.toList fvsRhs)
      -- extended abstraction environment
      env'    = foldr bindF env fs
      -- !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      -- !!! HACK: Since we do not know how to annotate the function    !!!
      -- !!! call within the replacement expression until the replace-  !!!
      -- !!! ment expression is actually inserted (see 'absType' in     !!!
      -- !!! 'absExpr' below), we use a dummy type for this. In turn,   !!!
      -- !!! this dummy type has to be filtered out when calculating    !!!
      -- !!! the typed free variables on right-hand sides (see above).  !!!
      -- !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      bindF f =
        Map.insert f ( apply (mkFun m pre dummyType f) (map (uncurry mkVar) fvs)
                     , fromJust $ lookup f ftys )
      -- newly abstracted functions, i.e. those whose lifted name has no
      -- entry in the value environment yet
      fs'     = filter (\f -> null $ lookupValue (liftIdent pre f) vEnv) fs
  withLocalAbstractEnv env' $ do
    -- add variables to functions; declarations that bind none of the
    -- newly abstracted functions are dropped
    fds' <- mapM (absFunDecl pre fvs lvs) [d | d <- fds, any (`elem` fs') (bv d)]
    -- abstract remaining declarations
    e'   <- absFunDecls pre lvs fdss vds e
    return (Let NoSpanInfo fds' e')

-- When the free variables of a function are abstracted, the type of the
-- function must be changed as well.

-- | Abstract a single local function declaration.
--
-- Arguments: the name prefix for the lifted function, the typed free
-- variables to add as leading parameters, the local variables in scope,
-- and the declaration itself.
--
-- For a function declaration the function is renamed with 'liftIdent',
-- the free variables are prepended to the patterns of every equation, the
-- equations are abstracted, and the lifted function is entered into the
-- value environment with its new type. For an external declaration every
-- declared variable is handled by 'absVar'. Any other declaration is an
-- error.
absFunDecl :: String -> [(Type, Ident)] -> [Ident] -> Decl Type
           -> LiftM (Decl Type)
absFunDecl pre fvs lvs (FunctionDecl p _ f eqs) = do
  m <- getModuleIdent
  -- Abstract the extended equations directly. This is exactly what
  -- 'absDecl' does for a function declaration, but it avoids building an
  -- intermediate declaration with an 'undefined' type annotation.
  eqs'' <- mapM (absEquation lvs) eqs'
  modifyValueEnv $ bindGlobalInfo
    (\qf tySc -> Value qf False (eqnArity $ head eqs') tySc) m f' $ polyType ty''
  return $ FunctionDecl p ty'' f' eqs''
  where -- name of the lifted function
        f' = liftIdent pre f
        -- type of the lifted function, read off its first extended equation
        ty' = foldr TypeArrow (typeOf rhs') (map typeOf ts')
          where Equation _ (FunLhs _ _ ts') rhs' = head eqs'
        ty'' = genType ty'
        -- equations with the free variables added as leading parameters
        eqs' = map addVars eqs
        -- renumber the type variables of a type to 0, 1, ... in the order
        -- in which 'typeVars' reports them
        genType ty''' = subst (foldr2 bindSubst idSubst tvs tvs') ty'''
          where tvs = nub (typeVars ty''')
                tvs' = map TypeVariable [0 ..]
        addVars (Equation p' (FunLhs _ _ ts) rhs) =
          Equation p' (FunLhs NoSpanInfo
            f' (map (uncurry (VariablePattern NoSpanInfo)) fvs ++ ts)) rhs
        addVars _ = error "Lift.absFunDecl.addVars: no pattern match"
absFunDecl pre _ _ (ExternalDecl p vs) = ExternalDecl p <$> mapM (absVar pre) vs
absFunDecl _ _ _ _ = error "Lift.absFunDecl: no pattern match"

-- | Rename an externally declared variable to its lifted name and enter
-- the lifted name into the value environment, with the arity given by
-- 'arrowArity' of its type.
absVar :: String -> Var Type -> LiftM (Var Type)
absVar pre (Var ty f) = do
  m <- getModuleIdent
  let lifted         = liftIdent pre f
      mkInfo qf tySc = Value qf False (arrowArity ty) tySc
  modifyValueEnv (bindGlobalInfo mkInfo m lifted (polyType ty))
  return (Var ty lifted)

-- | Abstract an expression.
--
-- An unqualified variable with an entry in the abstraction environment is
-- replaced by its replacement expression, which is then abstracted in turn.
-- Local declaration groups are handed to 'absDeclGroup'. All other
-- supported forms are traversed structurally; any remaining form is an
-- internal error.
absExpr :: String -> [Ident] -> Expression Type -> LiftM (Expression Type)
absExpr pre lvs expr = case expr of
  Literal _ _ _     -> return expr
  Constructor _ _ _ -> return expr
  Variable _ ty v
    | isQualified v -> return expr
    | otherwise     -> do
        env <- getAbstractEnv
        case Map.lookup (unqualify v) env of
          Nothing          -> return expr
          Just (repl, fty) ->
            -- !!! HACK: When inserting the replacement expression for an
            -- !!! abstracted function, the original function type is
            -- !!! matched against the instantiated function type in order
            -- !!! to obtain a type substitution that is then applied to
            -- !!! the type annotations in the replacement expression.
            let unifier = matchType fty ty idSubst
            in  absExpr pre lvs (fmap (subst unifier) (retype ty repl))
  Apply spi e1 e2   -> Apply spi <$> absExpr pre lvs e1 <*> absExpr pre lvs e2
  Let _ ds e        -> absDeclGroup pre lvs ds e
  Case spi ct e bs  ->
    Case spi ct <$> absExpr pre lvs e <*> mapM (absAlt pre lvs) bs
  Typed spi e ty    -> (\e' -> Typed spi e' ty) <$> absExpr pre lvs e
  _                 -> internalError $ "Lift.absExpr: " ++ show expr
  where
    -- Annotate the function at the head of a replacement expression with
    -- its proper type: starting from the type of the whole application,
    -- each argument's type is put in front as the spine is descended.
    retype ty' (Variable spi _ v') = Variable spi ty' v'
    retype ty' (Apply   spi e1 e2) =
      Apply spi (retype (TypeArrow (typeOf e2) ty') e1) e2
    retype _   _                   = internalError "Lift.absExpr.absType"

-- | Abstract the right-hand side of a case alternative; the variables
-- bound by the alternative's pattern are added to the local variables.
absAlt :: String -> [Ident] -> Alt Type -> LiftM (Alt Type)
absAlt pre lvs (Alt p t rhs) = Alt p t <$> absRhs pre (lvs ++ bv t) rhs

-- -----------------------------------------------------------------------------
-- Lifting
-- -----------------------------------------------------------------------------

-- After the abstraction pass, all local function declarations are lifted
-- to the top-level.

-- | Lift the local functions out of a function declaration. The result
-- starts with the declaration itself, stripped of its local functions,
-- followed by the lifted functions after 'renameFunDecl' has been applied
-- to them. Any other declaration is returned as a singleton.
liftFunDecl :: Eq a => Decl a -> [Decl a]
liftFunDecl decl = case decl of
  FunctionDecl p a f eqs ->
    let (eqs', localDss) = unzip (map liftEquation eqs)
    in  FunctionDecl p a f eqs' : map renameFunDecl (concat localDss)
  _ -> [decl]

-- | Lift the local functions out of a value declaration, returning the
-- remaining declaration and the lifted functions. Free-variable
-- declarations contain nothing to lift; any declaration other than a
-- pattern or free-variable declaration is an error.
liftVarDecl :: Eq a => Decl a -> (Decl a, [Decl a])
liftVarDecl decl = case decl of
  PatternDecl p t rhs -> let (rhs', lifted) = liftRhs rhs
                         in  (PatternDecl p t rhs', lifted)
  FreeDecl _ _        -> (decl, [])
  _                   -> error "Lift.liftVarDecl: no pattern match"

-- | Lift the local functions out of an equation's right-hand side,
-- returning the remaining equation and the lifted functions.
liftEquation :: Eq a => Equation a -> (Equation a, [Decl a])
liftEquation (Equation p lhs rhs) = first (Equation p lhs) (liftRhs rhs)

-- | Lift the local functions out of a simple right-hand side. Any other
-- form of right-hand side is an error.
liftRhs :: Eq a => Rhs a -> (Rhs a, [Decl a])
liftRhs rhs = case rhs of
  SimpleRhs p e _ -> let (e', lifted) = liftExpr e
                     in  (simpleRhs p e', lifted)
  _               -> error "Lift.liftRhs: no pattern match"

-- | Split a local declaration group into the value declarations that stay
-- local and the declarations to be lifted. The latter consist of the
-- group's function declarations (with their own local functions lifted)
-- followed by the functions lifted out of the value declarations.
liftDeclGroup :: Eq a => [Decl a] -> ([Decl a], [Decl a])
liftDeclGroup ds =
  (valDs', concatMap liftFunDecl funDs ++ concat liftedFromVals)
  where
    (funDs , valDs         ) = partition isFunDecl ds
    (valDs', liftedFromVals) = unzip (map liftVarDecl valDs)

-- | Lift all local function declarations out of an expression, returning
-- the remaining expression and the lifted declarations. A local
-- declaration group keeps only its value declarations and is rebuilt
-- with 'mkLet'. Expression forms not handled here are an internal error.
liftExpr :: Eq a => Expression a -> (Expression a, [Decl a])
liftExpr expr = case expr of
  Literal _ _ _      -> (expr, [])
  Variable _ _ _     -> (expr, [])
  Constructor _ _ _  -> (expr, [])
  Apply spi e1 e2    ->
    let (e1', ds1) = liftExpr e1
        (e2', ds2) = liftExpr e2
    in  (Apply spi e1' e2', ds1 ++ ds2)
  Let _ ds e         ->
    let (ds', ds1) = liftDeclGroup ds
        (e' , ds2) = liftExpr e
    in  (mkLet ds' e', ds1 ++ ds2)
  Case spi ct e alts ->
    let (e'   , ds' ) = liftExpr e
        (alts', dss') = unzip (map liftAlt alts)
    in  (Case spi ct e' alts', ds' ++ concat dss')
  Typed spi e ty     ->
    let (e', ds) = liftExpr e
    in  (Typed spi e' ty, ds)
  _                  -> internalError "Lift.liftExpr"

-- | Lift the local functions out of a case alternative's right-hand side.
liftAlt :: Eq a => Alt a -> (Alt a, [Decl a])
liftAlt (Alt p t rhs) = first (Alt p t) (liftRhs rhs)

-- -----------------------------------------------------------------------------
-- Renaming
-- -----------------------------------------------------------------------------

-- After all local function declarations have been lifted to top-level, we
-- may have to rename duplicate function arguments. Due to polymorphic let
-- declarations it could happen that an argument was added multiple times
-- instantiated with different types during the abstraction pass beforehand.

-- | Association list mapping an annotated variable (annotation and
-- original name) to the fresh name it is renamed to.
type RenameMap a = [((a, Ident), Ident)]

-- | Rename duplicate arguments in every equation of a function
-- declaration; other declarations are returned unchanged.
renameFunDecl :: Eq a => Decl a -> Decl a
renameFunDecl decl = case decl of
  FunctionDecl p a f eqs -> FunctionDecl p a f (map renameEquation eqs)
  _                      -> decl

-- | Rename duplicate arguments on the left-hand side of an equation and
-- apply the resulting renaming to its right-hand side.
renameEquation :: Eq a => Equation a -> Equation a
renameEquation (Equation p lhs rhs) =
  let (renaming, lhs') = renameLhs lhs
  in  Equation p lhs' (renameRhs renaming rhs)

-- | Rename duplicate variable patterns among the arguments of a function
-- left-hand side, returning the renaming together with the new left-hand
-- side. The patterns are folded from the right with 'renamePattern'.
-- Any other form of left-hand side is an error.
renameLhs :: Eq a => Lhs a -> (RenameMap a, Lhs a)
renameLhs lhs = case lhs of
  FunLhs spi f ts -> let (renaming, ts') = foldr renamePattern ([], []) ts
                     in  (renaming, FunLhs spi f ts')
  _               -> error "Lift.renameLhs"

-- | Folding step of 'renameLhs': prepend a pattern to the patterns
-- processed so far. A variable pattern whose name already occurs among
-- those patterns gets a fresh name, formed by appending a dot and the
-- current size of the renaming, and the renaming is extended accordingly.
-- Every other pattern is prepended unchanged.
renamePattern :: Eq a => Pattern a -> (RenameMap a, [Pattern a])
              -> (RenameMap a, [Pattern a])
renamePattern pat (rm, ts) = case pat of
  VariablePattern spi a v
    | v `elem` varPatNames ts ->
        let fresh = updIdentName (++ ('.' : show (length rm))) v
        in  (((a, v), fresh) : rm, VariablePattern spi a fresh : ts)
  _ -> (rm, pat : ts)

-- | Apply a renaming to the expression of a simple right-hand side. Any
-- other form of right-hand side is an error.
renameRhs :: Eq a => RenameMap a -> Rhs a -> Rhs a
renameRhs rm rhs = case rhs of
  SimpleRhs p e _ -> simpleRhs p (renameExpr rm e)
  _               -> error "Lift.renameRhs"

-- | Apply a renaming to an expression. An unqualified variable is
-- replaced when the pair of its annotation and name has an entry in the
-- renaming; qualified variables, literals and constructors are left
-- alone. Expression forms not handled here are an error.
renameExpr :: Eq a => RenameMap a -> Expression a -> Expression a
renameExpr rm expr = case expr of
  Literal _ _ _      -> expr
  Constructor _ _ _  -> expr
  Variable spi a qv
    | isQualified qv -> expr
    | otherwise      -> maybe expr (Variable spi a . qualify)
                                   (lookup (a, unqualify qv) rm)
  Typed spi e ty     -> Typed spi (renameExpr rm e) ty
  Apply spi e1 e2    -> Apply spi (renameExpr rm e1) (renameExpr rm e2)
  Let spi ds e       -> Let spi (map (renameDecl rm) ds) (renameExpr rm e)
  Case spi ct e alts ->
    Case spi ct (renameExpr rm e) (map (renameAlt rm) alts)
  _                  -> error "Lift.renameExpr"

-- | Apply a renaming to the right-hand side of a pattern declaration;
-- other declarations are returned unchanged.
renameDecl :: Eq a => RenameMap a -> Decl a -> Decl a
renameDecl rm decl = case decl of
  PatternDecl p t rhs -> PatternDecl p t (renameRhs rm rhs)
  _                   -> decl

-- | Apply a renaming to the right-hand side of a case alternative.
renameAlt :: Eq a => RenameMap a -> Alt a -> Alt a
renameAlt rm (Alt p t rhs) = Alt p t rhs'
  where rhs' = renameRhs rm rhs

-- ---------------------------------------------------------------------------
-- Auxiliary definitions
-- ---------------------------------------------------------------------------

-- | Whether a declaration is treated as a function declaration by this
-- pass: function declarations and external declarations are.
isFunDecl :: Decl a -> Bool
isFunDecl decl = case decl of
  FunctionDecl _ _ _ _ -> True
  ExternalDecl _ _     -> True
  _                    -> False

-- | Build a variable expression that refers to the lifted version of a
-- local function, qualified with the given module and carrying the given
-- annotation.
mkFun :: ModuleIdent -> String -> a -> Ident -> Expression a
mkFun m pre a f = Variable NoSpanInfo a (qualifyWith m (liftIdent pre f))

-- | Name of a lifted function: the prefix followed by the shown original
-- identifier, renamed with the original identifier's 'idUnique' value.
liftIdent :: String -> Ident -> Ident
liftIdent prefix x = renameIdent prefixed (idUnique x)
  where prefixed = mkIdent (prefix ++ showIdent x)

-- | Names of all top-level variable patterns in a list of patterns, in
-- order.
varPatNames :: [Pattern a] -> [Ident]
varPatNames pats = [v | Just v <- map varPatName pats]

-- | The variable bound by a variable pattern, or 'Nothing' for any other
-- pattern.
varPatName :: Pattern a -> Maybe Ident
varPatName pat = case pat of
  VariablePattern _ _ i -> Just i
  _                     -> Nothing

-- | Placeholder annotation for the function at the head of a replacement
-- expression (see 'bindF' in 'absFunDecls'). It is recognised by
-- 'isDummyType' and overwritten by 'absType' in 'absExpr' when the
-- replacement is inserted. The body of the 'TypeForall' is 'undefined'
-- and must never be forced.
dummyType :: Type
dummyType = TypeForall [] undefined

-- | Whether a type has the shape of 'dummyType', i.e. a 'TypeForall'
-- with an empty list of quantified variables. The body is not inspected.
isDummyType :: Type -> Bool
isDummyType ty = case ty of
  TypeForall [] _ -> True
  _               -> False