module Agda.TypeChecking.CompiledClause.Compile where
import Prelude hiding (null)
import Data.Maybe
import Data.Monoid
import qualified Data.Map as Map
import Data.List (genericReplicate, nubBy, findIndex)
import Data.Function
import Debug.Trace
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Pattern
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Coverage
import Agda.TypeChecking.Coverage.SplitTree
import Agda.TypeChecking.Monad
import Agda.TypeChecking.RecordPatterns
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Pretty (prettyTCM, nest, sep, text)
import Agda.Utils.Functor
import Agda.Utils.Maybe
import Agda.Utils.Null
import Agda.Utils.List
import Agda.Utils.Pretty (Pretty(..), prettyShow)
import qualified Agda.Utils.Pretty as P
#include "undefined.h"
import Agda.Utils.Impossible
-- | Turn a list of function clauses into compiled clauses (a case tree).
--
--   * With @Nothing@ the clauses are handed straight to 'compile'.
--
--   * With @Just (q, t)@ the clauses are first coverage checked
--     ('coverageCheck'), which yields a split tree.  The clauses are then
--     compiled along that tree by 'compileWithSplitTree' and the result is
--     post-processed by 'translateCompiledClauses'.
--
--   Debug output is emitted under the verbosity key @tc.cc@.
compileClauses ::
  Maybe (QName, Type)
  -> [Clause] -> TCM CompiledClauses
compileClauses mt cs = do
  shared <- sharedFun
  case mt of
    Nothing     -> return (compile shared cls)
    Just (q, t) -> do
      splitTree <- coverageCheck q t cs
      reportSDoc "tc.cc" 30 $ sep $
        text "clauses patterns before compilation"
          : map (prettyTCM . map unArg . clPats) cls
      reportSDoc "tc.cc" 50 $ sep
        [ text "clauses before compilation"
        , nest 2 (text (show cs))
        ]
      let withRecordSplits = compileWithSplitTree shared splitTree cls
      reportSDoc "tc.cc" 12 $ sep
        [ text "compiled clauses (still containing record splits)"
        , nest 2 $ text (show withRecordSplits)
        ]
      translateCompiledClauses withRecordSplits
  where
    -- The clauses in the stripped-down 'Cl' form used by the compiler;
    -- their patterns are passed through 'unnumberPatVars'.
    cls = map toCl cs
    toCl c = Cl (unnumberPatVars (clausePats c)) (clauseBody c)
-- | The clause representation used inside the clause compiler:
--   only the argument patterns and the body are kept.
data Cl = Cl
  { clPats :: [Arg Pattern]
      -- ^ The argument patterns of the clause.
  , clBody :: ClauseBody
      -- ^ The body of the clause.
  }
  deriving Show
-- | A clause is rendered as its pattern list, the text @->@, and its body.
instance Pretty Cl where
  pretty cl = lhs P.<+> P.text "->" P.<+> rhs
    where
      lhs = P.prettyList (clPats cl)
      rhs = pretty (clBody cl)
-- | A list of clauses in the stripped-down 'Cl' representation.
type Cls = [Cl]
-- | Compile clauses to a case tree, following the given split tree.
--
--   * At a 'SplittingDone' leaf the remaining work is delegated to 'compile'.
--
--   * At a 'SplitAt' node the clauses are split on the indicated argument
--     with 'splitOn'.  Constructor branches that have an entry in the list
--     of subtrees are compiled recursively along that subtree; all other
--     branches (constructor branches without a subtree, literal branches
--     and the catch-all branch) are compiled with 'compile'.
--
--   The first argument is applied to clause bodies by 'compile'.
compileWithSplitTree :: (Term -> Term) -> SplitTree -> Cls -> CompiledClauses
compileWithSplitTree shared (SplittingDone _) cs = compile shared cs
compileWithSplitTree shared (SplitAt i ts) cs =
  Case i $ compileBranches $ splitOn forceExpansion (unArg i) cs
  where
    -- This flag is the @single@ argument of 'splitOn'; it is set when the
    -- node has exactly one subtree.
    forceExpansion = length ts == 1

    compileBranches :: Case Cls -> Case CompiledClauses
    compileBranches Branches{ projPatterns   = cop
                            , conBranches    = cons
                            , litBranches    = lits
                            , catchAllBranch = fallback
                            } =
      Branches
        { projPatterns   = cop
        , conBranches    = Map.mapWithKey compileCon cons
        , litBranches    = fmap (compile shared) lits
        , catchAllBranch = fmap (compile shared) fallback
        }

    -- Pick the compiler for constructor @c@: follow its subtree if the
    -- split tree has one, otherwise fall back to 'compile'.
    compileCon c =
      fmap (maybe (compile shared) (compileWithSplitTree shared) (lookup c ts))
