module Agda.TypeChecking.CompiledClause.Compile where
import Prelude hiding (null)
import Data.Maybe
import Data.Monoid
import qualified Data.Map as Map
import Data.List (genericReplicate, nubBy, findIndex)
import Data.Function
import Debug.Trace
import Agda.Syntax.Common
import Agda.Syntax.Internal as I
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Coverage
import Agda.TypeChecking.Coverage.SplitTree
import Agda.TypeChecking.Monad
import Agda.TypeChecking.RecordPatterns
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Pretty
import Agda.Utils.Functor
import Agda.Utils.Null
import Agda.Utils.List
#include "undefined.h"
import Agda.Utils.Impossible
-- | Compile the clauses of a definition into a case tree.
--
--   * With @Nothing@, the clauses are compiled directly with 'compile',
--     which splits on the first column where the first clause has a
--     constructor, literal or projection pattern.
--
--   * With @Just (q, t)@, the coverage checker is run on @q : t@.
--     Compilation then follows the split tree it returns
--     ('compileWithSplitTree').
--     The result is post-processed by 'translateCompiledClauses'.
--     NOTE(review): that step presumably removes the record splits
--     mentioned in the level-12 debug output; confirm in its definition.
compileClauses ::
  Maybe (QName, Type)
  -> [Clause] -> TCM CompiledClauses
compileClauses mt cs = do
  let cls = [(clausePats c, clauseBody c) | c <- cs]
  case mt of
    Nothing -> return $ compile cls
    Just (q, t) -> do
      splitTree <- coverageCheck q t cs
      reportSDoc "tc.cc" 30 $ sep $
        text "clauses patterns before compilation"
          : map (prettyTCM . map unArg . fst) cls
      reportSDoc "tc.cc" 50 $
        sep [ text "clauses before compilation"
            , (nest 2 . text . show) cs
            ]
      let cc = compileWithSplitTree splitTree cls
      reportSDoc "tc.cc" 12 $ sep
        [ text "compiled clauses (still containing record splits)"
        , nest 2 $ text (show cc)
        ]
      -- Return the translated tree directly instead of re-binding
      -- (and shadowing) @cc@ only to 'return' it.
      translateCompiledClauses cc
-- | The part of a clause that compilation works on: its argument
--   patterns, paired with its body (which may be 'NoBody').
type Cl  = ( [I.Arg Pattern], ClauseBody )

-- | A sequence of clauses, in source order.
type Cls = [ Cl ]
-- | Compile clauses guided by a split tree produced by coverage checking.
--
--   At a @SplitAt i ts@ node, the clauses are split on column @i@.
--   'splitOn' is told whether the node has exactly one subtree.
--
--   If the resulting 'Case' has only constructor branches (no literal
--   branches and no catch-all), each constructor branch is compiled
--   with the subtree that @ts@ associates with that constructor.
--   Otherwise the split tree is ignored below this node and every
--   branch is compiled with 'compile'.
--
--   A @SplittingDone@ leaf falls back to 'compile'.
compileWithSplitTree :: SplitTree -> Cls -> CompiledClauses
compileWithSplitTree t cs = case t of
    SplitAt i ts    -> Case i $ compiles ts $ splitOn (length ts == 1) i cs
    SplittingDone{} -> compile cs
  where
    compiles :: SplitTrees -> Case Cls -> Case CompiledClauses
    compiles ts Branches{ projPatterns   = cop
                        , conBranches    = cons
                        , litBranches    = lits
                        , catchAllBranch = Nothing }
      | Map.null lits = empty { projPatterns = cop, conBranches = updCons cons }
      where
        -- A constructor branch without a matching subtree is a
        -- broken invariant of the coverage checker.
        updCons = Map.mapWithKey $ \ c cl ->
          let subtree = fromMaybe __IMPOSSIBLE__ $ lookup c ts
          in  compileWithSplitTree subtree <$> cl
    compiles _ br = compile <$> br
-- | @countInDotPatterns i cls@ walks the patterns of the first clause of
--   @cls@ and returns @i@ plus the number of dot patterns it passes.
--
--   Walking stops once @i@ non-dot patterns have been consumed. A dot
--   pattern sitting exactly at the stopping point is not counted.
--   In effect, an index that counts only non-dot patterns is turned into
--   an index that counts all patterns.
--
--   Fails (impossible) on an empty clause list, or when the first
--   clause has too few patterns.
--
--   NOTE(review): not used in the visible part of this module; confirm
--   whether it is still needed.
countInDotPatterns :: Int -> [Cl] -> Int
countInDotPatterns _ [] = __IMPOSSIBLE__
countInDotPatterns i ((ps, _) : _) = i + loop i (map unArg ps)
  where
    loop :: Int -> [Pattern] -> Int
    loop 0 _               = 0
    loop _ []              = __IMPOSSIBLE__
    -- Dot patterns are skipped without consuming the remaining count.
    loop k (DotP{} : rest) = 1 + loop k rest
    -- Any other pattern consumes one unit (was the ill-typed @i 1@).
    loop k (_ : rest)      = loop (k - 1) rest
