{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TemplateHaskell            #-}
module Clash.Rewrite.Types where
import Control.Concurrent.Supply             (Supply, freshId)
import Control.Lens                          (use, (.=), (<<%=))
import Control.Monad
import Control.Monad.Fix                     (MonadFix (..), fix)
import Control.Monad.Reader                  (MonadReader (..))
import Control.Monad.State                   (MonadState (..))
import Control.Monad.Writer                  (MonadWriter (..))
import Data.HashMap.Strict                   (HashMap)
import Data.HashSet                          (HashSet)
import Data.IntMap.Strict                    (IntMap)
import Data.Monoid                           (Any)
import Unbound.Generics.LocallyNameless      (Fresh (..))
import Unbound.Generics.LocallyNameless.Name (Name (..))
import SrcLoc (SrcSpan)
import Clash.Core.Evaluator      (PrimEvaluator)
import Clash.Core.Term           (Term, TmName, TmOccName)
import Clash.Core.Type           (Type)
import Clash.Core.TyCon          (TyCon, TyConName, TyConOccName)
import Clash.Core.Var            (Id, TyVar)
import Clash.Driver.Types        (BindingMap, DebugLevel)
import Clash.Netlist.Types       (HWType)
import Clash.Util
-- | One step of context for a term; a list of these is the first argument of
-- a 'Transform'.
--
-- NOTE(review): the per-constructor descriptions below are inferred from the
-- constructor names and payload types only -- confirm them against the
-- traversal code that builds the context list.
data CoreContext
  = AppFun           -- ^ Presumably: function position of an application
  | AppArg           -- ^ Presumably: argument position of an application
  | TyAppC           -- ^ Presumably: function position of a type application
  | LetBinding Id [Id] -- ^ Presumably: RHS of a let-binder (TODO confirm what
                       -- the 'Id' and the list of 'Id's stand for)
  | LetBody    [Id]  -- ^ Presumably: body of a let, with the let-bound 'Id's
  | LamBody    Id    -- ^ Presumably: body of a lambda, with the bound 'Id'
  | TyLamBody  TyVar -- ^ Presumably: body of a type lambda, with the bound
                     -- type variable
  | CaseAlt    [Id]  -- ^ Presumably: RHS of a case alternative, with the
                     -- 'Id's bound by its pattern
  | CaseScrut        -- ^ Presumably: scrutinee of a case expression
  | CastBody         -- ^ Presumably: body of a cast
  deriving (Eq,Show)
-- | State component of 'RewriteMonad' (see its 'MonadState' instance).
-- The type parameter @extra@ is the type of the '_extra' field.
--
-- Every field except '_curFun' carries a strictness annotation.
data RewriteState extra
  = RewriteState
  { _transformCounter :: {-# UNPACK #-} !Int
  -- ^ Presumably the number of transformations applied so far (inferred from
  -- the name; nothing in this module reads or writes it)
  , _bindings         :: !BindingMap
  -- ^ A 'BindingMap' -- NOTE(review): presumably the global binders being
  -- rewritten; confirm against the users of this field
  , _uniqSupply       :: !Supply
  -- ^ Supply from which 'getUniqueM' draws its next id
  , _curFun           :: (TmName,SrcSpan) -- NOTE(review): the only lazy field; presumably lazy on purpose -- confirm before adding a bang
  -- ^ Presumably the function currently being processed, with its source
  -- location (inferred from the name and type)
  , _nameCounter      :: {-# UNPACK #-} !Int
  -- ^ Counter read and post-incremented by 'fresh' to index free names
  , _extra            :: !extra
  -- ^ Additional state of a type chosen by the user of 'RewriteState'
  }
-- Template Haskell splice generating lenses for the fields above. The code
-- further down uses them without the leading underscore (e.g. @nameCounter@,
-- @uniqSupply@), so the splice has to stay above those uses.
-- NOTE(review): 'makeLenses' is not in the explicit "Control.Lens" import
-- list of this module; presumably it comes from "Clash.Util" -- confirm.
makeLenses ''RewriteState
-- | Environment component of 'RewriteMonad' (see its 'MonadReader' instance).
--
-- NOTE(review): none of these fields is used in this module, so the field
-- descriptions are inferred from names and types only -- confirm them against
-- the call sites.
data RewriteEnv
  = RewriteEnv
  { _dbgLevel       :: DebugLevel
  -- ^ Presumably the level at which debug messages are printed
  , _typeTranslator :: HashMap TyConOccName TyCon -> Bool -> Type
                    -> Maybe (Either String HWType)
  -- ^ Function from a 'TyCon' map, a 'Bool' and a 'Type' to an optional
  -- result that is either a 'String' or an 'HWType'.
  -- NOTE(review): presumably a custom Type-to-HWType translator where 'Left'
  -- carries an error message; the meaning of the 'Bool' is not visible here
  , _tcCache        :: HashMap TyConOccName TyCon
  -- ^ Map from type-constructor names to 'TyCon's (presumably a cache)
  , _tupleTcCache   :: IntMap TyConName
  -- ^ 'Int'-keyed map of 'TyConName's -- NOTE(review): presumably tuple type
  -- constructors keyed by arity; confirm
  , _evaluator      :: PrimEvaluator
  -- ^ A 'PrimEvaluator' (presumably used to evaluate primitives)
  , _allowZero      :: Bool
  -- ^ NOTE(review): presumably whether zero-bit-wide things are allowed;
  -- confirm
  , _topEntities    :: HashSet TmOccName
  -- ^ Set of term names, presumably those treated as top entities
  }
-- Template Haskell splice generating lenses for the fields above.
-- NOTE(review): 'makeLenses' is not in the explicit "Control.Lens" import
-- list of this module; presumably it comes from "Clash.Util" -- confirm.
makeLenses ''RewriteEnv
-- | The monad that rewrites run in: a function from a 'RewriteEnv' and the
-- current 'RewriteState' to a triple of result, updated state, and an 'Any'
-- output. The instances below give it reader, state and writer interfaces;
-- outputs of sequenced actions are combined with 'mappend'.
--
-- NOTE(review): the 'Any' output presumably records whether a rewrite changed
-- anything; confirm against the code that calls 'tell'.
newtype RewriteMonad extra a = R
  { runR :: RewriteEnv -> RewriteState extra -> (a,RewriteState extra,Any) }
-- | Maps over the result component only; the state and the 'Any' output
-- produced by the underlying action are passed through untouched.
instance Functor (RewriteMonad extra) where
  fmap g (R run) = R $ \env st ->
    case run env st of
      (x, st', out) -> (g x, st', out)
-- | Canonical 'Applicative' instance: 'pure' carries the real definition
-- (result with unchanged state and empty output) and '<*>' is derived from
-- the 'Monad' instance via 'ap'.
instance Applicative (RewriteMonad extra) where
  pure a = R (\_ s -> (a, s, mempty))
  (<*>)  = ap

-- | Sequencing threads the same environment into both actions, feeds the
-- state produced by the first action into the second, and combines both
-- outputs with 'mappend'.
instance Monad (RewriteMonad extra) where
  -- Canonical form: 'return' defers to 'pure' instead of the other way round,
  -- which keeps -Wnoncanonical-monad-instances quiet.
  return = pure
  m >>= k = R $ \r s ->
    case runR m r s of
      (a, s', w) -> case runR (k a) r s' of
        (b, s'', w') ->
          let w'' = mappend w w'
          -- Force the combined output so that no chain of 'mappend' thunks
          -- builds up over long sequences of binds.
          in  w'' `seq` (b, s'', w'')
-- | State interface over the 'RewriteState' component. None of the methods
-- looks at the environment, and all of them produce an empty output.
instance MonadState (RewriteState extra) (RewriteMonad extra) where
  -- Return the current state, leaving it in place.
  get = R $ \_ st -> (st, st, mempty)

  -- Replace the state wholesale.
  put st = R $ \_ _ -> ((), st, mempty)

  -- Run a pure state transition; the resulting pair is matched strictly.
  state step = R $ \_ st ->
    case step st of
      (x, st') -> (x, st', mempty)
-- | Fresh-name generation backed by the @nameCounter@ lens of the state.
instance Fresh (RewriteMonad extra) where
  fresh nm = case nm of
    -- 'Bn' names are handed back unchanged.
    Bn {} -> return nm
    -- 'Fn' names keep their string and get the current counter value as
    -- their index; the counter itself is post-incremented.
    Fn base _ -> do
      old <- nameCounter <<%= (+1)
      let idx = toInteger old
      -- Evaluate the index before the name is returned.
      idx `seq` return (Fn base idx)
-- | Unique ids are drawn from the supply stored behind the @uniqSupply@ lens.
instance MonadUnique (RewriteMonad extra) where
  getUniqueM = do
    supply <- use uniqSupply
    -- 'freshId' yields the next id together with the remaining supply.
    let drawn = freshId supply
    uniqSupply .= snd drawn
    -- Evaluate the id before handing it out.
    return $! fst drawn
-- | Writer interface over the 'Any' output component.
instance MonadWriter Any (RewriteMonad extra) where
  -- Embed a ready-made (result, output) pair; the state is left untouched.
  writer (x, out) = R $ \_ st -> (x, st, out)

  -- Emit output only.
  tell out = R $ \_ st -> ((), st, out)

  -- Run the action and return its output alongside its result; the output is
  -- still emitted as well.
  listen act = R $ \env st ->
    case runR act env st of
      (x, st', out) -> ((x, out), st', out)

  -- Run an action that yields a result plus an output transformer, and apply
  -- that transformer to the action's own output.
  pass act = R $ \env st ->
    case runR act env st of
      ((x, tweak), st', out) -> (x, st', tweak out)
-- | Reader interface over the 'RewriteEnv' component. 'ask' and 'reader'
-- leave the state as it is and produce an empty output.
instance MonadReader RewriteEnv (RewriteMonad extra) where
  -- Return the whole environment.
  ask = R $ \env st -> (env, st, mempty)

  -- Run the action under a modified environment.
  local modify (R run) = R $ \env st -> run (modify env) st

  -- Return a projection of the environment.
  reader sel = R $ \env st -> (sel env, st, mempty)
-- | Value recursion: the result component of the action is fed back lazily as
-- the argument of the function that builds the action.
instance MonadFix (RewriteMonad extra) where
  mfix f = R $ \env st ->
    -- Recursive knot: 'result' is projected out of 'tied' by a lazy pattern
    -- binding, so it is only demanded when @f@ inspects its argument.
    let tied           = runR (f result) env st
        (result, _, _) = tied
    in  tied
-- | A monadic term-to-term function that additionally receives a list of
-- 'CoreContext's.
--
-- NOTE(review): the list presumably describes where the term sits inside the
-- enclosing expression; its ordering is not visible in this module.
type Transform m = [CoreContext] -> Term -> m Term
-- | A 'Transform' specialised to 'RewriteMonad' with state payload @extra@.
type Rewrite extra = Transform (RewriteMonad extra)