{-# LANGUAGE OverloadedStrings #-}

module Rattus.Plugin.Strictify where
import Prelude hiding ((<>))
import Rattus.Plugin.Utils
import GhcPlugins


-- | Context threaded through 'strictifyExpr'.
data SCxt = SCxt
  { srcSpan :: !SrcSpan
    -- ^ Location reported in warnings; replaced when the traversal
    -- passes a 'SourceNote' tick.
  , checkStrictData :: !Bool
    -- ^ When 'True', warn about applications to arguments whose type is
    -- a lazy (non-strict) data type.
  }

-- | Transforms all functions into strict functions. If the
-- 'checkStrictData' field of the 'SCxt' argument is set to @True@,
-- then this function also checks for use of non-strict data types and
-- produces warnings if it finds any.
--
-- Concretely, the traversal:
--
-- * turns a non-recursive value binding @let x = e1 in e2@ into
--   @case e1 of x { DEFAULT -> e2 }@, so @e1@ is evaluated eagerly;
-- * wraps the body of every lambda over a lifted value binder in a @case@
--   on a fresh binder, forcing the argument on entry;
-- * recurses through 'Case', recursive 'Let', 'Cast', 'Tick' and 'App',
--   and returns every other expression form unchanged.
--
-- 'SourceNote' ticks update the span used for warnings. Once a lazy-type
-- warning has been emitted for an application, checking is switched off
-- for that application's sub-expressions.
strictifyExpr :: SCxt -> CoreExpr -> CoreM CoreExpr
strictifyExpr ss (Let (NonRec b e1) e2)
  -- Type, coercion and join-point binders cannot become case binders, so
  -- those bindings stay let-bound; only their sub-expressions are processed.
  | isTyVar b || isCoVar b || isJoinId b = do
      e1' <- strictifyExpr ss e1
      e2' <- strictifyExpr ss e2
      return (Let (NonRec b e1') e2')
  | otherwise = do
      e1' <- strictifyExpr ss e1
      e2' <- strictifyExpr ss e2
      return (Case e1' b (exprType e2) [(DEFAULT, [], e2')])
strictifyExpr ss (Case e b t alts) = do
  e' <- strictifyExpr ss e
  alts' <- traverse strictifyAlt alts
  return (Case e' b t alts')
  where
    strictifyAlt (con, args, rhs) = (,,) con args <$> strictifyExpr ss rhs
strictifyExpr ss (Let (Rec binds) e) = do
  binds' <- traverse (\(b, rhs) -> (,) b <$> strictifyExpr ss rhs) binds
  e' <- strictifyExpr ss e
  return (Let (Rec binds') e')
strictifyExpr ss (Lam b e)
  | not (isCoVar b)
      && not (isTyVar b)
      && tcIsLiftedTypeKind (typeKind (varType b)) = do
      e' <- strictifyExpr ss e
      -- The lambda gets a fresh binder and the original binder becomes the
      -- case binder, so the body's references to it see the forced value.
      b' <- mkSysLocalM (fsLit "strict") (varType b)
      return (Lam b' (Case (varToCoreExpr b') b (exprType e) [(DEFAULT, [], e')]))
  | otherwise = Lam b <$> strictifyExpr ss e
strictifyExpr ss (Cast e c) = (`Cast` c) <$> strictifyExpr ss e
strictifyExpr ss (Tick t e) = Tick t <$> strictifyExpr ss' e
  where
    -- Source notes refine the location used for warnings; all other ticks
    -- are traversed too (previously their subtrees were left untouched).
    ss' = case t of
      SourceNote loc _ -> ss {srcSpan = RealSrcSpan loc}
      _ -> ss
strictifyExpr ss (App e1 e2) = do
  ss' <-
    if warnLazyArg
      then do
        printMessage SevWarning (srcSpan ss)
          (text "The use of lazy type " <> ppr argTy <> " may lead to memory leaks")
        -- Report once per application, not again for its sub-expressions.
        return ss {checkStrictData = False}
      else return ss
  e1' <- strictifyExpr ss' e1
  e2' <- strictifyExpr ss' e2
  return (App e1' e2')
  where
    -- Lazily bound: 'argTy' is only demanded after the 'isType' test has
    -- ruled out a type argument, matching the original evaluation order.
    argTy = exprType e2
    warnLazyArg =
      checkStrictData ss
        && not (isType e2)
        && tcIsLiftedTypeKind (typeKind argTy)
        && not (isStrict argTy)
strictifyExpr _ e = return e


-- | Walks down the function position of an application chain, looking
-- through casts and ticks, and reports whether the head is a variable
-- accepted by 'isDelayVar'. Any other head yields 'False'.
isDelayApp :: Expr b -> Bool
isDelayApp expr = case expr of
  App fun _ -> isDelayApp fun
  Cast inner _ -> isDelayApp inner
  Tick _ inner -> isDelayApp inner
  Var v -> isDelayVar v
  _ -> False




-- | Is this variable one of @Delay@, @delay@, @Box@ or @box@ as defined in
-- "Rattus.Internal" or "Rattus.Primitives"?
--
-- Variables whose name has no defining module (local binders) are never
-- matched. The module restriction applies to all four names; previously
-- operator precedence let @Delay@/@delay@ from any module match, and
-- @box@ was missing because @delay@ was listed twice.
isDelayVar :: Var -> Bool
isDelayVar v =
  case nameModule_maybe name of
    Nothing -> False
    Just m ->
      occ `elem` delayNames
        && moduleNameString (moduleName m) `elem` rattusModules
  where
    name = varName v
    occ = getOccString name
    delayNames, rattusModules :: [String]
    delayNames = ["Delay", "delay", "Box", "box"]
    rattusModules = ["Rattus.Internal", "Rattus.Primitives"]

-- | Looking through ticks and casts, reports whether the expression is a
-- 'Case'. Despite the name, a lambda also counts, and any other form is
-- accepted exactly when it is a type argument ('isType').
isCase :: Expr b -> Bool
isCase expr = case expr of
  Case {} -> True
  Lam {} -> True
  Tick _ inner -> isCase inner
  Cast inner _ -> isCase inner
  _ -> isType expr

-- | Does @GHC.TopHandler.runMainIO@ occur as a variable anywhere in this
-- application tree (function or argument position), looking through casts
-- and ticks?
--
-- The variable is identified by its 'Name' (defining module plus
-- occurrence name) rather than by comparing pretty-printed output, which
-- depended on the printer's qualification style and rendered arbitrary
-- sub-expressions on every call.
isTophandler :: Expr b -> Bool
isTophandler (App e1 e2) = isTophandler e1 || isTophandler e2
isTophandler (Cast e _) = isTophandler e
isTophandler (Tick _ e) = isTophandler e
isTophandler (Var v) =
  case nameModule_maybe name of
    Just m ->
      moduleNameString (moduleName m) == "GHC.TopHandler"
        && getOccString name == "runMainIO"
    Nothing -> False
  where
    name = varName v
isTophandler _ = False