{-|
Module      : Idris.PartialEval
Description : Implementation of a partial evaluator.
Copyright   :
License     : BSD3
Maintainer  : The Idris Community.
-}
{-# LANGUAGE PatternGuards #-}

module Idris.PartialEval(
    partial_eval, getSpecApps, specType
  , mkPE_TyDecl, mkPE_TermDecl, PEArgType(..)
  , pe_app, pe_def, pe_clauses, pe_simple
  ) where

import Idris.AbsSyntax
import Idris.Delaborate

import Idris.Core.TT
import Idris.Core.CaseTree
import Idris.Core.Evaluate

import Control.Monad.State
import Control.Applicative
import Data.Maybe
import Debug.Trace

-- | Binding-time annotation attached to each argument of an application
-- that is being partially evaluated: whether the argument is static
-- (known, to be specialised away) or dynamic (kept as a parameter), and
-- whether it was implicit or explicit at the source level.
data PEArgType
  = ImplicitS -- ^ Implicit static argument
  | ImplicitD -- ^ Implicit dynamic argument
  | ExplicitS -- ^ Explicit static argument
  | ExplicitD -- ^ Explicit dynamic argument
  | UnifiedD  -- ^ Erasable dynamic argument (found under unification)
  deriving (Eq, Show)

-- | A partially evaluated function.
--
-- 'pe_app' is the lhs of the new definition, 'pe_def' the rhs (an
-- application of the original function), and 'pe_clauses' the
-- specialised implementation.
--
-- 'pe_simple' is set if the result is always reducible, because in such
-- a case we'll also need to reduce the static argument.
data PEDecl = PEDecl
  { pe_app     :: PTerm            -- ^ new application
  , pe_def     :: PTerm            -- ^ old application
  , pe_clauses :: [(PTerm, PTerm)] -- ^ clauses of new application
  , pe_simple  :: Bool             -- ^ if just one reducible clause
  }

-- | Partially evaluates the given terms under the given context.
--
-- Plain terms ('Left') are passed through untouched; for a clause
-- ('Right' (lhs, rhs)) only the right hand side is specialised.
--
-- Before evaluation, entries of the name list that carry an explicit
-- reduction limit are merged: repeated limits for the same name are
-- summed. Names without a limit get a default limit.
--
-- It is an error (the result is 'Nothing') if partial evaluation fails
-- to make any progress. Making progress is defined as: all of the names
-- given with explicit reduction limits (in practice, the function being
-- specialised) must have reduced at least once. Names whose limit is 1
-- or less are exempt. The check is only performed when exactly one term
-- is given.
-- If we don't do this, we might end up making an infinite function after
-- applying the transformation.
partial_eval :: Context
            -> [(Name, Maybe Int)]
            -> [Either Term (Term, Term)]
            -> Maybe [Either Term (Term, Term)]
partial_eval ctxt ns_in tms = mapM evalOne tms
  where
    -- The name list with duplicate explicit limits folded together
    limits = mergeLimits ns_in

    mergeLimits [] = []
    mergeLimits (entry@(n, lim) : rest) =
        case (lim, lookup n rest) of
          (Just x, Just (Just y)) ->
              mergeLimits ((n, Just (x + y)) : deleteFirst n rest)
          _ -> entry : mergeLimits rest

    -- Remove only the first entry for the given name
    deleteFirst _ [] = []
    deleteFirst n (e@(m, _) : es)
        | n == m = es
        | otherwise = e : deleteFirst n es

    evalOne (Left t) = Just (Left t)
    evalOne (Right (lhs, rhs)) = do
        let (rhs', reductions) = specialise ctxt [] (map toLimit limits) rhs
        when (length tms == 1) $ mapM_ (madeProgress limits) reductions
        Just (Right (lhs, rhs'))

    -- TMP HACK until I do PE by WHNF rather than using main evaluator
    toLimit (n, Just l) = (n, l)
    toLimit (n, Nothing)
        | isTCDict n ctxt = (n, 2)
        | otherwise = (n, 65536) -- somewhat arbitrary reduction limit

    -- Fails when a name with an explicit limit above 1 reports a
    -- remaining count that has not dropped below its starting limit
    madeProgress lims (n, r) =
        case lookup n lims of
          Just (Just start) | start > 1 && r >= start -> Nothing
          _ -> Just ()

-- | Specialises the type of a partially evaluated TT function, returning
-- a pair of the specialised type and the types of expected arguments.
--
-- Implicit dynamic arguments identical to an earlier implicit dynamic
-- argument are re-tagged 'UnifiedD' and later occurrences are substituted
-- by the earlier binder. Afterwards static arguments become @let@
-- bindings of their supplied value, 'UnifiedD' arguments are removed from
-- the type, and all other Pi binders are kept.
specType :: [(PEArgType, Term)] -> Type -> (Type, [(PEArgType, Term)])
specType args ty = (letBindStatics argInfo unifiedTy, argInfo)
  where
    (unifiedTy, tagged) = runState (unifyEq args ty) []
    argInfo = map fst tagged

    -- Walk the Pi binders alongside the argument annotations, deciding
    -- per annotation what happens to the binder
    letBindStatics ((tag, v) : rest) (Bind n (Pi i t k) sc) =
        case tag of
          -- static: bind the provided value instead of expecting it
          ExplicitS -> Bind n (Let t v) (letBindStatics rest sc)
          ImplicitS -> Bind n (Let t v) (letBindStatics rest sc)
          -- unified: erase the argument from the function type
          UnifiedD  -> letBindStatics rest sc
          -- dynamic: keep the binder as is
          _         -> Bind n (Pi i t k) (letBindStatics rest sc)
    letBindStatics _ t = t

    -- Record each argument annotation with the binder name it lines up
    -- with. An implicit dynamic argument equal to one seen before is
    -- recorded as unified and replaced by the earlier binder in the scope
    -- (the binder itself is erased later by letBindStatics).
    unifyEq (arg : rest) (Bind n (Pi i t k) sc) = do
        seen <- get
        case arg of
          (ImplicitD, _) | Just prev <- lookup arg seen -> do
              put (seen ++ [((UnifiedD, Erased), n)])
              sc' <- unifyEq rest (subst n (P Bound prev Erased) sc)
              return (Bind n (Pi i t k) sc')
          _ -> do
              put (seen ++ [(arg, n)])
              sc' <- unifyEq rest sc
              return (Bind n (Pi i t k) sc')
    -- Out of binders (or annotations): remaining annotations get a
    -- placeholder name
    unifyEq rest t = do
        modify (++ zip rest (repeat (sUN "_")))
        return t

-- | Creates an Idris type declaration given the current state and a
-- specialised TT function application type.
-- Can be used in combination with the output of 'specType'.
--
-- Only dynamic annotations ('ExplicitD', 'ImplicitD') that line up with
-- a Pi binder produce a binder in the result; any other annotation is
-- skipped without consuming a binder. Implicit arguments whose type is a
-- concrete interface constraint are dropped altogether.
--
-- This should: specialise any static argument position, then generalise
-- over any function applications in the result.
mkPE_TyDecl :: IState -> [(PEArgType, Term)] -> Type -> PTerm
mkPE_TyDecl ist args ty = build args ty
  where
    build [] t = delab ist t
    build ((ExplicitD, _) : rest) (Bind n (Pi _ t _) sc)
       = PPi expl n NoFC (binderTy t) (build rest sc)
    build ((ImplicitD, _) : rest) (Bind n (Pi _ t _) sc)
       | concreteInterface ist t = build rest sc
       | otherwise = PPi plicity n NoFC (binderTy t) (build rest sc)
      where plicity | interfaceConstraint ist t = constraint
                    | otherwise = impl
    build (_ : rest) t = build rest t

    binderTy t = delab ist (generaliseIn t)

    -- Replace every application headed by a defined function with a fresh
    -- bound name (sMN k "spec"), numbering them in encounter order
    generaliseIn t = evalState (gen t) 0

    gen t
       | (P _ fn _, _) <- unApply t, isFnName fn (tt_ctxt ist)
       = state (\k -> (P Bound (sMN k "spec") Erased, k + 1))
    gen (App s f a) = App s <$> gen f <*> gen a
    gen t = return t

-- | Checks if a given argument type is an interface constraint: its head
-- is a name reference with exactly one match in the interface table.
interfaceConstraint :: IState -> TT Name -> Bool
interfaceConstraint ist v =
    case unApply v of
      (P _ c _, _) -> isSingleton (lookupCtxt c (idris_interfaces ist))
      _            -> False
  where
    isSingleton [_] = True
    isSingleton _   = False

-- | Checks that the given type is an interface constraint whose arguments
-- are all concrete: constants, or references to names with exactly one
-- type in the context applied only to concrete arguments (i.e. that it
-- doesn't contain any complex terms).
concreteInterface :: IState -> TT Name -> Bool
concreteInterface ist v
    | interfaceConstraint ist v
    , (P _ _ _, args) <- unApply v = all concrete args
    | otherwise = False
  where
    concrete (Constant _) = True
    concrete tm =
        case unApply tm of
          (P _ n _, args) | [_] <- lookupTy n (tt_ctxt ist) -> all concrete args
          _ -> False

-- | Builds the specialised declaration for a function with a known
-- pattern-matching definition.
--
-- If all of the dynamic positions on the lhs of every clause are
-- variables (rather than patterns or constants) then we can just make a
-- simple definition directly applying the specialised function, since we
-- know the definition isn't going to block on any of the dynamic
-- arguments in this case.
--
-- Otherwise every original clause becomes a clause of the new function:
--
-- * static arguments are dropped from the lhs; where the original
--   pattern was a name, that name is substituted by the static value in
--   the explicit dynamic arguments that follow it;
-- * implicit and unified dynamic arguments are dropped;
-- * the rhs applies the specialised function to the explicit static
--   values and the explicit arguments of the new lhs, in order.
mkNewPats :: IState
          -> [(Term, Term)]      -- ^ definition to specialise
          -> [(PEArgType, Term)] -- ^ arguments to specialise with
          -> Name                -- ^ New name
          -> Name                -- ^ Specialised function name
          -> PTerm               -- ^ Default lhs
          -> PTerm               -- ^ Default rhs
          -> PEDecl
mkNewPats ist d ns newname sname lhs rhs | all dynVar (map fst d)
     = PEDecl lhs rhs [(lhs, rhs)] True
  where dynVar ap = case unApply ap of
                         (_, args) -> dynArgs ns args
        dynArgs _ [] = True -- can definitely reduce from here
        -- if Static, doesn't matter what the argument is
        dynArgs ((ImplicitS, _) : ns) (a : as) = dynArgs ns as
        dynArgs ((ExplicitS, _) : ns) (a : as) = dynArgs ns as
        -- if Dynamic, it had better be a variable or we'll need to
        -- do some more work
        -- NOTE(review): 'P' matches any name reference, not only pattern
        -- variables; confirm that treating every reference as a variable
        -- here is intended.
        dynArgs (_ : ns) (V _     : as) = dynArgs ns as
        dynArgs (_ : ns) (P _ _ _ : as) = dynArgs ns as
        dynArgs _ _ = False -- and now we'll get stuck

mkNewPats ist d ns newname sname lhs rhs =
    PEDecl lhs rhs (map mkClause d) False
  where
    -- Rewrite one original clause into a clause of the new function
    mkClause :: (Term, Term) -> (PTerm, PTerm)
    mkClause (oldlhs, oldrhs)
         = let (_, as) = unApply oldlhs
               lhsargs = mkLHSargs [] ns as
               lhs = PApp emptyFC (PRef emptyFC [] newname) lhsargs
               rhs = PApp emptyFC (PRef emptyFC [] sname)
                                  (mkRHSargs ns lhsargs) in
                     (lhs, rhs)

    -- Build the new lhs arguments, accumulating in 'sub' the static
    -- values to substitute for names bound by earlier static patterns
    mkLHSargs _ [] _ = []
    -- dynamics don't appear if they're implicit
    mkLHSargs sub ((ExplicitD, t) : ns) (a : as)
         = pexp (delab ist (substNames sub a)) : mkLHSargs sub ns as
    mkLHSargs sub ((ImplicitD, _) : ns) (a : as)
         = mkLHSargs sub ns as
    mkLHSargs sub ((UnifiedD, _) : ns) (a : as)
         = mkLHSargs sub ns as
    -- statics get dropped in any case
    mkLHSargs sub ((ImplicitS, t) : ns) (a : as)
         = mkLHSargs (extend a t sub) ns as
    mkLHSargs sub ((ExplicitS, t) : ns) (a : as)
         = mkLHSargs (extend a t sub) ns as
    mkLHSargs sub _ [] = [] -- no more LHS

    -- Only a name pattern can be substituted by the static value
    extend (P _ n _) t sub = (n, t) : sub
    extend _ _ sub = sub

    -- Explicit statics are passed by value; explicit dynamics consume the
    -- next argument of the new lhs; everything else is omitted
    mkRHSargs ((ExplicitS, t) : ns) as = pexp (delab ist t) : mkRHSargs ns as
    mkRHSargs ((ExplicitD, t) : ns) (a : as) = a : mkRHSargs ns as
    mkRHSargs (_ : ns) as = mkRHSargs ns as
    mkRHSargs _ _ = []

-- | Creates a new declaration for a specialised function application.
--
-- The new lhs applies @newname@ to the explicit dynamic arguments. The
-- default rhs applies @sname@ to all of the argument terms, with the
-- terms of implicit arguments in the delaborated result replaced by
-- placeholders. If @sname@ has a pattern-matching definition in the
-- current state, its clauses are specialised one by one via
-- 'mkNewPats'; otherwise the result is a single clause that directly
-- applies the function being specialised.
mkPE_TermDecl :: IState
              -> Name
              -> Name
              -> [(PEArgType, Term)]
              -> PEDecl
mkPE_TermDecl ist newname sname ns =
    maybe simpleDecl clauseDecl (lookupCtxtExact sname (idris_patdefs ist))
  where
    newLhs = PApp emptyFC (PRef emptyFC [] newname)
                  [pexp (delab ist tm) | (ExplicitD, tm) <- ns]
    oldRhs = mapPT placeholderImps
                   (delab ist (mkApp (P Ref sname Erased) (map snd ns)))

    -- No pattern definition to work from: one direct clause
    simpleDecl = PEDecl newLhs oldRhs [(newLhs, oldRhs)] True

    -- Specialise the (lhs, rhs) pairs of the existing definition
    clauseDecl def =
        mkNewPats ist [(l, r) | (_, l, r) <- fst def]
                  ns newname sname newLhs oldRhs

    placeholderImps (PApp fc f as) = PApp fc f (map blankImp as)
    placeholderImps t = t

    blankImp a@(PImp {}) = a { getTm = Placeholder }
    blankImp a = a

-- | Get specialised applications for a given function.
--
-- Walks the term and returns, for every application of a defined,
-- specialisable, non-metavariable function whose argument count matches
-- its recorded static and implicit argument information and which has at
-- least one static position, the function name paired with a
-- binding-time annotation for each argument. Explicit dynamic arguments
-- are replaced by a placeholder reference named after their position
-- (e.g. @0arg@).
--
-- The traversal descends into both sides of such applications, into
-- @let@ values, and into the scopes of binders; nothing else is visited.
getSpecApps :: IState
            -> [Name]
            -> Term
            -> [(Name, [(PEArgType, Term)])]
getSpecApps ist env tm = ga env (explicitNames tm) where

--     staticArg env True _ tm@(P _ n _) _ | n `elem` env = Just (True, tm)
--     staticArg env True _ tm@(App f a) _ | (P _ n _, args) <- unApply tm,
--                                            n `elem` env = Just (True, tm)
    -- Annotate one argument: 'x' says whether the position is static and
    -- 'imp' is the argument's plicity. 'env' is currently unused.
    staticArg env x imp tm n
         | x && imparg imp = (ImplicitS, tm)
         | x = (ExplicitS, tm)
         | imparg imp = (ImplicitD, tm)
         | otherwise = (ExplicitD, (P Ref (sUN (show n ++ "arg")) Erased))

    imparg (PExp _ _ _ _) = False
    imparg _ = True

    -- Annotate the statics/plicities/arguments/positions in lockstep.
    buildApp env [] [] _ _ = []
    buildApp env (s:ss) (i:is) (a:as) (n:ns)
        = staticArg env s i a n : buildApp env ss is as ns
    -- Unreachable while 'ga' checks that statics, implicits and arguments
    -- agree in length; kept so a mismatch can never crash with a pattern
    -- match failure.
    buildApp _ _ _ _ _ = []

    -- if we have a *defined* function that has static arguments,
    -- it will become a specialised application
    ga env tm@(App _ f a) | (P _ n _, args) <- unApply tm,
                          n `notElem` map fst (idris_metavars ist) =
        ga env f ++ ga env a ++
          case (lookupCtxtExact n (idris_statics ist),
                  lookupCtxtExact n (idris_implicits ist)) of
               (Just statics, Just imps)
                 | length statics == length args
                 -- buildApp walks imps in lockstep with args as well
                 , length imps == length args
                 , or statics
                 , specialisable (tt_ctxt ist) n
                 -> [(n, buildApp env statics imps args [0..])]
               _ -> []
    ga env (Bind n (Let t v) sc) = ga env v ++ ga (n : env) sc
    ga env (Bind n t sc) = ga (n : env) sc
    ga env t = []

    -- A function is only specialisable if there are no overlapping
    -- cases in the case tree (otherwise the partial evaluation could
    -- easily get stuck)
    specialisable :: Context -> Name -> Bool
    specialisable ctxt n = case lookupDefExact n ctxt of
                                Just (CaseOp _ _ _ _ _ cds) ->
                                     noOverlap (snd (cases_compiletime cds))
                                _ -> False

    noOverlap :: SC -> Bool
    noOverlap (Case _ _ [DefaultCase sc]) = noOverlap sc
    noOverlap (Case _ _ alts) = noOverlapAlts alts
    noOverlap _ = True

    -- There's an overlap if the case tree has a default case along with
    -- some other cases. It's fine if there's a default case on its own.
    noOverlapAlts (ConCase _ _ _ sc : rest)
        = noOverlapAlts rest && noOverlap sc
    -- NOTE(review): unlike the other alternatives, the FnCase sub-tree is
    -- not checked for overlap; confirm this is intended.
    noOverlapAlts (FnCase _ _ sc : rest) = noOverlapAlts rest
    noOverlapAlts (ConstCase _ sc : rest)
        = noOverlapAlts rest && noOverlap sc
    noOverlapAlts (SucCase _ sc : rest)
        = noOverlapAlts rest && noOverlap sc
    noOverlapAlts (DefaultCase _ : _) = False
    noOverlapAlts _ = True