{-# LANGUAGE PatternGuards #-}

module Idris.DataOpts where

-- Forcing, detagging and collapsing

import Idris.AbsSyntax
import Idris.AbsSyntaxTree
import Idris.Core.TT

import Control.Applicative
import qualified Data.IntMap as M
import Data.List
import Data.Maybe
import Debug.Trace

-- | Forceability information keyed by constructor-argument position
-- (the 0-based index assigned while walking the constructor's Pi binders).
type ForceMap = M.IntMap Forceability

-- | Work out which arguments of the constructor @n@ (whose type is @t@ and
-- which belongs to the family @typeName@) are forceable, and store the
-- result in the optimisation table of the Idris state.
--
-- The positions listed in @expforce@ are recorded as 'Unconditional' in
-- addition to whatever 'getForcedArgs' infers.  The constructor is also
-- registered for IBC output and the outcome is logged.
forceArgs :: Name -> Name -> [Int] -> Type -> Idris ()
forceArgs typeName n expforce t = do
    ist <- getIState
    let optCtxt  = idris_optimisation ist
        inferred = getForcedArgs ist typeName t
        explicit = [(ix, Unconditional) | ix <- expforce]
        -- Reuse the existing entry for this constructor, if there is one.
        existing = case lookupCtxt n optCtxt of
            info : _ -> info
            []       -> Optimise False False [] []
        -- Explicit entries come last in the association list.
        updated  = existing { forceable = M.toList inferred ++ explicit }
    putIState (ist { idris_optimisation = addDef n updated optCtxt })
    addIBC (IBCOpt n)
    iLOG $ "Forced: " ++ show n ++ " " ++ show inferred ++ "\n   from " ++ show t

-- | Compute the forceability of every argument of a constructor of the
-- family @typeName@ whose type is @t@.
--
-- Two sources are combined:
--
--   * arguments that occur (bare, or guarded only by data constructors) in
--     the arguments of the constructor's return type are 'Unconditional';
--
--   * arguments whose own type is headed by a collapsible family are
--     'Unconditional', and arguments whose type is headed by @typeName@
--     itself are at least 'Conditional'.
getForcedArgs :: IState -> Name -> Type -> ForceMap
getForcedArgs ist typeName t = addCollapsibleArgs 0 t (forcedInTarget 0 t)
  where
    -- Tag used for the machine-generated names that stand for ctor arguments.
    argTag = "ctor_arg"

    -- Replace the variable bound by the i-th Pi with a bound name carrying
    -- the number i, so that its occurrences can be recognised later on.
    label i = instantiate (P Bound (sMN i argTag) Erased)

    -- Does the optimisation table say (unambiguously) that this family
    -- is collapsible?
    isCollapsible :: Name -> Bool
    isCollapsible tyName = case lookupCtxt tyName (idris_optimisation ist) of
        [info] -> collapsible info
        _      -> False

    -- Walk the Pi binders, extending the map according to the head of each
    -- argument's type.
    addCollapsibleArgs :: Int -> Type -> ForceMap -> ForceMap
    addCollapsibleArgs i (Bind _ (Pi argTy) scope) acc =
        addCollapsibleArgs (i + 1) (label i scope) acc'
      where
        acc' = case unApply argTy of
            (P _ tyName _, _)
                -- an argument of collapsible type is unconditionally forceable
                | isCollapsible tyName -> M.insert i Unconditional acc
                -- a recursive occurrence is (at least) conditionally forceable
                | tyName == typeName   -> M.insertWith max i Conditional acc
            _ -> acc
    addCollapsibleArgs _ _ acc = acc

    -- Skip over the Pi binders (labelling as we go) and inspect the
    -- arguments of the application that forms the return type.
    forcedInTarget :: Int -> Type -> ForceMap
    forcedInTarget i (Bind _ (Pi _) scope) = forcedInTarget (i + 1) (label i scope)
    forcedInTarget _ target@(App _ _)      = unionMap guardedArgs (snd (unApply target))
    forcedInTarget _ _                     = M.empty

    -- Labelled arguments reachable through data constructor applications only.
    guardedArgs :: Term -> ForceMap
    guardedArgs tm@(App _ _)
        | (P (DCon _ _) _ _, args) <- unApply tm
        = M.unionWith max (unionMap bareArg args) (unionMap guardedArgs args)
    guardedArgs tm = bareArg tm

    -- A term that is exactly one of the labelled arguments.
    bareArg :: Term -> ForceMap
    bareArg (P _ (MN i tag) _)
        | tag == txt argTag = M.singleton i Unconditional
    bareArg _ = M.empty

    -- Map and merge, keeping the larger forceability on collisions.
    unionMap :: (a -> ForceMap) -> [a] -> ForceMap
    unionMap f xs = M.unionsWith max (map f xs)

-- | Decide whether the family @tn@ with constructors @ctors@ can be
-- collapsed and update the Idris state accordingly.
--
-- If every argument of every constructor has forceability information and
-- the constructors' return-type arguments are provably pairwise disjoint,
-- the family and all its constructors are marked collapsible.  Failing
-- that, a family with a single constructor is checked for newtype
-- treatment.  Otherwise nothing is changed.
collapseCons :: Name -> [(Name, Type)] -> Idris ()
collapseCons tn ctors = do
    ist <- getIState
    if all (ctorCollapsible ist) ctors && disjointTerms ctorTargetArgs
        then mapM_ setCollapsible (tn : map fst ctors)
        else case ctors of
            [(cn, ct)] -> checkNewType ist cn ct
            _          -> return () -- nothing can be done
  where
    -- Arguments of the return type of each constructor.
    ctorTargetArgs = [snd (unApply (getRetTy cty)) | (_, cty) <- ctors]

    -- Every argument position of the constructor appears in its force map.
    ctorCollapsible :: IState -> (Name, Type) -> Bool
    ctorCollapsible ist (cn, cty) = all isForced [0 .. length (getArgTys cty) - 1]
      where
        isForced ix = M.member ix known
        known = case lookupCtxt cn (idris_optimisation ist) of
            info : _ -> M.fromList (forceable info)
            []       -> M.empty

    -- One constructor; if exactly one argument remains after removing the
    -- unconditionally forced ones, flag the constructor as a newtype.
    checkNewType :: IState -> Name -> Type -> Idris ()
    checkNewType ist cn ct = case lookupCtxt cn optCtxt of
        info : _
            | length (getArgTys ct) == forcedCnt (M.fromList (forceable info)) + 1
            -> putIState (ist { idris_optimisation =
                                  addDef cn (info { isnewtype = True }) optCtxt })
        _ -> return ()
      where
        optCtxt = idris_optimisation ist

    -- Set the collapsible flag for a name, creating a fresh entry (and
    -- registering it for IBC output) when the name has none yet.
    setCollapsible :: Name -> Idris ()
    setCollapsible n = do
        st <- getIState
        iLOG $ show n ++ " collapsible"
        let optCtxt    = idris_optimisation st
            store info = putIState (st { idris_optimisation = addDef n info optCtxt })
        case lookupCtxt n optCtxt of
            info : _ -> store (info { collapsible = True })
            []       -> do
                store (Optimise True False [] [])
                addIBC (IBCOpt n)

    -- Every pair of argument lists must differ provably in at least one
    -- corresponding position.
    disjointTerms :: [[Term]] -> Bool
    disjointTerms []            = True
    disjointTerms (pats : rest) =
        all (any (uncurry disjoint) . zip pats) rest && disjointTerms rest

    -- Return True  if the two patterns are provably disjoint.
    -- Return False if they're not or if unsure.
    disjoint :: Term -> Term -> Bool
    disjoint x y = case (unApply x, unApply y) of
        -- two data constructor applications: different constructors, or
        -- the same constructor with some pair of disjoint arguments
        ((P (DCon _ _) nx _, xargs), (P (DCon _ _) ny _, yargs)) ->
            nx /= ny || or (zipWith disjoint xargs yargs)
        _ -> False

-- | Syntax that the data type optimisations in this module can be applied to.
class Optimisable term where
    -- | Rewrite a term according to the optimisation information held in
    -- the Idris state.
    applyOpts :: term -> Idris term
    -- | Remove what is made redundant by collapsible types.  Some instances
    -- implement this as the identity.
    stripCollapsed :: term -> Idris term

-- | Pairs are optimised component-wise, first component first.
instance (Optimisable a, Optimisable b) => Optimisable (a, b) where
    applyOpts (x, y) = do
        x' <- applyOpts x
        y' <- applyOpts y
        return (x', y')
    stripCollapsed (x, y) = do
        x' <- stripCollapsed x
        y' <- stripCollapsed y
        return (x', y')

-- | Triples keep their first component untouched; the second and third
-- components are optimised in that order.
instance (Optimisable a, Optimisable b) => Optimisable (vs, a, b) where
    applyOpts (v, x, y) = do
        x' <- applyOpts x
        y' <- applyOpts y
        return (v, x', y')
    stripCollapsed (v, x, y) = do
        x' <- stripCollapsed x
        y' <- stripCollapsed y
        return (v, x', y')

-- | Lists are optimised element by element, left to right.
instance Optimisable a => Optimisable [a] where
    applyOpts xs = mapM applyOpts xs
    stripCollapsed xs = mapM stripCollapsed xs

-- | Optimise whichever alternative is present and re-wrap it in the same
-- constructor.
instance Optimisable a => Optimisable (Either a (a, a)) where
    applyOpts = either (fmap Left . applyOpts) (fmap Right . applyOpts)
    stripCollapsed = either (fmap Left . stripCollapsed) (fmap Right . stripCollapsed)

-- Raw is for compile time optimisation (before type checking)
-- Term is for run time optimisation (after type checking, collapsing allowed)

-- Compile time: no collapsing

-- | Raw terms: applications headed by a variable that has an entry in the
-- optimisation table get their arguments re-marked by 'applyDataOpt';
-- existing 'RForce' wrappers are dropped; 'stripCollapsed' does nothing.
instance Optimisable Raw where
    applyOpts app@(RApp fn arg) = case raw_unapply app of
        -- application spine headed by a variable
        (Var n, args) -> do
            args' <- mapM applyOpts args
            ist <- getIState
            return $ case lookupCtxt n (idris_optimisation ist) of
                info : _ -> applyDataOpt info n args'
                []       -> raw_apply (Var n) args'
        -- any other head: just recurse into both sides
        _ -> RApp <$> applyOpts fn <*> applyOpts arg
    applyOpts (RBind n b sc) = do
        b' <- applyOpts b
        sc' <- applyOpts sc
        return (RBind n b' sc')
    applyOpts (RForce inner) = applyOpts inner
    applyOpts other = return other

    stripCollapsed = return

-- Binder types are replaced by Erased (makes ibc smaller, and we don't
-- need them); only the value of a Let is processed further.
instance Optimisable (Binder (TT Name)) where
    applyOpts (Let _ val) = Let Erased <$> applyOpts val
    applyOpts binder      = return (binder { binderTy = Erased })

    stripCollapsed (Let _ val) = Let Erased <$> stripCollapsed val
    stripCollapsed binder      = return (binder { binderTy = Erased })

-- | Raw binders: the binder type is processed and put back in place.
--
-- NOTE(review): 'applyOpts' only visits the binder type, whereas
-- 'stripCollapsed' also visits the value of a Let -- confirm the
-- asymmetry is intended.
instance Optimisable (Binder Raw) where
    applyOpts binder =
        (\ty -> binder { binderTy = ty }) <$> applyOpts (binderTy binder)

    stripCollapsed (Let ty val) = do
        ty' <- stripCollapsed ty
        val' <- stripCollapsed val
        return (Let ty' val')
    stripCollapsed binder =
        (\ty -> binder { binderTy = ty }) <$> stripCollapsed (binderTy binder)

-- | The forceability recorded for argument positions 0, 1, 2, ... as an
-- infinite list; positions without an entry yield 'Nothing'.
forcedArgSeq :: OptInfo -> [Maybe Forceability]
forcedArgSeq oi = [M.lookup ix table | ix <- [0 ..]]
  where
    table = M.fromList (forceable oi)

-- | Number of positions in the map that are marked 'Unconditional'.
forcedCnt :: ForceMap -> Int
forcedCnt fm = M.size (M.filter (== Unconditional) fm)

-- | Rebuild the application of @n@ to @args@, wrapping every argument whose
-- position is 'Unconditional' in 'RForce' and leaving the others as they are.
applyDataOpt :: OptInfo -> Name -> [Raw] -> Raw
applyDataOpt oi n args =
    raw_apply (Var n) [mark f arg | (f, arg) <- zip (forcedArgSeq oi) args]
  where
    mark (Just Unconditional) arg = RForce arg
    mark _                    arg = arg

-- Run-time: do everything

-- Namespace components that the Nat special cases below compare against the
-- namespace part of an NS name.
prel = [txt "Nat", txt "Prelude"]

-- | Checked terms.
--
-- 'applyOpts' replaces a handful of names from the namespace 'prel' with
-- big-integer primitives or the identity, rewrites (possibly unapplied)
-- data constructors through 'applyDataOptRT' when the optimisation table
-- has an entry for them, and otherwise recurses structurally.
--
-- 'stripCollapsed' walks the leading pattern-variable binders and
-- substitutes 'Erased' for every variable whose type is headed by a
-- collapsible family.
instance Optimisable (TT Name) where
    -- Names with a fixed replacement
    applyOpts (P _ (NS (UN fn) ns) _)
        | ns == prel
        , Just replacement <- lookup fn natPrims
        = return replacement
      where
        natPrims =
            [ (txt "plus",           P Ref (sUN "prim__addBigInt") Erased)
            , (txt "mult",           P Ref (sUN "prim__mulBigInt") Erased)
            , (txt "fromIntegerNat", natId)
            , (txt "toIntegerNat",   natId)
            ]
        natId = App (P Ref (sNS (sUN "id") ["Basics","Prelude"]) Erased) Erased
    -- An applied fromIntegerNat disappears altogether
    applyOpts (App (P _ (NS (UN fn) ns) _) x)
        | fn == txt "fromIntegerNat" && ns == prel
        = applyOpts x
    -- A data constructor standing on its own
    applyOpts con@(P (DCon tag arity) n _) = do
        ist <- getIState
        return $ case lookupCtxt n (idris_optimisation ist) of
            info : _ -> applyDataOptRT info n tag arity []
            []       -> con
    -- Applications: treat a data constructor head specially
    applyOpts tm@(App fn arg) = case unApply tm of
        (con@(P (DCon tag arity) n _), args) -> do
            args' <- mapM applyOpts args
            ist <- getIState
            return $ case lookupCtxt n (idris_optimisation ist) of
                info : _ -> applyDataOptRT info n tag arity args'
                []       -> mkApp con args'
        _ -> App <$> applyOpts fn <*> applyOpts arg
    applyOpts (Bind n b sc) = Bind n <$> applyOpts b <*> applyOpts sc
    applyOpts (Proj tm i)   = (\tm' -> Proj tm' i) <$> applyOpts tm
    applyOpts other         = return other

    stripCollapsed (Bind n (PVar x) sc) = do
        -- NOTE: This assumes that the head of 'x' is in normal form, which
        -- it has to be before now because we're not keeping track of
        -- an environment so we can't do it here.
        erase <- case unApply x of
            (P _ ty _, _) -> do
                ist <- getIState
                return $ case lookupCtxt ty (idris_optimisation ist) of
                    [info] -> collapsible info
                    _      -> False
            _ -> return False
        sc' <- stripCollapsed sc
        return (Bind n (PVar x) (if erase then instantiate Erased sc' else sc'))
    stripCollapsed other = return other

-- | Optimise the application of the data constructor @n@ (with the given
-- @tag@ and @arity@) to @args@.
--
-- The arguments are saturated first so that erasure happens uniformly: one
-- machine-generated variable is appended per missing argument and the
-- optimised result is wrapped in a lambda for each of them.  When nothing
-- is missing, no variables and no lambdas are introduced.
applyDataOptRT :: OptInfo -> Name -> Int -> Int -> [Term] -> Term
applyDataOptRT oi n tag arity args = foldr lam body extra
  where
    -- Names standing for the arguments that were not supplied.
    extra = [sMN i "sat" | i <- [1 .. arity - length args]]

    body = doOpts n (args ++ [P Bound x Erased | x <- extra])
                    (collapsible oi) (M.fromList (forceable oi))

    lam x inner = Bind x (Lam Erased) (pToV x inner)

    natNS = [txt "Nat", txt "Prelude"]

    -- Nat special cases
    -- TODO: Would be nice if this was configurable in idris source!
    doOpts (NS (UN c) ns) [] _ _
        | c == txt "Z", ns == natNS
        = Constant (BI 0)
    doOpts (NS (UN c) ns) [k] _ _
        | c == txt "S", ns == natNS
        = App (App (P Ref (sUN "prim__addBigInt") Erased) k) (Constant (BI 1))

    -- Constructors of collapsible types vanish entirely.
    doOpts _ _ True _ = Erased
    -- Otherwise drop the unconditionally forced arguments; a newtype
    -- constructor is replaced by its single remaining argument.
    doOpts cn cargs _ forceMap
        | not (isnewtype oi) = mkApp con kept
        | [val] <- kept      = val
        | otherwise
        = error $ "Can't happen (newtype not a singleton): " ++ show kept
      where
        con  = P (DCon tag (arity - forcedCnt forceMap)) cn Erased
        kept = [a | (f, a) <- zip (forcedArgSeq oi) cargs, f /= Just Unconditional]