-- | Compile clauses without a split tree.
--
--   The clauses are split on the first column where the first clause has
--   a constructor, literal or projection pattern (see 'nextSplit').
--   When no such column remains, the first clause matches. The result is
--   then 'Done', carrying the clause's variable names ('underscore' for
--   dot patterns) and its shared body. If that clause has no body, the
--   result is 'Fail'.
compile :: Cls -> CompiledClauses
compile cs =
  case nextSplit cs of
    Just (isRecP, n) -> Case n $ compile <$> splitOn isRecP n cs
    Nothing ->
      case cs of
        []          -> __IMPOSSIBLE__
        (ps, b) : _ -> maybe Fail (Done (map (fmap patVarName) ps) . shared) (getBody b)
  where
    -- Only variable and dot patterns can remain once no split is left.
    patVarName p = case p of
      VarP x  -> x
      DotP _  -> underscore
      ConP{}  -> __IMPOSSIBLE__
      LitP{}  -> __IMPOSSIBLE__
      ProjP{} -> __IMPOSSIBLE__
-- | Locate the first column of the first clause that holds a proper
--   split (see 'properSplit').
--
--   Returns that column's index together with the flag from
--   'properSplit', or 'Nothing' if every column is a variable or dot
--   pattern. An empty clause list is impossible.
nextSplit :: Cls -> Maybe (Bool, Int)
nextSplit [] = __IMPOSSIBLE__
nextSplit ((ps, _) : _) =
  listToMaybe
    [ (isRec, col)
    | (col, p)   <- zip [0 ..] ps
    , Just isRec <- [properSplit (unArg p)]
    ]
-- | Decide whether a pattern forces a split.
--
--   Variable and dot patterns do not split ('Nothing'). Constructor,
--   literal and projection patterns do. The result is @Just True@
--   exactly for a constructor pattern whose 'conPRecord' info is present,
--   and @Just False@ for all other splitting patterns.
properSplit :: Pattern -> Maybe Bool
properSplit pat = case pat of
  ConP _ info _ -> Just (isJust (conPRecord info))
  LitP{}        -> Just False
  ProjP{}       -> Just False
  VarP{}        -> Nothing
  DotP{}        -> Nothing
-- | 'True' for patterns that match anything at their position (variable
--   and dot patterns). 'False' for constructor, literal and projection
--   patterns.
isVar :: Pattern -> Bool
isVar pat = case pat of
  VarP{}  -> True
  DotP{}  -> True
  ConP{}  -> False
  LitP{}  -> False
  ProjP{} -> False
-- | Split clauses on column @n@.
--
--   Catch-all clauses are first expanded ('expandCatchAlls'). Each clause
--   is then turned into a singleton 'Case' ('splitC'), and these are
--   combined in clause order with the 'Monoid' instance of 'Case'.
splitOn :: Bool -> Int -> Cls -> Case Cls
splitOn single n =
  mconcat . map (fmap (: []) . splitC n) . expandCatchAlls single n
-- | Put a single clause into the 'Case' branch selected by its pattern in
--   column @n@:
--
--   * a projection becomes a projection case, with the column removed;
--   * a constructor becomes a constructor case, with the column replaced
--     by the constructor's sub-patterns and the arity recorded;
--   * a literal becomes a literal case, with the column removed;
--   * a variable or dot pattern leaves the clause as a catch-all,
--     unchanged.
splitC :: Int -> Cl -> Case Cl
splitC n cl@(ps, b) =
  case unArg focus of
    ProjP d     -> projCase d (dropped, b)
    ConP c _ qs -> conCase (conName c) $
                     WithArity (length qs) (before ++ map (fmap namedThing) qs ++ after, b)
    LitP l      -> litCase l (dropped, b)
    VarP{}      -> catchAll cl
    DotP{}      -> catchAll cl
  where
    (before, focus, after) = extractNthElement' n ps
    dropped                = before ++ after
