module Agda.Compiler.Epic.Forcing where
import Control.Applicative
import Control.Arrow (first, second)
import Control.Monad
import Control.Monad.State
import Control.Monad.Trans
import Data.List
import qualified Data.Map as M
import Data.Maybe
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Rules.LHS.Unify
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Telescope
import Agda.Utils.List
import Agda.Utils.Size
import Agda.Compiler.Epic.CompileState hiding (conPars)
import Agda.Compiler.Epic.AuxAST(pairwiseFilter)
#include "../../undefined.h"
import Agda.Utils.Impossible
-- | Entry point of this module: take the telescope of @typ@ (via
--   'telView') and run 'remForced' on the compiled clauses in it.
removeForced :: MonadTCM m => CompiledClauses -> Type -> Compile m CompiledClauses
removeForced cc typ = do
    view <- lift (telView typ)
    case view of
      TelV tele _ -> remForced cc tele
-- | Look up the type of a name in the imported signature.  Used here for
--   both constructor and datatype names.  A missing name is impossible; the
--   error is raised lazily, when the returned type is forced.
constrType :: MonadTCM m => QName -> Compile m Type
constrType q = do
    defs <- lift (gets (sigDefinitions . stImports))
    return $ case M.lookup q defs of
      Just d  -> defType d
      Nothing -> __IMPOSSIBLE__
-- | Number of parameters of a datatype or record, looked up in the imported
--   signature.  Any other kind of definition reports 0 parameters; an
--   unknown name is impossible (raised lazily when the result is forced).
dataParameters :: MonadTCM m => QName -> Compile m Nat
dataParameters name = do
    defs <- lift (gets (sigDefinitions . stImports))
    return $ case M.lookup name defs of
      Just def -> parsOf (theDef def)
      Nothing  -> __IMPOSSIBLE__
  where
    parsOf :: Defn -> Nat
    parsOf d = case d of
      Datatype {dataPars = p} -> p
      Record   {recPars  = p} -> p
      _                       -> 0
-- | Does variable @n@ occur in a compiled clause tree, either as the
--   variable a 'Case' splits on or inside a 'Done' body?
isIn :: MonadTCM m => Nat -> CompiledClauses -> Compile m Bool
isIn n cc = case cc of
    Case i brs
      | fromIntegral i == n -> return True
      | otherwise           -> isInCase n (fromIntegral i, brs)
    Done _ t -> return (isInTerm n t)
    Fail     -> return False
