-- | Common sub-expression elimination and variable hoisting for 'Lam'
-- expressions.

module Feldspar.DSL.Sharing where



import Control.Monad.State
import Data.Typeable

import Feldspar.DSL.Expression
import Feldspar.DSL.Lambda



-- | Replace occurrences of a sub-expression inside a larger expression.
--
-- A node of the whole expression is replaced when 'exprEq' reports it equal to
-- the sub-expression and the replacement can be converted to the node's type
-- with 'exprCast'. A replaced node is not traversed any further. All other
-- nodes are rebuilt from their recursively processed children; leaves are
-- returned unchanged.
substitute :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a  -- ^ Sub-expression to be replaced
    -> Lam expr ra a  -- ^ Replacing sub-expression
    -> Lam expr rb b  -- ^ Whole expression
    -> Lam expr rb b
substitute old new whole
    | Just new' <- exprCast new, exprEq old whole = new'
    | otherwise = case whole of
        -- Go under the binder by substituting in the body lazily, once the
        -- bound variable is supplied.
        Lambda body -> Lambda (\var -> substitute old new (body var))
        fun :$: arg -> substitute old new fun :$: substitute old new arg
        _           -> whole

-- | Count how many times a sub-expression occurs in a whole expression.
--
-- This runs 'countM' with an initial state of @0@ and discards the final
-- state (presumably the fresh-variable supply used when going under lambdas
-- -- confirm against 'freshVar').
count :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a  -- ^ Sub-expression
    -> Lam expr rb b  -- ^ Whole expression
    -> Integer
count sub whole = evalState occurrences 0
  where
    occurrences = countM sub whole

-- | Monadic worker for 'count'.
--
-- If the whole expression can be converted to the type of the sub-expression
-- (via 'exprCast') and 'exprEqLam' reports the two as equal, the result is @1@
-- and the matched expression is not searched any further. In every other case
-- the children are searched by 'countNonEq'.
countM :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a
    -> Lam expr rb b
    -> State Integer Integer
