{-# LANGUAGE CPP #-}

module Agda.Compiler.Epic.Forcing where

import Control.Applicative
import Control.Arrow (first, second)
import Control.Monad
import Control.Monad.State
import Control.Monad.Trans

import Data.List
import qualified Data.Map as M
import Data.Maybe

import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Rules.LHS.Unify
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Telescope
import Agda.Utils.List
import Agda.Utils.Size

import Agda.Compiler.Epic.CompileState hiding (conPars)
import Agda.Compiler.Epic.AuxAST(pairwiseFilter)

#include "../../undefined.h"
import Agda.Utils.Impossible

-- | Replace the uses of forced variables in a CompiledClauses with the
--   function arguments that they correspond to.
--
--   Note that this works on CompiledClauses where the term's variable indexes
--   have been reversed, which means that the case variables match the
--   variables in the term.
--
--   The telescope view of the given type is used as the initial context for
--   'remForced'.
removeForced :: MonadTCM m => CompiledClauses -> Type -> Compile m CompiledClauses
removeForced cc typ =
    lift (telView typ) >>= \ (TelV tele _) -> remForced cc tele

-- | Returns the type of a constructor given its name.
--
--   The name is looked up among the imported signature definitions; a missing
--   entry is treated as an internal error.
constrType :: MonadTCM m => QName -> Compile m Type
constrType q = do
    defs <- lift (gets (sigDefinitions . stImports))
    return $ case M.lookup q defs of
        Just d  -> defType d
        Nothing -> __IMPOSSIBLE__

-- | Returns how many parameters a datatype has.
--
--   Records are handled as well; any other kind of definition is reported as
--   having zero parameters. A name that is not among the imported signature
--   definitions is treated as an internal error.
dataParameters :: MonadTCM m => QName -> Compile m Nat
dataParameters name = do
    defs <- lift (gets (sigDefinitions . stImports))
    return $ case M.lookup name defs of
        Just d  -> defnPars (theDef d)
        Nothing -> __IMPOSSIBLE__
  where
    defnPars :: Defn -> Nat
    defnPars defn = case defn of
        Datatype {dataPars = p} -> p
        Record   {recPars  = p} -> p
        _                       -> 0 -- Not so sure about this.

-- | Is variable n used in a CompiledClause?
isIn :: MonadTCM m => Nat -> CompiledClauses -> Compile m Bool
-- A case split on variable n itself counts as a use of n; otherwise the
-- branches are searched (with n re-indexed as needed, see 'isInCase').
n `isIn` Case i brs | n == fromIntegral i = return True
                    | otherwise = n `isInCase` (fromIntegral i, brs)
n `isIn` Done _ t = return $ n `isInTerm` t
n `isIn` Fail     = return $ False

-- | Is variable n used in any branch of a case split on variable i?
--
--   When the scrutinee i is smaller than n, n is re-indexed before descending
--   into a branch: by @par - 1@ for a constructor branch (where @par@ comes
--   from 'getConPar' for that constructor) and by @-1@ for a literal branch.
--   The catch-all branch is searched with n unchanged.
isInCase :: MonadTCM m => Nat -> (Nat, Case CompiledClauses) -> Compile m Bool
n `isInCase` (i, Branches { conBranches    = cbrs
                          , litBranches    = lbrs
                          , catchAllBranch = cabr}) = do
    -- Constructor branches: the scrutinee is replaced by the constructor's
    -- arguments, so variables above it shift by (par - 1).
    cbrs' <- (or <$>) $ forM (M.toList cbrs) $ \ (constr, cc) -> do
        if i < n
            then do
                par <- fromIntegral <$> getConPar constr
                (n + par - 1) `isIn` cc
            else n `isIn` cc

    -- Literal branches: the scrutinee binding disappears, so variables above
    -- it shift down by one.
    lbrs' <- (or <$>) $ forM (M.toList lbrs) $ \ (_, cc) ->
        (if i < n then (n - 1) else n) `isIn` cc

    cabr' <- case cabr of
        Nothing -> return False
        Just cc -> n `isIn` cc

    return (cbrs' || lbrs' || cabr')

-- | Does variable n occur in the term?
--
--   The index is incremented when going under a 'Lam' body or the codomain of
--   a 'Pi'. 'Sort', 'MetaV' and 'DontCare' are treated as never containing the
--   variable (the original author marked the first two as uncertain).
isInTerm :: Nat -> Term -> Bool
n `isInTerm` term =
    let recs = any (isInTerm n . unArg)
    in case term of
        Var i as -> i == n || recs as
        Lam _ ab -> (n+1) `isInTerm` absBody ab
        Lit _    -> False
        Def _ as -> recs as
        Con _ as -> recs as
        Pi a b   -> n `isInTerm` unEl (unArg a) || (n+1) `isInTerm` unEl (absBody b)
        Fun a b  -> n `isInTerm` unEl (unArg a) || n `isInTerm` unEl b
        Sort sor -> False -- ?
        MetaV meta as -> False -- can't occur?
        DontCare -> False

{- |
  insertTele i xs t tele
                    tpos
  tele := Gamma ; (i : T as) ; Delta
  n    := parameters T
  xs'  := xs `apply` (take n as)
becomes
  tpos
  ( Gamma ; xs' ; Delta[i := t] -- note that Delta still references Gamma correctly
  , T as ^ (size xs')
  )

  We raise the type since we have added xs' new bindings after Gamma, and `as`
  can only refer to Gamma.

  The binding at position i is removed from the telescope and replaced by the
  telescope of xs' (which is empty when no type to insert is given, as far as
  'telView' of the binding's own type yields no bindings -- NOTE(review):
  confirm for the literal case).
-}
insertTele ::  MonadTCM m => Int -- ^ ABS `pos` in tele
    -> Maybe Type   -- ^ If Just, it is the type to insert patterns from
                    --   is nothing if we only want to delete a binding.
    -> Term         -- ^ Term to replace at pos
    -> Telescope    -- ^ The telescope `tele` where everything is at
    -> Compile m ( Telescope
                 , ( Type
                   , Type
                   )
                 )
    -- ^ Returns (resulting telescope, the type at pos in tele, the
    --   return type of the inserted type).
insertTele 0 ins term (ExtendTel t to) = do
    t' <- lift $ normalise t
    -- The binding's (normalised) type is expected to be a 'Def'; this lazy
    -- pattern fails when 'st' or 'arg' is demanded otherwise.
    let Def st arg = unEl . unArg $ t'
    -- Apply the parameters of the type of t
    -- Because: parameters occurs in the type of constructors but are not bound by it.
    pars <- dataParameters st
    TelV ctele ctyp <- lift $ telView $ maybe (unArg t')
                            (`apply` take (fromIntegral pars) arg) ins
    -- Sanity check: the type must be applied to at least 'pars' arguments.
    () <- if length (take (fromIntegral pars) arg) == fromIntegral pars
        then return ()
        else __IMPOSSIBLE__
    -- we deal with absBody to directly since we remove t
    return ( ctele +:+  (subst term $ raiseFrom 1 (size ctele) (absBody to))
           , (raise (size ctele) $ unArg t , ctyp)
           )
  where
    -- Append the telescope, we raise since we add a new binding and all the previous
    -- bindings need to be preserved
    (+:+) :: Telescope -> Telescope -> Telescope
    EmptyTel        +:+ t2 = t2
    ExtendTel t t1  +:+ t2 = ExtendTel t t1 {absBody = absBody t1 +:+ {-raise 1-} t2 }
-- This case is impossible since we are trying to split a variable outside the tele
insertTele n ins term EmptyTel = __IMPOSSIBLE__
-- Not yet at the target position: keep this binding and recurse into the tail.
insertTele n ins term (ExtendTel x xs) = do
    (xs', typ) <- insertTele (n - 1) ins term (absBody xs)
    return (ExtendTel x xs {absBody = xs'} , typ)

-- | Build the application of constructor c to n variable arguments, using
--   the variable indices n - 1 down to 0 (in that order).
mkCon c n = Con c [ defaultArg $ Var (fromIntegral i) [] | i <- [n - 1, n - 2 .. 0] ]

-- | Run 'unifyIndices_' on the two argument lists at the given type, with the
--   telescope added to the context via 'addCtxTel'. The @[Nat]@ argument is
--   passed on as the list of flexible variables.
unifyI :: MonadTCM m => Telescope -> [Nat] -> Type -> Args -> Args -> Compile m [Maybe Term]
unifyI tele flex typ a1 a2 = lift $ addCtxTel tele $ unifyIndices_ flex typ a1 a2

-- | Keep only the first n bindings of a telescope. Asking for more bindings
--   than the telescope contains is an internal error.
takeTele 0 _ = EmptyTel
takeTele n (ExtendTel t ts) = ExtendTel t ts {absBody = takeTele (n-1) (absBody ts) }
takeTele _ _ = __IMPOSSIBLE__

-- | Remove forced variables cased on in the current top-level case in the CompiledClauses
remForced :: MonadTCM m
    => CompiledClauses -- ^ Remove cases on forced variables in this
    -> Telescope       -- ^ The current context we are in
    -> Compile m CompiledClauses
remForced ccOrig tele = case ccOrig of
    Case n brs -> do
        -- Get all constructor branches
        cbs <- forM (M.toList $ conBranches brs) $ \(constr, cc) -> do
            par <- getConPar constr
            typ <- constrType constr
            -- Update tele with the telescope from the constructor's type
            (tele', (ntyp, ctyp)) <- insertTele n (Just typ) (mkCon constr par) tele
            ntyp <- lift $ reduce ntyp
            ctyp <- lift $ reduce ctyp

            notForced <- getIrrFilter constr
            -- Get the variables that are forced, relative to the position after constr
            -- (only those that are actually used in this branch are kept)
            forcedVars <- filterM ((`isIn` cc) . (flip subtract (fromIntegral $ n + par - 1))) $
                pairwiseFilter (map not notForced) $ map fromIntegral [par-1,par-2..0]
            if null forcedVars
                -- Nothing to replace here: just continue below this branch.
                then (,) constr <$> remForced cc tele'
                else do
                    unif <- case (unEl ntyp, unEl ctyp) of
                        (Def st a1, Def st' a2) | st == st' -> do
                            typPars <- fromIntegral <$> dataParameters st
                            setType <- constrType st
                            {-
                            We are splitting on C xs
                            we know that C : ts -> T ss ; for some T
                            we also know from tele that we are splitting on T as
                            we want to unify ss with as, but not taking into account
                            the Data parameters to T.
                            -}
                            unifyI (takeTele (n + par) tele')
                                   (map fromIntegral [0 .. n + par])
                                   -- Don't unify the constructor arguments
                                   (setType `apply` take typPars a1)
                                   (drop typPars a1)
                                   (drop typPars a2)
                        -- Both types must be applications of the same name.
                        x -> __IMPOSSIBLE__
                    -- we calculate the new tpos from n (the old one) by adding
                    -- how many more bindings we have
                    (,) constr <$> replaceForced (fromIntegral $ n + par, tele') forcedVars (cc, unif)
        lbs <- forM (M.toList $ litBranches brs) $ \(lit, cc) -> do
            -- We have one less binding
            (newTele, _) <- insertTele n Nothing (Lit lit) tele
            (,) lit <$> remForced cc newTele

        -- The catch-all branch is processed in the unchanged telescope.
        cabs <- case catchAllBranch brs of
            Nothing -> return Nothing
            Just cc -> Just <$> remForced cc tele

        return $ Case n brs { conBranches    = M.fromList cbs
                            , litBranches    = M.fromList lbs
                            , catchAllBranch = cabs
                            }
    -- Leaves are returned unchanged.
    Done n t -> return $ Done n t
    Fail     -> return Fail

-- | State threaded through 'replaceForced' while digging out forced variables.
data FoldState = FoldState
    { clauseToFix  :: CompiledClauses
      -- ^ The clause currently being rewritten.
    , clausesAbove :: CompiledClauses -> CompiledClauses
      -- ^ Wrapper of enclosing case splits; applied to the rewritten clause
      --   at the end of 'replaceForced'.
    , unification  :: [Maybe Term]
      -- ^ The unification result (from 'unifyI'), updated as cases are added.
    , theTelescope :: Telescope
      -- ^ The telescope the clause currently lives in.
    , telePos      :: Nat
      -- ^ Current position in the telescope (absolute).
    } deriving Show

-- Some utility functions

-- | 'foldM' with the step function as the last argument.
foldM' :: Monad m => a -> [b] -> (a -> b -> m a) -> m a
foldM' z xs f = foldM f z xs

-- | Lift through two monad transformers.
lift2 :: (MonadTrans t, Monad (t1 m), MonadTrans t1, Monad m) => m a -> t (t1 m) a
lift2 = lift . lift

-- | Like 'modify', but the update function may run effects in the monad.
modifyM :: (MonadState a m) => (a -> m a) -> m ()
modifyM f = get >>= f >>= put -- (>>= put) . (get >>=)

-- | replaceForced (tpos, tele) forcedVars (cc, unification)
--   For each forcedVar dig out the corresponding case and continue to remForced.
replaceForced :: MonadTCM m => (Nat, Telescope) -> [Nat] -> (CompiledClauses, [Maybe Term]) -> Compile m CompiledClauses
replaceForced (telPos, tele) forcedVars (cc, unif) = do
    let origSt = FoldState
          { clauseToFix  = cc
          , clausesAbove = id
          , unification  = unif
          , theTelescope = tele
          , telePos      = telPos
          }
    -- Process the forced variables one at a time, threading the clause, the
    -- telescope and the unification result through the fold state. Only the
    -- final state is needed, hence forM_.
    st <- flip execStateT origSt $ forM_ forcedVars $ \ forcedVar -> do
        curUnif <- gets unification
        let (caseVar, caseTerm) = findPosition forcedVar curUnif
        curPos <- gets telePos
        -- findPosition yields a relative position; termToBranch wants an
        -- absolute one.
        termToBranch (curPos - caseVar - 1) caseTerm forcedVar
    -- Continue below the (possibly rewritten) clause and put the case splits
    -- that were introduced back on top of the result.
    clausesAbove st <$> remForced (clauseToFix st) (theTelescope st)
  where
    {-
      In this function the following de Bruijn is:
      forcedVar : Relative
      caseVar   : Absolute
      telePos   : Absolute
    -}
    termToBranch :: MonadTCM m => Nat -> Term -> Nat -> StateT FoldState (Compile m) ()
    termToBranch caseVar caseTerm forcedVar = case caseTerm of
        Var i _ | i == forcedVar -> do
            telPos <- gets telePos
            -- A renaming that is the identity everywhere except at index
            -- (telPos - forcedVar - 1), which is sent to caseVar.
            let sub = [0..telPos - forcedVar - 2] ++ [caseVar] ++ [telPos - forcedVar..]
            modifyM $ \ st -> do
                newClauseToFix <- substCC sub (clauseToFix st)
                return st
                  { clauseToFix = newClauseToFix
                  , unification = substs (map (flip Var []) sub) (unification st)
                  }
          -- This is impossible since we have already looked and it should
          -- be the correct Var
          | otherwise -> __IMPOSSIBLE__
        Con c args -> do
            telPos <- gets telePos
            let (nextCaseVarInCon, nextCaseTerm) = findPosition forcedVar (map (Just . unArg) args)
                nextCaseVar = nextCaseVarInCon + caseVar
                newBinds    = fromIntegral $ length args - 1
                -- we have added newBinds new bindings and removed one before telePos
                nextTelePos = telPos + newBinds
            ctyp <- lift (constrType c)
            modifyM $ \ st -> do
                -- Replace the binding at caseVar by the constructor's telescope.
                (newTele , _) <- lift $ insertTele (fromIntegral caseVar) (Just ctyp)
                                            (mkCon c (length args)) (theTelescope st)
                -- We have to update the unifications-list so that we don't try
                -- to dig out the same again later.
                let newUnif = raiseFrom (telPos - caseVar) newBinds $
                        replaceAt (fromIntegral $ telPos - caseVar - 1)
                                  (unification st)
                                  (reverse $ map (Just . unArg) args)
                        -- The variables in the unification-list is
                        -- relative so we need to reverse the args
                        -- so they get in the right place.
                return st
                  { clauseToFix = raiseFromCC caseVar newBinds
                        (substCCBody caseVar
                            (Con c $ map (defaultArg . flip Var [])
                                         [caseVar .. caseVar + newBinds])
                            (clauseToFix st))
                  , theTelescope = newTele
                  , unification  = newUnif
                  , telePos      = nextTelePos
                  }
            -- Keep digging inside the constructor argument that mentions the
            -- forced variable, then record the case split on caseVar above
            -- everything produced so far.
            termToBranch nextCaseVar nextCaseTerm forcedVar
            modify $ \ st -> st { clausesAbove = Case (fromIntegral caseVar) . conCase c . (clausesAbove st) }
        _ -> __IMPOSSIBLE__

-- | Add @add@ to every case variable and every term variable of the clause
--   that is at or above @from@. The first component of every 'Done' is
--   increased by @add@ unconditionally.
--
-- Note: Absolute positions
raiseFromCC :: Nat -> Nat -> CompiledClauses -> CompiledClauses
raiseFromCC from add cc = case cc of
    Case n (Branches cbr lbr cabr) ->
        Case (fromIntegral $ raiseN from add (fromIntegral n)) $
            Branches (M.map rec cbr) (M.map rec lbr) (fmap rec cabr)
    Done i t -> Done (i + fromIntegral add) $ raiseFrom from add t
    Fail     -> Fail
  where
    rec = raiseFromCC from add
    raiseN :: Nat -> Nat -> Nat -> Nat
    raiseN from add n | from <= n = n + add
                      | otherwise = n

-- | Substitute with the Substitution, this will adjust with the new bindings in the
--   CompiledClauses
--
--   The substitution is a list of variable indices: position @k@ holds the
--   index that variable @k@ is renamed to.
substCC :: MonadTCM m => [Nat] -> CompiledClauses -> StateT FoldState (Compile m) CompiledClauses
substCC ss cc = case cc of
    Done i t -> do
        return $ Done i (substs (map (flip Var []) ({-reverse $ take i -} ss)) t)
    Fail     -> return Fail
    Case n brs -> do
        {-
        In a Case split, if we should change n to m, then all the binders in
        this pattern should also change from being based on n to be based on m.
        -}
        cbs <- forM (M.toList $ conBranches brs) $ \ (c, br) -> do
            nargs <- lift2 $ constructorArity c
            -- The entry for n is replaced by nargs consecutive entries that
            -- start at the target of n; entries after n move up by nargs - 1.
            let delta = (ss !! n) - fi n
                ss' = take n ss
                   ++ [fi n + delta .. fi n + delta + nargs - 1]
                   ++ map (+ (nargs - 1)) (drop (n+1) ss)
            (,) c <$> substCC ss' br
        lbs <- forM (M.toList $ litBranches brs) $ \ (l, br) -> do
            -- We have one less binder here
            (,) l <$> substCC (replaceAt n ss []) br
        cabs <- case catchAllBranch brs of
            Nothing -> return Nothing
            Just br -> Just <$> substCC ss br
        return $ Case (fromIntegral (ss !! n)) Branches
            { conBranches    = M.fromList cbs
            , litBranches    = M.fromList lbs
            , catchAllBranch = cabs
            }
  where
    fi = fromIntegral

-- | Substitute variable n for term t in the body of cc
substCCBody :: Nat -> Term -> CompiledClauses -> CompiledClauses
substCCBody n t cc = substsCCBody (vs [0..n - 1] ++ [t] ++ vs [n + 1..]) cc
  where vs = map (flip Var [])

-- | Perform a substitution in the body of cc
--
--   Only the terms at the 'Done' leaves are changed; case variables are left
--   as they are.
substsCCBody :: [Term] -> CompiledClauses -> CompiledClauses
substsCCBody ss cc = case cc of
    Case n brs -> Case n (substsCCBody ss <$> brs)
    Done i t -> Done i (substs ss t)
    Fail     -> Fail

-- | Find the location where a certain Variable index is by searching the
--   constructors as well, i.e. find a term that can be transformed into a
--   pattern that contains the variable with the given index.
--
--   Returns the position of the first such term in the list together with
--   the term itself. It is an internal error (@__IMPOSSIBLE__@) if no such
--   term is present.
findPosition :: Nat -> [Maybe Term] -> (Nat, Term)
findPosition var ts =
    case [ (pos, t) | (pos, Just t) <- zip [0..] ts, mentionsVar t ] of
        (found : _) -> found
        []          -> __IMPOSSIBLE__
  where
    -- Is the term the variable itself, or a constructor application that
    -- contains it (at any depth)?
    mentionsVar :: Term -> Bool
    mentionsVar t = case t of
      Var i _ | var == i -> True
      Con _ args         -> any (mentionsVar . unArg) args
      _                  -> False