{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Agda.TypeChecking.RecordPatterns
( translateRecordPatterns
, translateCompiledClauses
, translateSplitTree
, recordPatternToProjections
) where
import Control.Arrow (first, second)
import Control.Monad.Fix
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as List
import Data.Maybe
import qualified Data.Map as Map
import qualified Data.Traversable
import Agda.Syntax.Common
import Agda.Syntax.Internal as I
import Agda.Syntax.Internal.Pattern as I
import Agda.Syntax.Literal
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Coverage.SplitTree
import Agda.TypeChecking.EtaContract
import Agda.TypeChecking.Datatypes
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Pretty hiding (pretty)
import Agda.TypeChecking.Records
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.Interaction.Options
import Agda.Utils.Either
import Agda.Utils.Functor
import Agda.Utils.List
import qualified Agda.Utils.Map as Map
import Agda.Utils.Maybe
import Agda.Utils.Permutation hiding (dropFrom)
import Agda.Utils.Pretty (Pretty(..))
import qualified Agda.Utils.Pretty as P
import Agda.Utils.Size
#include "undefined.h"
import Agda.Utils.Impossible
-- | For a record pattern, compute one function per pattern variable (in
--   left-to-right order) that extracts the variable's value from a term
--   matched against the whole pattern, by composing the record field
--   projections along the path down to that variable.
--
--   A variable pattern yields the identity.  Literal, dot, 'IApplyP' and
--   'DefP' patterns, as well as constructor patterns that are not record
--   patterns, are rejected with 'ShouldBeRecordPattern'.  Projection
--   patterns are not expected here ('__IMPOSSIBLE__').
recordPatternToProjections :: DeBruijnPattern -> TCM [Term -> Term]
recordPatternToProjections p =
  case p of
    VarP{}       -> return [id]
    LitP{}       -> notARecordPattern
    DotP{}       -> notARecordPattern
    ConP _ ci ps -> do
      whenNothing (conPRecord ci) notARecordPattern
      let recType = unArg $ fromMaybe __IMPOSSIBLE__ $ conPType ci
      reportSDoc "tc.rec" 45 $ vcat
        [ "recordPatternToProjections: "
        , nest 2 $ "constructor pattern " <+> prettyTCM p <+> " has type " <+> prettyTCM recType
        ]
      reportSLn "tc.rec" 70 $ " type raw: " ++ show recType
      fields <- getRecordTypeFields recType
      concat <$> zipWithM throughField fields (map namedArg ps)
    ProjP{}      -> __IMPOSSIBLE__
    IApplyP{}    -> notARecordPattern
    DefP{}       -> notARecordPattern
  where
    notARecordPattern :: TCM a
    notARecordPattern = typeError $ ShouldBeRecordPattern p

    -- Extractors of the sub-pattern, each preceded by projecting @field@.
    throughField field sub =
      map (. (`applyE` [Proj ProjSystem $ unArg field])) <$> recordPatternToProjections sub
-- | Pointwise conjunction of a non-empty list of Boolean rows.  The result
--   has the length of the shortest row.  An empty list of rows is an
--   error (the fold is 'foldl1').
conjColumns :: [[Bool]] -> [Bool]
conjColumns rows = foldl1 (\ acc row -> zipWith (&&) acc row) rows
-- | @insertColumn i a rows@ inserts @a@ at position @i@ of every row.
--   A row shorter than @i@ gets @a@ appended at its end.
insertColumn :: Int -> a -> [[a]] -> [[a]]
insertColumn i a = map $ \ row ->
  let (before, after) = splitAt i row
  in  before ++ a : after
-- | @cutSublist i n row@ splits @row@ into the first @i@ elements, the
--   next (up to) @n@ elements, and the remainder.
cutSublist :: Int -> Int -> [a] -> ([a], [a], [a])
cutSublist i n row = (prefix, middle, suffix)
  where
    (prefix, rest)   = splitAt i row
    (middle, suffix) = splitAt n rest
-- | Classify a split tag.  Returns whether the tag is a constructor of a
--   record with eta-equality ('YesEta'), and the arity used for the split:
--
--   * the constructor arity for a data constructor;
--   * the number of fields for a record constructor;
--   * 0 for a literal;
--   * 1 for the catch-all.
getEtaAndArity :: SplitTag -> TCM (Bool, Nat)
getEtaAndArity tag = case tag of
    SplitCon c    -> classify <$> getConstructorInfo c
    SplitLit _    -> return (False, 0)
    SplitCatchall -> return (False, 1)
  where
    classify (DataCon n)        = (False, n)
    classify (RecordCon eta fs) = (eta == YesEta, size fs)
-- | Translate compiled clauses (a case tree) in two passes.
--
--   1. A non-copattern case ('projPatterns' is 'False') whose only
--      constructor branch is for a constructor of an eta-record
--      ('RecordCon' 'YesEta') is turned into an eta-match: the branch is
--      moved from 'conBranches' into 'etaBranch'.  An already existing
--      'etaBranch' is translated recursively, and the constructor branches
--      are then replaced by an empty map.
--
--   2. 'recordExpressionsToCopatterns' then rewrites record-expression
--      right-hand sides into copattern matches.
--
--   The input and both intermediate results are logged under
--   @tc.cc.record@ at verbosity 20.
translateCompiledClauses :: CompiledClauses -> TCM CompiledClauses
translateCompiledClauses cc = do
  reportSDoc "tc.cc.record" 20 $ vcat
    [ "translate record patterns in compiled clauses"
    , nest 2 $ return $ pretty cc
    ]
  cc <- loop cc
  reportSDoc "tc.cc.record" 20 $ vcat
    [ "translated compiled clauses (no eta record patterns):"
    , nest 2 $ return $ pretty cc
    ]
  cc <- recordExpressionsToCopatterns cc
  reportSDoc "tc.cc.record" 20 $ vcat
    [ "translated compiled clauses (record expressions to copatterns):"
    , nest 2 $ return $ pretty cc
    ]
  return cc
  where
    -- Walk the case tree; only 'Case' nodes need any work.
    loop :: CompiledClauses -> TCM (CompiledClauses)
    loop cc = case cc of
      Fail      -> return cc
      Done{}    -> return cc
      Case i cs -> loops i cs

    -- Translate all branches of a split on argument @i@ and decide whether
    -- its constructor branches should become an eta-match.
    loops :: Arg Int
          -> Case CompiledClauses
          -> TCM CompiledClauses
    loops i cs@Branches{ projPatterns   = comatch
                       , conBranches    = conMap
                       , etaBranch      = eta
                       , litBranches    = litMap
                       , fallThrough    = fT
                       , catchAllBranch = catchAll
                       , lazyMatch      = lazy } = do
      catchAll <- traverse loop catchAll
      litMap   <- traverse loop litMap
      (conMap, eta) <- do
        -- Keep the constructor branches (translated recursively), no eta branch.
        let noEtaCase = (, Nothing) <$> (traverse . traverse) loop conMap
            -- Empty the constructor branches and put branch @b@ for head
            -- @ch@ (translated recursively) into the eta slot.
            yesEtaCase ch b = (Map.empty,) . Just . (ch,) <$> traverse loop b
        case Map.toList conMap of
          -- There is already an eta branch (the original one): recurse into it.
          _ | Just (ch, b) <- eta -> yesEtaCase ch b
          -- A single constructor branch outside a copattern match becomes
          -- an eta-match if the constructor belongs to an eta-record.
          [(c, b)] | not comatch ->
            getConstructorInfo c >>= \ case
              RecordCon YesEta fs ->
                let ch = ConHead c Inductive fs in
                yesEtaCase ch b
              _ -> noEtaCase
          _ -> noEtaCase
      return $ Case i cs{ conBranches    = conMap
                        , etaBranch      = eta
                        , litBranches    = litMap
                        , fallThrough    = fT
                        , catchAllBranch = catchAll
                        }
-- | Rewrite right-hand sides that are record expressions (constructor
--   applications with origin 'ConORec') of an eta-record into copattern
--   matches on the record's projections, one branch per field.  The
--   generated branches are translated again, so nested record expressions
--   are rewritten as well.
--
--   The rewrite is skipped, and the clause returned unchanged, when:
--
--   * the record has no fields;
--   * the constructor application is not saturated (the number of
--     eliminations differs from the number of fields);
--   * some field is irrelevant and irrelevant projections are disabled.
recordExpressionsToCopatterns :: CompiledClauses -> TCM CompiledClauses
recordExpressionsToCopatterns cc =
  case cc of
    Case i bs -> Case i <$> traverse recordExpressionsToCopatterns bs
    Fail      -> return cc
    Done xs (Con c ConORec es) -> do
      -- The head must be defined as a constructor; otherwise the pattern
      -- bind fails.
      Constructor{} <- theDef <$> getConstInfo (conName c)
      irrProj <- optIrrelevantProjections <$> pragmaOptions
      getConstructorInfo (conName c) >>= \ case
        RecordCon YesEta fs
          | ar <- length fs, ar > 0
          , length es == ar
          , irrProj || not (any isIrrelevant fs) -> do
              -- The branch for each projection returns the corresponding
              -- constructor argument; the arguments must all be 'Apply's.
              let body (Apply v) = WithArity 0 $ Done xs (unArg v)
                  body _         = __IMPOSSIBLE__
                  bs = Branches
                    { projPatterns   = True
                    , conBranches    = Map.fromList $ zip (map unArg fs) (map body es)
                    , etaBranch      = Nothing
                    , litBranches    = Map.empty
                    , catchAllBranch = Nothing
                    , fallThrough    = Nothing
                    , lazyMatch      = False
                    }
              Case (defaultArg $ length xs) <$> traverse recordExpressionsToCopatterns bs
        _ -> return cc
    Done{} -> return cc
-- | @replaceByProjections (Arg ai i) projs cc@ collapses the
--   @n = length projs@ context variables starting at position @i@ into a
--   single variable.  The new variable is named by joining the old names,
--   or @"r"@ if there were none.  In right-hand sides, those variables are
--   replaced by the projections @projs@ applied to the new variable.
--
--   Splits at positions below @i@ are kept.  Inside their branches, the
--   position of the collapsed block shifts according to the branch.  Splits
--   at positions @>= i@ move left by @n - 1@.
replaceByProjections :: Arg Int -> [QName] -> CompiledClauses -> CompiledClauses
replaceByProjections (Arg ai i) projs cc =
  let n = length projs

      loop :: Int -> CompiledClauses -> CompiledClauses
      loop i cc = case cc of
        Case j cs
          | unArg j < i -> Case j $ loops i cs
          | otherwise   -> Case (j <&> \ k -> k - (n-1)) $ fmap (loop i) cs
        Done xs v ->
          let (xs0, xs1, xs2) = cutSublist i n xs
              names | null xs1  = ["r"]
                    | otherwise = map unArg xs1
              x   = Arg ai $ foldr1 appendArgNames names
              xs' = xs0 ++ x : xs2
              -- One projection of the new variable (index 0) per field,
              -- in reverse order.
              us  = map (\ p -> Var 0 [Proj ProjSystem p]) (reverse projs)
          in  Done xs' $ applySubst (liftS (length xs2) $ us ++# raiseS 1) v
        Fail -> Fail

      -- Translate the branches of a split left of the collapsed block.
      -- A constructor branch of arity @ar@ moves the block by @ar - 1@;
      -- a literal branch moves it by -1.
      loops :: Int -> Case CompiledClauses -> Case CompiledClauses
      loops i bs@Branches{ conBranches    = conMap
                         , litBranches    = litMap
                         , catchAllBranch = catchAll } =
        bs{ conBranches    = fmap (\ (WithArity ar c) -> WithArity ar $ loop (i + ar - 1) c) conMap
          , litBranches    = fmap (loop (i - 1)) litMap
          , catchAllBranch = fmap (loop i) catchAll
          }
  in  loop i cc
-- | Recognise a split that has exactly one constructor branch, no literal
--   branches and no catch-all, where 'isRecordConstructor' accepts the
--   constructor.  For such a split, return the record's field names and
--   the body of the single branch; otherwise return 'Nothing'.
isRecordCase :: Case c -> TCM (Maybe ([QName], c))
isRecordCase Branches{ conBranches    = conMap
                     , litBranches    = litMap
                     , catchAllBranch = Nothing }
  | Map.null litMap
  , [(con, WithArity _ body)] <- Map.toList conMap =
      isRecordConstructor con >>= \ case
        Nothing                            -> return Nothing
        Just (_, Record{ recFields = fs }) -> return $ Just (map unArg fs, body)
        Just _                             -> __IMPOSSIBLE__
isRecordCase _ = return Nothing
-- | Annotation attached to every split node by 'recordSplitTree'.
data RecordSplitNode = RecordSplitNode
  { splitTag           :: SplitTag
    -- ^ The tag of the branch (constructor, literal or catch-all).
  , splitArity         :: Int
    -- ^ The arity reported by 'getEtaAndArity' for 'splitTag'.
  , splitRecordPattern :: Bool
    -- ^ 'True' when 'splitTag' is an eta-record constructor and all its
    --   field columns were flagged 'True' in the subtree.
  }

-- | A split tree whose branches carry 'RecordSplitNode' annotations.
type RecordSplitTree  = SplitTree' RecordSplitNode
-- | The annotated branches of a 'RecordSplitTree'.
type RecordSplitTrees = SplitTrees' RecordSplitNode
-- | Annotate every branch of a split tree with a 'RecordSplitNode'.
--
--   Internally each subtree also produces one 'Bool' per variable of its
--   context.  A variable is flagged 'True' when it is not split on, or is
--   split on only by eta-record constructors whose fields are flagged in
--   the same way.  Flags from different branches are combined with
--   'conjColumns'.
recordSplitTree :: SplitTree -> TCM RecordSplitTree
recordSplitTree t = snd <$> loop t
  where
    -- @loop t@ returns the flags for the context of @t@ and the annotated tree.
    loop :: SplitTree -> TCM ([Bool], RecordSplitTree)
    loop t = case t of
      SplittingDone n -> return (replicate n True, SplittingDone n)
      SplitAt i ts    -> do
        (xs, ts) <- loops (unArg i) ts
        return (xs, SplitAt i ts)

    -- @loops i ts@ handles the branches of a split on variable @i@.
    loops :: Int -> SplitTrees -> TCM ([Bool], RecordSplitTrees)
    loops i ts = do
      (xss, ts) <- unzip <$> do
        forM ts $ \ (c, t) -> do
          (xs, t) <- loop t
          (isRC, n) <- getEtaAndArity c
          -- In the branch, variable @i@ is replaced by @n@ columns; fold
          -- their flags back into the single flag for @i@.
          let (xs0, rest) = splitAt i xs
              (xs1, xs2)  = splitAt n rest
              x           = isRC && and xs1
              xs'         = xs0 ++ x : xs2
          -- Return the flags in the context of this split node (xs'), not
          -- those of the branch (xs): only then do the rows of all branches
          -- line up for 'conjColumns'.
          return (xs', (RecordSplitNode c n x, t))
      return (conjColumns xss, ts)
-- | Remove from a split tree the splits on eta-record constructors whose
--   field variables are never split on again, except by splits that are
--   themselves removed.  A removed split is replaced by its only branch,
--   with that branch's variable positions adjusted by 'dropFrom'.
translateSplitTree :: SplitTree -> TCM SplitTree
translateSplitTree t = snd <$> loop t
  where
    -- @loop t@ returns the translated tree and, for every variable in the
    -- context of @t@, whether it is still a candidate for removal.
    loop :: SplitTree -> TCM ([Bool], SplitTree)
    loop t = case t of
      SplittingDone n ->
        -- No splits below: every variable is a candidate.
        return (replicate n True, SplittingDone n)
      SplitAt i ts    -> do
        (x, xs, ts) <- loops (unArg i) ts
        -- A removable record split is replaced by its single branch.
        let t' = if x then
                   case ts of
                     [(_, sub)] -> sub
                     _          -> __IMPOSSIBLE__
                 else SplitAt i ts
        return (xs, t')

    -- @loops i ts@ returns whether the split on @i@ is removable, the
    -- column flags in the split node's context, and the translated branches.
    loops :: Int -> SplitTrees -> TCM (Bool, [Bool], SplitTrees)
    loops i ts = do
      (rs, xss, ts) <- unzip3 <$> do
        forM ts $ \ (c, t) -> do
          (xs, t) <- loop t
          (isRC, n) <- getEtaAndArity c
          -- The branch has @n@ columns in place of variable @i@.
          let (xs0, rest) = splitAt i xs
              (xs1, xs2)  = splitAt n rest
              x           = isRC && and xs1
              xs'         = xs0 ++ x : xs2
              -- Re-collapse the @n@ field columns into one if removing.
              t'          = if x then dropFrom i (n - 1) t else t
          return (x, xs', (c, t'))
      let x = and rs
      -- Invariant: a removable split has exactly one branch, and otherwise
      -- no branch is a removable record split.
      if x then unless (rs == [True]) __IMPOSSIBLE__
           else when (or rs) __IMPOSSIBLE__
      return (x, conjColumns xss, ts)
-- | @dropFrom i n@ removes the @n@ variable positions starting at @i@
--   from a structure that refers to variables by position.
class DropFrom a where
  dropFrom :: Int -> Int -> a -> a

-- | Leaf variable counts shrink by @n@.  Split positions @>= i + n@ move
--   left by @n@ and positions @< i@ stay (checked in this order).  A split
--   on a removed position is impossible.
instance DropFrom (SplitTree' c) where
  dropFrom _ n (SplittingDone m) = SplittingDone (m - n)
  dropFrom i n (SplitAt x@(Arg ai j) ts)
    | j >= i + n = SplitAt (Arg ai (j - n)) (dropFrom i n ts)
    | j < i      = SplitAt x (dropFrom i n ts)
    | otherwise  = __IMPOSSIBLE__

-- | Only the subtree of a tagged branch is affected.
instance DropFrom (c, SplitTree' c) where
  dropFrom i n = fmap (dropFrom i n)

-- | Apply to every element.
instance DropFrom a => DropFrom [a] where
  dropFrom i n = fmap (dropFrom i n)
-- | Translate away the system-generated record patterns of a clause (see
--   'translatePattern'), replacing each one by a single variable.  The
--   clause telescope, patterns and body are rewritten to match.  Old and
--   new clauses are logged under @tc.lhs.recpat@.
translateRecordPatterns :: Clause -> TCM Clause
translateRecordPatterns clause = do
  -- ps: translated patterns.
  -- s:  one term per variable/dot leaf of the old patterns (left to right),
  --     expressed via the new variables (see 'removeTree').
  -- cs: one change per new variable/dot position (see 'Change').
  (ps, s, cs) <- runRecPatM $ translatePatterns $ unnumberPatVars $ namedClausePats clause
  let
    -- One new pattern position per change.
    noNewPatternVars = size cs
    s' = reverse s
    mkSub s = s ++# raiseS noNewPatternVars
    -- NOTE(review): rhsSubst is not referenced below.
    rhsSubst = mkSub s'
    perm = fromMaybe __IMPOSSIBLE__ $ clausePerm clause
    -- As mkSub s', but with s' first permuted by the reversed clause permutation.
    rhsSubst' = mkSub $ permute (reverseP perm) s'
    -- The old telescope: flattened, paired with its names, and permuted by
    -- the inverse of the compacted clause permutation.
    flattenedOldTel =
      permute (invertP __IMPOSSIBLE__ $ compactP perm) $
      zip (teleNames $ clauseTel clause) $
      flattenTel $
      clauseTel clause
    -- Apply a substitution to the types of a flattened telescope with
    -- optional entries.
    substTel = map . fmap . second . applySubst
    -- The new flattened telescope.  'Nothing' marks dot patterns (see
    -- 'translateTel').
    newTel' =
      substTel rhsSubst' $
      translateTel cs $
      flattenedOldTel
    -- Permutation computed by reorderTel_ on the new types, with the
    -- positions of dot patterns filtered out.
    newPerm = adjustForDotPatterns $
      reorderTel_ $ map (maybe __DUMMY_DOM__ snd) newTel'
      where
        isDotP n = case List.genericIndex cs n of
          Left DotP{} -> True
          _ -> False
        adjustForDotPatterns (Perm n is) =
          Perm n (filter (not . isDotP) is)
    -- Renaming induced by the (reversed) new permutation.
    lhsSubst' = renaming __IMPOSSIBLE__ (reverseP newPerm)
    -- Composite substitution applied to the new patterns and the old body.
    lhsSubst = applySubst lhsSubst' rhsSubst'
    -- The new telescope: entries permuted by newPerm and unflattened.
    newTel =
      uncurry unflattenTel . unzip $
      map (fromMaybe __IMPOSSIBLE__) $
      permute newPerm $
      substTel lhsSubst' $
      newTel'
    -- The translated clause.
    c = clause
      { clauseTel = newTel
      , namedClausePats = numberPatVars __IMPOSSIBLE__ newPerm $ applySubst lhsSubst ps
      , clauseBody = applySubst lhsSubst $ clauseBody clause
      }
  reportSDoc "tc.lhs.recpat" 20 $ vcat
    [ "Original clause:"
    , nest 2 $ inTopContext $ vcat
        [ "delta =" <+> prettyTCM (clauseTel clause)
        , "pats =" <+> text (show $ clausePats clause)
        ]
    , "Intermediate results:"
    , nest 2 $ vcat
        [ "ps =" <+> text (show ps)
        , "s =" <+> prettyTCM s
        , "cs =" <+> prettyTCM cs
        , "flattenedOldTel =" <+> (text . show) flattenedOldTel
        , "newTel' =" <+> (text . show) newTel'
        , "newPerm =" <+> prettyTCM newPerm
        ]
    ]
  reportSDoc "tc.lhs.recpat" 20 $ vcat
    [ "lhsSubst' =" <+> (text . show) lhsSubst'
    , "lhsSubst =" <+> (text . show) lhsSubst
    , "newTel =" <+> prettyTCM newTel
    ]
  reportSDoc "tc.lhs.recpat" 10 $
    escapeContext (size $ clauseTel clause) $ vcat
      [ "Translated clause:"
      , nest 2 $ vcat
          [ "delta =" <+> prettyTCM (clauseTel c)
          , "ps =" <+> text (show $ clausePats c)
          , "body =" <+> text (show $ clauseBody c)
          , "body =" <+> addContext (clauseTel c) (maybe "_|_" prettyTCM (clauseBody c))
          ]
      ]
  return c
-- | Monad of the record pattern translation: 'TCM' extended with
--
--   * a state holding the number of variables allocated so far; and
--   * a reader holding the total number of variables allocated by the
--     end of the run (see 'runRecPatM').
newtype RecPatM a = RecPatM (TCMT (ReaderT Nat (StateT Nat IO)) a)
  deriving (Functor, Applicative, Monad,
            MonadIO, MonadTCM, HasOptions, MonadDebug,
            MonadTCEnv, MonadTCState)

-- | Run a 'RecPatM' computation with the counter starting at 0.
--
--   The reader value is the final counter of the same run, obtained by
--   tying a knot with 'mfix'.  The irrefutable pattern @~(_, noVars)@
--   avoids forcing the result tuple before it has been produced.
runRecPatM :: RecPatM a -> TCM a
runRecPatM (RecPatM m) =
  mapTCMT (\m -> do
             (x, noVars) <- mfix $ \ ~(_, noVars) ->
                              runStateT (runReaderT m noVars) 0
             return x)
          m
-- | Allocate the next fresh variable and bump the counter.
--
--   Returns the pattern @varP "r"@ and the term 'var' with index
--   @total - allocated - 1@.  Here @allocated@ is the number of variables
--   allocated before this one, and @total@ is the reader value (the final
--   count, as supplied by 'runRecPatM').
nextVar :: RecPatM (Pattern, Term)
nextVar = RecPatM $ do
  allocated <- lift get
  total     <- lift ask
  lift $ put $ succ allocated
  return (varP "r", var $ total - allocated - 1)
-- | Whether a leaf of a 'RecordTree' is a variable or a dot pattern.
data Kind = VarPat | DotPat
  deriving Eq

-- | A change produced by 'removeTree' for one new pattern position.
--
--   * @Left p@ keeps the leaf pattern @p@ (a variable or a dot pattern).
--   * @Right (count, x, t)@ records that a record pattern was replaced by
--     one variable named @x@ of type @t@; @count k@ is the number of
--     replaced leaves of kind @k@.
type Change  = Either Pattern (Kind -> Nat, ArgName, Dom Type)
-- | Changes, in left-to-right pattern order.
type Changes = [Change]
-- | Debug rendering of a leaf-count function: its value at both kinds.
instance Pretty (Kind -> Nat) where
  pretty f =
    "(VarPat:" P.<+> shown VarPat P.<+>
    "DotPat:" P.<+> shown DotPat P.<> ")"
    where
      shown = P.text . show . f

-- | Lift the pure rendering into the type-checking monad.
instance PrettyTCM (Kind -> Nat) where
  prettyTCM f = return (pretty f)

-- | Debug rendering of a 'Change'.
instance PrettyTCM Change where
  prettyTCM = either prettyTCM $ \ (f, x, t) ->
    "Change" <+> prettyTCM f <+> text x <+> prettyTCM t
-- | A record pattern seen as a tree.
--
--   * A 'Leaf' holds a pattern; it is a variable or a dot pattern when
--     built by 'recordTree' / 'translatePattern'.
--   * A 'RecCon' node holds the record type and, for each field, the
--     function projecting that field paired with the field's subtree.
data RecordTree
  = Leaf Pattern
  | RecCon (Arg Type) [(Term -> Term, RecordTree)]
-- | All leaves of a record tree, from left to right.  Each leaf comes with
--   the composite projection reaching it from the root and its 'Kind'.
--   Leaves other than variables and dot patterns are impossible.
projections :: RecordTree -> [(Term -> Term, Kind)]
projections tree = case tree of
  Leaf DotP{}   -> [(id, DotPat)]
  Leaf VarP{}   -> [(id, VarPat)]
  Leaf _        -> __IMPOSSIBLE__
  RecCon _ args -> [ (inner . outer, kind)
                   | (outer, sub)  <- args
                   , (inner, kind) <- projections sub
                   ]
-- | Replace a record tree by one pattern position.  A fresh variable is
--   always allocated.
--
--   Returns three things:
--
--   * the replacement pattern: the leaf itself for a 'Leaf', or the fresh
--     variable for a 'RecCon';
--   * one term per leaf: the leaf's projection applied to the fresh
--     variable;
--   * the single 'Change' describing the replacement.
removeTree :: RecordTree -> RecPatM (Pattern, [Term], Changes)
removeTree tree = do
  (freshPat, freshTerm) <- nextVar
  let leaves     = projections tree
      leafTerms  = [ proj freshTerm | (proj, _) <- leaves ]
      count kind = length [ () | (_, k) <- leaves, k == kind ]
  return $ case tree of
    Leaf p     -> (p, leafTerms, [Left p])
    RecCon t _ -> (freshPat, leafTerms, [Right (count, "r", domFromArg t)])
-- | Translate one pattern.
--
--   * Constructor patterns whose record origin is 'PatOSystem' go through
--     'recordTree' and, when they form a tree, 'removeTree'.
--   * Variable and dot patterns go to 'removeTree' as leaves.
--   * Other constructor patterns and 'DefP' patterns are translated
--     argument-wise.
--   * Literal, projection and 'IApplyP' patterns are kept, contributing
--     no terms and no changes.
translatePattern :: Pattern -> RecPatM (Pattern, [Term], Changes)
translatePattern p = case p of
  ConP c ci ps
    | Just PatOSystem <- conPRecord ci -> recordTree p >>= either id removeTree
    | otherwise                        -> rebuildWith (ConP c ci) ps
  DefP o q ps -> rebuildWith (DefP o q) ps
  VarP{}      -> removeTree (Leaf p)
  DotP{}      -> removeTree (Leaf p)
  LitP{}      -> unchanged
  ProjP{}     -> unchanged
  IApplyP{}   -> unchanged
  where
    unchanged = return (p, [], [])
    -- Translate the arguments and re-apply the head.
    rebuildWith mk args = do
      (args', s, cs) <- translatePatterns args
      return (mk args', s, cs)
-- | Translate a list of argument patterns.  Each argument keeps its
--   'Arg'/'Named' wrappers; terms and changes are concatenated from left
--   to right.
translatePatterns :: [NamedArg Pattern] -> RecPatM ([NamedArg Pattern], [Term], Changes)
translatePatterns ps = do
  (newPats, terms, changes) <- unzip3 <$> traverse (translatePattern . namedArg) ps
  let rewrap new old = fmap (new <$) old
  return (zipWith rewrap newPats ps, concat terms, concat changes)
-- | View a pattern as a 'RecordTree', if possible.
--
--   * @Right tree@ is returned for variables, dot patterns, and
--     'PatOSystem' record constructor patterns whose arguments are all
--     trees.
--   * Otherwise @Left m@ is returned, where @m@ translates the pattern
--     without building a tree for it: a record constructor pattern keeps
--     its head, its subtrees are removed with 'removeTree', and its other
--     arguments run their own deferred translations.
recordTree ::
  Pattern ->
  RecPatM (Either (RecPatM (Pattern, [Term], Changes)) RecordTree)
recordTree p@(ConP c ci ps) | Just PatOSystem <- conPRecord ci = do
  let t = fromMaybe __IMPOSSIBLE__ $ conPType ci
  rs <- mapM (recordTree . namedArg) ps
  case allRight rs of
    -- Some argument is not a tree: keep the constructor pattern.
    Nothing ->
      return $ Left $ do
        (ps', ss, cs) <- unzip3 <$> mapM (either id removeTree) rs
        return (ConP c ci (ps' `withNamedArgsFrom` ps),
                concat ss, concat cs)
    -- All arguments are trees: build a node with one projection per field.
    Just ts -> liftTCM $ do
      t <- reduce t
      reportSDoc "tc.rec" 45 $ vcat
        [ "recordTree: "
        , nest 2 $ "constructor pattern " <+> prettyTCM p <+> " has type " <+> prettyTCM t
        ]
      -- The type inside the Arg is reduced again before its fields are read.
      fields <- getRecordTypeFields =<< reduce (unArg t)
      let proj p = (`applyE` [Proj ProjSystem $ unArg p])
      return $ Right $ RecCon t $ zip (map proj fields) ts
recordTree p@(ConP _ ci _) = return $ Left $ translatePattern p
recordTree p@DefP{} = return $ Left $ translatePattern p
recordTree p@VarP{} = return (Right (Leaf p))
recordTree p@DotP{} = return (Right (Leaf p))
recordTree p@LitP{} = return $ Left $ translatePattern p
recordTree p@ProjP{}= return $ Left $ translatePattern p
recordTree p@IApplyP{}= return $ Left $ translatePattern p
-- | Build the new flattened telescope from the changes and the old
--   flattened telescope.
--
--   * @Left DotP{}@: the dot pattern gets 'Nothing' and consumes no old
--     entry.
--   * @Right (n, x, t)@: the replaced record pattern becomes the entry
--     @(x, t)@ and drops @n VarPat@ old entries.
--   * Any other @Left@: the next old entry is kept.
--
--   Leftover old entries, or a non-'Right' change with no old entry left,
--   are impossible.
translateTel
  :: Changes
     -- ^ Changes, in pattern order.
  -> [(ArgName, Dom Type)]
     -- ^ Old telescope, flattened.
  -> [Maybe (ArgName, Dom Type)]
     -- ^ New telescope, flattened; 'Nothing' at dot pattern positions.
translateTel (Left (DotP{}) : rest) tel = Nothing : translateTel rest tel
translateTel (Right (n, x, t) : rest) tel = Just (x, t) :
                                              translateTel rest
                                                (drop (n VarPat) tel)
translateTel (Left _ : rest) (t : tel) = Just t : translateTel rest tel
translateTel [] [] = []
translateTel (Left _ : _) [] = __IMPOSSIBLE__
translateTel [] (_ : _) = __IMPOSSIBLE__