countM sub whole = case exprCast whole of
    Nothing     -> countNonEq sub whole
    Just whole' -> do
        -- 'asTypeOf' pins the cast's target type to that of the sub-expression.
        same <- exprEqLam sub (whole' `asTypeOf` sub)
        if same
            then return 1
            else countNonEq sub whole

-- | Count occurrences of a sub-expression in the children of an expression
-- that is already known not to match as a whole.
--
-- A lambda is entered by applying it to a variable obtained from 'freshVar';
-- for an application the counts of both sides are added (function side
-- first). Any other expression contributes @0@.
countNonEq :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a
    -> Lam expr rb b
    -> State Integer Integer
countNonEq sub whole = case whole of
    Lambda body -> freshVar "" >>= \var -> countM sub (body var)
    fun :$: arg -> do
        inFun <- countM sub fun
        inArg <- countM sub arg
        return (inFun + inArg)
    _ -> return 0



-- | Existential wrapper around a 'Lam' expression. It hides the @ra@ and @a@
-- type parameters while packing their 'Typeable' instances, so that they are
-- available again when pattern matching on 'SomeLam'.
data SomeLam expr = forall ra a .
    (Typeable ra, Typeable a) => SomeLam (Lam expr ra a)

-- | Custom parameters to sharing transformation. The 'necessary' predicate
-- gives a necessary condition for lifting an expression, and 'sufficient' gives
-- a sufficient condition. Note that 'necessary' takes precedence over
-- 'sufficient'.
--
-- The 'sharingPoint' field determines whether the expression is a valid point
-- for introducing a 'Let'.
data Params expr = Params
    { necessary    :: SomeLam expr -> Bool
        -- ^ Must hold for an expression to be lifted (checked by 'liftable'
        -- before anything else that depends on the parameters)
    , sufficient   :: SomeLam expr -> Bool
        -- ^ When it holds, an expression that passes the other checks in
        -- 'liftable' is lifted even if the built-in heuristic rejects it
    , sharingPoint :: SomeLam expr -> Bool
        -- ^ Applied by 'sharing' to the expression currently being
        -- transformed; a binding is only introduced there when it holds
    }

-- | Context carried by 'choose' while it walks an expression looking for a
-- sub-expression to lift. 'initEnv' builds the context for the root, and
-- 'choose' updates the flags as it descends.
data Env expr = Env
    { inLambda :: Bool  -- ^ Whether the current expression is inside a lambda
    , subExpr  :: Bool  -- ^ Whether the current expression is a sub-expression
    , counter  :: SomeLam expr -> Integer
        -- ^ Counting the number of occurrences of an expression in the
        -- environment
    }

-- | Default parameters: every expression satisfies 'necessary', none satisfies
-- 'sufficient' (so lifting is decided by the built-in heuristic alone), and
-- every expression is a valid 'sharingPoint'.
simpleParams :: Params expr
simpleParams = Params
    { necessary    = \_ -> True
    , sufficient   = \_ -> False
    , sharingPoint = \_ -> True
    }

-- | Initial environment for searching the given root expression: not inside
-- a lambda, not a sub-expression, and with 'counter' counting occurrences
-- within that root (via 'count').
initEnv :: (ExprEq expr, Typeable ra, Typeable a)
    => Lam expr ra a
    -> Env expr
initEnv root = Env
    { inLambda = False
    , subExpr  = False
    , counter  = \some -> case some of
        SomeLam sub -> count sub root
    }

-- | Variable with an empty name. 'independent' applies lambda bodies to it so
-- that it can inspect them.
dummy = Variable ""
-- | Marker variable named @"PLACEHOLDER"@. 'choose' applies lambda bodies to it
-- so that 'independent' can recognise sub-expressions that mention the bound
-- variable.
ph    = Variable "PLACEHOLDER"

-- | Tells whether an expression is compound, i.e. a lambda or an application.
-- Every other form of expression is considered atomic.
compound :: Lam expr ra a -> Bool
compound e = case e of
    Lambda _ -> True
    _ :$: _  -> True
    _        -> False

-- | Tells whether an expression is free of @"PLACEHOLDER"@ variables.
--
-- Lambda bodies are inspected by applying them to 'dummy'; applications are
-- independent when both sides are. Any other non-variable expression counts
-- as independent.
independent :: Lam expr ra a -> Bool
independent e = case e of
    Variable name -> name /= "PLACEHOLDER"
    Lambda body   -> independent (body dummy)
    fun :$: arg   -> independent fun && independent arg
    _             -> True

-- | Decides whether a sub-expression in a given environment can be lifted out.
--
-- All of the following must hold, checked in this order:
--
-- * the expression is 'independent' (lifting a dependent expression would be
--   semantically incorrect);
--
-- * it is a proper sub-expression (otherwise 'sharing' could pick the root
--   itself and loop forever);
--
-- * the 'necessary' predicate accepts it;
--
-- * either the built-in heuristic or the 'sufficient' predicate accepts it.
--
-- The heuristic accepts compound expressions that are inside a lambda or
-- occur more than once according to the environment's 'counter'.
liftable :: (Typeable ra, Typeable a)
    => Params expr
    -> Env expr
    -> Lam expr ra a -> Bool
liftable params env e = and
    [ independent e
    , subExpr env
    , necessary params wrapped
    , worthwhile || sufficient params wrapped
    ]
  where
    wrapped    = SomeLam e
    repeated   = counter env wrapped > 1
    worthwhile = compound e && (inLambda env || repeated)

-- | Pick a sub-expression to lift out, if there is one.
--
-- The expression itself is returned when it is 'liftable'. Otherwise the
-- search continues in its children: a lambda body is entered by applying it to
-- the placeholder variable 'ph', and for an application the function side is
-- searched before the argument side. Leaves yield 'Nothing'.
choose :: (Typeable ra, Typeable a) =>
    Params expr -> Env expr -> Lam expr ra a -> Maybe (SomeLam expr)
choose par env e
    | liftable par env e = Just (SomeLam e)
    | otherwise = case e of
        Lambda body -> choose par underLambda (body ph)
        fun :$: arg -> choose par below fun `mplus` choose par below arg
        _           -> Nothing
  where
    -- Environment for children of the current node
    below       = env {subExpr = True}
    -- Same, but additionally recording that a binder has been crossed
    underLambda = below {inLambda = True}

-- | Perform common sub-expression elimination and variable hoisting.
--
-- When 'choose' finds a sub-expression to lift and the current expression is a
-- valid 'sharingPoint', the sub-expression is bound with 'let_' (under the
-- name @"v"@) and every occurrence of it in the current expression is replaced
-- by the bound variable. Both the bound expression and the resulting body are
-- transformed recursively. Otherwise the transformation continues in the
-- children via 'descend'.
sharing :: forall expr ra a . (ExprEq expr, Typeable ra, Typeable a)
    => Params expr
    -> Lam expr ra a
    -> Lam expr ra a
sharing par whole
    | Just chosen <- choose par (initEnv whole) whole
    , sharingPoint par (SomeLam whole)
    = share chosen
    | otherwise = descend par whole
  where
    share :: SomeLam expr -> Lam expr ra a
    share (SomeLam sub) =
        let_ "v" (sharing par sub) $ \var ->
            sharing par (substitute sub var whole)

-- | Apply 'sharing' to the immediate children of an expression without
-- introducing a binding at the expression itself. Expressions that are
-- neither lambdas nor applications are returned unchanged.
descend :: (ExprEq expr, Typeable ra, Typeable a)
    => Params expr
    -> Lam expr ra a
    -> Lam expr ra a
descend params e = case e of
    Lambda body -> Lambda (\var -> sharing params (body var))
    fun :$: arg -> sharing params fun :$: sharing params arg
    _           -> e

-- | 'sharing' with the default parameters 'simpleParams'.
simpleSharing :: (ExprEq expr, Typeable ra, Typeable a) =>
    Lam expr ra a -> Lam expr ra a
simpleSharing e = sharing simpleParams e

-- | Checks if the expression computes a function. Can be used in the parameters
-- passed to 'sharing'.
--
-- The argument is never inspected; only its result type @a@ matters. The
-- outermost type constructor of @a@ is compared with that of a reference
-- function type rather than with a hard-coded name, because the textual name
-- of the function constructor differs between compiler versions (@"->"@ on
-- older GHCs, @FUN@ on newer ones).
isFunction :: forall expr ra a . Typeable a => Lam expr ra a -> Bool
isFunction _ = typeRepTyCon (typeOf (undefined :: a)) == funTyCon
  where
    -- Type constructor of an arbitrary function type; 'undefined' is only
    -- used as a type witness and is never evaluated.
    funTyCon = typeRepTyCon (typeOf (undefined :: () -> ()))