-- Andre Santos' Local Transformations (Ch 3 in his dissertation)
module Language.HERMIT.Primitive.Local.Case
       ( -- * Rewrites on Case Expressions
         externals
       , letFloatCase
       , caseFloatApp
       , caseFloatArg
       , caseFloatCase
       , caseFloatLet
       , caseFloat
       , caseReduce
       , caseSplit
       , caseSplitInline
       )
where


import GhcPlugins

import Data.List
import Control.Arrow
import Control.Applicative

import Language.HERMIT.GHC
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.Monad

import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC hiding (externals)
import Language.HERMIT.Primitive.Inline hiding (externals)
import Language.HERMIT.Primitive.AlphaConversion hiding (externals)

import qualified Language.Haskell.TH as TH

-- NOTE: these are hard to test in small examples, as GHC does them for us, so use with caution
------------------------------------------------------------------------------

-- | Externals relating to Case expressions.
--
--   Each entry is built with 'external' from a command name, a rewrite, and a
--   list of help strings; most entries additionally attach tags with '.+'.
--   The commented-out entries at the head of the list are rules that were
--   considered but are believed not to be expressible in Core.
externals :: [External]
externals =
         [ -- I'm not sure this is possible. In core, v2 can only be a Constructor, Lit, or DEFAULT
           -- In the last case, v1 is already inlined in e. So we can't construct v2 as a Var.
         --   external "case-elimination" (promoteR $ not_defined "case-elimination" :: RewriteH Core)
         --             [ "case v1 of v2 -> e ==> e[v1/v2]" ]                                         .+ Unimplemented .+ Eval
         --   -- Again, don't think the lhs of this rule is possible to construct in core.
         -- , external "default-binding-elim" (promoteR $ not_defined "default-binding-elim" :: RewriteH Core)
         --             [ "case v of ...;w -> e ==> case v of ...;w -> e[v/w]" ]                      .+ Unimplemented .+ Eval
         --   -- Again, don't think the lhs of this rule is possible to construct in core.
         -- , external "case-merging" (promoteR $ not_defined "case-merging" :: RewriteH Core)
         --             [ "case v of ...; d -> case v of alt -> e ==> case v of ...; alt -> e[v/d]" ] .+ Unimplemented .+ Eval
           external "let-float-case" (promoteExprR letFloatCase :: RewriteH Core)
                     [ "case (let v = ev in e) of ... ==> let v = ev in case e of ..." ]  .+ Commute .+ Shallow .+ Eval .+ Bash
         , external "case-float-app" (promoteExprR caseFloatApp :: RewriteH Core)
                     [ "(case ec of alt -> e) v ==> case ec of alt -> e v" ]              .+ Commute .+ Shallow .+ Bash
           -- Tagged PreCondition: 'caseFloatArg' is documented as only safe if f is strict.
         , external "case-float-arg" (promoteExprR caseFloatArg :: RewriteH Core)
                     [ "f (case s of alt -> e) ==> case s of alt -> f e" ]                .+ Commute .+ Shallow .+ PreCondition
         , external "case-float-case" (promoteExprR caseFloatCase :: RewriteH Core)
                     [ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval .+ Bash
         , external "case-float-let" (promoteExprR caseFloatLet :: RewriteH Core)
                     [ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Bash
           -- 'caseFloat' tries each of the four case-float rewrites above, including 'caseFloatArg'.
         , external "case-float" (promoteExprR caseFloat :: RewriteH Core)
                     [ "Float a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
         , external "case-reduce" (promoteExprR caseReduce :: RewriteH Core)
                     [ "case-of-known-constructor"
                     , "case C v1..vn of C w1..wn -> e ==> e[v1/w1..vn/wn]" ] .+ Shallow .+ Eval .+ Bash
           -- The two case-split commands take a Template Haskell name argument and carry no tags.
         , external "case-split" (promoteExprR . caseSplit :: TH.Name -> RewriteH Core)
                [ "case-split 'x"
                , "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ]
           -- NOTE(review): "occurances" in the help text below is a misspelling of
           -- "occurrences"; it is runtime text, so it is left untouched here.
         , external "case-split-inline" (caseSplitInline :: TH.Name -> RewriteH Core)
                [ "Like case-split, but additionally inlines the matched constructor "
                , "applications for all occurances of the named variable." ]
         ]

-- not_defined :: String -> RewriteH CoreExpr
-- not_defined nm = fail $ nm ++ " not implemented!"

-- | Float a @let@ out of the scrutinee of a @case@:
--
--   @case (let v = e1 in e2) of alts@ ==> @let v = e1 in case e2 of alts@
--
--   If any variable bound by the let also appears among the free variables of
--   the whole case expression, 'alphaLet' is applied to the scrutinee first.
letFloatCase :: RewriteH CoreExpr
letFloatCase = prefixFailMsg "Let floating from Case failed: " $
  do
     letBound  <- caseT letVarsT (const (pure ())) (\ vs _ _ _ -> vs)
     -- free variables of the whole expression, so we get type variables too
     exprFrees <- freeVarsT
     let noClash = null (exprFrees `intersect` letBound)
         scrutR  = if noClash then idR else alphaLet
     caseT scrutR
           (const idR)
           (\ (Let bnds body) bndr ty alts -> Let bnds (Case body bndr ty alts))

-- | Float a @case@ out of the function position of an application:
--
--   @(case s of alt1 -> e1; alt2 -> e2) v@ ==> @case s of alt1 -> e1 v; alt2 -> e2 v@
--
--   The case binder and the individual alternatives are passed through
--   'alphaCaseBinder' / 'alphaAlt' whenever their bound variables intersect
--   the free variables of the argument @v@.
caseFloatApp :: RewriteH CoreExpr
caseFloatApp = prefixFailMsg "Case floating from App function failed: " $
  do
    -- one list of clashing variables per alternative, in alternative order
    altClashes  <- appT caseAltVarsT freeVarsT (\ altVars argFrees -> map (intersect argFrees) altVars)
    bndrClashes <- appT caseBinderVarT freeVarsT intersect
    let fixBinder = if null bndrClashes then idR else alphaCaseBinder Nothing
        fixAlt n  = if null (altClashes !! n) then idR else alphaAlt
    appT (fixBinder >>> caseAllR idR fixAlt)
         idR
         (\ (Case scrut bndr _ty alts) arg ->
              -- the result type is taken from the first alternative applied to the argument
              let firstRhs = case head alts of (_, _, rhs) -> rhs
                  resTy    = exprType (App firstRhs arg)
              in Case scrut bndr resTy [ (con, vs, App rhs arg) | (con, vs, rhs) <- alts ])

-- | Float a @case@ out of the argument position of an application:
--
--   @f (case s of alt1 -> e1; alt2 -> e2)@ ==> @case s of alt1 -> f e1; alt2 -> f e2@
--
--   Only safe if @f@ is strict.
--
--   The case binder and the individual alternatives are passed through
--   'alphaCaseBinder' / 'alphaAlt' whenever their bound variables intersect
--   the free variables of the function @f@.
caseFloatArg :: RewriteH CoreExpr
caseFloatArg = prefixFailMsg "Case floating from App argument failed: " $
  do
    -- one list of clashing variables per alternative, in alternative order
    altClashes  <- appT freeVarsT caseAltVarsT (\ funFrees altVars -> map (intersect funFrees) altVars)
    bndrClashes <- appT freeVarsT caseBinderVarT intersect
    let fixBinder = if null bndrClashes then idR else alphaCaseBinder Nothing
        fixAlt n  = if null (altClashes !! n) then idR else alphaAlt
    appT idR
         (fixBinder >>> caseAllR idR fixAlt)
         (\ fun (Case scrut bndr _ty alts) ->
              -- the result type is taken from the function applied to the first alternative
              let firstRhs = case head alts of (_, _, rhs) -> rhs
                  resTy    = exprType (App fun firstRhs)
              in Case scrut bndr resTy [ (con, vs, App fun rhs) | (con, vs, rhs) <- alts ])

-- | Float an inner @case@ out of the scrutinee of an outer @case@:
--
--   case (case s1 of alt11 -> e11; alt12 -> e12) of alt21 -> e21; alt22 -> e22
--   ==>
--   case s1 of
--     alt11 -> case e11 of alt21 -> e21; alt22 -> e22
--     alt12 -> case e12 of alt21 -> e21; alt22 -> e22
--
--   The outer alternatives are duplicated into every inner alternative.
--   The floated-out case is annotated with the result type of the /outer/
--   case, since each of its alternatives is now a copy of the outer case.
caseFloatCase :: RewriteH CoreExpr
caseFloatCase = prefixFailMsg "Case floating from Case failed: " $
  do
    -- per inner alternative: which of its bound variables also occur among the
    -- variables that 'altFreeVarsT' reports for the outer alternatives?
    captures <- caseT caseAltVarsT (const altFreeVarsT) $ \ vss bndr _ fs -> map (intersect (concatMap ($ bndr) fs)) vss
    -- does the binder of the inner case, shadow a free variable in any of the outer case alts?
    -- notice, caseBinderVarT returns a singleton list
    binderCapture <- caseT caseBinderVarT (const altFreeVarsT) $ \ innerBindr bndr _ fs -> intersect (concatMap ($ bndr) fs) innerBindr
    caseT ((if null binderCapture then idR else alphaCaseBinder Nothing)
           >>> caseAllR idR (\i -> if null (captures !! i) then idR else alphaAlt)
          )
          (const idR)
          -- The inner case's own type annotation is dropped: after floating, every
          -- alternative of the new outer case has type ty2, not the inner case's type.
          (\ (Case s1 b1 _ alts1) b2 ty2 alts2 -> Case s1 b1 ty2 [ (c1, ids1, Case e1 b2 ty2 alts2) | (c1, ids1, e1) <- alts1 ])

-- | let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e
--
--   The resulting case is annotated with the type of the let body @e@, because
--   every alternative now ends in @e@ rather than in the original right-hand
--   side of the binding.
caseFloatLet :: RewriteH CoreExpr
caseFloatLet = prefixFailMsg "Case floating from Let failed: " $
     -- vs: is the let-bound variable also bound by one of the case alternatives?
  do vs <- letNonRecT caseAltVarsT idR (\ letVar caseVars _ -> elem letVar $ concat caseVars)
     let bdsAction = if not vs then idR else nonRecR alphaCase
     -- NOTE(review): only the let-bound variable is checked against the alternative
     -- binders; free variables of the let body are not checked here -- confirm
     -- whether that capture is ruled out elsewhere.
     letT bdsAction idR $ \ (NonRec v (Case s b _ alts)) e -> Case s b (exprType e) [ (con, ids, Let (NonRec v ec) e) | (con, ids, ec) <- alts]


-- | Float a Case whatever the context.
--
--   Tries, in order, 'caseFloatApp', 'caseFloatArg', 'caseFloatCase' and
--   'caseFloatLet', and reports a single generic message if none applies.
caseFloat :: RewriteH CoreExpr
caseFloat = setFailMsg "Unsuitable expression for Case floating." anyCaseFloat
  where
    anyCaseFloat = caseFloatApp <+ caseFloatArg <+ caseFloatCase <+ caseFloatLet

-- | Case-of-known-constructor rewrite.
--
--   @case C v1..vn of C w1..wn -> e@ is turned into nested non-recursive lets
--   binding the case binder to the scrutinee and each @wi@ to @vi@, after
--   which 'letSubstR' is applied repeatedly (if it applies at all).
--
--   Fails if the head of the scrutinee is not a data constructor, if there is
--   not exactly one alternative for that constructor, or if the number of
--   pattern binders differs from the number of value arguments.
caseReduce :: RewriteH CoreExpr
caseReduce = letTransform >>> tryR (repeatR letSubstR)
    where letTransform = prefixFailMsg "Case reduction failed: " $
                         withPatFailMsg (wrongExprForm "Case e v t alts") $
                         do Case s binder _ alts <- idR
                            case isDataCon s of
                              Nothing -> fail "head of scrutinee is not a data constructor."
                              Just (dc, args) -> case [ (bs, rhs) | (DataAlt dc', bs, rhs) <- alts, dc == dc' ] of
                                    [(bs,e')] -> let valArgs = filter isValArg args -- discard any type arguments
                                                  in -- zip would silently truncate on a mismatch, leaving
                                                     -- binders unbound or bound to the wrong argument
                                                     if length bs == length valArgs
                                                       then return $ nestedLets e' $ (binder, s) : zip bs valArgs
                                                       else fail "constructor value arguments do not match the binders of the alternative."
                                    []   -> fail "no matching alternative."
                                    _    -> fail "more than one matching alternative."

-- | Decompose an expression with 'collectArgs'; if its head is a variable
--   that 'isDataConId_maybe' recognises, return that data constructor together
--   with all collected arguments. Otherwise the result is 'Nothing'.
isDataCon :: CoreExpr -> Maybe (DataCon, [CoreExpr])
isDataCon expr =
    case collectArgs expr of
      (Var v, args) -> fmap (\ dc -> (dc, args)) (isDataConId_maybe v)
      _             -> fail "not a var"

-- | Wrap an expression in one non-recursive @let@ per binding (rather than a
--   single recursive let). The first binding in the list becomes the outermost let.
nestedLets :: CoreExpr -> [(Id, CoreExpr)] -> CoreExpr
nestedLets body binds = foldr wrap body binds
  where
    wrap (v, rhs) inner = Let (NonRec v rhs) inner

-- | Case split a free variable in an expression:
--
-- Assume expression e which mentions x :: [a]
--
-- e ==> case x of x
--         [] -> e
--         (a:b) -> e
--
-- The first free identifier matching the given name is used, both as the
-- scrutinee and as the case binder. One alternative is generated per data
-- constructor of its type, each with freshly created pattern variables.
--
-- Fails if the name is not free in the expression, if the variable's type is
-- not an application of a type constructor, or if that type constructor has
-- no data constructors to split on.
caseSplit :: TH.Name -> RewriteH CoreExpr
caseSplit nm = do
    frees <- freeIdsT
    contextfreeT $ \ e ->
        case [ i | i <- frees, cmpTHName2Id nm i ] of
            []    -> fail "caseSplit: provided name is not free"
            (i:_) ->
                -- use the total variant so that e.g. a type-variable type fails cleanly
                case splitTyConApp_maybe (idType i) of
                    Nothing -> fail "caseSplit: type of the named variable is not a type constructor application"
                    Just (tycon, tys) ->
                        case tyConDataCons tycon of
                            -- refusing here avoids building a case with no alternatives
                            []  -> fail "caseSplit: type of the named variable has no data constructors"
                            dcs -> do
                                -- base names for the fresh pattern variables: "a".."z", repeating
                                let aNms = map (:[]) $ cycle ['a'..'z']
                                dcsAndVars <- mapM (\dc -> do
                                                        vs <- sequence [ newVarH a ty | (a,ty) <- zip aNms $ dataConInstArgTys dc tys ]
                                                        return (dc,vs)) dcs
                                return $ Case (Var i) i (exprType e) [ (DataAlt dc, vs, e) | (dc,vs) <- dcsAndVars ]

-- | Like 'caseSplit', but afterwards also runs 'inlineName' bottom-up over the
-- result, so that occurrences of the named variable are inlined wherever that
-- rewrite succeeds.
--
-- > caseSplitInline nm = caseSplit nm >>> anybuR (inlineName nm)
caseSplitInline :: TH.Name -> RewriteH Core
caseSplitInline nm = splitStage >>> inlineStage
  where
    splitStage  = promoteR (caseSplit nm)
    inlineStage = anybuR (promoteExprR (inlineName nm))