module Agda.Compiler.Epic.Smashing where
import Control.Arrow((&&&))
import Control.Monad
import Control.Monad.State
import Control.Monad.Trans
import Data.List
import qualified Data.Map as M
import Data.Map (Map)
import Data.Maybe
import qualified Data.Set as S
import Data.Set (Set)
import Agda.Syntax.Common
import Agda.Syntax.Internal as SI
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Rules.LHS.Unify
import Agda.Compiler.Epic.AuxAST as AA
import Agda.Compiler.Epic.CompileState
import Agda.Compiler.Epic.Interface
import Agda.Utils.Monad
import Agda.Utils.Size
import qualified Agda.Utils.HashMap as HM
#include "../../undefined.h"
import Agda.Utils.Impossible
-- | Parameter count recorded in a definition, converted to any integral
-- type. Records report 'recPars', constructors report 'conPars'; every
-- other kind of definition reports zero.
defnPars :: Integral n => Defn -> n
defnPars defn = case defn of
  Record{recPars = np}      -> fromIntegral np
  Constructor{conPars = np} -> fromIntegral np
  _                         -> 0
-- | Run the smashing pass over a list of compiled functions.
--
-- For each 'AA.Fun' whose qualified name is found in the imported
-- signature, 'smashable' is asked whether the result can be built from
-- the definition's type alone. If so, the function body is replaced by
-- the inferred expression, the function is marked for inlining, and
-- @" [SMASHED]"@ is appended to its comment. In every other case the
-- function is returned unchanged:
--
-- * it is not an 'AA.Fun';
-- * it has no name in the signature;
-- * its result is not inferable.
smash'em :: [Fun] -> Compile TCM [Fun]
smash'em funs = do
  defs <- lift (gets (sigDefinitions . stImports))
  forM funs $ \f -> case f of
    AA.Fun{} -> case funQName f >>= flip HM.lookup defs of
      Nothing -> do
        lift $ reportSDoc "epic.smashing" 10 $ vcat
          [ (text . show) f <+> text " was not found"]
        return f
      Just def -> do
        lift $ reportSLn "epic.smashing" 10 $ "running on:" ++ (show (funQName f))
        -- The arity handed to 'smashable' counts the function's own
        -- arguments plus the record/constructor parameters ('defnPars').
        minfered <- smashable (length (funArgs f) + defnPars (theDef def)) (defType def)
        case minfered of
          Just infered -> do
            -- Use the module's verbosity key ("epic.smashing").
            -- Previously the key was "smashing", so -v epic.smashing
            -- did not enable this message.
            lift $ reportSDoc "epic.smashing" 5 $ vcat
              [ prettyTCM (defName def) <+> text "is smashable"]
            return f { funExpr    = infered
                     , funInline  = True
                     , funComment = funComment f ++ " [SMASHED]"
                     }
          Nothing -> return f
    _ -> do
      lift $ reportSLn "epic.smashing" 10 $ "smashing!"
      return f
-- | Concatenate two telescopes. The flattened entries of the left
-- telescope are raised by the size of the right one. They are then
-- followed by the flattened right telescope and rebuilt with the binder
-- names of both sides.
(+++) :: Telescope -> Telescope -> Telescope
left +++ right = unflattenTel allNames (shiftedLeft ++ flattenTel right)
  where
    allNames    = teleNames left ++ teleNames right
    shiftedLeft = raise (size right) <$> flattenTel left
