module Agda.Compiler.Epic.Injection where
import Control.Monad.State
import Data.Function
import Data.Ix
import Data.List
import Data.Map(Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Set(Set)
import qualified Data.Set as S
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Literal
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Monad hiding ((!!!))
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.Utils.Monad
import Agda.Utils.Size
import Agda.Compiler.Epic.CompileState
import qualified Agda.Compiler.Epic.FromAgda as FA
import Agda.Compiler.Epic.Interface as Interface
#include "../../undefined.h"
import Agda.Utils.Impossible
-- | Find functions that are injective in one of their arguments and replace
-- their compiled clauses by a direct projection of that argument.
--
-- For every 'Function' in @defs@ a candidate witness is computed with
-- 'isInjective'; 'solve' then keeps the candidates whose constructor-tag
-- constraints can be satisfied together.  Each surviving function is recorded
-- in 'injectiveFuns' (via 'modifyEI') and its 'funCompiled' becomes a 'Done'
-- over @arity@ unnamed arguments that returns the injective argument as a
-- variable.  All other definitions are returned unchanged.
findInjection :: [(QName, Definition)] -> Compile TCM [(QName, Definition)]
findInjection defs = do
    funs <- forM defs $ \(name, def) -> case theDef def of
        f@(Function{}) -> isInjective name (funClauses f)
        _              -> return Nothing
    newNames <- M.keys <$> gets (Interface.conArity . curModule)
    injFuns  <- solve newNames (catMaybes funs)
    defs' <- forM defs $ \(q, def) -> case lookup q injFuns of
        Nothing -> return (q, def)
        Just inj@(InjectiveFun var arity) -> case theDef def of
            f@(Function{}) -> do
                modifyEI $ \s -> s { injectiveFuns = M.insert q inj (injectiveFuns s) }
                let ns = replicate (fromIntegral arity) (Arg NotHidden Relevant "")
                    -- @Done ns@ binds @arity@ arguments; argument number @var@
                    -- (counted from 0 on the left) is the variable
                    -- @arity - var - 1@, since the leftmost binder carries the
                    -- largest index.
                    projection = Var (arity - var - 1) []
                return (q, def { theDef = f { funCompiled = Done ns projection } })
            -- 'isInjective' only produces candidates for 'Function's.
            _ -> __IMPOSSIBLE__
    lift $ reportSLn "epic.injection" 10 $ "injfuns: " ++ show injFuns
    return defs'
-- | Overwrite the compiled clauses of @name@ with @cc@, both in the current
-- signature and in the imported signature of the TCM state.  A signature
-- without an entry for @name@ is left untouched ('M.adjust'); an entry that is
-- not a 'Function' is an internal error.
replaceFunCC :: QName -> CompiledClauses -> Compile TCM ()
replaceFunCC name cc = lift (modify updateState)
  where
    updateState st = st
        { stSignature = inSig (stSignature st)
        , stImports   = inSig (stImports st)
        }
    inSig sig = sig { sigDefinitions = M.adjust replaceDef name (sigDefinitions sig) }
    replaceDef :: Definition -> Definition
    replaceDef def
        | f@Function{} <- theDef def = def { theDef = f { funCompiled = cc } }
        | otherwise                  = __IMPOSSIBLE__
-- | Outcome of an injectivity check: 'Nothing' means the check failed,
-- @Just ps@ lists pairs of constructors that must be unified (given the same
-- tag) by 'solve' for the check to hold.
type InjConstraints = Maybe [(QName,QName)]
-- | Look for an argument position in which the function @nam@, defined by the
-- clauses @cls@, is injective.
--
-- Positions @0 .. total - 1@ are tried, where @total@ is the number of
-- patterns of the first clause.  A position is accepted when
-- 'isInjectiveHere' returns 'Just' for every clause.  The result pairs the
-- 'InjectiveFun' witness with the concatenated constructor constraints of all
-- clauses.  The first accepted position wins.  The result is 'Nothing' when no
-- position is accepted or when there are no clauses.
isInjective :: QName
            -> [Clause]
            -> Compile TCM (Maybe ((QName, InjectiveFun)
                                  , [(QName, QName)]
                                  ))
isInjective _   [] = return Nothing
isInjective nam cls@(cl : _) = do
    let total = length (clausePats cl)
    -- Every position is checked (and logged), even after one has succeeded.
    results <- forM [0 .. total - 1] $ \i -> do
        perClause <- forM cls $ isInjectiveHere nam i
        -- 'sequence' is 'Just' only if every clause succeeded at position i.
        return $ case sequence perClause of
            Just cs -> Just ((nam, InjectiveFun (fromIntegral i) (fromIntegral total)), concat cs)
            Nothing -> Nothing
    return (listToMaybe (catMaybes results))
-- | Strip every 'Bind' layer from a clause body and return the term
-- underneath.  Calling it on a body that ends in 'NoBody' is an internal error.
remAbs :: ClauseBody -> Term
remAbs (Body t)  = t
remAbs (Bind ab) = remAbs (absBody ab)
remAbs NoBody    = __IMPOSSIBLE__
-- | True when a clause body, after skipping all 'Bind' layers, is 'NoBody'.
isNoBody :: ClauseBody -> Bool
isNoBody (Body _)  = False
isNoBody (Bind ab) = isNoBody (absBody ab)
isNoBody NoBody    = True
-- | Convert a clause pattern into the term it denotes.
--
-- @n@ is the de Bruijn index given to the rightmost pattern variable; each
-- variable further left gets a larger index.  Dot patterns yield their term
-- unchanged and literal patterns become 'Lit'.
patternToTerm :: Nat -> Pattern -> Term
patternToTerm n p = case p of
    VarP v -> Var n []
    DotP t -> t
    -- Arguments are converted right to left ('foldr'): each one starts at
    -- @n@ plus the number of variables bound by the arguments to its right
    -- ('nrBinds').  The original 'Arg' wrappers are kept and only their
    -- contents are replaced.
    ConP c typ args -> Con c $ zipWith (\ arg t -> arg {unArg = t}) args
                             $ snd
                             $ foldr (\ arg (n, ts) -> (n + nrBinds arg, patternToTerm n arg : ts))
                                     (n , [])
                             $ map unArg args
    LitP l -> Lit l
-- | Number of pattern variables bound by a pattern: one per 'VarP', none for
-- dot and literal patterns, summed over the arguments of a constructor.
nrBinds :: Num i => Pattern -> i
nrBinds (VarP _)        = 1
nrBinds (DotP _)        = 0
nrBinds (LitP _)        = 0
nrBinds (ConP _ _ args) = sum [ nrBinds (unArg a) | a <- args ]
-- | Build a substitution over the variables of the patterns @ps@ that skips
-- over dot patterns.
--
-- The patterns are flattened into one flag per variable-like position
-- ('True' for a dot pattern).  Walking them in reverse, each non-dot position
-- receives the next index shifted by the number of dot positions seen so far.
-- The tail continues indefinitely from the last index reached.
substForDot :: [Arg Pattern] -> Substitution
substForDot ps = [ Var k [] | k <- go 0 0 (reverse (dotFlags ps)) ]
  where
    go next skipped flags = case flags of
      [] -> [next + skipped ..]
      isDot : rest
        | isDot     -> go next (skipped + 1) rest
        | otherwise -> next + skipped : go (next + 1) skipped rest
    dotFlags = concatMap (flagsOf . unArg)
    flagsOf pat = case pat of
      VarP _        -> [False]
      DotP _        -> [True]
      ConP _ _ args -> dotFlags args
      LitP _        -> [False]
-- | Check whether clause @clause@ of @nam@ is injective in argument @idx@.
--
-- Clauses whose body is 'NoBody' impose no constraints ('emptyC').
-- Otherwise the pattern at position @idx@ is turned into a term with
-- 'patternToTerm', using as starting index the number of variables bound by
-- the patterns to its right, and 'substForDot' is applied to it.  The result
-- is compared with the reduced clause body using '<:'.  During the comparison
-- @nam@ itself is assumed injective at @idx@, alongside the injective
-- functions known from the imported modules.
isInjectiveHere :: QName
                -> Int
                -> Clause
                -> Compile TCM InjConstraints
isInjectiveHere nam idx clause
    | isNoBody (clauseBody clause) = return emptyC
    | otherwise = do
        reducedBody <- lift $ reduce rawBody
        imported    <- gets (injectiveFuns . importedModules)
        let assumed = M.insert nam (InjectiveFun (fromIntegral idx) (genericLength pats)) imported
        res <- (patTerm' <: reducedBody) assumed
        lift $ reportSDoc "epic.injection" 20 $ vcat
            [ text "isInjective:" <+> text (show nam)
            , text "at Index :" <+> text (show idx)
            , nest 2 $ vcat
                [ text "clause :" <+> text (show clause)
                , text "t :" <+> prettyTCM patTerm
                , text "idxR :" <+> (text . show) idxR
                , text "body' :" <+> (text . show) reducedBody
                ]
            , text "res :" <+> text (show res)
            ]
        return res
  where
    pats     = clausePats clause
    idxR     = sum . map (nrBinds . unArg) . genericDrop (idx + 1) $ pats
    patTerm  = patternToTerm idxR $ unArg $ pats !! idx
    patTerm' = substs (substForDot pats) patTerm
    rawBody  = remAbs $ clauseBody clause
-- | Expand an integer literal into constructor form built from the builtin
-- @suc@ and @zero@ ('primSuc', 'primZero').  A positive @n@ becomes @n@
-- applications of @suc@ to @zero@, and a non-positive integer becomes @zero@.
-- Other literals are returned unchanged as 'Lit'.
litToCon :: Literal -> TCM Term
litToCon l = case l of
    LitInt r n
        | n > 0     -> do
            inner <- litToCon (LitInt r (n - 1))
            suc   <- primSuc
            return $ suc `apply` [defaultArg inner]
        | otherwise -> primZero
    lit -> return $ Lit lit
-- | True for literals that 'litToCon' expands into constructors (integers).
litCon :: Literal -> Bool
litCon l = case l of
    LitInt{} -> True
    _        -> False
-- | Substitute @ins@ for the variable @index@; every other variable @i@ is
-- mapped to @Var i []@.
insertAt :: (Nat,Term) -> Term -> Term
insertAt (index, ins) = substs (map pick [0 ..])
  where
    pick i
        | i == index = ins
        | otherwise  = Var i []
-- | Select the candidate injective functions whose tag constraints can be
-- satisfied together, and commit the resulting constructor tags.
--
-- Solving starts from 'initialTags', built from the already known tags (of
-- the current and imported modules) and the constructors @newNames@.
-- Candidates are processed in order.  A candidate is accepted only if all of
-- its constructor pairs 'unify' on top of the state accepted so far;
-- otherwise it is dropped and the state is left unchanged.  Afterwards
-- 'updateTags' stores a tag for every constructor.  Accepted functions are
-- returned in reverse order of @xs@.
solve :: [QName] -> [((QName, InjectiveFun), [(QName,QName)])] -> Compile TCM [(QName, InjectiveFun)]
solve newNames xs = do
    conGraph <- M.union <$> gets (constrTags . curModule) <*> gets (constrTags . importedModules)
    (funs, mconstr) <- foldM step ([], Just $ initialTags conGraph newNames) xs
    case mconstr of
        -- 'step' never replaces a 'Just' state by 'Nothing'.
        Nothing     -> __IMPOSSIBLE__
        Just constr -> updateTags constr
    return funs
  where
    -- Try one candidate; keep the previous state if it is not solvable.
    step (accepted, prev) (fun, cons) = do
        m <- foldM solvable prev cons
        return $ case m of
            Nothing -> (accepted, prev)
            Just _  -> (fun : accepted, m)

    solvable :: Maybe Tags -> (QName, QName)
             -> Compile TCM (Maybe Tags)
    solvable Nothing   _        = return Nothing
    solvable (Just st) (c1, c2) = unify c1 c2 st

    -- Store every fixed tag.  Then, if an untagged group remains, obtain a tag
    -- for it with assignConstrTag', fix that tag for the whole group and
    -- repeat.
    updateTags :: Tags -> Compile TCM ()
    updateTags tags = do
        let (hasTags, eqs) = M.partition isTag (constrGroup tags)
            isTag (IsTag _) = True
            isTag _         = False
        forM_ (M.toList hasTags) $ \ (c, tagged) -> case tagged of
            IsTag tag -> putCon c tag
            _         -> __IMPOSSIBLE__
        case M.toList eqs of
            (c, Same n) : _ -> do
                let grp = eqGroups tags !!! n
                tag <- assignConstrTag' c (S.toList grp)
                updateTags . fromMaybe __IMPOSSIBLE__ =<< setTag n tag tags { constrGroup = eqs }
            _ -> return ()

    -- Record @tag@ for @con@ unless the imported modules already have a tag
    -- for it.
    putCon :: QName -> Tag -> Compile TCM ()
    putCon con tag = do
        m <- gets (constrTags . importedModules)
        case M.lookup con m of
            Nothing -> putConstrTag con tag
            Just _  -> return ()
-- | A successful check that imposes no constraints.
emptyC :: InjConstraints
emptyC = return []
-- | Add the requirement that @q1@ and @q2@ share a tag.  Trivial pairs
-- (@q1 == q2@) are not recorded, and a failed check stays failed.
addConstraint :: QName -> QName -> InjConstraints -> InjConstraints
addConstraint q1 q2 = fmap extend
  where
    extend cs
        | q1 == q2  = cs
        | otherwise = (q1, q2) : cs
-- | Combine the outcomes of several checks: fails as soon as one of them
-- failed, otherwise concatenates all constraint lists in order.
unionConstraints :: [InjConstraints] -> InjConstraints
unionConstraints = fmap concat . sequence
-- | @(t1 <: t2) injs@ compares a term @t1@ built from a clause pattern with a
-- term @t2@ from the reduced clause body.  Each function recorded in @injs@
-- is treated as injective in its recorded argument.
--
-- Returns 'Nothing' when the terms do not match.  Otherwise it returns the
-- constructor pairs (see 'addConstraint') that must share a tag for the match
-- to hold.
(<:) :: Term -> Term -> (QName :-> InjectiveFun) -> Compile TCM InjConstraints
-- Integer literals on either side are first expanded into constructor form.
(Lit l <: t1) injs | litCon l = do
    l' <- lift $ litToCon l
    (l' <: t1) injs
(t1 <: Lit l) injs | litCon l = do
    l' <- lift $ litToCon l
    (t1 <: l') injs
-- A call to a function recorded as injective is looked through, but only when
-- it has exactly the recorded arity: comparison continues against the reduced
-- injective argument.
(t1 <: Def n2 args2) injs | Just (InjectiveFun argn arit) <- M.lookup n2 injs =
    if genericLength args2 /= arit
      then return Nothing
      else do
        arg <- lift $ reduce $ unArg $ args2 !! fromIntegral argn
        (t1 <: arg) injs
-- Same variable (resp. same definition) applied to equally many arguments:
-- compare the reduced arguments pairwise.
(Var n1 args1 <: Var n2 args2) injs | n1 == n2 && length args1 == length args2 = do
    args1' <- map unArg <$> mapM (lift . reduce) args1
    args2' <- map unArg <$> mapM (lift . reduce) args2
    unionConstraints <$> zipWithM (\a b -> (a <: b) injs) args1' args2'
(Def q1 args1 <: Def q2 args2) injs | q1 == q2 && length args1 == length args2 = do
    args1' <- map unArg <$> mapM (lift . reduce) args1
    args2' <- map unArg <$> mapM (lift . reduce) args2
    unionConstraints <$> zipWithM (\a b -> (a <: b) injs) args1' args2'
-- Constructors: keep the arguments selected by 'notForced', require equally
-- many on both sides, compare them pairwise and record that @c1@ and @c2@
-- must share a tag.
-- NOTE(review): only the right-hand arguments are reduced here, unlike the
-- Var/Def cases -- confirm this asymmetry is intended.
(Con c1 args1 <: Con c2 args2) injs = do
    args1' <- map unArg <$> flip notForced args1 <$> getForcedArgs c1
    args2' <- map unArg <$> (mapM (lift . reduce) =<< flip notForced args2 <$> getForcedArgs c2)
    if length args1' == length args2'
        then addConstraint c1 c2 <$> unionConstraints <$> zipWithM (\a b -> (a <: b) injs) args1' args2'
        else return Nothing
-- Anything else does not match.
(_ <: _) _ = return Nothing
-- | What the solver knows about a constructor's tag: either the constructor
-- belongs to equivalence group @n@ ('Same'), or its tag is fixed ('IsTag').
data TagEq
    = Same Int
    | IsTag Tag
  deriving Eq
-- | Solver state for constructor tags.
data Tags = Tags
    { eqGroups    :: Int :-> Set QName
      -- ^ Members of each equivalence group, keyed by group id.
    , constrGroup :: QName :-> TagEq
      -- ^ For each constructor, its group id or its fixed tag.
    }
-- | Starting solver state.  Each constructor in @newNames@ gets its own
-- singleton group, numbered from 0 in list order.  Each constructor in
-- @setTags@ is marked with its known tag.  A name present in both keeps its
-- known tag in 'constrGroup' (left-biased 'M.union').
initialTags :: Map QName Tag -> [QName] -> Tags
initialTags setTags newNames =
    Tags { eqGroups    = M.fromList [ (gid, S.singleton c) | (gid, c) <- numbered ]
         , constrGroup = M.union (M.map IsTag setTags)
                                 (M.fromList [ (c, Same gid) | (gid, c) <- numbered ])
         }
  where
    numbered = zip [0 ..] newNames
-- | Require constructors @c1@ and @c2@ to share a tag.
--
-- The result depends on what is known about each constructor:
--
-- * Both are in the same group, or both have the same fixed tag: the state
--   is returned unchanged.
-- * Both are in different groups: the groups are merged ('mergeGroups').
-- * One has a fixed tag: that tag is fixed for the other's group ('setTag').
-- * Both have different fixed tags: the result is 'Nothing'.
unify :: QName -> QName -> Tags -> Compile TCM (Maybe Tags)
unify c1 c2 ts =
    case (constrGroup ts !!! c1, constrGroup ts !!! c2) of
        (Same n1, Same n2)
            | n1 == n2  -> return (Just ts)
            | otherwise -> mergeGroups n1 n2 ts
        (IsTag t1, IsTag t2)
            | t1 == t2  -> return (Just ts)
            | otherwise -> return Nothing
        -- NOTE(review): unlike mergeGroups, these cases do not consult
        -- 'unifiable'; confirm a fixed tag can never clash with another
        -- constructor of the same datatype in the group.
        (Same n1, IsTag t2) -> setTag n1 t2 ts
        (IsTag t1, Same n2) -> setTag n2 t1 ts
-- | Fix the tag of every member of group @gid@ to @tag@.  Always succeeds.
setTag :: Int -> Tag -> Tags -> Compile TCM (Maybe Tags)
setTag gid tag ts = return (Just (ts { constrGroup = newGroups }))
  where
    members   = S.toList (eqGroups ts !!! gid)
    newGroups = M.union (M.fromList [ (c, IsTag tag) | c <- members ]) (constrGroup ts)
-- | Merge group @n2@ into group @n1@.  The merge is allowed only when
-- 'unifiable' holds for every pair with one member from each group; otherwise
-- the result is 'Nothing'.  All pairs are checked.  Group @n2@ is removed, and
-- its members are re-pointed to @n1@.
mergeGroups :: Int -> Int -> Tags -> Compile TCM (Maybe Tags)
mergeGroups n1 n2 ts = do
    compatible <- and <$> sequence [ unifiable a b | a <- S.toList grp1, b <- S.toList grp2 ]
    return $ if compatible then Just merged else Nothing
  where
    grp1   = eqGroups ts !!! n1
    grp2   = eqGroups ts !!! n2
    merged = ts { eqGroups    = M.delete n2 (M.insert n1 (S.union grp1 grp2) (eqGroups ts))
                , constrGroup = M.union (M.fromList [ (c, Same n1) | c <- S.toList grp2 ])
                                        (constrGroup ts)
                }
-- | Two constructors may share a tag only if 'getConData' gives different
-- results for them (looked up for @c1@ first, then @c2@).
unifiable :: QName -> QName -> Compile TCM Bool
unifiable c1 c2 = (/=) <$> getConData c1 <*> getConData c2
-- | Lookup for keys that must be present; a missing key is an internal error.
(!!!) :: Ord k => k :-> v -> k -> v
m !!! k = fromMaybe __IMPOSSIBLE__ (M.lookup k m)