{-# LANGUAGE DeriveDataTypeable #-}
module Agda.TypeChecking.CompiledClause where
import Prelude hiding (null)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Semigroup hiding (Arg(..))
import Data.Data (Data)
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Generic
import Agda.Syntax.Literal
import Agda.Syntax.Position
import Agda.Utils.Null
import Agda.Utils.Pretty
import Agda.Utils.Impossible
-- | A value together with an arity.  Used for the constructor and eta
--   branches of 'Case' ('projCase' stores projections with arity 0);
--   '<>' on this type requires both arities to agree.
data WithArity c = WithArity { arity :: Int, content :: c }
  deriving (Data, Functor, Foldable, Traversable, Show)
-- | The branches of a single case split in a compiled clause tree.
--
--   Single-branch values are built by 'litCase', 'conCase', 'etaCase',
--   'projCase' and 'catchAll'; cases are merged with '<>' and the
--   neutral element is 'empty'.
data Case c = Branches
  { projPatterns   :: Bool
    -- ^ Set by 'projCase': the split is on projections (copatterns)
    --   rather than constructors, and the keys of 'conBranches' are then
    --   projection names.  Printed as @record@.
  , conBranches    :: Map QName (WithArity c)
    -- ^ Constructor (or projection) branches keyed by name, each paired
    --   with its arity.  Subtrees for equal keys are merged with '<>'.
  , etaBranch      :: Maybe (ConHead, WithArity c)
    -- ^ Optional eta branch for the given constructor (see 'etaCase').
    --   Merging two cases that both carry one is treated as impossible.
  , litBranches    :: Map Literal c
    -- ^ Literal branches keyed by the literal.
  , catchAllBranch :: Maybe c
    -- ^ Optional catch-all branch.
  , fallThrough :: Maybe Bool
    -- ^ 'Nothing' in 'empty'; 'catchAll' sets @Just True@, 'conCase'
    --   takes it as an argument, the other smart constructors set
    --   @Just False@.  Two 'Just' values are combined with '&&' by '<>'.
    --   NOTE(review): presumably says whether to fall through to the
    --   catch-all when the argument is not canonical — confirm.
  , lazyMatch :: Bool
    -- ^ Whether the match may be performed lazily.  'checkLazyMatch'
    --   clears it unless there is no catch-all, at most one constructor
    --   branch, no literal branches and no projection patterns.
  }
  deriving (Data, Functor, Foldable, Traversable, Show)
-- | A compiled clause tree whose leaves carry values of type @a@.
data CompiledClauses' a
  = Case (Arg Int) (Case (CompiledClauses' a))
    -- ^ A case split; every branch leads to another compiled clause
    --   tree.  Printed as @case n of ...@.  NOTE(review): the 'Int'
    --   presumably identifies the scrutinised argument — confirm.
  | Done [Arg ArgName] a
    -- ^ A leaf carrying a list of argument names and a body; printed as
    --   @done[names] body@.  NOTE(review): the names are presumably the
    --   variables bound by the matched clause — confirm.
  | Fail
    -- ^ A failure leaf, printed as @fail@ (presumably: no clause
    --   matches — confirm).
  deriving (Data, Functor, Traversable, Foldable, Show)
-- | Compiled clauses whose leaves are internal 'Term's.
type CompiledClauses = CompiledClauses' Term
-- | A case with exactly one literal branch (not lazy,
--   @fallThrough = Just False@).
litCase :: Literal -> c -> Case c
litCase lit rhs = Branches
  { projPatterns   = False
  , conBranches    = Map.empty
  , etaBranch      = Nothing
  , litBranches    = Map.singleton lit rhs
  , catchAllBranch = Nothing
  , fallThrough    = Just False
  , lazyMatch      = False
  }
-- | A case with exactly one constructor branch.  The 'Bool' becomes the
--   'fallThrough' flag; the case is not lazy.
conCase :: QName -> Bool -> WithArity c -> Case c
conCase con ft branch = Branches
  { projPatterns   = False
  , conBranches    = Map.singleton con branch
  , etaBranch      = Nothing
  , litBranches    = Map.empty
  , catchAllBranch = Nothing
  , fallThrough    = Just ft
  , lazyMatch      = False
  }
-- | A case consisting of only an eta branch for the given constructor.
--   This is the only smart constructor that marks the case as lazy.
etaCase :: ConHead -> WithArity c -> Case c
etaCase con branch = Branches
  { projPatterns   = False
  , conBranches    = Map.empty
  , etaBranch      = Just (con, branch)
  , litBranches    = Map.empty
  , catchAllBranch = Nothing
  , fallThrough    = Just False
  , lazyMatch      = True
  }
-- | A copattern case with one projection branch.  The projection is
--   stored in 'conBranches' with arity 0 and 'projPatterns' is set.
projCase :: QName -> c -> Case c
projCase proj rhs = Branches
  { projPatterns   = True
  , conBranches    = Map.singleton proj (WithArity 0 rhs)
  , etaBranch      = Nothing
  , litBranches    = Map.empty
  , catchAllBranch = Nothing
  , fallThrough    = Just False
  , lazyMatch      = False
  }
-- | A case consisting of only a catch-all branch
--   (@fallThrough = Just True@, not lazy).
catchAll :: c -> Case c
catchAll rhs = Branches
  { projPatterns   = False
  , conBranches    = Map.empty
  , etaBranch      = Nothing
  , litBranches    = Map.empty
  , catchAllBranch = Just rhs
  , fallThrough    = Just True
  , lazyMatch      = False
  }
-- | Keep 'lazyMatch' only if the case still qualifies for a lazy match.
--
--   To qualify there must be no catch-all branch, at most one
--   constructor branch, no literal branches and no projection patterns.
--   An already-'False' flag stays 'False'.
checkLazyMatch :: Case c -> Case c
checkLazyMatch br = br { lazyMatch = lazyMatch br && canBeLazy }
  where
    canBeLazy =  null (catchAllBranch br)
              && Map.size (conBranches br) <= 1
              && null (litBranches br)
              && not (projPatterns br)
-- | Does any case split in the tree have a catch-all branch?
--   A split that has one answers 'True' without inspecting its subtrees.
hasCatchAll :: CompiledClauses -> Bool
hasCatchAll = go
  where
    go (Case _ br) = case catchAllBranch br of
      Just _  -> True
      Nothing -> any go br
    go Done{} = False
    go Fail{} = False
-- | Does any case split in the tree split on projections (copatterns)?
hasProjectionPatterns :: CompiledClauses -> Bool
hasProjectionPatterns = go
  where
    go (Case _ br) = projPatterns br || any go br
    go Done{}      = False
    go Fail{}      = False
-- | Combine the contents of two values of the same arity.
--   Combining values of different arities is impossible.
instance Semigroup c => Semigroup (WithArity c) where
  x <> y
    | arity x /= arity y = __IMPOSSIBLE__   -- arities must match!
    | otherwise          = WithArity (arity x) (content x <> content y)
-- | 'mempty' has no meaningful arity: forcing its 'arity' is impossible.
instance (Semigroup c, Monoid c) => Monoid (WithArity c) where
  mempty        = WithArity { arity = __IMPOSSIBLE__, content = mempty }
  mappend w1 w2 = w1 <> w2
-- | Merge the branches of two cases.
--
--   * Constructor and literal branches are unioned; subtrees for the same
--     key are merged with '<>'.
--   * At most one side may carry an eta branch; two is impossible.
--   * Catch-all branches are combined with '<>' on 'Maybe'.
--   * 'fallThrough' treats 'Nothing' as neutral and combines two 'Just'
--     values with '&&'.
--   * The result is lazy only if both sides are, and 'checkLazyMatch'
--     then re-validates that flag.
instance Semigroup m => Semigroup (Case m) where
  Branches cop cs eta ls m b lazy <> Branches cop' cs' eta' ls' m' b' lazy' = checkLazyMatch $
    Branches (cop || cop')
             (Map.unionWith (<>) cs cs')
             (unionEta eta eta')
             (Map.unionWith (<>) ls ls')
             (m <> m')
             (combineFallThrough b b')
             (lazy && lazy')
   where
     -- 'Nothing' must be neutral: 'empty' has @fallThrough = Nothing@, so
     -- @empty <> x@ has to keep x's flag.  (Previously the first equation
     -- returned the outer first argument, i.e. always 'Nothing', making
     -- 'Nothing' absorbing and breaking the Monoid identity law.)
     combineFallThrough Nothing   ft        = ft
     combineFallThrough ft        Nothing   = ft
     combineFallThrough (Just f1) (Just f2) = Just (f1 && f2)
     unionEta Nothing e       = e
     unionEta e       Nothing = e
     unionEta Just{}  Just{}  = __IMPOSSIBLE__
-- | The neutral case is 'empty' (no branches, 'fallThrough' = 'Nothing').
instance (Semigroup m, Monoid m) => Monoid (Case m) where
  mempty        = empty
  mappend c1 c2 = c1 <> c2
-- | The empty case has no branches at all.  'null' ignores the
--   'projPatterns', 'fallThrough' and 'lazyMatch' flags.
instance Null (Case m) where
  empty = Branches
    { projPatterns   = False
    , conBranches    = Map.empty
    , etaBranch      = Nothing
    , litBranches    = Map.empty
    , catchAllBranch = Nothing
    , fallThrough    = Nothing
    , lazyMatch      = True
    }
  null br = and
    [ null (conBranches br)
    , null (etaBranch br)
    , null (litBranches br)
    , null (catchAllBranch br)
    ]
-- | Print only the content; the arity is not shown.
instance Pretty a => Pretty (WithArity a) where
  pretty wa = pretty (content wa)
-- | Print the branches one per line: constructor branches, the eta
--   branch, literal branches, then the catch-all.  A lazy match is
--   prefixed with @~@.
instance Pretty a => Pretty (Case a) where
  prettyPrec p br =
    mparens (p > 0) $ (if lazyMatch br then "~" else empty) <+> vcat rows
    where
      rows = concat
        [ prettyMap (conBranches br)
        , maybe [] prettyEta (etaBranch br)
        , prettyMap (litBranches br)
        , maybe [] prettyCatchAll (catchAllBranch br)
        ]
      prettyEta (c, cc) = [("eta" <+> pretty c <+> "->") <?> pretty cc]
      prettyCatchAll x  = ["_ ->" <+> pretty x]
-- | Render each entry of a map as @key -> value@, in ascending key
--   order, with the value nested by 2 when the line is split.
prettyMap :: (Pretty k, Pretty v) => Map k v -> [Doc]
prettyMap = Map.foldrWithKey (\k v rows -> row k v : rows) []
  where
    row k v = sep [ pretty k <+> "->", nest 2 (pretty v) ]
-- | Print a compiled clause tree.  Projection splits are printed as
--   @record@, other splits as @case n of@.
instance Pretty CompiledClauses where
  pretty cc = case cc of
    Fail      -> "fail"
    Done hs t -> ("done" <> pretty hs) <?> pretty t
    Case n bs
      | projPatterns bs -> sep [ "record", nest 2 $ pretty bs ]
      | otherwise       -> text ("case " ++ prettyShow n ++ " of") <?> pretty bs
-- | Kill ranges in the content; the arity is left unchanged.
instance KillRange c => KillRange (WithArity c) where
  killRange (WithArity n c) = WithArity n (killRange c)
-- | Kill ranges in all branches.  The 'projPatterns', 'fallThrough' and
--   'lazyMatch' flags are kept as they are.
instance KillRange c => KillRange (Case c) where
  killRange br = br
    { conBranches    = killRangeMap (conBranches br)
    , etaBranch      = killRange (etaBranch br)
    , litBranches    = killRangeMap (litBranches br)
    , catchAllBranch = killRange (catchAllBranch br)
    }
-- | Kill ranges throughout a compiled clause tree.
instance KillRange CompiledClauses where
  killRange cc = case cc of
    Fail      -> Fail
    Done xs v -> killRange2 Done xs v
    Case i br -> killRange2 Case i br
-- | Terms are found only in the content, via the derived
--   'Traversable'/'Foldable' instances.
instance TermLike a => TermLike (WithArity a) where
  traverseTermM f = traverse (traverseTermM f)
  foldTerm f      = foldMap (foldTerm f)
-- | Terms are found in every branch subtree, via the derived
--   'Traversable'/'Foldable' instances.
instance TermLike a => TermLike (Case a) where
  traverseTermM f = traverse (traverseTermM f)
  foldTerm f      = foldMap (foldTerm f)
-- | Terms are found in the leaves of the tree, via the derived
--   'Traversable'/'Foldable' instances.
instance TermLike a => TermLike (CompiledClauses' a) where
  traverseTermM f = traverse (traverseTermM f)
  foldTerm f      = foldMap (foldTerm f)