-- | Try to build an expression inhabiting the type @dat args@.
--
-- * A data type with exactly one constructor, or a record type: the
--   constructor's type is instantiated with the leading parameter
--   arguments. Each non-forced field is then inferred with
--   'inferableTerm'. The result is the constructor (with its tag)
--   applied to those fields, or 'Nothing' if any field fails.
-- * A function: @Def dat args@ is normalised and the result handed to
--   'inferableTerm'.
-- * Any other definition, or a name already in @visited@, yields
--   'Nothing'. The @visited@ set stops infinite unfolding.
inferable :: Set QName -> QName -> [Arg Term] -> Compile TCM (Maybe Expr)
inferable visited dat args
  | S.member dat visited = return Nothing
  | otherwise = do
      lift $ reportSLn "epic.smashing" 10 $ " inferring:" ++ (show dat)
      sigDefs <- lift (gets (sigDefinitions . stImports))
      let def = fromMaybe __IMPOSSIBLE__ $ HM.lookup dat sigDefs
      case theDef def of
        Datatype{dataCons = [con], dataPars = np} -> viaConstructor con np
        Datatype{}                                -> return Nothing
        Record{recCon = con, recPars = np}        -> viaConstructor con np
        Function{} -> do
          normal <- lift $ normalise $ Def dat args
          inferableTerm visited' normal
        other -> do
          lift $ reportSLn "epic.smashing" 10 $ " failed (inferable): " ++ (show other)
          return Nothing
  where
    visited' = S.insert dat visited
    -- Infer every non-forced field of constructor @con@, after applying
    -- its type to the first @np@ arguments.
    viaConstructor con np = do
      sigDefs <- lift (gets (sigDefinitions . stImports))
      let conDef = fromMaybe __IMPOSSIBLE__ $ HM.lookup con sigDefs
      forced <- getForcedArgs con
      TelV tel _ <- lift $ telView (defType conDef `apply` genericTake np args)
      tag <- getConstrTag con
      lift $ reportSDoc "epic.smashing" 10 $ nest 2 $ vcat
        [ text "inferableArgs!"
        , text "tele" <+> prettyTCM tel
        , text "constr:" <+> prettyTCM con
        ]
      fields <- forM (notForced forced $ flattenTel tel) (inferableTerm visited' . unEl . unDom)
      return (AA.Con tag con <$> sequence fields)
-- | Try to build an expression inhabiting the type denoted by a term.
--
-- * @Def q as@ is delegated to 'inferable'.
-- * A @Pi@ type is inferable when its codomain is. The inferred
--   codomain is wrapped in a lambda whose binder is named @"_"@.
-- * A @Sort@ yields @Just AA.UNIT@.
-- * Anything else is logged and yields 'Nothing'.
--
-- @visited@ is passed through to 'inferable' to stop cycles.
inferableTerm :: Set QName -> Term -> Compile TCM (Maybe Expr)
inferableTerm visited t = case t of
  Def q as -> inferable visited q as
  Pi _ b   -> (AA.Lam "_" <$>) <$> inferableTerm visited (unEl $ unAbs b)
  Sort {}  -> return . return $ AA.UNIT
  _        -> do
    lift $ reportSLn "epic.smashing" 10 $ " failed to infer: " ++ show t
    return Nothing
-- | Decide whether a definition of type @typ@ can be replaced by an
-- expression built from its result type alone.
--
-- The type is split into its telescope and result type, and the result
-- type is handed to 'inferableTerm'. If that succeeds, the expression
-- is wrapped in one ignoring lambda per telescope entry beyond the
-- first @origArity@. A non-positive difference adds no lambdas.
--
-- Returns 'Nothing' when the result type is not inferable.
smashable :: Int -> Type -> Compile TCM (Maybe Expr)
smashable origArity typ = do
  TelV tele retType <- lift $ telView typ
  inf <- inferableTerm S.empty (unEl retType)
  lift $ reportSDoc "epic.smashing" 10 $ nest 2 $ vcat
    [ text "Result is"
    , text "inf: " <+> (text . show) inf
    , text "type: " <+> prettyTCM retType
    ]
  -- Was `size tele origArity`, which applied `size` to two arguments
  -- and did not type-check. The intended count is the difference.
  return $ buildLambda (size tele - origArity) <$> inf
-- | Wrap an expression in @n@ lambdas whose binders are all named @"_"@.
-- A non-positive @n@ returns the expression unchanged.
buildLambda :: (Ord n, Num n) => n -> Expr -> Expr
buildLambda n e
  | n <= 0    = e
  -- Was `buildLambda (n 1) e`, which applied the number `n` as a
  -- function. Recurse on the decremented count instead.
  | otherwise = AA.Lam "_" (buildLambda (n - 1) e)