-- | Prepare clauses for splitting on column @n@.
--
--   Clauses whose column @n@ is a variable or dot pattern (catch-alls)
--   get specialised copies added in front of them: one copy per distinct
--   constructor or literal pattern that occurs in column @n@ in any of
--   the clauses (see 'expansions').
--
--   * With @single = True@, every catch-all clause is preceded by such
--     copies.
--
--   * With @single = False@, the list is returned unchanged if every
--     clause is a catch-all at @n@. Otherwise clauses are processed left
--     to right: non-catch-alls are kept, catch-alls are preceded by their
--     copies. Processing stops once the remaining tail consists only of
--     catch-alls.
expandCatchAlls :: Bool -> Int -> Cls -> Cls
expandCatchAlls single n cs =
  if single then doExpand =<< cs else
  case cs of
    _ | all (isCatchAllNth . fst) cs -> cs
    -- NB: this alternative shadows the argument @cs@ with the tail.
    -- 'expansions' in the where clause still refers to the ARGUMENT @cs@,
    -- i.e. it is computed over all clauses.
    (ps, b) : cs | not (isCatchAllNth ps) -> (ps, b) : expandCatchAlls False n cs
                 | otherwise -> map (expand ps b) expansions ++ (ps, b) : expandCatchAlls False n cs
    -- Unreachable: @all@ on the empty list is True, so [] hits the first
    -- alternative.
    _ -> __IMPOSSIBLE__
  where
    -- Copies precede the original catch-all clause, which is kept last.
    doExpand c@(ps, b)
      | isVar $ unArg $ nth ps = map (expand ps b) expansions ++ [c]
      | otherwise              = [c]
    -- NOTE(review): also True when the clause has fewer than n+1
    -- patterns (take 1 yields []); confirm that is intended.
    isCatchAllNth ps = all (isVar . unArg) $ take 1 $ drop n ps
    -- Pattern in column n; a clause that is too short is a broken
    -- invariant.
    nth qs = headWithDefault __IMPOSSIBLE__ $ drop n qs
    -- Identity of an expansion: literal or constructor, ignoring
    -- sub-patterns.
    classify (LitP l) = Left l
    classify (ConP c _ _) = Right c
    classify _ = __IMPOSSIBLE__
    -- Distinct non-variable patterns in column n, over all clauses, in
    -- first-occurrence order (nubBy keeps the first representative).
    expansions = nubBy ((==) `on` (classify . unArg))
               . filter (not . isVar . unArg)
               . map (nth . fst)
               $ cs
    -- Specialise clause (ps, b) to pattern q in column n. The body is
    -- rewritten with 'substBody': the n'-th pattern-bound variable is
    -- replaced by the constructor term / literal.
    expand ps b q =
      case unArg q of
        -- The constructor pattern gets a variable sub-pattern for each of
        -- its arguments, and the body gets m = length qs' new binders.
        -- The new binders appear in the term as @var@ applied to
        -- @downFrom m@ (presumably m-1 down to 0 -- confirm against
        -- Agda.Utils.List).
        ConP c mt qs' -> (ps0 ++ [q $> ConP c mt conPArgs] ++ ps1,
                          substBody n' m (Con c conArgs) b)
          where
            m = length qs'
            conPArgs = map (fmap ($> VarP underscore)) qs'
            conArgs = zipWith (\ q n -> q $> var n) qs' $ downFrom m
        -- Literals bind nothing, so no new binders are introduced.
        LitP l -> (ps0 ++ [q $> LitP l] ++ ps1, substBody n' 0 (Lit l) b)
        _ -> __IMPOSSIBLE__
      where
        (ps0, rest) = splitAt n ps
        ps1 = maybe __IMPOSSIBLE__ snd $ uncons rest
        -- Number of body binders introduced by the patterns before
        -- column n.
        n' = countVars ps0
        countVars = sum . map (count . unArg)
        -- Variable and dot patterns each contribute one binder;
        -- constructor patterns contribute their sub-patterns' binders;
        -- anything else contributes none.
        count VarP{} = 1
        count (ConP _ _ ps) = countVars $ map (fmap namedThing) ps
        count DotP{} = 1
        count _ = 0
-- | @substBody n m v b@ rewrites a clause body.
--
--   The first @n@ binders of @b@ are skipped. The next binder is
--   substituted by @v@ and replaced with @m@ fresh binders (named
--   'underscore'). Variables free under the removed binder are raised by
--   @m@ before substitution, so they still point past the new binders.
--
--   'NoBody' is left as is. Running out of binders is impossible.
substBody :: Int -> Int -> Term -> ClauseBody -> ClauseBody
substBody _ _ _ NoBody = NoBody
substBody 0 m v b = case b of
  Bind bnd -> foldr (.) id (replicate m (Bind . Abs underscore)) $ subst 0 v (absBody $ raise m bnd)
  _ -> __IMPOSSIBLE__
substBody n m v b = case b of
  -- Descend under one binder (was the ill-typed @n 1@).
  Bind bnd -> Bind $ fmap (substBody (n - 1) m v) bnd
  _ -> __IMPOSSIBLE__