module Idris.Erasure (performUsageAnalysis, mkFieldName) where
import Idris.AbsSyntax
import Idris.ASTUtils
import Idris.Core.CaseTree
import Idris.Core.Evaluate
import Idris.Core.TT
import Idris.Error
import Idris.Options
import Idris.Primitives
import Prelude hiding (id, (.))
import Control.Arrow
import Control.Category
import Control.Monad (unless, when)
import Control.Monad.State
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.IntSet (IntSet)
import qualified Data.IntSet as IS
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (pack)
import qualified Data.Text as T
-- | For every name, the argument positions found to be used, each with the
-- reasons that made it used.
type UseMap = Map Name (IntMap (Set Reason))

-- | A position of a definition: argument number @i@, or its result.
data Arg = Arg Int | Result deriving (Eq, Ord)

-- | Arguments print as their index, the result prints as @*@.
instance Show Arg where
    show arg = case arg of
        Arg i  -> show i
        Result -> "*"

-- | A node of the dependency graph: one position of a named definition.
type Node = (Name, Arg)

-- | A set of nodes that must all be used for a constraint to fire.
type Cond = Set Node

-- | Why a node is used: a function name and one of its argument indices.
type Reason = (Name, Int)

-- | Nodes that become used, each with the reasons it was reached.
type DepSet = Map Node (Set Reason)

-- | Conditional dependencies: once every node of a key 'Cond' is used, all
-- nodes of the corresponding 'DepSet' are used as well.
type Deps = Map Cond DepSet

-- | What the analysis tracks for a variable in scope.
data VarInfo = VI
    { viDeps   :: DepSet      -- ^ nodes that become used when the variable is used
    , viFunArg :: Maybe Int   -- ^ the argument of the enclosing function this variable came from, if any
    , viMethod :: Maybe Name  -- ^ the method field name this variable stands for, if it was bound from an implementation constructor
    }
    deriving Show

