{-# LANGUAGE PatternGuards #-}

-- | Support for partial evaluation: specialising clause right-hand sides,
-- computing the type of a specialised function, building the declarations
-- for it, and finding applications that are candidates for specialisation.
module Idris.PartialEval(partial_eval, getSpecApps, specType,
                         mkPE_TyDecl, mkPE_TermDecl, PEArgType(..)) where

import Idris.AbsSyntax
import Idris.Delaborate

import Idris.Core.TT
import Idris.Core.Evaluate

import Control.Monad.State
import Debug.Trace

-- | Specialise the right-hand side of every pattern clause.
--
-- 'Left' entries are returned unchanged.  For a 'Right' clause the left-hand
-- side is kept and the right-hand side is passed through 'specialise',
-- together with the given context and a reduction limit for each name in
-- @ns@.  A name whose limit is 'Nothing' gets the default limit below.
partial_eval :: Context -> [(Name, Maybe Int)] ->
                [Either Term (Term, Term)] ->
                [Either Term (Term, Term)]
partial_eval ctxt ns tms = map peClause tms
  where
    peClause (Left t) = Left t
    peClause (Right (lhs, rhs))
        = let rhs' = specialise ctxt [] (map toLimit ns) rhs in
              Right (lhs, rhs')

    toLimit (n, Nothing) = (n, 65536) -- somewhat arbitrary reduction limit
    toLimit (n, Just l) = (n, l)

-- | Compute the type of a specialised function from the classified
-- arguments of an application and the type of the function being applied.
--
-- Returns the rewritten type together with the argument list as it was
-- recorded while walking the type (implicit dynamic arguments that repeat an
-- earlier one are re-tagged as 'UnifiedD').
specType :: [(PEArgType, Term)] -> Type -> (Type, [(PEArgType, Term)])
specType args ty = let (t, args') = runState (unifyEq args ty) [] in
                       (st (map fst args') t, map fst args')
  where
    -- Rewrite the binders of the type according to the argument kinds:
    -- static arguments turn their Pi binder into a Let carrying the value,
    -- unified arguments have their binder dropped, and every other
    -- argument keeps its Pi binder.
    st ((ExplicitS, v) : xs) (Bind n (Pi t) sc)
         = Bind n (Let t v) (st xs sc)
    st ((ImplicitS, v) : xs) (Bind n (Pi t) sc)
         = Bind n (Let t v) (st xs sc)
    st ((UnifiedD, _) : xs) (Bind n (Pi t) sc)
         = st xs sc
    st (_ : xs) (Bind n (Pi t) sc)
         = Bind n (Pi t) (st xs sc)
    st _ t = t

    -- Walk arguments and Pi binders in step.  The state is the list of
    -- arguments seen so far, each paired with the name of its binder.
    unifyEq (imp@(ImplicitD, _) : xs) (Bind n (Pi t) sc)
         = do amap <- get
              case lookup imp amap of
                   -- The same implicit dynamic argument was already seen
                   -- under binder n': record this one as unified and refer
                   -- to n' in the rest of the type instead of n.
                   Just n' ->
                        do put (amap ++ [((UnifiedD, Erased), n)])
                           sc' <- unifyEq xs (subst n (P Bound n' Erased) sc)
                           return (Bind n (Pi t) sc') -- erase later
                   _ -> do put (amap ++ [(imp, n)])
                           sc' <- unifyEq xs sc
                           return (Bind n (Pi t) sc')
    unifyEq (x : xs) (Bind n (Pi t) sc)
         = do seen <- get
              put (seen ++ [(x, n)])
              sc' <- unifyEq xs sc
              return (Bind n (Pi t) sc')
    -- No Pi binder left (or no arguments left): any remaining arguments are
    -- recorded against the placeholder name "_".
    unifyEq xs t = do seen <- get
                      put (seen ++ zip xs (repeat (sUN "_")))
                      return t

-- | Build the high-level type declaration of a specialised function.
--
-- Explicit dynamic arguments become explicit Pi bindings.  Implicit dynamic
-- arguments are dropped when their type is a concrete class instance
-- ('concreteClass'), become constraint bindings when their type is a class
-- constraint ('classConstraint'), and implicit bindings otherwise.  Any other
-- argument is skipped without consuming a binder.  Once the arguments are
-- exhausted the remaining type is delaborated as it stands.
mkPE_TyDecl :: IState -> [(PEArgType, Term)] -> Type -> PTerm
mkPE_TyDecl ist args ty = mkty args ty
  where
    mkty ((ExplicitD, _) : xs) (Bind n (Pi t) sc)
       = PPi expl n (delab ist t) (mkty xs sc)
    mkty ((ImplicitD, _) : xs) (Bind n (Pi t) sc)
         | concreteClass ist t = mkty xs sc
         | classConstraint ist t
             = PPi constraint n (delab ist t) (mkty xs sc)
         | otherwise = PPi impl n (delab ist t) (mkty xs sc)
    mkty (_ : xs) t
       = mkty xs t
    mkty [] t = delab ist t

-- | True when the term is an application headed by a name that has exactly
-- one entry in the class table of the given state.
classConstraint :: IState -> Term -> Bool
classConstraint ist v
    | (P _ c _, _) <- unApply v = case lookupCtxt c (idris_classes ist) of
                                       [_] -> True
                                       _ -> False
    | otherwise = False

-- | True when the term is a class constraint (see 'classConstraint') whose
-- arguments are all concrete: each is either a constant, or an application
-- of a name with exactly one type in the context to concrete arguments.
concreteClass :: IState -> Term -> Bool
concreteClass ist v
    | not (classConstraint ist v) = False
    | (P _ _ _, args) <- unApply v = all concrete args
    | otherwise = False
  where
    concrete (Constant _) = True
    concrete tm | (P _ n _, args) <- unApply tm
                     = case lookupTy n (tt_ctxt ist) of
                             [_] -> all concrete args
                             _ -> False
                | otherwise = False

-- | Build the single defining clause of a specialised function.
--
-- The left-hand side applies @newname@ to the delaborated explicit dynamic
-- arguments only.  The right-hand side applies the original function
-- @sname@ to every argument, with implicit arguments in applications
-- replaced by placeholders.
mkPE_TermDecl :: IState -> Name -> Name ->
                 [(PEArgType, Term)] -> [(PTerm, PTerm)]
mkPE_TermDecl ist newname sname ns
    = let lhs = PApp emptyFC (PRef emptyFC newname) (map pexp (mkp ns))
          rhs = eraseImps $ delab ist (mkApp (P Ref sname Erased) (map snd ns)) in
          [(lhs, rhs)]
  where
    -- Keep only the explicit dynamic arguments, delaborated.
    mkp [] = []
    mkp ((ExplicitD, tm) : tms) = delab ist tm : mkp tms
    mkp (_ : tms) = mkp tms

    eraseImps tm = mapPT deImp tm

    deImp (PApp fc t as) = PApp fc t (map deImpArg as)
    deImp t = t

    -- Implicit arguments are replaced by a placeholder.
    deImpArg a@(PImp _ _ _ _ _) = a { getTm = Placeholder }
    deImpArg a = a

-- | How an argument of an application takes part in specialisation.
data PEArgType = ImplicitS -- ^ implicit argument marked static
               | ImplicitD -- ^ implicit argument not marked static
               | ExplicitS -- ^ explicit argument marked static
               | ExplicitD -- ^ explicit argument not marked static
               | UnifiedD  -- ^ implicit dynamic argument that 'specType'
                           --   found to repeat an earlier one
  deriving (Eq, Show)

-- | Collect the applications in a term that are candidates for
-- specialisation, each with its classified arguments.
--
-- An application qualifies when its head is a name with exactly one entry
-- in both the statics table and the implicits table, the number of
-- arguments matches both of those entries, and at least one argument is
-- marked static.  If the recorded static or implicit information does not
-- line up with the arguments, the application is not reported.
getSpecApps :: IState -> [Name] -> Term ->
               [(Name, [(PEArgType, Term)])]
getSpecApps ist env tm = ga env (explicitNames tm)
  where
    -- Classify one argument from its static flag and implicit information.
    -- An explicit dynamic argument is replaced by a reference to a
    -- placeholder name built from the argument's position.
    --
    -- Earlier, disabled clauses (kept for reference) additionally consulted
    -- the names bound in the environment:
    --     staticArg env True _ tm@(P _ n _) _ | n `elem` env = Just (True, tm)
    --     staticArg env True _ tm@(App f a) _ | (P _ n _, args) <- unApply tm,
    --                                           n `elem` env = Just (True, tm)
    staticArg x imp arg n
         | x && imparg imp = (ImplicitS, arg)
         | x = (ExplicitS, arg)
         | imparg imp = (ImplicitD, arg)
         | otherwise = (ExplicitD, P Ref (sUN (show n ++ "arg")) Erased)

    imparg (PExp _ _ _ _) = False
    imparg _ = True

    -- Classify the arguments pairwise.  The caller checks that the lists
    -- have matching lengths; the final equation keeps the function total
    -- rather than failing on a pattern match if they ever do not.
    buildApp (s : ss) (i : is) (a : rest) (k : ks)
        = staticArg s i a k : buildApp ss is rest ks
    buildApp _ _ _ _ = []

    ga bound app@(App f a)
      | (P _ n _, args) <- unApply app =
          ga bound f ++ ga bound a ++
            case (lookupCtxt n (idris_statics ist),
                  lookupCtxt n (idris_implicits ist)) of
                 ([statics], [imps])
                   | length statics == length args
                   , length imps == length args
                   , or statics
                       -> [(n, buildApp statics imps args [0..])]
                 _ -> []
    ga bound (Bind n _ sc) = ga (n : bound) sc
    ga _ _ = []