{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}
module Agda.Compiler.Treeless.Simplify (simplifyTTerm) where

import Control.Arrow (first, second, (***))
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Traversable (traverse)
import Data.List

import Agda.Syntax.Treeless
import Agda.Syntax.Internal (Substitution'(..))
import Agda.Syntax.Literal
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Primitive
import Agda.TypeChecking.Substitute
import Agda.Utils.Maybe

import Agda.Compiler.Treeless.Subst
import Agda.Compiler.Treeless.Pretty
import Agda.Compiler.Treeless.Compare
import Agda.Utils.Pretty

import Agda.Utils.Impossible
#include "undefined.h"

-- | Environment threaded through the simplifier.
data SEnv = SEnv
  { envSubst :: Substitution' TTerm
    -- ^ Substitution consulted by 'lookupVar'.
  , envRewrite :: [(TTerm, TTerm)]
    -- ^ Rewrite rules as @(lhs, rhs)@ pairs; 'rewrite' tries them in
    --   list order and uses the first whose left-hand side matches.
  }

-- | The simplifier monad: a reader over the simplifier environment 'SEnv'.
type S = Reader SEnv

-- | Run a simplifier computation starting from the identity substitution
--   and an empty list of rewrite rules.
runS :: S a -> a
runS m = runReader m initialEnv
  where
    initialEnv = SEnv { envSubst = IdS, envRewrite = [] }

-- | Look up variable @i@ in the current substitution.
lookupVar :: Int -> S TTerm
lookupVar i = asks (\ env -> lookupS (envSubst env) i)

-- | Run a computation with the environment's substitution transformed by
--   the given function. The rewrite rules are left untouched.
onSubst :: (Substitution' TTerm -> Substitution' TTerm) -> S a -> S a
onSubst f = local modifySubst
  where
    modifySubst env = env { envSubst = f (envSubst env) }

-- | Run a computation with the substitution @rho@ applied to both sides of
--   every rewrite rule in the environment.
onRewrite :: Substitution' TTerm -> S a -> S a
onRewrite rho = local $ \ env ->
  env { envRewrite = [ (applySubst rho lhs, applySubst rho rhs)
                     | (lhs, rhs) <- envRewrite env ] }

-- | Run a computation with the rule @lhs -> rhs@ added in front of the
--   existing rewrite rules, so it takes precedence in 'rewrite'.
addRewrite :: TTerm -> TTerm -> S a -> S a
addRewrite lhs rhs = local pushRule
  where
    pushRule env = env { envRewrite = (lhs, rhs) : envRewrite env }

-- | Run a computation under @i@ binders: the substitution is lifted by @i@
--   and the rewrite rules are raised by @i@.
underLams :: Int -> S a -> S a
underLams i m = onRewrite (raiseS i) (onSubst (liftS i) m)

-- | 'underLams' specialised to a single binder.
underLam :: S a -> S a
underLam m = underLams 1 m

-- | Run a computation under a let binding of @u@: the substitution is
--   extended with @u@ and then weakened by one, and the rewrite rules are
--   raised by one.
underLet :: TTerm -> S a -> S a
underLet u = onRewrite (raiseS 1) . onSubst bindLet
  where
    bindLet rho = wkS 1 (u :# rho)

-- | Rewrite a term with the rules in the environment. The right-hand side
--   of the first rule whose left-hand side is 'equalTerms' to @t@ is
--   returned; if no rule applies, @t@ is returned unchanged.
rewrite :: TTerm -> S TTerm
rewrite t = do
  rules <- asks envRewrite
  -- Lazy right fold: the first matching rule wins and later rules are
  -- never inspected.
  pure $ foldr (\ (lhs, rhs) fallback -> if equalTerms t lhs then rhs else fallback) t rules

-- | Names of the builtins the simplifier treats specially. Each field is
--   filled in by 'simplifyTTerm' with the result of 'getBuiltinName'.
--   The field order matters: 'simplifyTTerm' applies the constructor
--   positionally.
data FunctionKit = FunctionKit
  { modAux   :: Maybe QName  -- ^ from 'builtinNatModSucAux'
  , divAux   :: Maybe QName  -- ^ from 'builtinNatDivSucAux'
  , natMinus :: Maybe QName  -- ^ from 'builtinNatMinus'
  , true     :: Maybe QName  -- ^ from 'builtinTrue'
  , false    :: Maybe QName  -- ^ from 'builtinFalse'
  }

-- | Entry point of the simplifier: look up the builtin names collected in
--   'FunctionKit', then run the pure simplifier on the given term.
simplifyTTerm :: TTerm -> TCM TTerm
simplifyTTerm t = do
  -- The lookups are performed in the same order as the record fields.
  modName   <- getBuiltinName builtinNatModSucAux
  divName   <- getBuiltinName builtinNatDivSucAux
  minusName <- getBuiltinName builtinNatMinus
  trueName  <- getBuiltinName builtinTrue
  falseName <- getBuiltinName builtinFalse
  let kit = FunctionKit
        { modAux   = modName
        , divAux   = divName
        , natMinus = minusName
        , true     = trueName
        , false    = falseName
        }
  return (runS (simplify kit t))

-- | Build the simplifier for a given 'FunctionKit'.
--
--   The kit's fields (@modAux@, @divAux@, @natMinus@, @true@, @false@) are
--   brought into scope by the record wildcard and used by the local helpers
--   below. The result is the local function @simpl@, which runs in the 'S'
--   reader monad.
simplify :: FunctionKit -> TTerm -> S TTerm
simplify FunctionKit{..} = simpl
  where
    -- Main traversal. The term is first passed through rewrite' (primitive
    -- simplification followed by the rewrite rules of the environment); the
    -- result is then dispatched on by shape.
    simpl t = rewrite' t >>= \ t -> case t of

      TDef{}         -> pure t
      TPrim{}        -> pure t

      -- A variable is replaced by what the substitution maps it to, but
      -- only when that term is atomic (see isAtomic).
      TVar x  -> do
        v <- lookupVar x
        pure $ if isAtomic v then v else t

      -- Applications of the div/mod helper builtins whose first argument is
      -- the literal 0 and whose second and fourth arguments coincide.
      TApp (TDef f) [TLit (LitNat _ 0), m, n, m']
        -- div/mod are equivalent to quot/rem on natural numbers.
        | m == m', Just f == divAux -> simpl $ tOp PQuot n (tPlusK 1 m)
        | m == m', Just f == modAux -> simpl $ tOp PRem n (tPlusK 1 m)

      TApp (TPrim _) _ -> pure t  -- taken care of by rewrite'

      -- Simplify head and arguments, then rebuild the application through
      -- maybeMinusToPrim (which falls back to tApp).
      TApp f es -> do
        f  <- simpl f
        es <- traverse simpl es
        maybeMinusToPrim f es
      TLam b         -> TLam <$> underLam (simpl b)
      TLit{}         -> pure t
      TCon{}         -> pure t
      -- Simplify the bound term, then the body with the bound term recorded
      -- in the substitution. tLet substitutes instead of building a let when
      -- the bound term is a variable.
      TLet e b       -> do
        e <- simpl e
        tLet e <$> underLet e (simpl b)

      TCase x t d bs -> do
        -- Inspect what the scrutinee variable is mapped to, with any
        -- leading lets split off.
        v <- lookupVar x
        let (lets, u) = letView v
        case u of                          -- TODO: also for literals
          -- Scrutinee is a constructor application: select the branch now.
          _ | Just (c, as) <- conView u -> simpl $ matchCon lets c as d bs
          -- Scrutinee is itself a case: the outer case (case1) is pushed
          -- into the default and every branch of the inner case, bound by a
          -- let in each.
          TCase y t1 d1 bs1 -> simpl $ mkLets lets $ TCase y t1 (distrDef case1 d1) $
                                       map (distrCase case1) bs1
            where
              -- Γ x Δ -> Γ _ Δ Θ y, where x maps to y and Θ are the lets
              n     = length lets
              rho   = liftS (x + n + 1) (raiseS 1)    `composeS`
                      singletonS (x + n + 1) (TVar 0) `composeS`
                      raiseS (n + 1)
              case1 = applySubst rho (TCase x t d bs)

              -- An unreachable inner default stays unreachable.
              distrDef v d | isUnreachable d = tUnreachable
                           | otherwise       = tLet d v

              -- For constructor branches the outer case is raised past the
              -- a pattern variables (from index 1 upwards).
              distrCase v (TACon c a b) = TACon c a $ TLet b $ raiseFrom 1 a v
              distrCase v (TALit l b)   = TALit l   $ TLet b v
              distrCase v (TAGuard g b) = TAGuard g $ TLet b v

          -- Nothing known about the scrutinee: simplify the default and the
          -- alternatives and rebuild with the smart constructor tCase.
          _ -> do
            d  <- simpl d
            bs <- traverse (simplAlt x) bs
            tCase x t d bs

      TUnit          -> pure t
      TSort          -> pure t
      TErased        -> pure t
      TError{}       -> pure t

    -- View a term as a constructor applied to (possibly zero) arguments.
    conView (TCon c)           = Just (c, [])
    conView (TApp (TCon c) as) = Just (c, as)
    conView e                  = Nothing

    -- Split off leading lets: returns the bound terms (outermost first) and
    -- the remaining body.
    letView (TLet e b) = first (e :) $ letView b
    letView e          = ([], e)

    -- Wrap b in lets binding es, the first element outermost.
    mkLets es b = foldr TLet b es

    -- Select the alternative for constructor c. Literal and guard
    -- alternatives are skipped; if no constructor alternative matches, the
    -- default d is returned as is. On a match the constructor arguments are
    -- let-bound around the branch body, and the whole is wrapped in lets.
    matchCon _ _ _ d [] = d
    matchCon lets c as d (TALit{}   : bs) = matchCon lets c as d bs
    matchCon lets c as d (TAGuard{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TACon c' a b : bs)
      | c == c'        = flip (foldr TLet) lets $ mkLet 0 as (raiseFrom a (length lets) b)
      | otherwise      = matchCon lets c as d bs
      where
        -- Bind the arguments one after another; the i-th argument is raised
        -- by i because it sits under the i lets introduced before it.
        mkLet _ []       b = b
        mkLet i (a : as) b = TLet (raise i a) $ mkLet (i + 1) as b

    -- Simplify an application of a primitive. The arguments are simplified,
    -- a second copy of the arguments is built with variables and lets
    -- inlined, simplPrim' is run on the inlined copy, and that result is
    -- kept only if betterThan prefers it. Any other term is returned
    -- unchanged.
    simplPrim (TApp f@TPrim{} args) = do
        args    <- mapM simpl args
        inlined <- mapM inline args
        let u = TApp f args
            v = simplPrim' (TApp f inlined)
        pure $ if v `betterThan` u then v else u
      where
        -- Follow the substitution until a variable maps to itself.
        inline (TVar x)                   = do
          v <- lookupVar x
          if v == TVar x then pure v else inline v
        inline (TApp f@TPrim{} args)      = TApp f <$> mapM inline args
        -- A let whose body is a case on variable 0 is left alone.
        inline u@(TLet _ (TCase 0 _ _ _)) = pure u
        inline (TLet e b)                 = inline (subst 0 e b)
        inline u                          = pure u
    simplPrim t = pure t

    -- Pure arithmetic simplification of primitive applications.
    simplPrim' :: TTerm -> TTerm
    -- Both sides add the same constant (per constArithView): compare the
    -- remaining terms.
    simplPrim' (TApp (TPrim PLt) [u, v])
      | Just (PAdd, k, u) <- constArithView u,
        Just (PAdd, j, v) <- constArithView v,
        k == j = tOp PLt u v
    -- Both sides apply the same operation (PAdd or PSub) with the same
    -- constant: compare the remaining terms.
    simplPrim' (TApp (TPrim PEq) [u, v])
      | Just (op1, k, u) <- constArithView u,
        Just (op2, j, v) <- constArithView v,
        op1 == op2, k == j,
        elem op1 [PAdd, PSub] = tOp PEq u v
    -- Multiplication where either operand is the integer 0.
    simplPrim' (TApp (TPrim PMul) [u, v])
      | Just 0 <- intView u = tInt 0
      | Just 0 <- intView v = tInt 0
    -- PMul/PQuot with negated operands (negation being 0 - x, see negView):
    -- two negations cancel, a single one is moved outside the operation.
    simplPrim' (TApp (TPrim op) [u, v])
      | Just u <- negView u,
        Just v <- negView v,
        elem op [PMul, PQuot] = tOp op u v
      | Just u <- negView u,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
      | Just v <- negView v,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
    -- PRem: a negated first operand negates the result (and a negation on
    -- the second operand is stripped); a negation on the second operand
    -- alone is dropped.
    simplPrim' (TApp (TPrim PRem) [u, v])
      | Just u <- negView u  = simplArith $ tOp PSub (tInt 0) (tOp PRem u (unNeg v))
      | Just v <- negView v  = tOp PRem u v
    -- Any other binary primitive: simplify both operands, then normalise
    -- with simplArith.
    simplPrim' (TApp f@(TPrim op) [u, v]) = simplArith $ TApp f [simplPrim' u, simplPrim' v]
    simplPrim' u = u

    -- Strip one negation if present.
    unNeg u | Just v <- negView u = v
            | otherwise           = u

    -- Recognise a negation, i.e. the integer 0 minus b, and return b.
    negView (TApp (TPrim PSub) [a, b])
      | Just 0 <- intView a = Just b
    negView _ = Nothing

    -- Count arithmetic operations
    -- u is preferred when it has no more operations than v. Variables and
    -- literals are free; any term that is not a binary primitive
    -- application, a variable or a literal is charged 1000.
    betterThan u v = operations u <= operations v
      where
        operations (TApp (TPrim _) [a, b]) = 1 + operations a + operations b
        operations TVar{}                  = 0
        operations TLit{}                  = 0
        operations _                       = 1000

    -- Primitive simplification followed by the environment's rewrite rules.
    rewrite' t = rewrite =<< simplPrim t

    -- View a term as a natural-number literal k combined with another term
    -- u, returning (op, k, u). With the literal on the left the operation
    -- (PAdd or PSub) is reported as written; with the literal on the right,
    -- a subtraction of k is reported as an addition of -k.
    constArithView :: TTerm -> Maybe (TPrim, Integer, TTerm)
    constArithView (TApp (TPrim op) [TLit (LitNat _ k), u])
      | elem op [PAdd, PSub] = Just (op, k, u)
    constArithView (TApp (TPrim op) [u, TLit (LitNat _ k)])
      | op == PAdd = Just (op, k, u)
      | op == PSub = Just (PAdd, -k, u)
    constArithView _ = Nothing

    -- Simplify one alternative of a case on variable x. For constructor and
    -- literal alternatives the body is simplified with an extra rewrite rule
    -- relating the scrutinee's value to the matched pattern (see
    -- maybeAddRewrite). Under a constructor alternative with a pattern
    -- variables the scrutinee is variable x + a, and conTerm is the
    -- constructor applied to TVar (a - 1) down to TVar 0.
    simplAlt x (TACon c a b) = TACon c a <$> underLams a (maybeAddRewrite (x + a) conTerm $ simpl b)
      where conTerm = mkTApp (TCon c) [TVar i | i <- reverse $ take a [0..]]
    simplAlt x (TALit l b)   = TALit l   <$> maybeAddRewrite x (TLit l) (simpl b)
    simplAlt x (TAGuard g b) = TAGuard   <$> simpl g <*> simpl b

    -- If the substitution maps x to itself, run cont unchanged; otherwise
    -- run it with a rule rewriting the term x is mapped to into rhs.
    maybeAddRewrite x rhs cont = do
      v <- lookupVar x
      case v of
        TVar y | x == y -> cont
        _ -> addRewrite v rhs cont

    -- Is the term the constructor named by the kit's true field?
    isTrue (TCon c) = Just c == true
    isTrue _        = False

    -- Is the term the constructor named by the kit's false field?
    isFalse (TCon c) = Just c == false
    isFalse _        = False

    -- An application of the natMinus builtin to exactly two arguments a and
    -- b becomes the primitive PSub when rewriting establishes one of:
    -- b < a is true, b < 1 + a is true, or both b < a and a < b are false.
    -- Otherwise the application is rebuilt with tApp.
    maybeMinusToPrim f@(TDef minus) es@[a, b]
      | Just minus == natMinus = do
      b_a  <- rewrite' (tOp PLt b a)
      b_sa <- rewrite' (tOp PLt b (tOp PAdd (tInt 1) a))
      a_b  <- rewrite' (tOp PLt a b)
      if isTrue b_a || isTrue b_sa || isFalse b_a && isFalse a_b
        then pure $ tOp PSub a b
        else tApp f es

    maybeMinusToPrim f es = tApp f es

    -- Smart let: a let-bound variable is substituted directly.
    tLet (TVar x) b = subst 0 (TVar x) b
    tLet e b        = TLet e b

    -- Smart case constructor. Unreachable alternatives are dropped. When the
    -- default is unreachable and the last remaining alternative is a literal
    -- or guard alternative, that alternative's body becomes the new default;
    -- when it is a constructor alternative the alternatives are kept. A case
    -- with an unreachable default and no remaining alternatives collapses to
    -- the default.
    tCase :: Int -> CaseType -> TTerm -> [TAlt] -> S TTerm
    tCase x t d bs
      | isUnreachable d =
        case reverse bs' of
          [] -> pure d
          TALit _ b   : as  -> pure $ tCase' x t b (reverse as)
          TAGuard _ b : as  -> pure $ tCase' x t b (reverse as)
          TACon c a b : _   -> pure $ tCase' x t d bs'
      | otherwise = pure $ TCase x t d bs'
      where
        bs' = filter (not . isUnreachable) bs

        -- A case without alternatives is just its default.
        tCase' x t d [] = d
        tCase' x t d bs = TCase x t d bs

    -- Smart application. Arguments are pushed under lets and into the
    -- default and alternatives of a case (which is then simplified again).
    -- A variable head is replaced by its value when that value is atomic and
    -- different from the variable, or a lambda. Lambdas are beta-reduced: by
    -- substitution for variable arguments, through a let otherwise.
    tApp :: TTerm -> [TTerm] -> S TTerm
    tApp (TLet e b) es = TLet e <$> underLet e (tApp b (raise 1 es))
    tApp (TCase x t d bs) es = do
      d  <- tApp d es
      bs <- mapM (`tAppAlt` es) bs
      simpl $ TCase x t d bs    -- will resimplify branches
    tApp (TVar x) es = do
      v <- lookupVar x
      case v of
        _ | v /= TVar x && isAtomic v -> tApp v es
        TLam{} -> tApp v es   -- could blow up the code
        _      -> pure $ mkTApp (TVar x) es
    tApp f [] = pure f
    tApp (TLam b) (TVar i : es) = tApp (subst 0 (TVar i) b) es
    tApp (TLam b) (e : es) = tApp (TLet e b) es
    tApp f es = pure $ TApp f es

    -- Apply the body of a case alternative to es; under a constructor
    -- alternative the arguments are raised past its a pattern variables.
    tAppAlt (TACon c a b) es = TACon c a <$> underLams a (tApp b (raise a es))
    tAppAlt (TALit l b) es   = TALit l   <$> tApp b es
    tAppAlt (TAGuard g b) es = TAGuard g <$> tApp b es

    -- Terms treated as atomic; simpl and tApp replace a variable by its
    -- value from the substitution when the value satisfies this predicate.
    isAtomic v = case v of
      TVar{}    -> True
      TCon{}    -> True
      TPrim{}   -> True
      TDef{}    -> True
      TLit{}    -> True
      TSort{}   -> True
      TErased{} -> True
      TError{}  -> True
      _         -> False

-- | A sum in normal form: an integer constant together with a list of
--   positively or negatively signed terms (see 'Atom').
type Arith = (Integer, [Atom])

-- | A term occurring in a sum with a sign. The constructor order is
--   significant for the derived 'Ord' instance, which 'sortR' relies on.
data Atom
  = Pos TTerm  -- ^ term that is added
  | Neg TTerm  -- ^ term that is subtracted
  deriving (Show, Eq, Ord)

-- | Flip the sign of an atom.
aNeg :: Atom -> Atom
aNeg atom = case atom of
  Pos a -> Neg a
  Neg a -> Pos a

-- | Cancel atoms against their negations. Going left to right, an atom
--   whose negation occurs later in the list is dropped together with the
--   first such occurrence; all other atoms are kept in order.
aCancel :: [Atom] -> [Atom]
aCancel []            = []
aCancel (atom : rest)
  | opposite `elem` rest = aCancel (delete opposite rest)
  | otherwise            = atom : aCancel rest
  where
    opposite = aNeg atom

-- | Sort a list in descending order.
sortR :: Ord a => [a] -> [a]
sortR = sortBy (\ x y -> compare y x)

-- | Add two sums: the constants are added, and the atoms are concatenated,
--   sorted in descending order and cancelled.
aAdd :: Arith -> Arith -> Arith
aAdd (c1, atoms1) (c2, atoms2) = (c1 + c2, aCancel (sortR (atoms1 ++ atoms2)))

-- | Subtract the second sum from the first: the constants are subtracted,
--   and the atoms of the second sum are negated before being merged,
--   sorted in descending order and cancelled.
aSub :: Arith -> Arith -> Arith
aSub (c1, atoms1) (c2, atoms2) = (c1 - c2, aCancel (sortR merged))
  where
    merged = atoms1 ++ map aNeg atoms2

-- | Turn a sum back into a term.
--
--   * No atoms: just the constant.
--   * Constant 0 and at least one positive atom: start from the first
--     positive atom and add/subtract the remaining atoms.
--   * Negative constant and at least one positive atom: as above, then
--     subtract the absolute value of the constant at the end.
--   * Otherwise: start from the constant and add/subtract all atoms.
fromArith :: Arith -> TTerm
fromArith (n, atoms) = case (atoms, break isPos atoms) of
  ([], _) -> tInt n
  (_, (before, Pos a : after))
    | n == 0 -> fromFirstPos
    | n < 0  -> tOp PSub fromFirstPos (tInt (negate n))
    where
      -- Fold the remaining atoms onto the first positive one.
      fromFirstPos = foldl addAtom a (before ++ after)
  -- Positive constant, or no positive atom at all.
  _ -> foldl addAtom (tInt n) atoms

-- | Is the atom positively signed?
isPos :: Atom -> Bool
isPos atom = case atom of
  Pos{} -> True
  Neg{} -> False

-- | Extend a term with a signed atom: positive atoms are added with 'PAdd',
--   negative ones subtracted with 'PSub'.
addAtom :: TTerm -> Atom -> TTerm
addAtom t atom = case atom of
  Pos a -> tOp PAdd t a
  Neg a -> tOp PSub t a

-- | Convert a term into a normalised sum. Integer terms (per 'intView')
--   become constants, binary 'PAdd' and 'PSub' applications are converted
--   recursively, and any other term becomes a single positive atom.
toArith :: TTerm -> Arith
toArith t
  | Just n <- intView t = (n, [])
  | otherwise           = case t of
      TApp (TPrim PAdd) [a, b] -> aAdd (toArith a) (toArith b)
      TApp (TPrim PSub) [a, b] -> aSub (toArith a) (toArith b)
      _                        -> (0, [Pos t])

-- | Normalise an arithmetic term by converting it to a sum with 'toArith'
--   and back to a term with 'fromArith'.
simplArith :: TTerm -> TTerm
simplArith t = fromArith (toArith t)