-- | Variables in scope, keyed by name.
type Vars = Map Name VarInfo
-- | Run the erasure (usage) analysis from the given root names.
--
-- The analysis proceeds in these steps:
--
-- * builds the dependency map of everything reachable from the roots;
-- * solves it with 'minimalUsage' and logs the intermediate results;
-- * when the 'WarnReach' command-line option is set, warns about arguments
--   listed as inaccessible that turned out to be reachable;
-- * fails (via 'ifail') if any postulate is reachable;
-- * stores the used argument positions of every function in its
--   call-graph entry.
--
-- Returns the reachable names. With no roots nothing is analysed and the
-- result is empty.
performUsageAnalysis :: [Name] -> Idris [Name]
performUsageAnalysis [] = return []
performUsageAnalysis roots = do
    -- nothing modifies the state before these reads, so a single snapshot suffices
    ist <- getIState
    let ctx     = tt_ctxt ist
        ifaces  = idris_interfaces ist
        opt     = idris_optimisation ist
        used    = idris_erasureUsed ist
        externs = idris_externs ist

        depMap = buildDepMap ifaces used (S.toList externs) ctx roots

        (residDeps, (reachableNames, minUse)) = minimalUsage depMap
        usage = M.toList minUse

    logErasure 5 $ "Original deps:\n" ++ unlines (map fmtItem . M.toList $ depMap)
    logErasure 3 $ "Reachable names:\n" ++ unlines (map (indent . show) . S.toList $ reachableNames)
    logErasure 4 $ "Minimal usage:\n" ++ fmtUseMap usage
    logErasure 5 $ "Residual deps:\n" ++ unlines (map fmtItem . M.toList $ residDeps)

    checkEnabled <- (WarnReach `elem`) . opt_cmdline . idris_options <$> getIState
    when checkEnabled $
        mapM_ (checkAccessibility opt) usage

    reachablePostulates <- S.intersection reachableNames . idris_postulates <$> getIState
    unless (S.null reachablePostulates) $
        ifail ("reachable postulates:\n" ++ intercalate "\n" ["  " ++ show n | n <- S.toList reachablePostulates])

    mapM_ storeUsage usage
    return $ S.toList reachableNames
  where
    indent = ("  " ++)

    fmtItem :: (Cond, DepSet) -> String
    fmtItem (cond, deps) = indent $ show (S.toList cond) ++ " -> " ++ show (M.toList deps)

    fmtUseMap :: [(Name, IntMap (Set Reason))] -> String
    fmtUseMap = unlines . map (\(n,is) -> indent $ show n ++ " -> " ++ fmtIxs is)

    fmtIxs :: IntMap (Set Reason) -> String
    fmtIxs = intercalate ", " . map fmtArg . IM.toList
      where
        fmtArg (i, rs)
            | S.null rs = show i
            | otherwise = show i ++ " from " ++ intercalate ", " (map show $ S.toList rs)

    -- Record the used argument positions (with reasons) of a name in its
    -- call-graph entry.
    storeUsage :: (Name, IntMap (Set Reason)) -> Idris ()
    storeUsage (n, args) = fputState (cg_usedpos . ist_callgraph n) flat
      where
        flat = [(i, S.toList rs) | (i,rs) <- IM.toList args]

    -- Warn (log level 0) about every argument that the optimisation info of
    -- the function lists as inaccessible but that the analysis found used.
    checkAccessibility :: Ctxt OptInfo -> (Name, IntMap (Set Reason)) -> Idris ()
    checkAccessibility opt (fn, reachable)
        | Just (Optimise inaccessible _ _) <- lookupCtxtExact fn opt
        , eargs@(_:_) <- [ fmt argName (S.toList rs)
                         | (i, argName) <- inaccessible
                         , rs <- maybeToList $ IM.lookup i reachable
                         ]
        = warn $ show fn ++ ": inaccessible arguments reachable:\n  " ++ intercalate "\n  " eargs
        | otherwise = return ()
      where
        fmt argName [] = show argName ++ " (no more information available)"
        fmt argName rs = show argName ++ " from " ++ intercalate ", " [show rn ++ " arg# " ++ show ri | (rn,ri) <- rs]
        warn = logErasure 0
-- | A numbered dependency constraint: when every node of the condition is
-- used, all nodes of the dependency set are used too.
type Constraint = (Cond, DepSet)

-- | Compute the least set of used nodes implied by the dependency map.
--
-- Constraints with an empty condition hold unconditionally and seed the
-- search; 'forwardChain' then propagates reachability to a fixed point.
--
-- Returns two things:
--
-- * the constraints left unsatisfied, re-keyed by condition;
-- * the solution split into the names whose result is reachable and the
--   used argument positions of each name, with their reasons.
minimalUsage :: Deps -> (Deps, (Set Name, UseMap))
minimalUsage deps = (fromNumbered *** gather) solved
  where
    numbered :: IntMap Constraint
    numbered = IM.fromList (zip [0 ..] (M.toList deps))

    solved :: (IntMap Constraint, DepSet)
    solved = forwardChain (index numbered) seedDeps seedDeps numbered

    -- everything required by the unconditional constraints
    seedDeps :: DepSet
    seedDeps = M.unionsWith S.union (map snd (filter (S.null . fst) (IM.elems numbered)))

    -- back to a map keyed by condition; constraints that ended up with the
    -- same condition are merged
    fromNumbered :: IntMap Constraint -> Deps
    fromNumbered = M.fromListWith (M.unionWith S.union) . IM.elems

    -- for each node, the numbers of the constraints whose condition mentions it
    index :: IntMap Constraint -> Map Node IntSet
    index cs = M.fromListWith IS.union
        [ (node, IS.singleton i)
        | (i, (cond, _)) <- IM.toList cs
        , node <- S.toList cond
        ]

    -- reached results become reachable names; reached arguments become usage
    gather :: DepSet -> (Set Name, UseMap)
    gather solution = (results, argUses)
      where
        entries = M.toList solution
        results = S.fromList [n | ((n, Result), _) <- entries]
        argUses = M.fromListWith (IM.unionWith S.union)
            [(n, IM.singleton i rs) | ((n, Arg i), rs) <- entries]
