module Agda.TypeChecking.CompiledClause.Compile where
import Prelude hiding (null)
import Control.Applicative
import Control.Monad
import Control.Monad.Trans.Identity
import Data.Maybe
import qualified Data.Map as Map
import Data.List (nubBy)
import Data.Function
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Pattern
import Agda.TypeChecking.CompiledClause
import Agda.TypeChecking.Coverage
import Agda.TypeChecking.Coverage.SplitTree
import Agda.TypeChecking.Forcing
import Agda.TypeChecking.Monad
import Agda.TypeChecking.RecordPatterns
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Free.Precompute
import Agda.TypeChecking.Reduce
import Agda.Utils.Functor
import Agda.Utils.Maybe
import Agda.Utils.List
import qualified Agda.Utils.Pretty as P
import Agda.Utils.Update
import Agda.Utils.Impossible
-- | Whether 'compileClauses'' should run 'translateCompiledClauses'
--   (the record-pattern translation) on the compiled clauses before
--   returning them.
data RunRecordPatternTranslation = RunRecordPatternTranslation | DontRunRecordPatternTranslation
  deriving (Eq)
-- | Compile clauses without running the coverage checker.
--
--   Clauses flagged as unreachable are discarded first.  If a split tree is
--   supplied it drives the compilation ('compileWithSplitTree'); otherwise
--   the left-to-right strategy of 'compile' is used.  The record-pattern
--   translation ('translateCompiledClauses') runs only when requested.
compileClauses' :: RunRecordPatternTranslation -> [Clause] -> Maybe SplitTree -> TCM CompiledClauses
compileClauses' recpat cs mSplitTree = do
  -- Only clauses not marked as unreachable take part in compilation.
  let reachable = filter ((Just True /=) . clauseUnreachable) cs
  cls <- map unBruijn <$> normaliseProjP reachable
  let compiled = maybe (compile cls) (`compileWithSplitTree` cls) mSplitTree
  case recpat of
    RunRecordPatternTranslation     -> runIdentityT (translateCompiledClauses compiled)
    DontRunRecordPatternTranslation -> return compiled