-- | Compile clauses to a case tree without a split tree.
--
--   If 'nextSplit' reports an argument to split on, a 'Case' node on that
--   argument is produced and each branch obtained from 'splitOn' is compiled
--   recursively.  Otherwise only the first clause is inspected: if it has a
--   body the result is 'Done' (with the names of its patterns and the body
--   passed through the first argument), if it has none the result is 'Fail'.
compile :: (Term -> Term) -> Cls -> CompiledClauses
compile shared cs = maybe leaf node (nextSplit cs)
  where
    node (isRecP, n) = Case n $ compile shared <$> splitOn isRecP (unArg n) cs

    -- No split is left; any clauses after the first one are ignored here.
    leaf = case cs of
      []     -> __IMPOSSIBLE__
      cl : _ -> case getBody (clBody cl) of
        Just t  -> Done (map (fmap name) (clPats cl)) (shared t)
        Nothing -> Fail

    -- Only variable and dot patterns can remain once no split is left.
    name p = case p of
      VarP x  -> x
      DotP _  -> underscore
      ConP{}  -> __IMPOSSIBLE__
      LitP{}  -> __IMPOSSIBLE__
      ProjP{} -> __IMPOSSIBLE__
-- | Find the next argument to split on: the position (counted from 0) of
--   the first pattern of the first clause for which 'properSplit' returns
--   a result.  That result is returned alongside the position; the position
--   carries the argument info of the pattern it was found at.
--
--   Must not be called with an empty clause list.
nextSplit :: Cls -> Maybe (Bool, Arg Int)
nextSplit []            = __IMPOSSIBLE__
nextSplit (Cl ps _ : _) = headMaybe
  [ (isRec, Arg ai i)
  | (Arg ai p, i) <- zip ps [0..]
  , Just isRec    <- [properSplit p]
  ]
-- | Does the pattern require a split?
--
--   Variable and dot patterns yield @Nothing@.  Literal and projection
--   patterns yield @Just False@.  A constructor pattern yields @Just r@,
--   where @r@ tells whether 'conPRecord' of its pattern info is a @Just@.
properSplit :: Pattern -> Maybe Bool
properSplit p = case p of
  ConP _ cpi _ -> Just (isJust (conPRecord cpi))
  LitP{}       -> Just False
  ProjP{}      -> Just False
  VarP{}       -> Nothing
  DotP{}       -> Nothing
-- | @True@ exactly for variable and dot patterns, i.e. for the patterns
--   on which 'properSplit' returns @Nothing@.
isVar :: Pattern -> Bool
isVar p = case p of
  VarP{}  -> True
  DotP{}  -> True
  ConP{}  -> False
  LitP{}  -> False
  ProjP{} -> False
-- | @splitOn single n cs@ splits the clauses on their @n@-th argument.
--
--   The clauses are first passed through @'expandCatchAlls' single n@.
--   Each resulting clause is then classified by 'splitC', wrapped into a
--   singleton list, and the per-clause results are merged with 'mconcat'.
splitOn :: Bool -> Int -> Cls -> Case Cls
splitOn single n cs = mconcat
  [ fmap (: []) (splitC n cl)
  | cl <- expandCatchAlls single n cs
  ]
-- | Classify a single clause by its @n@-th pattern.
--
--   * Projection and literal patterns: the pattern is removed from the
--     clause, which goes into the matching projection/literal branch.
--
--   * Constructor pattern: the pattern is replaced by its argument
--     patterns and the clause goes into the branch of that constructor,
--     tagged with the number of constructor arguments.
--
--   * Variable and dot patterns: the clause goes, unchanged, into the
--     catch-all branch.
splitC :: Int -> Cl -> Case Cl
splitC n cl@(Cl ps b) = case unArg focus of
    ProjP d     -> projCase d dropped
    ConP c _ qs -> conCase (conName c) $ WithArity (length qs) $
                     Cl (before ++ map (fmap namedThing) qs ++ after) b
    LitP l      -> litCase l dropped
    VarP{}      -> fallback
    DotP{}      -> fallback
  where
    (before, focus, after) = extractNthElement' n ps
    -- The clause with the @n@-th pattern removed.
    dropped  = Cl (before ++ after) b
    fallback = catchAll cl
