-- | General purpose tree walking boilerplate.
module DDC.Core.Transform.TransformDownX
        ( TransformDownMX(..)
        , transformDownX
        , transformDownX')
where
import DDC.Core.Module
import DDC.Core.Exp
import DDC.Core.Compounds
import DDC.Type.Env             (KindEnv, TypeEnv)
import Data.Functor.Identity
import Control.Monad
import qualified DDC.Type.Env   as Env


-- | Top-down rewrite of all core expressions in a thing.
transformDownX
        :: forall (c :: * -> * -> *) a n
        .  (Ord n, TransformDownMX Identity c)
        => (KindEnv n -> TypeEnv n -> Exp a n -> Exp a n)       
                        -- ^ The worker function is given the current kind and type environments.
        -> KindEnv n    -- ^ Initial kind environment.
        -> TypeEnv n    -- ^ Initial type environment.
        -> c a n        -- ^ Transform this thing.
        -> c a n

transformDownX f kenv tenv xx
        = runIdentity 
        $ transformDownMX 
                (\kenv' tenv' x -> return (f kenv' tenv' x)) 
                kenv tenv xx


-- | Like transformDownX, but without using environments.
transformDownX'
        :: forall (c :: * -> * -> *) a n
        .  (Ord n, TransformDownMX Identity c)
        => (Exp a n -> Exp a n)       
                        -- ^ The worker function is given the current
                        --      kind and type environments.
        -> c a n        -- ^ Transform this thing.
        -> c a n

transformDownX' f xx
        = transformDownX (\_ _ -> f) Env.empty Env.empty xx


-------------------------------------------------------------------------------
class TransformDownMX m (c :: * -> * -> *) where
 -- | Top-down monadic rewrite of all core expressions in a thing.
 transformDownMX
        :: Ord n
        => (KindEnv n -> TypeEnv n -> Exp a n -> m (Exp a n))
                        -- ^ The worker function is given the current
                        --      kind and type environments.
        -> KindEnv n    -- ^ Initial kind environment.
        -> TypeEnv n    -- ^ Initial type environment.
        -> c a n        -- ^ Transform this thing.
        -> m (c a n)


instance Monad m => TransformDownMX m Module where
 transformDownMX f kenv tenv !mm
  = do  x'    <- transformDownMX f kenv tenv $ moduleBody mm
        return  $ mm { moduleBody = x' }


instance Monad m => TransformDownMX m Exp where
 transformDownMX f kenv tenv !xx
  = {-# SCC transformDownMX #-} 
    f kenv tenv xx >>= \xx' -> 
     case xx' of
        XVar{}          -> return xx'
        XCon{}          -> return xx'

        XLAM a b x1
         -> liftM3 XLAM (return a) (return b)
                        (transformDownMX f (Env.extend b kenv) tenv x1)

        XLam a b  x1    
         -> liftM3 XLam (return a) (return b) 
                        (transformDownMX f kenv (Env.extend b tenv) x1)

        XApp a x1 x2    
         -> liftM3 XApp (return a) 
                        (transformDownMX f kenv tenv x1) 
                        (transformDownMX f kenv tenv x2)

        XLet a lts x
         -> do  lts'      <- transformDownMX f kenv tenv lts
                let kenv' = Env.extends (specBindsOfLets lts')   kenv
                let tenv' = Env.extends (valwitBindsOfLets lts') tenv
                x'        <- transformDownMX f kenv' tenv' x
                return  $ XLet a lts' x'
                
        XCase a x alts
         -> liftM3 XCase (return a)
                         (transformDownMX f kenv tenv x)
                         (mapM (transformDownMX f kenv tenv) alts)

        XCast a c x       
         -> liftM3 XCast
                        (return a) (return c)
                        (transformDownMX f kenv tenv x)

        XType{}         -> return xx'
        XWitness{}      -> return xx'


instance Monad m => TransformDownMX m Lets where
 transformDownMX f kenv tenv xx
  = case xx of
        LLet b x
         -> liftM2 LLet (return b)
                        (transformDownMX f kenv tenv x)
        
        LRec bxs
         -> do  let (bs, xs) = unzip bxs
                let tenv'    = Env.extends bs tenv
                xs'          <- mapM (transformDownMX f kenv tenv') xs
                return       $ LRec $ zip bs xs'

        LPrivate{}      -> return xx
        LWithRegion{}   -> return xx


instance Monad m => TransformDownMX m Alt where
 transformDownMX f kenv tenv alt
  = case alt of
        AAlt p@(PData _ bsArg) x
         -> let tenv'   = Env.extends bsArg tenv
            in  liftM2  AAlt (return p) 
                        (transformDownMX f kenv tenv' x)

        AAlt PDefault x
         ->     liftM2  AAlt (return PDefault)
                        (transformDownMX f kenv tenv x)