-- | Propagate reachability through the numbered constraints until a round
-- reaches nothing new.
--
-- Each round only revisits the constraints whose condition mentions a node
-- reached in the previous round.
--
-- * Those nodes are removed from the condition.
-- * A constraint whose condition becomes empty is dropped, and its
--   dependencies count as newly reached.
--
-- Returns the constraints that were never satisfied together with
-- everything reached.
forwardChain
    :: Map Node IntSet   -- ^ node index: constraints whose condition mentions each node
    -> DepSet            -- ^ everything reached so far
    -> DepSet            -- ^ nodes reached in the previous round
    -> IntMap Constraint -- ^ constraints not yet satisfied
    -> (IntMap Constraint, DepSet)
forwardChain index solution previouslyNew constrs =
    if M.null reached
        then (constrs, solution)
        else forwardChain index (M.unionWith S.union solution reached) reached remaining
  where
    solvedNodes :: Set Node
    solvedNodes = M.keysSet previouslyNew

    -- only these constraints can change in this round
    candidates :: IntSet
    candidates = IS.unions
        (map (\node -> M.findWithDefault IS.empty node index) (M.keys previouslyNew))

    (reached, remaining) = IS.foldr visit (M.empty, constrs) candidates

    visit :: Int -> (DepSet, IntMap Constraint) -> (DepSet, IntMap Constraint)
    visit i acc@(news, pending) = case IM.lookup i pending of
        -- already satisfied and removed earlier
        Nothing -> acc
        Just (cond, deps) ->
            let cond' = cond S.\\ solvedNodes
            in if S.null cond'
                -- fully satisfied: fire it and forget it
                then (M.unionWith S.union news deps, IM.delete i pending)
                else if S.size cond' < S.size cond
                    -- partially satisfied: keep the smaller condition
                    then (news, IM.insert i (cond', deps) pending)
                    else acc
-- | Build the conditional dependency map for everything reachable from the
-- start names.
--
-- An entry @cond -> deps@ means: once every node in @cond@ is used, every
-- node in @deps@ is used as well (this is what 'forwardChain' solves).
-- Definitions are visited depth-first from @startNames@ and looked up in
-- @ctx@. Afterwards one unconditional entry is added. It marks the
-- following as used:
--
-- * the start names and the IO runners;
-- * the explicitly @used@ arguments;
-- * every argument of each used primitive and of each @externs@ entry
--   (name, arity).
--
-- Calls 'error' on definitions or terms the analysis cannot handle.
--
-- NOTE(review): the interface context argument is not consulted here; it
-- is kept so the signature stays unchanged.
buildDepMap :: Ctxt InterfaceInfo -> [(Name, Int)] -> [(Name, Int)] ->
               Context -> [Name] -> Deps
buildDepMap _ci used externs ctx startNames
    = addPostulates used $ dfs S.empty M.empty startNames
  where
    -- Add the single unconditional entry (empty condition) listing the
    -- nodes that are used no matter what.
    addPostulates :: [(Name, Int)] -> Deps -> Deps
    addPostulates usedArgs deps =
        foldr (\(ds, rs) -> M.insertWith (M.unionWith S.union) ds rs) deps postulates
      where
        -- @conds ==> nodes@: if all of @conds@ are used, so are @nodes@
        (==>) ds rs = (S.fromList ds, M.fromList [(r, S.empty) | r <- rs])
        -- the given argument positions of the named definition
        it n is = [(sUN n, Arg i) | i <- is]

        -- prim__believe_me is excluded from the "all arguments of used
        -- primitives" rule; only its argument 2 is listed explicitly below
        specialPrims = S.fromList [sUN "prim__believe_me"]
        usedNames = allNames deps S.\\ specialPrims
        usedPrims = [(p_name p, p_arity p) | p <- primitives, p_name p `S.member` usedNames]

        postulates =
            [ [] ==> concat
                [ map (\n -> (n, Result)) startNames
                , [(sUN "run__IO", Result), (sUN "run__IO", Arg 1)]
                , [(sUN "call__IO", Result), (sUN "call__IO", Arg 2)]
                , map (\(n, i) -> (n, Arg i)) usedArgs
                , it "MkIO"         [2]
                , it "prim__IO"     [1]
                , [(pairCon, Arg 2), (pairCon, Arg 3)]
                , it "prim_fork"    [0]
                , it "unsafePerformPrimIO"  [1]
                , it "prim__believe_me" [2]
                -- every argument of every used primitive and of every extern
                , [(n, Arg i) | (n, arity) <- usedPrims, i <- [0 .. arity - 1]]
                , [(n, Arg i) | (n, arity) <- externs,   i <- [0 .. arity - 1]]
                ]
            ]

    -- Depth-first traversal: collect the deps of each name once and then
    -- visit every other name they mention.
    dfs :: Set Name -> Deps -> [Name] -> Deps
    dfs _       deps [] = deps
    dfs visited deps (n : ns)
        | n `S.member` visited = dfs visited deps ns
        | otherwise = dfs (S.insert n visited) (union deps' deps) (next ++ ns)
      where
        next = [m | m <- S.toList depn, m `S.notMember` visited]
        depn = S.delete n $ allNames deps'
        deps' = getDeps n

    -- All names mentioned in conditions or dependency sets.
    allNames :: Deps -> Set Name
    allNames = S.unions . map names . M.toList
      where
        names (cs, ns) = S.map fst cs `S.union` S.map fst (M.keysSet ns)

    -- Dependencies of a single definition.
    getDeps :: Name -> Deps
    getDeps (SN (WhereN _ (SN (ImplementationCtorN _)) (MN _ _)))
        = M.empty  -- synthetic implementation-constructor field (see 'mkFieldName'): no deps of its own
    getDeps n = case lookupDefExact n ctx of
        Just def -> getDepsDef n def
        Nothing  -> error $ "erasure checker: unknown reference: " ++ show n

    getDepsDef :: Name -> Def -> Deps
    getDepsDef _  (Function _ _)   = error "a function encountered"
    getDepsDef _  (TyDecl _ _)     = M.empty
    getDepsDef _  (Operator _ _ _) = M.empty
    getDepsDef fn (CaseOp _ _ tys _ _ cdefs)
        = getDepsSC fn etaVars (etaMap `M.union` varMap) sc
      where
        -- the runtime case tree binds the first (length vars) arguments;
        -- the remaining declared arguments get fresh eta variables
        etaIdx = [length vars .. length tys - 1]
        etaVars = [eta i | i <- etaIdx]
        etaMap = M.fromList [varPair (eta i) i | i <- etaIdx]
        eta i = MN i (pack "eta")

        -- using the variable bound to argument argNo marks that argument of fn as used
        varMap = M.fromList [varPair v i | (v,i) <- zip vars [0..]]
        varPair n argNo = (n, VI
            { viDeps   = M.singleton (fn, Arg argNo) S.empty
            , viFunArg = Just argNo
            , viMethod = Nothing
            })
        (vars, sc) = cases_runtime cdefs

    -- Apply a term to the given (eta) variables.
    etaExpand :: [Name] -> Term -> Term
    etaExpand []       t = t
    etaExpand (n : ns) t = etaExpand ns (App Complete t (P Ref n Erased))

    -- Dependencies of a case tree of function fn.
    getDepsSC :: Name -> [Name] -> Vars -> SC -> Deps
    getDepsSC _  _  _   ImpossibleCase     = M.empty
    getDepsSC _  _  _  (UnmatchedCase _)   = M.empty
    -- nested projections are peeled off until a variable is reached
    getDepsSC fn es vs (ProjCase (Proj t _) alts) = getDepsSC fn es vs (ProjCase t alts)
    getDepsSC fn es vs (ProjCase (P  _ n _) alts) = getDepsSC fn es vs (Case Shared n alts)
    getDepsSC _  _  _  (ProjCase t alts)   = error $ "ProjCase not supported:\n" ++ show (ProjCase t alts)
    getDepsSC fn es vs (STerm    t)        = getDepsTerm vs [] (S.singleton (fn, Result)) (etaExpand es t)
    getDepsSC fn es vs (Case _ n alts)
        = addTagDep $ unionMap (getDepsAlt fn es vs casedVar) alts
      where
        -- with several alternatives, the scrutinee's deps are needed
        -- whenever fn's result is
        addTagDep = case alts of
            [_] -> id
            _   -> M.insertWith (M.unionWith S.union) (S.singleton (fn, Result)) (viDeps casedVar)
        casedVar  = fromMaybe (error $ "nonpatvar in case: " ++ show n) (M.lookup n vs)

    getDepsAlt :: Name -> [Name] -> Vars -> VarInfo -> CaseAlt -> Deps
    getDepsAlt _  _  _  _   (FnCase _ _ _)   = M.empty
    getDepsAlt fn es vs _   (ConstCase _ sc) = getDepsSC fn es vs sc
    getDepsAlt fn es vs _   (DefaultCase sc) = getDepsSC fn es vs sc
    -- the bound predecessor shares the scrutinee's info
    getDepsAlt fn es vs var (SucCase   n sc) = getDepsSC fn es (M.insert n var vs) sc
    getDepsAlt fn es vs var (ConCase n _ ns sc)
        = getDepsSC fn es (vs' `M.union` vs) sc
      where
        -- using field j additionally uses field j of constructor n,
        -- recorded with the scrutinee's argument position as the reason
        vs' = M.fromList [(v, VI
            { viDeps   = M.insertWith S.union (n, Arg j) (S.singleton (fn, varIdx)) (viDeps var)
            , viFunArg = viFunArg var
            , viMethod = meth j
            })
          | (v, j) <- zip ns [0..]]

        varIdx = fromMaybe
            (error $ "erasure analysis: scrutinee of " ++ show n ++ " in " ++ show fn
                  ++ " is not a function argument")
            (viFunArg var)

        -- fields of an implementation constructor stand for its methods
        meth :: Int -> Maybe Name
        meth | SN (ImplementationCtorN _) <- n = \j -> Just (mkFieldName n j)
             | otherwise = \_ -> Nothing

    -- Dependencies of a term evaluated under condition cd; bs holds the
    -- local (de Bruijn) binders in scope.
    getDepsTerm :: Vars -> [(Name, Cond -> Deps)] -> Cond -> Term -> Deps

    getDepsTerm vs bs cd (P _ n _)
        -- local binder
        | Just deps <- lookup n bs
        = deps cd
        -- pattern/case variable
        | Just var <- M.lookup n vs
        = M.singleton cd (viDeps var)
        -- machine names are expected to be bound
        | MN _ _ <- n
        = error $ "erasure analysis: variable " ++ show n ++ " unbound in " ++ show (S.toList cd)
        -- global reference: its result is used
        | otherwise = M.singleton cd (M.singleton (n, Result) S.empty)

    getDepsTerm _ bs cd (V i) = snd (bs !! i) cd

    getDepsTerm vs bs cd (Bind n bdr body)
        -- references to a lambda/pi-bound variable contribute nothing
        | Lam _ _ <- bdr = getDepsTerm vs ((n, const M.empty) : bs) cd body
        | Pi _ _ _ _ <- bdr = getDepsTerm vs ((n, const M.empty) : bs) cd body
        -- a let-bound value is analysed once, under the current condition
        | Let _ _ t <- bdr = var t cd `union` getDepsTerm vs ((n, const M.empty) : bs) cd body
        | NLet _ t <- bdr = var t cd `union` getDepsTerm vs ((n, const M.empty) : bs) cd body
      where
        var t cd' = getDepsTerm vs bs cd' t

    getDepsTerm vs bs cd app@(App _ _ _)
        | (fun, args) <- unApply app = case fun of
            -- implementation constructor: each method is analysed on its own
            P (DCon _ _ _) ctorName@(SN (ImplementationCtorN _)) _
                -> conditionalDeps ctorName args
                    `union` unionMap (methodDeps ctorName) (zip [0..] args)
            -- type constructor: all arguments unconditionally
            P (TCon _ _) _ _ -> unconditionalDeps args
            -- data constructor: argument i only if field i is used
            P (DCon _ _ _) n _ -> conditionalDeps n args
            -- mkForeignPrim: the first four arguments are not analysed
            P _ (UN n) _
                | n == T.pack "mkForeignPrim"
                -> unconditionalDeps $ drop 4 args
            P _ n _
                -- local binder applied to arguments
                | Just deps <- lookup n bs
                    -> deps cd `union` unconditionalDeps args
                -- variable standing for a method
                | Just var  <- M.lookup n vs
                , Just meth <- viMethod var
                    -> viDeps var `ins` conditionalDeps meth args
                -- other variable: arguments unconditionally
                | Just var <- M.lookup n vs
                    -> viDeps var `ins` unconditionalDeps args
                -- global function: argument i only if used by it
                | otherwise
                    -> conditionalDeps n args
            V i -> snd (bs !! i) cd `union` unconditionalDeps args
            -- applied lambdas are treated as lets
            Bind _ (Lam _ _) _ -> getDepsTerm vs bs cd (lamToLet app)
            -- applied lets are treated as applied lambdas. The arguments
            -- applied to the let are analysed unconditionally, as
            -- let-bound values are; they used to be dropped silently.
            Bind n (Let _ ty t') t
                -> getDepsTerm vs bs cd (App Complete (Bind n (Lam RigW ty) t) t')
                    `union` unconditionalDeps args
            Bind n (NLet ty t') t
                -> getDepsTerm vs bs cd (App Complete (Bind n (Lam RigW ty) t) t')
                    `union` unconditionalDeps args
            Proj t i
                -> error $ "cannot[0] analyse projection !" ++ show i ++ " of " ++ show t
            Erased -> M.empty
            _ -> error $ "cannot analyse application of " ++ show fun ++ " to " ++ show args
      where
        ins = M.insertWith (M.unionWith S.union) cd

        unconditionalDeps :: [Term] -> Deps
        unconditionalDeps = unionMap (getDepsTerm vs bs cd)

        -- Using n's result uses n; the first (arity of n) arguments are
        -- needed only when the matching argument of n is used, any extra
        -- arguments unconditionally.
        conditionalDeps :: Name -> [Term] -> Deps
        conditionalDeps n
            = ins (M.singleton (n, Result) S.empty) . unionMap (getDepsArg n) . zip indices
          where
            indices = map Just [0 .. getArity n - 1] ++ repeat Nothing
            getDepsArg m (Just i,  t) = getDepsTerm vs bs (S.insert (m, Arg i) cd) t
            getDepsArg _ (Nothing, t) = getDepsTerm vs bs cd t

        -- A method body, analysed under the condition that its synthetic
        -- field name's result is used; its lambda parameters map to that
        -- field's arguments.
        methodDeps :: Name -> (Int, Term) -> Deps
        methodDeps ctorName (methNo, t)
            = getDepsTerm (vars `M.union` vs) (bruijns ++ bs) cond body
          where
            vars = M.fromList [(v, VI
                { viDeps   = deps i
                , viFunArg = Just i
                , viMethod = Nothing
                }) | (v, i) <- zip params [0..]]
            deps i   = M.singleton (metameth, Arg i) S.empty
            bruijns  = reverse [(n, \cd' -> M.singleton cd' (deps i)) | (i, n) <- zip [0..] params]
            cond     = S.singleton (metameth, Result)
            metameth = mkFieldName ctorName methNo
            (params, body) = unfoldLams t

    -- NOTE(review): projection index -1 presumably denotes the predecessor of
    -- a natural number (cf. SucCase); it depends on whatever t depends on.
    getDepsTerm vs bs cd (Proj t (-1)) = getDepsTerm vs bs cd t
    getDepsTerm _  _  _  (Proj t i) = error $ "cannot[1] analyse projection !" ++ show i ++ " of " ++ show t

    getDepsTerm _ _ _ (Constant _) = M.empty
    getDepsTerm _ _ _ (TType    _) = M.empty
    getDepsTerm _ _ _ (UType    _) = M.empty
    getDepsTerm _ _ _  Erased      = M.empty
    getDepsTerm _ _ _  Impossible  = M.empty
    getDepsTerm _ _ _ t = error $ "cannot get deps of: " ++ show t

    -- Number of arguments of a name. For a synthetic field name (see
    -- 'mkFieldName') this is the arity of the corresponding field type of
    -- the constructor; otherwise it comes from the definition in ctx.
    getArity :: Name -> Int
    getArity (SN (WhereN _ ctorName (MN i _)))
        | Just (TyDecl (DCon _ _ _) ty) <- lookupDefExact ctorName ctx
        = let argTys = map snd $ getArgTys ty
            in if 0 <= i && i < length argTys
                then length $ getArgTys (argTys !! i)
                else error $ "invalid field number " ++ show i ++ " for " ++ show ctorName
        | otherwise = error $ "unknown implementation constructor: " ++ show ctorName
    getArity n = case lookupDefExact n ctx of
        Just (CaseOp _ _ tys _ _ _)      -> length tys
        Just (TyDecl (DCon _ arity _) _) -> arity
        Just (TyDecl Ref ty)             -> length $ getArgTys ty
        Just (Operator _ arity _)        -> arity
        Just df -> error $ "Erasure/getArity: unrecognised entity '"
                             ++ show n ++ "' with definition: "  ++ show df
        Nothing -> error $ "Erasure/getArity: definition not found for " ++ show n

    -- Rewrite @(\x. b) a@ as @let x = a in b@ (for each applied lambda).
    lamToLet :: Term -> Term
    lamToLet tm = lamToLet' args f
      where
        (f, args) = unApply tm

    lamToLet' :: [Term] -> Term -> Term
    lamToLet' (v:vs) (Bind n (Lam rig ty) tm) = Bind n (Let rig ty v) $ lamToLet' vs tm
    lamToLet'    []  tm = tm
    lamToLet'    vs  tm = error $
        "Erasure.hs:lamToLet': unexpected input: "
            ++ "vs = " ++ show vs ++ ", tm = " ++ show tm

    -- Split leading lambdas into their binder names and the body.
    unfoldLams :: Term -> ([Name], Term)
    unfoldLams (Bind n (Lam _ _) t) = let (ns,t') = unfoldLams t in (n:ns, t')
    unfoldLams t = ([], t)

    union :: Deps -> Deps -> Deps
    union = M.unionWith (M.unionWith S.union)

    unionMap :: (a -> Deps) -> [a] -> Deps
    unionMap f = M.unionsWith (M.unionWith S.union) . map f
-- | Synthetic name for field @fieldNo@ of constructor @ctorName@: an
-- @sMN@-built name @field@ numbered @fieldNo@, nested in a @WhereN@ scope
-- of the constructor.
mkFieldName :: Name -> Int -> Name
mkFieldName ctorName fieldNo = SN $ WhereN fieldNo ctorName fieldName
  where
    fieldName = sMN fieldNo "field"