-- | Compile the clauses of a definition into a case tree.
--
--   Without a name and type, the given clauses are compiled directly with
--   the left-to-right strategy of 'compile'.  With a name and type, the
--   clauses are coverage checked first.  The resulting split tree then guides
--   the compilation ('compileWithSplitTree'), and record-pattern splits are
--   translated away by 'translateCompiledClauses'.
compileClauses ::
  Maybe (QName, Type) -- ^ If given: coverage check against this name/type and translate record patterns.
  -> [Clause]
  -> TCM (Maybe SplitTree, Bool, CompiledClauses)
     -- ^ The split tree (only when coverage checking ran); whether
     --   'translateCompiledClauses' reported a change via 'runChangeT'
     --   (bound as @becameCopatternLHS@, presumably "a record expression
     --   became a copattern match" -- confirm); and the compiled clauses.
compileClauses mt cs = do
  -- Each clause is stripped down to a 'Cl' by 'unBruijn', which renames
  -- pattern variables using the clause permutation.
  case mt of
    -- NOTE(review): unlike the 'Just' branch and compileClauses', this branch
    -- neither drops unreachable clauses nor applies precomputeFreeVars_.
    Nothing -> (Nothing,False,) . compile . map unBruijn <$> normaliseProjP cs
    Just (q, t)  -> do
      splitTree <- coverageCheck q t cs
      reportSDoc "tc.cc.tree" 20 $ vcat
        [ "split tree of " <+> prettyTCM q <+> " from coverage check "
        , return $ P.pretty splitTree
        ]
      -- The clauses are re-read from the signature rather than taken from
      -- the argument; presumably the coverage check may have updated the
      -- stored clauses (confirm).  Clauses marked unreachable are discarded.
      let notUnreachable = (Just True /=) . clauseUnreachable
      cs <- normaliseProjP =<< instantiateFull =<< filter notUnreachable . defClauses <$> getConstInfo q
      let cls = map unBruijn cs
      reportSDoc "tc.cc" 30 $ sep $ do
        ("clauses patterns of " <+> prettyTCM q <+> " before compilation") : do
          map (prettyTCM . map unArg . clPats) cls
      reportSDoc "tc.cc" 50 $
        "clauses of " <+> prettyTCM q <+> " before compilation" <?> pretty cs
      -- Compile along the split tree; record splits are still present here.
      let cc = compileWithSplitTree splitTree cls
      reportSDoc "tc.cc" 20 $ sep
        [ "compiled clauses of " <+> prettyTCM q <+> " (still containing record splits)"
        , nest 2 $ return $ P.pretty cc
        ]
      -- Remove record splits; 'runChangeT' also reports whether anything changed.
      (cc, becameCopatternLHS) <- runChangeT $ translateCompiledClauses cc
      reportSDoc "tc.cc" 12 $ sep
        [ "compiled clauses of " <+> prettyTCM q
        , nest 2 $ return $ P.pretty cc
        ]
      return (Just splitTree, becameCopatternLHS, fmap precomputeFreeVars_ cc)
-- | A stripped-down clause used during compilation (built by 'unBruijn'):
--   its patterns, with pattern-variable names instead of de Bruijn
--   variables, and its optional body.
data Cl = Cl
  { clPats :: [Arg Pattern]
      -- ^ The patterns of the clause, in order.
  , clBody :: Maybe Term
      -- ^ The body; 'Nothing' compiles to 'Fail' (see 'compile'),
      --   presumably for absurd clauses.
  } deriving (Show)
-- | Render a clause as its pattern list, an arrow, and its body; a missing
--   body is rendered as @_|_@.
instance P.Pretty Cl where
  pretty cl = lhs P.<+> "->" P.<+> rhs
    where
      lhs = P.prettyList (clPats cl)
      rhs = case clBody cl of
        Nothing -> "_|_"
        Just t  -> P.pretty t
-- | A list of stripped-down clauses.
type Cls = [Cl]
-- | Strip a 'Clause' down to a 'Cl'.  Pattern variables keep only their
--   names ('dbPatVarName'), and the renaming derived from the clause
--   permutation is applied to the patterns as well as to the body.
--   Raises an internal error if the clause has no permutation.
unBruijn :: Clause -> Cl
unBruijn c = Cl pats body
  where
    sub  = renamingR $ fromMaybe __IMPOSSIBLE__ (clausePerm c)
    pats = applySubst sub [ fmap (fmap dbPatVarName . namedThing) na | na <- namedClausePats c ]
    body = applySubst sub (clauseBody c)
-- | Compile clauses by following the given split tree.  Where the tree
--   ends, or a branch has no corresponding subtree, compilation continues
--   with the left-to-right strategy of 'compile'.
compileWithSplitTree :: SplitTree -> Cls -> CompiledClauses
compileWithSplitTree t cs = case t of
  SplitAt i lz ts -> Case i $ compiles lz ts $ splitOn (length ts == 1) (unArg i) cs
        -- With exactly one branch in the split tree at this point, 'splitOn'
        -- is asked to force catch-all expansion at position i (presumably so
        -- that record splits can later be collapsed soundly -- confirm).
  SplittingDone n -> compile cs
    -- End of the split tree: continue with 'compile'.
  where
    -- Compile each branch of the split, recursing into the matching subtree
    -- of @ts@ when there is one.
    compiles :: LazySplit -> SplitTrees -> Case Cls -> Case CompiledClauses
    compiles lz ts br@Branches{ projPatterns = cop
                              , conBranches = cons
                              , etaBranch   = Nothing
                              , litBranches = lits
                              , fallThrough = fT
                              , catchAllBranch = catchAll
                              , lazyMatch = lazy }
      = br{ conBranches    = updCons cons
          , etaBranch      = Nothing
          , litBranches    = updLits lits
          , fallThrough    = fT
          , catchAllBranch = updCatchall catchAll
          -- The match is lazy if either the branches or the split tree say so.
          , lazyMatch      = lazy || lz == LazySplit
          }
      where
        -- Constructor branches: the clauses sit under 'WithArity', hence '<$>'.
        updCons = Map.mapWithKey $ \ c cl ->
         caseMaybe (lookup (SplitCon c) ts) compile compileWithSplitTree <$> cl
         -- No subtree for this key: fall back to 'compile'.
        updLits = Map.mapWithKey $ \ l cl ->
          caseMaybe (lookup (SplitLit l) ts) compile compileWithSplitTree cl
        updCatchall = fmap $ caseMaybe (lookup SplitCatchall ts) compile compileWithSplitTree
    -- An eta branch is never expected here; presumably such branches are only
    -- introduced later by 'translateCompiledClauses' (confirm).
    compiles _ _ Branches{etaBranch = Just{}} = __IMPOSSIBLE__  
-- | Compile clauses with the left-to-right strategy.  The position chosen
--   by 'nextSplit' is split on, recursively.  When no split remains, the
--   first clause decides: its body becomes a 'Done' leaf, and a missing body
--   becomes 'Fail'.  No clauses at all also yields 'Fail'.
compile :: Cls -> CompiledClauses
compile [] = Fail
compile cls@(firstCl : _) =
  case nextSplit cls of
    Just (isRecP, n) ->
      Case n (compile <$> splitOn isRecP (unArg n) cls)
    Nothing ->
      -- More than one clause may reach this point; the first one is used.
      maybe Fail (Done (map (fmap boundName) (clPats firstCl))) (clBody firstCl)
  where
    -- Once all splits are done, only variable-like patterns may remain.
    boundName p = case p of
      VarP _ x        -> x
      IApplyP _ _ _ x -> x
      DotP _ _        -> underscore
      ConP{}          -> __IMPOSSIBLE__
      DefP{}          -> __IMPOSSIBLE__
      LitP{}          -> __IMPOSSIBLE__
      ProjP{}         -> __IMPOSSIBLE__
-- | Choose the next position to split on by inspecting the first clause.
--
--   The first choice is the leftmost proper split ('properSplit') that is
--   not a lazy constructor match.  Failing that, it is the leftmost
--   constructor pattern at which every other clause has the same
--   constructor.  The 'Bool' in the result is the flag returned by
--   'properSplit'.  The clause list must be non-empty.
nextSplit :: Cls -> Maybe (Bool, Arg Int)
nextSplit []                 = __IMPOSSIBLE__
nextSplit (Cl ps _ : others) = firstSplit eager <|> firstSplit agreed
  where
    -- Leftmost proper split whose pattern passes the filter @ok@.
    firstSplit ok = listToMaybe
      [ (isRec, Arg ai i)
      | (i, Arg ai p) <- zip [0 ..] ps
      , Just isRec    <- [properSplit p]
      , ok i p
      ]
    -- Everything except a lazy constructor match counts as eager.
    eager _ p = case p of
      ConP _ cpi _ -> not (conPLazy cpi)
      _            -> True
    -- A constructor pattern at position i shared by all other clauses.
    agreed i (ConP c _ _) = all (sameConAt i (conName c)) others
    agreed _ _            = False
    sameConAt i cname other = case map unArg (drop i (clPats other)) of
      ConP c' _ _ : _ -> conName c' == cname
      _               -> False
-- | Classify a pattern for splitting.  Variable-like patterns ('VarP',
--   'DotP', 'IApplyP') give 'Nothing'.  Every other pattern gives 'Just' a
--   flag, which is 'True' only for a constructor pattern that is either a
--   record pattern of origin 'PatORec' or marked fall-through.
properSplit :: Pattern' a -> Maybe Bool
properSplit p = case p of
  ConP _ cpi _ -> Just (isRecordPat cpi || conPFallThrough cpi)
  DefP{}       -> Just False
  LitP{}       -> Just False
  ProjP{}      -> Just False
  VarP{}       -> Nothing
  DotP{}       -> Nothing
  IApplyP{}    -> Nothing
  where
    isRecordPat cpi = conPRecord cpi && patOrigin (conPInfo cpi) == PatORec
-- | Is this a variable-like pattern ('VarP', 'DotP' or 'IApplyP')?
--   Invariant: @isVar == isNothing . properSplit@; keep the two in sync.
isVar :: Pattern' a -> Bool
isVar p = case p of
  VarP{}    -> True
  DotP{}    -> True
  IApplyP{} -> True
  ConP{}    -> False
  DefP{}    -> False
  LitP{}    -> False
  ProjP{}   -> False
-- | @splitOn single n cs@ first expands catch-alls at position @n@
--   ('expandCatchAlls'; expansion is forced when @single@ holds).  It then
--   splits each clause at @n@ ('splitC') and merges the per-clause cases
--   in order.
splitOn :: Bool -> Int -> Cls -> Case Cls
splitOn single n cs =
  mconcat [ (: []) <$> splitC n cl | cl <- expandCatchAlls single n cs ]
-- | Split a single clause on its @n@-th pattern.
--
--   A projection or literal pattern yields a projection/literal branch with
--   that pattern removed.  A constructor or definition pattern yields a
--   constructor branch in which the pattern is replaced by its direct
--   subpatterns.  A variable-like pattern, or a missing @n@-th pattern,
--   yields a catch-all branch holding the unchanged clause.
splitC :: Int -> Cl -> Case Cl
splitC n (Cl ps b) = maybe catchAllCase branchOn focus
  where
    (before, rest) = splitAt n ps
    after          = drop 1 rest
    focus          = unArg <$> listToMaybe rest
    catchAllCase   = catchAll (Cl ps b)
    dropFocus      = Cl (before ++ after) b
    inlineSubs qs  = Cl (before ++ map (fmap namedThing) qs ++ after) b
    branchOn p = case p of
      ProjP _ d   -> projCase d dropFocus
      LitP _ l    -> litCase l dropFocus
      ConP c i qs ->
        (conCase (conName c) (conPFallThrough i) (WithArity (length qs) (inlineSubs qs)))
          { lazyMatch = conPLazy i }
      DefP _ q qs ->
        (conCase q False (WithArity (length qs) (inlineSubs qs)))
          { lazyMatch = False }
      IApplyP{}   -> catchAllCase
      VarP{}      -> catchAllCase
      DotP{}      -> catchAllCase
-- | Expand catch-all patterns at position @n@ into the concrete patterns
--   that other clauses use at that position.
--
--   With @single@ set, every clause whose @n@-th pattern exists and is a
--   variable is preceded by one expanded copy per distinct concrete pattern
--   found at @n@ in any clause.  The original clause is kept after the
--   copies.
--
--   Otherwise, if every clause is a catch-all at @n@ (or has fewer
--   patterns), nothing changes.  If not, the clauses are walked in order.
--   Each catch-all clause is preceded by expansions collected from itself
--   and the clauses after it; catch-all clauses that are too short are
--   eta-expanded first (see 'ensureNPatterns').
expandCatchAlls :: Bool -> Int -> Cls -> Cls
expandCatchAlls single n cs =
  -- When @single@ holds, expansion is forced for all clauses
  -- (compileWithSplitTree passes True when the split tree has exactly
  -- one branch at this position).
  if single then doExpand =<< cs else
  case cs of
  _                | all (isCatchAllNth . clPats) cs -> cs
  -- NB: the inner @cs@ below (the tail) shadows the argument.  The
  -- where-bound @expansions@ still uses the argument, i.e. this clause and
  -- all following ones.
  c@(Cl ps b) : cs | not (isCatchAllNth ps) -> c : expandCatchAlls False n cs
                   | otherwise -> map (expand c) expansions ++ c : expandCatchAlls False n cs
  -- Unreachable: the empty list satisfies the first alternative.
  _ -> __IMPOSSIBLE__
  where
    -- Forced expansion of a single clause.  @expansions@ may be empty, so
    -- the original clause is always kept (last).
    doExpand c@(Cl ps _)
      | exCatchAllNth ps = map (expand c) expansions ++ [c]
      | otherwise = [c]
    -- True if the n-th pattern is a variable, or there is no n-th pattern.
    isCatchAllNth ps = all (isVar . unArg) $ take 1 $ drop n ps
    -- True if the n-th pattern exists and is a variable.
    exCatchAllNth ps = any (isVar . unArg) $ take 1 $ drop n ps
    -- Key used to de-duplicate expansions: literal, constructor, or definition.
    classify (LitP _ l)   = Left l
    classify (ConP c _ _) = Right (Left c)
    classify (DefP _ q _) = Right (Right q)
    classify _            = __IMPOSSIBLE__
    -- The distinct non-variable patterns at position n in the clauses of the
    -- argument list (each paired with the preceding patterns of the clause
    -- it came from).  These are what a catch-all gets expanded into.
    expansions = nubOn (classify . unArg . snd)
               . mapMaybe (notVarNth . clPats)
               $ cs
    notVarNth
      :: [Arg Pattern]
      -> Maybe ([Arg Pattern]  -- The first @n@ patterns.
               , Arg Pattern)  -- The @n+1@st pattern, which is not a variable.
    notVarNth ps = do
      let (ps1, ps2) = splitAt n ps
      p <- listToMaybe ps2
      guard $ not $ isVar $ unArg p
      return (ps1, p)
    -- Instantiate clause @cl@ at position n with the head of pattern @q@.
    -- Its direct subpatterns become fresh "_" variables, and the body is
    -- updated by substituting the corresponding term for the old variable.
    expand cl (qs, q) =
      case unArg q of
        ConP c mt qs' -> Cl (ps0 ++ [q $> ConP c mt conPArgs] ++ ps1)
                            (substBody n' m (Con c ci (map Apply conArgs)) b)
          where
            ci       = fromConPatternInfo mt
            m        = length qs'
            -- Replace all direct subpatterns of q by "_" variables.
            -- NOTE(review): some of these may need to be IApplyP patterns
            -- rather than plain variables -- confirm.
            conPArgs = map (fmap ($> varP "_")) qs'
            conArgs  = zipWith (\ q' i -> q' $> var i) qs' $ downFrom m
        LitP i l -> Cl (ps0 ++ [q $> LitP i l] ++ ps1) (substBody n' 0 (Lit l) b)
        DefP o d qs' -> Cl (ps0 ++ [q $> DefP o d conPArgs] ++ ps1)
                            (substBody n' m (Def d (map Apply conArgs)) b)
          where
            m        = length qs'
            -- Replace all direct subpatterns of q by "_" variables.
            conPArgs = map (fmap ($> varP "_")) qs'
            conArgs  = zipWith (\ q' i -> q' $> var i) qs' $ downFrom m
        _ -> __IMPOSSIBLE__
      where
        -- Clauses may have fewer than n+1 patterns (e.g. eta-contracted
        -- ones), so eta-expand using the ArgInfos of qs ++ [q].  qs has
        -- exactly n elements, by construction in notVarNth.
        Cl ps b = ensureNPatterns (n + 1) (map getArgInfo $ qs ++ [q]) cl
        -- Lazy pattern binding; it should not fail because ps has at least
        -- n+1 patterns after ensureNPatterns.
        (ps0, _:ps1) = splitAt n ps
        -- Number of pattern variables bound to the right of position n.
        n' = countPatternVars ps1
-- | @ensureNPatterns n ais0 cl@ eta-expands @cl@ so that it has @n@
--   patterns.  A clause that already has at least @n@ patterns is returned
--   unchanged.
--
--   The new patterns are "_" variables whose 'ArgInfo's are the entries of
--   @ais0@ after the first @length (clPats cl)@ ones.  The body is raised
--   past the new variables and applied to them.
--
--   NOTE(review): this assumes @length ais0 == n@.  Otherwise the number of
--   new patterns and arguments differs from the amount the body is raised by.
ensureNPatterns :: Int -> [ArgInfo] -> Cl -> Cl
ensureNPatterns n ais0 cl@(Cl ps b)
  | missing <= 0 = cl
  | otherwise    = Cl (ps ++ newPats) (raise missing b `apply` newArgs)
  where
    have    = length ps
    newAis  = drop have ais0
    missing = n - have
    newPats = [ Arg ai (varP "_") | ai <- newAis ]
    newArgs = [ Arg ai (var i) | (i, ai) <- zip (downFrom missing) newAis ]
-- | @substBody n m v@ applies the substitution @liftS n (v :# raiseS m)@.
--   It leaves the @n@ innermost variables alone, puts @v@ in place of the
--   next one, and maps the remaining variables through @raiseS m@
--   (presumably to make room for @m@ variables that @v@ refers to).
substBody :: (Subst t a) => Int -> Int -> t -> a -> a
substBody n m v = applySubst rho
  where
    rho = liftS n (v :# raiseS m)
instance PrecomputeFreeVars a => PrecomputeFreeVars (CompiledClauses' a) where