-- | Does variable @n@ occur in any branch of a split on variable @i@?
--
--   When the split variable @i@ is below @n@, the index of @n@ changes
--   inside the branches:
--
--   * in a constructor branch it is shifted by @par - 1@, where @par@ is
--     'getConPar' of that constructor;
--   * in a literal branch it is shifted down by one.
--
--   The catch-all branch is searched with @n@ unchanged.
isInCase :: MonadTCM m => Nat -> (Nat, Case CompiledClauses) -> Compile m Bool
n `isInCase` (i, Branches { conBranches    = cbrs
                          , litBranches    = lbrs
                          , catchAllBranch = cabr }) = do
    cbrs' <- fmap or $ forM (M.toList cbrs) $ \ (constr, cc) ->
        if i < n
            then do
                par <- fromIntegral <$> getConPar constr
                (n + par - 1) `isIn` cc
            else n `isIn` cc
    lbrs' <- fmap or $ forM (M.toList lbrs) $ \ (_, cc) ->
        (if i < n then n - 1 else n) `isIn` cc
    cabr' <- case cabr of
        Nothing -> return False
        Just cc -> n `isIn` cc
    return (cbrs' || lbrs' || cabr')
-- | Does de Bruijn variable @n@ occur in @term@?
--
--   The index is incremented under the binders of 'Lam' and of the 'Pi'
--   codomain.  Sorts, metavariables (including their arguments) and
--   'DontCare' are treated as not mentioning the variable.
isInTerm :: Nat -> Term -> Bool
isInTerm n term = case term of
    Var i args  -> i == n || anyArg args
    Lam _ body  -> isInTerm (n + 1) (absBody body)
    Lit _       -> False
    Def _ args  -> anyArg args
    Con _ args  -> anyArg args
    Pi dom cod  -> inType n (unArg dom) || inType (n + 1) (absBody cod)
    Fun dom cod -> inType n (unArg dom) || inType n cod
    Sort _      -> False
    MetaV _ _   -> False
    DontCare    -> False
  where
    anyArg   = any (isInTerm n . unArg)
    inType k = isInTerm k . unEl
-- | Splice a constructor's argument telescope into @tele@ at position @n@.
--
--   At position 0 the following happens:
--
--   * The entry's type is normalised and must be a datatype application.
--   * The constructor telescope @ctele@ is obtained from 'telView' of one
--     of two types: @ins@ (the constructor type) applied to the datatype's
--     parameters, or, when @ins@ is 'Nothing', the entry type itself.
--   * @term@ is substituted for the replaced variable in the rest of the
--     telescope.
--
--   The result is the new telescope together with a pair: the replaced
--   entry's type, raised past @ctele@, and the target type from 'telView'.
--   Running off the end of the telescope is impossible.
insertTele :: MonadTCM m
           => Int
           -> Maybe Type
           -> Term
           -> Telescope
           -> Compile m ( Telescope
                        , ( Type
                          , Type
                          )
                        )
insertTele 0 ins term (ExtendTel t to) = do
    t' <- lift $ normalise t
    -- NOTE(review): irrefutable; the entry is expected to be a datatype
    -- application, otherwise this fails when st/arg are first used.
    let Def st arg = unEl . unArg $ t'
    pars <- dataParameters st
    TelV ctele ctyp <- lift $ telView $ maybe (unArg t')
                                              (`apply` take (fromIntegral pars) arg) ins
    -- The datatype application must supply all of its parameters.
    unless (length (take (fromIntegral pars) arg) == fromIntegral pars) __IMPOSSIBLE__
    return ( ctele +:+ (subst term $ raiseFrom 1 (size ctele) (absBody to))
           , (raise (size ctele) $ unArg t , ctyp)
           )
  where
    -- Telescope concatenation.
    (+:+) :: Telescope -> Telescope -> Telescope
    EmptyTel       +:+ rest = rest
    ExtendTel a tl +:+ rest = ExtendTel a tl { absBody = absBody tl +:+ rest }
insertTele n ins term EmptyTel = __IMPOSSIBLE__
insertTele n ins term (ExtendTel x xs) = do
    (xs', typ) <- insertTele (n - 1) ins term (absBody xs)
    return (ExtendTel x xs { absBody = xs' }, typ)
-- | @mkCon c n@ is constructor @c@ applied to the variables
--   @n - 1, n - 2, ..., 0@, so the last argument is variable 0.
--   For @n == 0@ (with a signed count such as 'Int') there are no arguments.
mkCon :: Integral a => QName -> a -> Term
mkCon c n = Con c [ defaultArg $ Var (fromIntegral i) [] | i <- [n - 1, n - 2 .. 0] ]
-- | Run 'unifyIndices_' on @a1@ and @a2@ at type @typ@, passing @flex@
--   through (presumably the flexible variables), in the context extended
--   with @tele@.
unifyI :: MonadTCM m => Telescope -> [Nat] -> Type -> Args -> Args -> Compile m [Maybe Term]
unifyI tele flex typ a1 a2 =
    lift (addCtxTel tele (unifyIndices_ flex typ a1 a2))
-- | Keep the first @n@ entries of a telescope.  Asking for more entries
--   than the telescope has is impossible.
takeTele :: (Eq a, Num a) => a -> Telescope -> Telescope
takeTele 0 _                = EmptyTel
takeTele n (ExtendTel t ts) = ExtendTel t ts { absBody = takeTele (n - 1) (absBody ts) }
takeTele _ _                = __IMPOSSIBLE__
-- | Walk a compiled clause tree and remove forced constructor arguments.
--
--   @tele@ is the telescope of the variables in scope.
--
--   For each constructor branch of a 'Case' on variable @n@:
--
--   * The constructor's arguments are spliced into the telescope with
--     'insertTele'.
--   * Candidate arguments are those kept by
--     @pairwiseFilter (map not notForced)@, with @notForced@ from
--     'getIrrFilter'.  The candidates that occur in the branch ('isIn')
--     become @forcedVars@.
--   * If there are none, the branch is simply processed recursively.
--   * Otherwise, the non-parameter arguments of the scrutinee's type and of
--     the constructor's result type are unified ('unifyI'), and the branch
--     is rewritten by 'replaceForced'.
--
--   Literal and catch-all branches are processed recursively.
remForced :: MonadTCM m
          => CompiledClauses
          -> Telescope
          -> Compile m CompiledClauses
remForced ccOrig tele = case ccOrig of
    Case n brs -> do
        cbs <- forM (M.toList $ conBranches brs) $ \ (constr, cc) -> do
            par <- getConPar constr
            typ <- constrType constr
            (tele', (ntyp0, ctyp0)) <- insertTele n (Just typ) (mkCon constr par) tele
            ntyp <- lift $ reduce ntyp0
            ctyp <- lift $ reduce ctyp0
            notForced <- getIrrFilter constr
            -- Argument k, counted from the right, is variable (n + par - 1) - k.
            forcedVars <- filterM ((`isIn` cc) . (flip subtract (fromIntegral $ n + par - 1)))
                        $ pairwiseFilter (map not notForced)
                        $ map fromIntegral [par - 1, par - 2 .. 0]
            if null forcedVars
                then (,) constr <$> remForced cc tele'
                else do
                    unif <- case (unEl ntyp, unEl ctyp) of
                        (Def st a1, Def st' a2) | st == st' -> do
                            typPars <- fromIntegral <$> dataParameters st
                            setType <- constrType st
                            unifyI (takeTele (n + par) tele')
                                   (map fromIntegral [0 .. n + par])
                                   (setType `apply` take typPars a1)
                                   (drop typPars a1)
                                   (drop typPars a2)
                        _ -> __IMPOSSIBLE__
                    (,) constr <$> replaceForced (fromIntegral $ n + par, tele')
                                                 forcedVars
                                                 (cc, unif)
        lbs <- forM (M.toList $ litBranches brs) $ \ (lit, cc) -> do
            (newTele, _) <- insertTele n Nothing (Lit lit) tele
            (,) lit <$> remForced cc newTele
        cabs <- case catchAllBranch brs of
            Nothing -> return Nothing
            Just cc -> Just <$> remForced cc tele
        return $ Case n brs { conBranches    = M.fromList cbs
                            , litBranches    = M.fromList lbs
                            , catchAllBranch = cabs }
    Done n t -> return $ Done n t
    Fail     -> return Fail
-- | State threaded through 'replaceForced'.
data FoldState = FoldState
    { clauseToFix  :: CompiledClauses
      -- ^ The clause tree being rewritten; 'remForced' is run on it at the end.
    , clausesAbove :: CompiledClauses -> CompiledClauses
      -- ^ Case splits built up so far; applied to the final result.
    , unification  :: [Maybe Term]
      -- ^ The unification result, kept in sync with the substitutions made.
    , theTelescope :: Telescope
      -- ^ The current telescope.
    , telePos      :: Nat
      -- ^ Current position counter; grows as constructor arguments are added.
    } deriving (Show)
-- | A monadic left fold over a list with the step function taken last,
--   so a trailing lambda can be used.  Same semantics as 'foldM'.
foldM' :: Monad m => a -> [b] -> (a -> b -> m a) -> m a
foldM' acc []       _    = return acc
foldM' acc (x : xs) step = step acc x >>= \ acc' -> foldM' acc' xs step
-- | Lift an action through two monad-transformer layers.
lift2 :: (MonadTrans t, Monad (t1 m), MonadTrans t1, Monad m) => m a -> t (t1 m) a
lift2 action = lift (lift action)
-- | Like 'modify', but the state transformer may itself run in the monad.
modifyM :: (MonadState a m) => (a -> m a) -> m ()
modifyM f = do
    s0 <- get
    s1 <- f s0
    put s1
-- | Rewrite a clause tree so that each forced variable is recovered from the
--   unification result instead of being bound directly.
--
--   For every variable in @forcedVars@, 'findPosition' locates the
--   unification entry that mentions it.  'termToBranch' then does one of
--   two things:
--
--   * If that entry is the variable itself, it renames the variable in the
--     clause tree.
--   * If it is a constructor application, it splits on it, records the
--     split in 'clausesAbove', and recurses into the constructor argument
--     that mentions the variable.
--
--   Finally 'remForced' is run on the rewritten clause in the updated
--   telescope, and the recorded splits are wrapped around the result.
replaceForced :: MonadTCM m
              => (Nat, Telescope) -> [Nat] -> (CompiledClauses, [Maybe Term])
              -> Compile m CompiledClauses
replaceForced (telPos, tele) forcedVars (cc, unif) = do
    let origSt = FoldState
          { clauseToFix  = cc
          , clausesAbove = id
          , unification  = unif
          , theTelescope = tele
          , telePos      = telPos
          }
    st <- flip execStateT origSt $ forM forcedVars $ \ forcedVar -> do
        curUnif <- gets unification
        let (caseVar, caseTerm) = findPosition forcedVar curUnif
        curPos <- gets telePos
        termToBranch (curPos - caseVar - 1) caseTerm forcedVar
    clausesAbove st <$> remForced (clauseToFix st) (theTelescope st)
  where
    termToBranch :: MonadTCM m => Nat -> Term -> Nat -> StateT FoldState (Compile m) ()
    termToBranch caseVar caseTerm forcedVar = case caseTerm of
      Var i _
        | i == forcedVar -> do
            pos <- gets telePos
            -- The identity list of indices, except that position
            -- (pos - forcedVar - 1) holds caseVar.
            let sub = [0 .. pos - forcedVar - 2] ++ [caseVar] ++ [pos - forcedVar ..]
            modifyM $ \ st -> do
              newClauseToFix <- substCC sub (clauseToFix st)
              return st
                { clauseToFix = newClauseToFix
                , unification = substs (map (flip Var []) sub) (unification st)
                }
        | otherwise -> __IMPOSSIBLE__
      Con c args -> do
        pos <- gets telePos
        let (nextCaseVarInCon, nextCaseTerm) =
              findPosition forcedVar (map (Just . unArg) args)
            nextCaseVar = nextCaseVarInCon + caseVar
            -- One variable is replaced by (length args) constructor arguments.
            newBinds    = fromIntegral $ length args - 1
            nextTelePos = pos + newBinds
        ctyp <- lift (constrType c)
        modifyM $ \ st -> do
          (newTele, _) <- lift $ insertTele (fromIntegral caseVar) (Just ctyp)
                                            (mkCon c (length args)) (theTelescope st)
          let newUnif = raiseFrom (pos - caseVar) newBinds $
                          replaceAt (fromIntegral $ pos - caseVar - 1)
                                    (unification st)
                                    (reverse $ map (Just . unArg) args)
          return st
            { clauseToFix  = raiseFromCC caseVar newBinds
                               (substCCBody caseVar
                                  (Con c $ map (defaultArg . flip Var [])
                                               [caseVar .. caseVar + newBinds])
                                  (clauseToFix st))
            , theTelescope = newTele
            , unification  = newUnif
            , telePos      = nextTelePos
            }
        termToBranch nextCaseVar nextCaseTerm forcedVar
        modify $ \ st -> st
          { clausesAbove = Case (fromIntegral caseVar) . conCase c . clausesAbove st
          }
      _ -> __IMPOSSIBLE__
-- | Shift every variable index at or above @from@ by @add@ throughout a
--   compiled clause tree.  This covers the 'Case' scrutinee indices (via
--   'raiseN') and the 'Done' bodies (via 'raiseFrom').  The bound-variable
--   count stored in each 'Done' is also increased by @add@.
raiseFromCC :: Nat -> Nat -> CompiledClauses -> CompiledClauses
raiseFromCC from add = go
  where
    go cc = case cc of
      Case n (Branches cbr lbr cabr) ->
        let n' = fromIntegral (raiseN from add (fromIntegral n))
        in  Case n' (Branches (M.map go cbr) (M.map go lbr) (fmap go cabr))
      Done i t -> Done (i + fromIntegral add) (raiseFrom from add t)
      Fail     -> Fail
-- | Shift index @n@ by @add@ when it is at or above @from@; smaller
--   indices are left alone.
raiseN :: Nat -> Nat -> Nat -> Nat
raiseN from add n = if n >= from then n + add else n
-- | Rename the variables of a compiled clause tree.
--
--   @ss !! k@ is the new index of variable @k@: 'Done' bodies are
--   substituted with @map Var ss@, and a 'Case' on variable @n@ becomes a
--   'Case' on @ss !! n@.
--
--   * In a constructor branch with @nargs@ arguments, variable @n@ is
--     replaced by @nargs@ consecutive indices starting at @ss !! n@.  The
--     entries after it are shifted by @nargs - 1@.
--   * In a literal branch, position @n@ is removed from the renaming.
substCC :: MonadTCM m => [Nat] -> CompiledClauses -> StateT FoldState (Compile m) CompiledClauses
substCC ss cc = case cc of
    Done i t -> return $ Done i (substs (map (flip Var []) ss) t)
    Fail     -> return Fail
    Case n brs -> do
        cbs <- forM (M.toList $ conBranches brs) $ \ (c, br) -> do
            nargs <- lift2 $ constructorArity c
            let delta = (ss !! n) - fi n
                ss'   = take n ss
                     ++ [fi n + delta .. fi n + delta + nargs - 1]
                     ++ map (+ (nargs - 1)) (drop (n + 1) ss)
            (,) c <$> substCC ss' br
        lbs <- forM (M.toList $ litBranches brs) $ \ (l, br) ->
            (,) l <$> substCC (replaceAt n ss []) br
        cabs <- case catchAllBranch brs of
            Nothing -> return Nothing
            Just br -> Just <$> substCC ss br
        return $ Case (fromIntegral (ss !! n))
            Branches { conBranches    = M.fromList cbs
                     , litBranches    = M.fromList lbs
                     , catchAllBranch = cabs
                     }
  where
    fi = fromIntegral
-- | Substitute term @t@ for variable @n@ in the 'Done' bodies of a compiled
--   clause tree.  Every other variable keeps its index.
substCCBody :: Nat -> Term -> CompiledClauses -> CompiledClauses
substCCBody n t cc = substsCCBody (vs [0 .. n - 1] ++ [t] ++ vs [n + 1 ..]) cc
  where vs = map (flip Var [])
-- | Apply the substitution @ss@ to every 'Done' body of a compiled clause
--   tree.  The 'Case' indices themselves are left unchanged.
substsCCBody :: [Term] -> CompiledClauses -> CompiledClauses
substsCCBody ss = go
  where
    go (Case n brs) = Case n (fmap go brs)
    go (Done i t)   = Done i (substs ss t)
    go Fail         = Fail
-- | Locate the first entry of @ts@ that is a term mentioning variable @var@.
--   The variable may appear directly or nested inside constructor
--   arguments.  Returns the entry's index together with the term.
--   The variable is assumed to occur; otherwise the result fails when forced.
findPosition :: Nat -> [Maybe Term] -> (Nat, Term)
findPosition var ts = (fromIntegral ix, fromJust (ts !! ix))
  where
    Just ix = findIndex (maybe False mentions) ts

    mentions :: Term -> Bool
    mentions t = case t of
      Var i _ | i == var -> True
      Con _ args         -> any (mentions . unArg) args
      _                  -> False