module Agda.TypeChecking.RecordPatterns
( translateRecordPatterns
, translateCompiledClauses
, translateSplitTree
, recordPatternToProjections
) where
import Control.Applicative
import Control.Arrow ((***))
import Control.Monad.Fix
import Control.Monad.Reader
import Control.Monad.State
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Traversable as Trav
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Coverage.SplitTree
import Agda.TypeChecking.Datatypes
import Agda.TypeChecking.Errors
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Records
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.Utils.Either
import Agda.Utils.List
import qualified Agda.Utils.Map as Map
import Agda.Utils.Maybe
import Agda.Utils.Permutation
import Agda.Utils.Size
#include "../undefined.h"
import Agda.Utils.Impossible
-- | Given a pattern built from typed constructor patterns
-- (@ConP _ (Just _) _@) and variables, computes one projection function per
-- variable, in left-to-right order. Applying such a function to a term of
-- the pattern's type yields the nested field projections that reach the
-- corresponding variable.
--
-- Raises 'ShouldBeRecordPattern' for literals, dot patterns, and
-- constructor patterns without a type annotation.
recordPatternToProjections :: Pattern -> TCM [Term -> Term]
recordPatternToProjections p =
  case p of
    VarP{}             -> return [id]
    LitP{}             -> notRecord
    DotP{}             -> notRecord
    ConP _ Nothing _   -> notRecord
    ConP _ (Just t) ps -> do
      recTy  <- reduce t
      fields <- getRecordTypeFields (unArg recTy)
      let -- Projection function for a single field.
          fieldProj f x = Def (unArg f) [defaultArg x]
          -- Prefix every projection path of a sub-pattern with the
          -- projection of the field that sub-pattern occupies.
          below prj q = map (prj .) <$> recordPatternToProjections q
      concat <$> zipWithM below (map fieldProj fields) (map unArg ps)
  where
    notRecord = typeError $ ShouldBeRecordPattern p
-- | Pointwise conjunction of a list of Boolean columns.
--
-- The result is as long as the shortest input (via 'zipWith'). The empty
-- list yields the unit @repeat True@, which is also how this module
-- represents "no constraint" (for 'Fail' and a missing catch-all branch).
-- This makes the function total, unlike the previous @foldl1@ version,
-- which crashed on an empty list.
conjColumns :: [[Bool]] -> [Bool]
conjColumns = foldr (zipWith (&&)) (repeat True)
-- | @insertColumn i a rows@ inserts @a@ at position @i@ of every row.
-- If a row is shorter than @i@, @a@ is appended at its end.
insertColumn :: Int -> a -> [[a]] -> [[a]]
insertColumn i a = map (uncurry (\before after -> before ++ a : after) . splitAt i)
-- | @cutSublist i n row@ splits @row@ into the first @i@ elements, the next
-- @n@ elements, and the remainder.
cutSublist :: Int -> Int -> [a] -> ([a], [a], [a])
cutSublist i n row = (prefix, middle, suffix)
  where
    (prefix, rest)   = splitAt i row
    (middle, suffix) = splitAt n rest
