```
-- | Common sub-expression elimination and variable hoisting for 'Lam'
-- expressions.

module Feldspar.DSL.Sharing where

import Data.Typeable

import Feldspar.DSL.Expression
import Feldspar.DSL.Lambda

-- | Replace occurrences of a sub-expression by another expression.
--
-- At each node, if the replacement can be cast (with 'exprCast') to the type
-- of the current expression and 'exprEq' reports the current expression equal
-- to the target, the replacement is returned as-is (it is not traversed).
-- Otherwise the traversal continues into lambda bodies and into both sides of
-- applications; any other expression is returned unchanged.
substitute :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
           => Lam expr ra a  -- ^ Sub-expression to be replaced
           -> Lam expr ra a  -- ^ Replacing sub-expression
           -> Lam expr rb b  -- ^ Whole expression
           -> Lam expr rb b
substitute old new whole = case exprCast new of
    Just new' | exprEq old whole -> new'
    _ -> case whole of
        Lambda body -> Lambda $ \v -> substitute old new (body v)
        fun :$: arg -> substitute old new fun :$: substitute old new arg
        _           -> whole

-- | Count the number of occurrences of a sub-expression in an expression.
--
-- Runs 'countM' with its 'State' counter starting at 0. The counter is the
-- state that 'countNonEq' draws on through 'freshVar' (presumably a
-- fresh-variable supply; confirm against 'freshVar').
count :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
      => Lam expr ra a  -- ^ Sub-expression
      -> Lam expr rb b  -- ^ Whole expression
      -> Integer
count sub whole = countM sub whole `evalState` 0

-- | Monadic worker for 'count'.
--
-- If the whole expression has the sub-expression's type (checked with
-- 'exprCast') and 'exprEqLam' reports the two equal, that is one occurrence
-- and the search stops at this node. Otherwise the children are searched
-- with 'countNonEq'.
countM :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
       => Lam expr ra a
       -> Lam expr rb b
       -> State Integer Integer
countM sub whole
    | Just whole' <- exprCast whole = do
        same <- exprEqLam sub (whole' `asTypeOf` sub)
        if same then return 1 else countNonEq sub whole
    | otherwise = countNonEq sub whole

-- | Count occurrences of the sub-expression among the children of an
-- expression.
--
-- * A lambda body is searched after instantiating the bound variable with
--   one obtained from @freshVar ""@.
-- * An application is searched in its function part, then its argument
--   part, and the two counts are summed.
-- * Any other expression contributes 0.
countNonEq :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
           => Lam expr ra a
           -> Lam expr rb b
           -> State Integer Integer
countNonEq sub whole = case whole of
    Lambda body -> freshVar "" >>= \v -> countM sub (body v)
    fun :$: arg -> do
        inFun <- countM sub fun
        inArg <- countM sub arg
        return (inFun + inArg)
    _ -> return 0

-- | A 'Lam' expression whose two type indices are hidden, packaged together
-- with the 'Typeable' evidence for them. This lets expressions of different
-- types be handed to the predicates in 'Params' and to the occurrence
-- counter in 'Env'.
data SomeLam expr
    = forall r t . (Typeable r, Typeable t) => SomeLam (Lam expr r t)

-- | Custom parameters to the sharing transformation.
--
-- The 'necessary' predicate gives a necessary condition for lifting an
-- expression, and 'sufficient' gives a sufficient condition. Note that
-- 'necessary' takes precedence over 'sufficient': an expression rejected by
-- 'necessary' is never lifted, even if 'sufficient' accepts it.
--
-- The 'sharingPoint' field determines whether the expression is a valid point
-- for introducing a 'Let'.
data Params expr = Params
    { necessary    :: SomeLam expr -> Bool
      -- ^ Must hold for an expression to be lifted.
    , sufficient   :: SomeLam expr -> Bool
      -- ^ Alternative to the built-in heuristic (compound, and either inside
      -- a lambda or occurring more than once). The other lifting conditions,
      -- including 'necessary', still apply.
    , sharingPoint :: SomeLam expr -> Bool
      -- ^ Whether a 'Let' may be introduced at the given whole expression.
      -- When it does not hold, 'sharing' only descends into the children.
    }

-- | Context threaded through the search for a liftable sub-expression
-- ('choose' updates it, 'liftable' reads it).
data Env expr = Env
    { inLambda :: Bool  -- ^ Whether the current expression is inside a lambda
    , subExpr  :: Bool  -- ^ Whether the current expression is a sub-expression
                        -- (False only at the root, where lifting is refused)
    , counter  :: SomeLam expr -> Integer
      -- ^ Counting the number of occurrences of an expression in the
      -- environment ('initEnv' counts within the whole expression being
      -- transformed)
    }

-- | Default parameters.
--
-- Every expression passes the necessary condition and none passes the
-- sufficient one, so the built-in heuristic alone decides what is lifted.
-- Every expression is a sharing point.
simpleParams :: Params expr
simpleParams = Params (\_ -> True) (\_ -> False) (\_ -> True)

-- | Environment for the root of an expression.
--
-- The root is not inside a lambda and is not a sub-expression. The counter
-- reports how often an expression occurs within the given root (via
-- 'count').
initEnv :: (ExprEq expr, Typeable ra, Typeable a)
        => Lam expr ra a
        -> Env expr
initEnv whole = Env False False (\(SomeLam sub) -> count sub whole)

-- | Variable substituted for a lambda's bound variable when 'independent'
-- looks inside the lambda body. Its name differs from the placeholder name,
-- so it never makes an expression count as dependent.
--
-- NOTE(review): 'dummy' and 'ph' have no type signatures. They appear to be
-- used at several different 'Lam' types, so they must generalize. Confirm
-- the intended type (and any constraints of 'Variable') before adding
-- signatures.
dummy = Variable ""
-- | Placeholder substituted by 'choose' for the bound variable of every
-- lambda it descends into. 'independent' rejects any expression containing a
-- variable of this name, which keeps expressions that mention a lambda-bound
-- variable from being lifted out of that lambda.
ph    = Variable "PLACEHOLDER"

-- | Checks whether an expression is compound: a lambda or an application.
-- Every other constructor (e.g. 'Variable') counts as a leaf.
compound :: Lam expr ra a -> Bool
compound expr = case expr of
    Lambda _ -> True
    _ :$: _  -> True
    _        -> False

-- | Checks that an expression contains no variable named @"PLACEHOLDER"@
-- (the name used by 'ph').
--
-- Lambda bodies are inspected by instantiating the bound variable with
-- 'dummy', which is not a placeholder. Leaves other than variables are
-- always independent.
independent :: Lam expr ra a -> Bool
independent expr = case expr of
    Variable name -> name /= "PLACEHOLDER"
    Lambda body   -> independent (body dummy)
    fun :$: arg   -> independent fun && independent arg
    _             -> True

-- | Checks whether a sub-expression in a given environment can be lifted out.
--
-- All of the following must hold, checked left to right:
--
-- * it contains no placeholder ('independent');
-- * it is a proper sub-expression ('subExpr');
-- * it satisfies 'necessary';
-- * either the built-in heuristic accepts it (compound, and inside a lambda
--   or occurring more than once), or it satisfies 'sufficient'.
liftable :: (Typeable ra, Typeable a)
         => Params expr
         -> Env expr
         -> Lam expr ra a -> Bool
liftable params env a = and
    [ independent a  -- Lifting dependent expressions is semantically incorrect
    , subExpr env    -- Otherwise infinite loop
    , necessary params packed
    , (compound a && (inLambda env || counter env packed > 1))
        || sufficient params packed
    ]
  where
    packed = SomeLam a

-- | Chooses a sub-expression to lift out.
--
-- The search is pre-order and left-to-right:
--
-- * the expression itself, if it is 'liftable';
-- * otherwise, for a lambda, its body, with 'ph' standing for the bound
--   variable;
-- * for an application, its function part, then its argument part.
--
-- Descending marks the environment as a sub-expression, and a lambda also
-- marks it as inside a lambda.
choose :: (Typeable ra, Typeable a)
       => Params expr -> Env expr -> Lam expr ra a -> Maybe (SomeLam expr)
choose par env a
    | liftable par env a = Just (SomeLam a)
    | otherwise = case a of
        Lambda body ->
            choose par (env {inLambda = True, subExpr = True}) (body ph)
        fun :$: arg ->
            let inner = env {subExpr = True}
            in  choose par inner fun `mplus` choose par inner arg
        _ -> Nothing

-- | Perform common sub-expression elimination and variable hoisting.
--
-- When 'choose' finds a liftable sub-expression and 'sharingPoint' accepts
-- the whole expression:
--
-- * the sub-expression is bound with 'let_';
-- * its occurrences (as matched by 'substitute') are replaced by the bound
--   variable;
-- * both the bound expression and the rewritten body are processed again.
--
-- Otherwise the transformation moves on to the children via 'descend'.
sharing :: forall expr ra a . (ExprEq expr, Typeable ra, Typeable a)
        => Params expr
        -> Lam expr ra a
        -> Lam expr ra a
sharing par a
    | Just picked <- choose par (initEnv a) a
    , sharingPoint par (SomeLam a)
    = share picked
    | otherwise
    = descend par a
  where
    -- Bind the chosen sub-expression and rewrite the whole expression in
    -- terms of the bound variable.
    share :: SomeLam expr -> Lam expr ra a
    share (SomeLam sub) =
        let_ "v" (sharing par sub) (\x -> sharing par (substitute sub x a))

-- | Apply 'sharing' to the immediate children of an expression: the body of
-- a lambda, or both sides of an application. Any other expression is
-- returned unchanged.
descend :: (ExprEq expr, Typeable ra, Typeable a)
        => Params expr
        -> Lam expr ra a
        -> Lam expr ra a
descend params expr = case expr of
    Lambda body -> Lambda (\v -> sharing params (body v))
    fun :$: arg -> sharing params fun :$: sharing params arg
    _           -> expr

-- | 'sharing' with 'simpleParams'. Apart from the independence and
-- sub-expression checks, lifting is decided by the built-in heuristic alone.
simpleSharing :: (ExprEq expr, Typeable ra, Typeable a)
              => Lam expr ra a -> Lam expr ra a
simpleSharing expr = sharing simpleParams expr

-- | Checks if the expression computes a function, i.e. whether its value
-- type @a@ is a function type. Can be used in the parameters passed to
-- 'sharing'. Only the type is examined; the argument itself is never
-- evaluated.
--
-- The outermost type constructor is compared (via 'Eq' on 'TyCon') with that
-- of a known function type, not by its printed name. The printed name of the
-- function type constructor differs between GHC versions: @"->"@ in older
-- releases, @"FUN"@ since GHC 9.0. A string comparison therefore silently
-- yields 'False' on newer compilers.
isFunction :: forall expr ra a . Typeable a => Lam expr ra a -> Bool
isFunction _ = typeRepTyCon (typeOf (undefined :: a)) == funTyCon
  where
    funTyCon = typeRepTyCon (typeOf (undefined :: () -> ()))

```