-- | @expandCatchAlls single n cs@ expands clauses that have a variable or
--   dot pattern at position @n@: in front of such a clause it inserts one
--   copy per distinct constructor/literal pattern found at position @n@ of
--   the clauses @cs@, with the catch-all specialised to that pattern and
--   the body adjusted by 'substBody'.  The original clause is always kept.
--
--   * If @single@ is set, every clause of @cs@ is treated this way.
--
--   * Otherwise the clauses are processed front to back, and processing
--     stops (leaving the remaining clauses untouched) as soon as all
--     remaining clauses are catch-alls at position @n@.  In this mode the
--     expansions are recomputed from the remaining clauses at each step.
expandCatchAlls :: Bool -> Int -> Cls -> Cls
expandCatchAlls single n cs
  | single                          = concatMap expandIfVar cs
  | all (isCatchAllNth . clPats) cs = cs
  | otherwise                       = case cs of
      cl@(Cl ps b) : rest
        | isCatchAllNth ps -> specialised ps b ++ cl : expandCatchAlls False n rest
        | otherwise        -> cl : expandCatchAlls False n rest
      -- The empty list is already handled by the @all@ guard above.
      [] -> __IMPOSSIBLE__
  where
    -- All specialisations of the clause @Cl ps b@; possibly none.
    specialised ps b = map (expand ps b) expansions

    -- Used when expansion is forced: expand the clause if its @n@-th
    -- pattern is a variable, and keep the clause itself in any case.
    expandIfVar cl@(Cl ps b)
      | isVar (unArg (nth ps)) = specialised ps b ++ [cl]
      | otherwise              = [cl]

    -- True if the @n@-th pattern is a variable, or if there is no
    -- @n@-th pattern at all.
    isCatchAllNth ps = all (isVar . unArg) $ take 1 $ drop n ps

    nth qs = headWithDefault __IMPOSSIBLE__ $ drop n qs

    classify (LitP l)     = Left l
    classify (ConP c _ _) = Right c
    classify _            = __IMPOSSIBLE__

    -- The non-variable patterns at position @n@ of the clauses @cs@,
    -- one representative per literal / constructor.
    expansions = nubBy ((==) `on` (classify . unArg))
               . filter (not . isVar . unArg)
               . map (nth . clPats)
               $ cs

    -- Specialise the clause @Cl ps b@ to the pattern @q@ at position @n@.
    expand ps b q =
        case unArg q of
          ConP c mt qs' ->
            let arity     = length qs'
                -- Every direct subpattern of @q@ becomes a wildcard.
                wildcards = map (fmap ($> VarP underscore)) qs'
                conArgs   = zipWith (\ sub i -> sub $> var i) qs' (downFrom arity)
            in  Cl (before ++ [q $> ConP c mt wildcards] ++ after)
                   (substBody boundBefore arity (Con c conArgs) b)
          LitP l -> Cl (before ++ [q $> LitP l] ++ after)
                       (substBody boundBefore 0 (Lit l) b)
          _      -> __IMPOSSIBLE__
      where
        (before, rest) = splitAt n ps
        after          = maybe __IMPOSSIBLE__ snd $ uncons rest

        -- Number of variable and dot patterns in the patterns before
        -- position @n@, including those nested in constructor patterns.
        boundBefore = countVars before
        countVars   = sum . map (count . unArg)
        count VarP{}          = 1
        count (ConP _ _ args) = countVars $ map (fmap namedThing) args
        count DotP{}          = 1
        count _               = 0
-- | @substBody n m v b@ descends under the first @n@ binders of the clause
--   body @b@ and replaces the next bound variable by the term @v@, which
--   may mention @m@ fresh variables.
--
--   At the target binder the abstraction is raised by @m@, @v@ is
--   substituted for de Bruijn index 0 in its body, and the result is
--   wrapped in @m@ new binders named 'underscore'.
--
--   A 'NoBody' is returned unchanged.  Any other body that is not a 'Bind'
--   where a binder is expected is an internal error.
substBody :: Int -> Int -> Term -> ClauseBody -> ClauseBody
substBody _ _ _ NoBody = NoBody
substBody 0 m v b = case b of
  Bind b' -> foldr (.) id (replicate m (Bind . Abs underscore)) $
               subst 0 v (absBody $ raise m b')
  _       -> __IMPOSSIBLE__
substBody n m v b = case b of
  -- Step under one binder; one fewer binder remains to be skipped.
  -- (The counter must be decremented with @n - 1@; writing @n 1@ would
  -- apply the 'Int' @n@ as a function and does not type check.)
  Bind b' -> Bind $ fmap (substBody (n - 1) m v) b'
  _       -> __IMPOSSIBLE__