-- | Translates away case splits on record constructors in compiled
-- clauses. The record variable's components are replaced by projections
-- of a single variable (see 'replaceByProjections').
translateCompiledClauses :: CompiledClauses -> TCM CompiledClauses
translateCompiledClauses cc = snd <$> loop cc
  where

    -- Translates a compiled-clause tree. It also returns one Bool per
    -- variable column. A column is True when every split on it in the tree
    -- is on a record constructor whose own field columns are all True.
    -- Literal splits and non-record (Left) constructor splits make it False.
    loop :: CompiledClauses -> TCM ([Bool], CompiledClauses)
    loop cc = case cc of
      -- No columns are constrained: use the unit of conjColumns.
      Fail      -> return (repeat True, cc)
      Done xs t -> return (map (const True) xs, cc)
      Case i cs -> loops i cs

    -- Translates the branches of a split on column i.
    loops :: Int
          -> Case CompiledClauses
          -> TCM ([Bool], CompiledClauses)
    loops i cs@(Branches { conBranches    = conMap
                         , litBranches    = litMap
                         , catchAllBranch = catchAll }) = do

      -- Catch-all branch. If it is absent, the columns are unconstrained.
      (xssa, catchAll) <- unzipMaybe <$> Trav.mapM loop catchAll
      let xsa = maybe (repeat True) id xssa

      -- Literal branches. Column i is split on a literal here, so False is
      -- inserted at position i in each branch's column list.
      (xssl, litMap) <- Map.unzip <$> Trav.mapM loop litMap
      let xsl = conjColumns (xsa : insertColumn i False (Map.elems xssl))

      -- Constructor branches. Each branch's columns are mapped back to
      -- this node's layout by collapsing the constructor's n argument
      -- columns into the single column i. mcc holds the branch's subtree
      -- with those columns replaced by projections, but only when the
      -- branch qualifies (x).
      (ccs, xssc, conMap) <- Map.unzip3 <$> do
        Trav.forM (Map.mapWithKey (,) conMap) $ \ (c, WithArity ar cc) -> do
          (xs, cc) <- loop cc
          dataOrRecCon <- getConstructorArity c
          let (isRC, n)   = either (False,) ((True,) . size) dataOrRecCon
              (xs0, rest) = genericSplitAt i xs
              (xs1, xs2 ) = genericSplitAt n rest
              x           = isRC && and xs1
              xs'         = xs0 ++ x : xs2
              -- Only forced when x holds, which implies isRC (a Right value).
              fs          = either __IMPOSSIBLE__ id dataOrRecCon
              mcc         = if x then [replaceByProjections i (map unArg fs) cc] else []
          -- The arity stored in the tree must agree with the constructor's.
          when (n /= ar) __IMPOSSIBLE__
          return (mcc, xs', WithArity ar cc)

      let xs = conjColumns (xsl : Map.elems xssc)

      -- If exactly one constructor branch qualified, the whole split is
      -- replaced by that branch's translated subtree. Otherwise the split
      -- is rebuilt from the translated branches. More than one qualifying
      -- branch is an internal error.
      case concat $ Map.elems ccs of
        []   -> return (xs, Case i $ Branches
                  { conBranches    = conMap
                  , litBranches    = litMap
                  , catchAllBranch = catchAll })
        [cc] -> return (xs, cc)
        _    -> __IMPOSSIBLE__
-- | @replaceByProjections i projs cc@ collapses the @n = length projs@
-- consecutive variable columns starting at column @i@ of @cc@ into a single
-- variable. The clause bodies are rewritten so that the removed variables
-- become the projections @projs@ applied to that single variable.
replaceByProjections :: Int -> [QName] -> CompiledClauses -> CompiledClauses
replaceByProjections i projs cc =
  let n = length projs

      loop :: Int -> CompiledClauses -> CompiledClauses
      loop i cc = case cc of
        Case j cs
          -- Splits left of the collapsed block keep their position. The
          -- block's start may move inside the branches, so loops adjusts i.
          | j < i     -> Case j $ loops i cs
          -- Splits right of the block move left by n - 1, because n
          -- columns become one.
          -- NOTE(review): j in [i, i + n - 1] is presumably excluded by the
          -- caller (those columns are expected to be unsplit here); confirm.
          | otherwise -> Case (j - (n - 1)) $ fmap (loop i) cs
        Done xs v ->
          -- Replace the n names xs1 by one combined name. In the body,
          -- substitute the n record variables by projections of a single
          -- variable (var 0 under the binders of xs2), and shift the
          -- remaining free variables accordingly.
          let (xs0, xs1, xs2) = cutSublist i n xs
              names | null xs1  = ["r"]
                    | otherwise = map unArg xs1
              x   = defaultArg $ foldr1 (++) names
              xs' = xs0 ++ x : xs2
              us  = map (\ p -> Def p [defaultArg $ var 0]) (reverse projs)
          in  Done xs' $ applySubst (liftS (length xs2) $ us ++# raiseS 1) v
        Fail -> Fail

      -- Pushes the column shift into the branches of a split left of i.
      -- A constructor branch of arity ar replaces its column by ar columns,
      -- moving i by ar - 1. A literal branch consumes its column, moving i
      -- by -1.
      loops :: Int -> Case CompiledClauses -> Case CompiledClauses
      loops i Branches{ conBranches    = conMap
                      , litBranches    = litMap
                      , catchAllBranch = catchAll } =
        Branches{ conBranches    = fmap (\ (WithArity ar c) -> WithArity ar $ loop (i + ar - 1) c) conMap
                , litBranches    = fmap (loop (i - 1)) litMap
                , catchAllBranch = fmap (loop i) catchAll
                }
  in  loop i cc
-- | Checks whether a case node is a split on a record constructor. This
-- requires exactly one constructor branch, no literal branches, no
-- catch-all branch, and a constructor that 'isRecordConstructor'
-- recognises. If so, returns the record's field names and the branch body.
isRecordCase :: Case c -> TCM (Maybe ([QName], c))
isRecordCase node = case node of
  Branches { conBranches    = conMap
           , litBranches    = litMap
           , catchAllBranch = Nothing }
    | Map.null litMap
    , [(con, WithArity _ body)] <- Map.toList conMap -> do
        info <- isRecordConstructor con
        case info of
          Just (_, Record { recFields = fs }) -> return $ Just (map unArg fs, body)
          Just _                              -> __IMPOSSIBLE__
          Nothing                             -> return Nothing
  _ -> return Nothing
-- | Annotation for a branch of a split tree, as built by 'recordSplitTree'.
data RecordSplitNode = RecordSplitNode
  { splitCon           :: QName
    -- ^ The constructor split on.
  , splitArity         :: Int
    -- ^ The constructor's arity, as computed from 'getConstructorArity'.
  , splitRecordPattern :: Bool
    -- ^ Whether the constructor is a record constructor and all columns
    -- for its arguments are themselves translatable in the subtree.
  }

-- | Split tree whose branches carry 'RecordSplitNode' annotations.
type RecordSplitTree  = SplitTree' RecordSplitNode
-- | Branch list of a 'RecordSplitTree'.
type RecordSplitTrees = SplitTrees' RecordSplitNode
-- | Annotates every branch of a split tree with 'RecordSplitNode'
-- information: the constructor, its arity, and whether the split on it
-- could be translated away as a record pattern.
recordSplitTree :: SplitTree -> TCM RecordSplitTree
recordSplitTree t = snd <$> loop t
  where

    -- Also returns one Bool per variable column of the tree's context.
    -- A column is True when every split on it is on a record constructor
    -- whose argument columns are again all True.
    loop :: SplitTree -> TCM ([Bool], RecordSplitTree)
    loop t = case t of
      SplittingDone n -> return (replicate n True, SplittingDone n)
      SplitAt i ts    -> do
        (xs, ts) <- loops i ts
        return (xs, SplitAt i ts)

    loops :: Int -> SplitTrees -> TCM ([Bool], RecordSplitTrees)
    loops i ts = do
      (xss, ts) <- unzip <$> do
        forM ts $ \ (c, t) -> do
          (xs, t) <- loop t
          (isRC, n) <- either (False,) ((True,) . size) <$> getConstructorArity c
          let (xs0, rest) = genericSplitAt i xs
              (xs1, xs2)  = genericSplitAt n rest
              x           = isRC && and xs1
              -- Map the branch's columns back to this node's layout by
              -- collapsing the n constructor-argument columns into column i.
              xs'         = xs0 ++ x : xs2
          -- Return the collapsed columns xs' rather than the branch's own
          -- layout xs, so that columns line up across branches of
          -- different arity.
          return (xs', (RecordSplitNode c n x, t))
      return (conjColumns xss, ts)
-- | Removes splits on record constructors from a split tree. A split is
-- removed when every column for the record's arguments is itself
-- translatable in the branch. The split node is replaced by its single
-- branch, whose split positions are shifted with 'dropFrom' because the
-- record's @n@ argument columns collapse back into one.
translateSplitTree :: SplitTree -> TCM SplitTree
translateSplitTree t = snd <$> loop t
  where

    -- Also returns one Bool per variable column of the tree's context.
    -- A column is True when every split on it is on a record constructor
    -- whose argument columns are again all True.
    loop :: SplitTree -> TCM ([Bool], SplitTree)
    loop t = case t of
      SplittingDone n ->
        return (replicate n True, SplittingDone n)
      SplitAt i ts -> do
        (x, xs, ts) <- loops i ts
        let t' = if x then
                   case ts of
                     [(_, t)] -> t
                     _        -> __IMPOSSIBLE__
                 else SplitAt i ts
        return (xs, t')

    -- Translates the branches of a split on column i. Returns whether the
    -- split is a translatable record split, the columns in this node's
    -- layout, and the translated branches.
    loops :: Int -> SplitTrees -> TCM (Bool, [Bool], SplitTrees)
    loops i ts = do
      (rs, xss, ts) <- unzip3 <$> do
        forM ts $ \ (c, t) -> do
          (xs, t) <- loop t
          (isRC, n) <- either (False,) ((True,) . size) <$> getConstructorArity c
          let (xs0, rest) = genericSplitAt i xs
              (xs1, xs2)  = genericSplitAt n rest
              x           = isRC && and xs1
              xs'         = xs0 ++ x : xs2
              -- The n argument columns collapse into one: remove n - 1
              -- columns after column i in the branch.
              t'          = if x then dropFrom i (n - 1) t else t
          return (x, xs', (c, t'))
      let x = and rs
      -- A translatable record split must be the only branch. Otherwise no
      -- branch may be translatable.
      if x then unless (rs == [True]) __IMPOSSIBLE__
           else when (or rs) __IMPOSSIBLE__
      return (x, conjColumns xss, ts)
-- | @dropFrom i n@ adjusts a split-tree structure for the removal of @n@
-- variable columns starting at column @i@.
class DropFrom a where
  dropFrom :: Int -> Int -> a -> a
-- | Removing @n@ columns starting at @i@ from a split tree has four cases:
--
-- * A leaf's column count shrinks by @n@.
-- * Splits at or beyond @i + n@ move left by @n@.
-- * Splits left of @i@ are unchanged.
-- * A split inside the removed block @[i, i + n)@ is an internal error.
instance DropFrom (SplitTree' c) where
  dropFrom i n t = case t of
    SplittingDone m -> SplittingDone (m - n)
    SplitAt j ts
      | j >= i + n -> SplitAt (j - n) $ dropFrom i n ts
      | j < i      -> SplitAt j $ dropFrom i n ts
      | otherwise  -> __IMPOSSIBLE__
-- | A branch: drop columns in its subtree and leave the constructor tag
-- untouched.
instance DropFrom (c, SplitTree' c) where
  dropFrom i n branch = case branch of
    (tag, subtree) -> (tag, dropFrom i n subtree)
-- | Lists: drop columns in every element.
instance DropFrom a => DropFrom [a] where
  dropFrom i n ts = [ dropFrom i n t | t <- ts ]
-- | Translates record patterns in a clause away. Record patterns (as
-- selected by 'translatePatterns') are replaced by single variables. The
-- clause telescope, permutation, patterns and body are adjusted so that
-- the original variables are expressed as projections of the new ones.
translateRecordPatterns :: Clause -> TCM Clause
translateRecordPatterns clause = do
  -- ps: the translated patterns.
  -- s:  one term per variable/dot-pattern leaf of the original patterns,
  --     left to right, built from the new variables (see removeTree).
  -- cs: one change entry per new variable (see Changes).
  (ps, s, cs) <- runRecPatM $ translatePatterns $ clausePats clause

  let
    -- One entry in cs is produced per fresh variable allocated.
    noNewPatternVars = size cs

    -- NOTE(review): presumably reversed so that the rightmost leaf
    -- corresponds to variable 0; confirm.
    s' = reverse s

    -- Extends a list of terms to a substitution. Variables beyond the
    -- list are raised past the new pattern variables.
    mkSub s = s ++# raiseS noNewPatternVars

    -- Substitution applied to the clause body.
    rhsSubst = mkSub s'
    -- The same terms, reordered by the clause permutation. It is used for
    -- the telescope and, after composition, for the patterns.
    rhsSubst' = mkSub $ permute (reverseP $ clausePerm clause) s'

    -- The old telescope as a flat list of (name, type) entries, permuted
    -- by the inverse of the compacted clause permutation.
    flattenedOldTel =
      permute (invertP $ compactP $ clausePerm clause) $
      zip (teleNames $ clauseTel clause) $
      flattenTel $
      clauseTel clause

    -- One entry per change: Nothing for dot patterns, otherwise the
    -- (name, type) of the new variable, with types rewritten by rhsSubst'.
    newTel' =
      map (fmap (id *** applySubst rhsSubst')) $
      translateTel cs $
      flattenedOldTel

    -- Permutation computed by reorderTel_ on the new telescope, with the
    -- positions of dot patterns filtered out.
    newPerm = adjustForDotPatterns $
                reorderTel_ $ map (maybe dummy snd) newTel'
      where
      -- Placeholder for dot-pattern entries.
      -- NOTE(review): assumes reorderTel_ does not inspect it.
      dummy = dummyDom
      isDotP n = case genericIndex cs n of
                   Left DotP{} -> True
                   _           -> False
      adjustForDotPatterns (Perm n is) =
        Perm n (filter (not . isDotP) is)

    -- Substitution from the permutation, and its composition with
    -- rhsSubst', used on the translated patterns.
    lhsSubst' = permToSubst (reverseP newPerm)
    lhsSubst  = applySubst lhsSubst' rhsSubst'

    -- The new telescope, reordered by newPerm. A Nothing (dot pattern)
    -- surviving the permutation is an internal error.
    newTel =
      uncurry unflattenTel . unzip $
      map (maybe __IMPOSSIBLE__ id) $
      permute newPerm $
      map (fmap (id *** applySubst lhsSubst')) $
      newTel'

    -- The translated clause.
    c = clause
          { clauseTel  = newTel
          , clausePerm = newPerm
          , clausePats = applySubst lhsSubst ps
          , clauseBody = translateBody cs rhsSubst $ clauseBody clause
          }

  reportSDoc "tc.lhs.recpat" 10 $
    escapeContext (size $ clauseTel clause) $ vcat
      [ text "Translated clause:"
      , nest 2 $ vcat
        [ text "delta =" <+> prettyTCM (clauseTel c)
        , text "perm =" <+> text (show $ clausePerm c)
        , text "ps =" <+> text (show $ clausePats c)
        , text "body =" <+> text (show $ clauseBody c)
        , text "body =" <+> prettyTCM (clauseBody c)
        ]
      ]

  return c
-- | Monad used while translating record patterns.
--
-- It wraps 'TCMT' over a 'Nat' reader and a 'Nat' state:
--
-- * The state counts the fresh variables handed out so far by 'nextVar'.
-- * The reader holds the final value of that counter, which 'runRecPatM'
--   feeds back using 'mfix'.
--
-- The derived 'MonadReader' / 'MonadState' instances are those for
-- 'TCEnv' / 'TCState'. The inner 'Nat' reader and state are reached with
-- 'lift', as in 'nextVar'.
newtype RecPatM a = RecPatM (TCMT (ReaderT Nat (StateT Nat IO)) a)
  deriving (Functor, Applicative, Monad,
            MonadIO, MonadTCM,
            MonadReader TCEnv, MonadState TCState)
-- | Runs a 'RecPatM' computation in 'TCM'.
--
-- The inner counter starts at 0. Via 'mfix' and a lazy pattern, its final
-- value is passed back in as the reader environment. This lets 'nextVar'
-- refer to the total number of variables before that number is known.
-- Nothing may force this value during the computation itself.
runRecPatM :: RecPatM a -> TCM a
runRecPatM (RecPatM m) =
  mapTCMT (\m -> do
             (x, noVars) <- mfix $ \ ~(_, noVars) ->
                              runStateT (runReaderT m noVars) 0
             return x)
          m
-- | Allocates the next fresh pattern variable.
--
-- Returns a variable pattern named @"r"@ together with its term.
-- Variables are counted from the left by the state counter @n@. The term
-- index counts from the right, @noVars - n - 1@, where @noVars@ is the
-- total number of variables supplied by 'runRecPatM'. The first variable
-- therefore gets the highest index.
nextVar :: RecPatM (Pattern, Term)
nextVar = RecPatM $ do
  n <- lift get
  lift $ put $ succ n
  noVars <- lift ask
  -- The index must not be forced here: noVars is only known once the
  -- whole computation has finished (see runRecPatM).
  return (VarP "r", Var (noVars - n - 1) [])
-- | The kind of a leaf in a 'RecordTree': a variable pattern or a dot
-- pattern.
data Kind = VarPat | DotPat
  deriving Eq
-- | How pattern translation changed the patterns. There is one entry per
-- call to 'removeTree', in left-to-right order:
--
-- * @Left p@: the leaf pattern @p@ (a variable or dot pattern) was kept.
--
-- * @Right (count, x, dom)@: a record pattern was replaced by a new
--   variable named @x@ with domain @dom@. @count k@ is the number of
--   leaves of kind @k@ that the record pattern contained.
type Changes = [Either Pattern (Kind -> Nat, String, Dom Type)]
-- | A pattern built only from typed constructor patterns, variables and
-- dot patterns (see 'recordTree').
data RecordTree
  = Leaf Pattern
    -- ^ A variable or dot pattern.
  | RecCon (Arg Type) [(Term -> Term, RecordTree)]
    -- ^ A record pattern with its (reduced) record type. Each field is
    -- given as its projection function paired with the sub-tree for that
    -- field's pattern.
-- | Lists, for every leaf of a record tree in left-to-right order, the
-- projection path from the root to that leaf and the leaf's 'Kind'.
-- A leaf that is neither a variable nor a dot pattern is an internal
-- error.
projections :: RecordTree -> [(Term -> Term, Kind)]
projections tree = case tree of
  Leaf DotP{}   -> [(id, DotPat)]
  Leaf VarP{}   -> [(id, VarPat)]
  Leaf _        -> __IMPOSSIBLE__
  RecCon _ args ->
    [ (inner . outer, k)
    | (outer, sub) <- args
    , (inner, k)   <- projections sub
    ]
-- | Replaces a record tree by a single fresh variable from 'nextVar'.
-- A fresh variable is allocated even for a plain leaf.
--
-- Returns three things:
--
-- * The pattern to use instead of the tree.
-- * One term per leaf: the leaf's projection path applied to the fresh
--   variable's term.
-- * A singleton 'Changes' list. A leaf is kept as @Left p@; a record
--   pattern becomes @Right (count, "r", dom)@.
removeTree :: RecordTree -> RecPatM (Pattern, [Term], Changes)
removeTree tree = do
  (freshPat, freshTerm) <- nextVar
  let paths   = projections tree
      terms   = [ path freshTerm | (path, _) <- paths ]
      count k = genericLength [ () | (_, k') <- paths, k' == k ]
  return $ case tree of
    RecCon ty _ -> (freshPat, terms, [Right (count, "r", domFromArg ty)])
    Leaf p      -> (p, terms, [Left p])
-- | Translates a single pattern. Maximal sub-patterns built only from
-- typed constructor patterns (@ConP _ (Just _) _@), variables and dot
-- patterns are replaced by fresh variables. Returns the new pattern, the
-- terms for the replaced leaves, and the list of changes.
translatePattern :: Pattern -> RecPatM (Pattern, [Term], Changes)
translatePattern p = case p of
  ConP c Nothing args -> do
    (args', s, cs) <- translatePatterns args
    return (ConP c Nothing args', s, cs)
  ConP _ (Just _) _   -> either id removeTree =<< recordTree p
  VarP{}              -> removeTree (Leaf p)
  DotP{}              -> removeTree (Leaf p)
  LitP{}              -> return (p, [], [])
-- | Translates a list of argument patterns with 'translatePattern'. Each
-- argument's 'Arg' information is kept, and the resulting terms and
-- changes are concatenated in order.
translatePatterns ::
  [Arg Pattern] -> RecPatM ([Arg Pattern], [Term], Changes)
translatePatterns ps = do
  (pats, substs, changes) <- fmap unzip3 $ mapM (translatePattern . unArg) ps
  return (withArgsFrom pats ps, concat substs, concat changes)
-- | Tries to view a pattern as a 'RecordTree'.
--
-- * @Right tree@: the pattern consists only of typed constructor
--   patterns, variables and dot patterns.
-- * @Left action@: some sub-pattern does not qualify. The action then
--   translates the pattern by removing each qualifying sub-tree
--   separately.
recordTree ::
  Pattern ->
  RecPatM (Either (RecPatM (Pattern, [Term], Changes)) RecordTree)
recordTree p = case p of
  ConP _ Nothing _ -> return $ Left $ translatePattern p
  ConP c (Just t) ps -> do
    subtrees <- mapM (recordTree . unArg) ps
    case allRight subtrees of
      Left mixed -> return $ Left $ do
        (ps', ss, cs) <- unzip3 <$> mapM (either id removeTree) mixed
        return (ConP c (Just t) (ps' `withArgsFrom` ps),
                concat ss, concat cs)
      Right leaves -> liftTCM $ do
        recTy  <- reduce t
        fields <- getRecordTypeFields (unArg recTy)
        let proj f x = Def (unArg f) [defaultArg x]
        return $ Right $ RecCon recTy $ zip (map proj fields) leaves
  VarP{} -> return $ Right $ Leaf p
  DotP{} -> return $ Right $ Leaf p
  LitP{} -> return $ Left $ translatePattern p
-- | Translates a flattened telescope according to a list of changes:
--
-- * A kept dot pattern yields 'Nothing' and consumes no telescope entry.
-- * A replaced record pattern yields its new (name, type). It skips the
--   old entries for the record's variable leaves (@n VarPat@ of them).
-- * Any other kept leaf passes the next telescope entry through.
--
-- A length mismatch is an internal error.
translateTel
  :: Changes
  -> [(String, Dom Type)]
  -> [Maybe (String, Dom Type)]
translateTel changes tel = case (changes, tel) of
  (Left DotP{} : rest, _)       -> Nothing : translateTel rest tel
  (Right (n, x, t) : rest, _)   ->
    Just (x, t) : translateTel rest (genericDrop (n VarPat) tel)
  (Left _ : rest, entry : tel') -> Just entry : translateTel rest tel'
  ([], [])                      -> []
  (Left _ : _, [])              -> __IMPOSSIBLE__
  ([], _ : _)                   -> __IMPOSSIBLE__
-- | Translates a clause body according to a list of changes:
--
-- * A replaced record pattern binds one new variable in place of the
--   binders for all of its leaves (variables and dot patterns).
-- * A kept leaf keeps its binder.
-- * The final body has the substitution applied.
--
-- 'NoBody' stays 'NoBody'. Any other mismatch is an internal error.
translateBody :: Changes -> Substitution -> ClauseBody -> ClauseBody
translateBody changes s body = case (changes, body) of
  (_, NoBody)                 -> NoBody
  (Right (n, x, _) : rest, b) ->
    Bind $ Abs x $ translateBody rest s $ dropBinds (n VarPat + n DotPat) b
  (Left _ : rest, Bind b)     -> Bind $ fmap (translateBody rest s) b
  ([], Body t)                -> Body $ applySubst s t
  _                           -> __IMPOSSIBLE__
-- | Converts a permutation into a substitution.
--
-- For @Perm n is@, each variable @i@ in @0 .. n - 1@ is mapped to
-- @var k@, where @k@ is the position of @i@ in @is@. This is followed by
-- @raiseS (size is)@. It is an internal error if some @i@ does not occur
-- in @is@.
permToSubst :: Permutation -> Substitution
permToSubst (Perm n is) =
  [ makeVar i | i <- [0 .. n - 1] ] ++# raiseS (size is)
  where
    makeVar i = case genericElemIndex i is of
      Nothing -> __IMPOSSIBLE__
      Just k  -> var k
-- | @dropBinds n b@ strips @n@ leading 'Bind' binders from a clause body
-- and returns the body underneath them. It is an internal error if the
-- body has fewer binders than that or @n@ is negative.
dropBinds :: Nat -> ClauseBody -> ClauseBody
dropBinds n b
  | n == 0              = b
  | Bind ab <- b, n > 0 = dropBinds (pred n) (absBody ab)
  | otherwise           = __IMPOSSIBLE__