module Agda.Compiler.Treeless.Unused
  ( usedArguments
  , stripUnusedArguments
  ) where

import Control.Arrow (first)
import Control.Applicative
import qualified Data.Set as Set
import Data.Maybe

import Agda.Syntax.Treeless
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Substitute

import Agda.Compiler.Treeless.Subst
import Agda.Compiler.Treeless.Pretty

-- | Compute, for each leading lambda of the given definition body, whether
-- the corresponding argument is used. The result has one flag per stripped
-- lambda, outermost argument first. The analysis starts from the optimistic
-- guess that no argument is used and is refined by 'computeUnused'.
usedArguments :: QName -> TTerm -> TCM [Bool]
usedArguments q t =
  let (arity, body) = lamView t
      noneUsed      = replicate arity False
  in  computeUnused q body noneUsed

-- | Iterate the used-argument approximation for @q@ until it stabilises.
--
-- @body@ is the definition body with its leading lambdas removed. @approx@
-- has one flag per removed lambda, outermost argument first. Each round:
--
--   * records the current approximation with 'setCompiledArgUse', so that
--     calls to @q@ inside @body@ (looked up via 'getCompiledArgUse') see it;
--   * recomputes which argument variables occur free in @body@.
--
-- The approximation is returned once a round leaves it unchanged.
computeUnused :: QName -> TTerm -> [Bool] -> TCM [Bool]
computeUnused q body approx = do
  reportSLn "treeless.opt.unused" 50 $ "Unused approximation for " ++ show q ++ ": " ++
                                       unwords (zipWith showFlag ['a'..] approx)
  setCompiledArgUse q approx
  live <- fvTerm body
  -- The outermost argument is bound furthest out, so it has the largest
  -- de Bruijn index (length approx - 1). The innermost argument has index 0.
  let argIdxs = reverse [0 .. length approx - 1]
      approx' = map (`Set.member` live) argIdxs
  if approx' /= approx
    then computeUnused q body approx'
    else return approx'
  where
    showFlag c isUsed
      | isUsed    = [c]
      | otherwise = "_"

    -- Free de Bruijn indices of a term. Arguments passed to a definition
    -- whose recorded argument-use flag is False are not inspected. Missing
    -- flags default to "used".
    fvTerm (TVar i)             = pure (Set.singleton i)
    fvTerm (TApp (TDef g) args) = do
      argUse <- getCompiledArgUse g
      let keptArgs = [ a | (a, True) <- zip args (argUse ++ repeat True) ]
      Set.unions <$> mapM fvTerm keptArgs
    fvTerm (TApp g args)        = Set.unions <$> mapM fvTerm (g : args)
    fvTerm (TLam b)             = dropBinders 1 <$> fvTerm b
    fvTerm (TLet e b)           = Set.union <$> fvTerm e <*> (dropBinders 1 <$> fvTerm b)
    fvTerm (TCase i _ dflt alts) = do
      fromDefault <- fvTerm dflt
      fromAlts    <- mapM fvAlt alts
      pure $ Set.insert i $ Set.unions (fromDefault : fromAlts)
    fvTerm TPrim{}              = pure Set.empty
    fvTerm TDef{}               = pure Set.empty
    fvTerm TLit{}               = pure Set.empty
    fvTerm TCon{}               = pure Set.empty
    fvTerm TUnit{}              = pure Set.empty
    fvTerm TSort{}              = pure Set.empty
    fvTerm TErased{}            = pure Set.empty
    fvTerm TError{}             = pure Set.empty

    -- Free indices of a case alternative. A constructor pattern binds
    -- @k@ variables around its right-hand side.
    fvAlt (TALit _ b)   = fvTerm b
    fvAlt (TAGuard g b) = Set.union <$> fvTerm g <*> fvTerm b
    fvAlt (TACon _ k b) = dropBinders k <$> fvTerm b

    -- Move a set of indices out from under @k@ binders. Indices bound by
    -- those binders (below @k@) are discarded, and the rest shift down by @k@.
    dropBinders 0 vs = vs
    dropBinders k vs = Set.filter (>= 0) (Set.mapMonotonic (subtract k) vs)

-- | Remove the leading lambdas whose argument flag in @used@ is False.
--
-- @used@ is given outermost argument first. Missing flags are treated as
-- "used", and extra flags beyond the number of leading lambdas are ignored.
-- References to a dropped argument in the body are replaced by 'TErased';
-- the remaining variables are renumbered to account for the removed binders.
stripUnusedArguments :: [Bool] -> TTerm -> TTerm
stripUnusedArguments used t = unlamView kept (applySubst rho body)
  where
    (arity, body) = lamView t
    -- Flags ordered innermost binder (de Bruijn index 0) first.
    flags = reverse (take arity (used ++ repeat True))
    kept  = length [ () | True <- flags ]
    -- A kept binder is lifted over. A dropped one maps to TErased and
    -- strengthens away its index.
    rho = foldr step idS flags
    step True  s = liftS 1 s
    step False s = TErased :# s

-- | Split a term into the number of leading 'TLam' binders and the body
-- under them. Terms that are not lambdas give @(0, t)@. The pair is built
-- lazily, one lambda at a time.
lamView :: TTerm -> (Int, TTerm)
lamView (TLam inner) = let (k, body) = lamView inner
                       in  (succ k, body)
lamView other        = (0, other)

-- | Wrap a term in @n@ 'TLam' binders, the inverse of 'lamView'.
--
-- A count of zero or less leaves the term unchanged. Previously a negative
-- count never reached the base case and built an infinite chain of lambdas.
unlamView :: Int -> TTerm -> TTerm
unlamView n t
  | n <= 0    = t
  | otherwise = TLam (unlamView (n - 1) t)