{-# LANGUAGE OverloadedStrings #-}

module Rattus.Plugin.Strictify where
import Prelude hiding ((<>))
import Rattus.Plugin.Utils
import GhcPlugins


-- | Context that 'strictifyExpr' threads through its traversal.
data SCxt = SCxt
  { srcSpan :: SrcSpan
    -- ^ Location attached to emitted warnings; 'strictifyExpr' replaces
    -- it whenever it descends through a source-note tick.
  , checkStrictData :: Bool
    -- ^ Whether 'strictifyExpr' should warn about function arguments of
    -- non-strict types.
  }

-- | Transforms all functions into strict functions. If the
-- 'checkStrictData' field of the 'SCxt' argument is set to @True@,
-- then this function also checks for use of non-strict data types and
-- produces warnings if it finds any.
--
-- The traversal rewrites the expression as follows:
--
-- * A non-recursive @let@ becomes a @case@ with a single @DEFAULT@
--   alternative that binds the same variable.
--
-- * A lambda over a value binder of lifted type gets a fresh binder
--   and immediately scrutinises it in a @case@ that rebinds the
--   original variable.
--
-- * Recursive @let@s, @case@s, casts, ticks and applications are
--   rebuilt from their recursively transformed sub-expressions.
--
-- * Every other expression is returned unchanged.
strictifyExpr :: SCxt -> CoreExpr -> CoreM CoreExpr
strictifyExpr ss (Let (NonRec b e1) e2) = do
  e1' <- strictifyExpr ss e1
  e2' <- strictifyExpr ss e2
  -- Replace the let-binding by a case on the bound expression.
  return (Case e1' b (exprType e2) [(DEFAULT, [], e2')])
strictifyExpr ss (Case scrut b t alts) = do
  scrut' <- strictifyExpr ss scrut
  alts' <- mapM strictifyAlt alts
  return (Case scrut' b t alts')
  where
    -- Only the right-hand side of an alternative is transformed.
    strictifyAlt (con, args, rhs) =
      (\rhs' -> (con, args, rhs')) <$> strictifyExpr ss rhs
strictifyExpr ss (Let (Rec binds) body) = do
  binds' <- mapM (\(b, rhs) -> (\rhs' -> (b, rhs')) <$> strictifyExpr ss rhs) binds
  body' <- strictifyExpr ss body
  return (Let (Rec binds') body')
strictifyExpr ss (Lam b e)
  | not (isCoVar b) && not (isTyVar b) && tcIsLiftedTypeKind (typeKind (varType b)) = do
      e' <- strictifyExpr ss e
      -- The lambda binds a fresh variable; the case on it rebinds the
      -- original binder for use in the body.
      b' <- mkSysLocalM (fsLit "strict") (varType b)
      return (Lam b' (Case (varToCoreExpr b') b (exprType e) [(DEFAULT, [], e')]))
  | otherwise = Lam b <$> strictifyExpr ss e
strictifyExpr ss (Cast e c) = do
  e' <- strictifyExpr ss e
  return (Cast e' c)
strictifyExpr ss (Tick t@(SourceNote sp _) e) = do
  -- Remember the most recent source location for warnings.
  e' <- strictifyExpr (ss{srcSpan = RealSrcSpan sp}) e
  return (Tick t e')
strictifyExpr ss (Tick t e) =
  -- Ticks without a source span must still be traversed; otherwise the
  -- wrapped expression would be left untransformed and unchecked.
  Tick t <$> strictifyExpr ss e
strictifyExpr ss (App e1 e2)
  | checkStrictData ss && not (isType e2) && tcIsLiftedTypeKind (typeKind argTy)
        && not (isStrict argTy) = do
      printMessage SevWarning (srcSpan ss)
        (text "The use of lazy type " <> ppr argTy <> " may lead to memory leaks")
      -- The check is switched off below this application, so no further
      -- warnings are produced for its sub-expressions.
      e1' <- strictifyExpr ss{checkStrictData = False} e1
      e2' <- strictifyExpr ss{checkStrictData = False} e2
      return (App e1' e2')
  | otherwise = do
      e1' <- strictifyExpr ss e1
      e2' <- strictifyExpr ss e2
      return (App e1' e2')
  where
    -- Evaluated lazily: only demanded once 'isType' has rejected 'e2'.
    argTy = exprType e2
strictifyExpr _ss e = return e


-- | Checks whether the function position of an expression is a variable
-- accepted by 'isDelayVar'. Applications are followed on their function
-- side, and casts and ticks are looked through.
isDelayApp :: Expr b -> Bool
isDelayApp expr = case expr of
  App fun _    -> isDelayApp fun
  Cast inner _ -> isDelayApp inner
  Tick _ inner -> isDelayApp inner
  Var v        -> isDelayVar v
  _            -> False




-- | Checks whether a variable is named @Delay@, @delay@ or @Box@ and is
-- defined in the module @Rattus.Internal@ or @Rattus.Primitives@.
--
-- Variables for which 'nameModule_maybe' yields 'Nothing' never match.
isDelayVar :: Var -> Bool
isDelayVar v = case nameModule_maybe name of
    Nothing -> False
    -- Both conditions must hold. Previously '&&' bound tighter than
    -- '||', so @Delay@ and @delay@ matched regardless of their module.
    Just m  -> occ `elem` delayNames
                 && moduleNameString (moduleName m) `elem` rattusModules
  where
    name = varName v
    occ = getOccString name
    -- NOTE(review): the previous condition listed "delay" twice; the
    -- second occurrence was possibly meant to be "box" -- confirm.
    delayNames = ["Delay", "delay", "Box"]
    rattusModules = ["Rattus.Internal", "Rattus.Primitives"]

-- | Returns 'True' for case expressions and lambda abstractions, looking
-- through ticks and casts. Any other expression is handed to 'isType'.
isCase expr = case expr of
  Case{}       -> True
  Lam{}        -> True
  Tick _ inner -> isCase inner
  Cast inner _ -> isCase inner
  other        -> isType other

-- | Checks whether an expression mentions @GHC.TopHandler.runMainIO@.
-- Both sides of an application are searched, and casts and ticks are
-- looked through. Any other expression matches only if its
-- pretty-printed form is exactly that string.
isTophandler expr = case expr of
  App fun arg  -> isTophandler fun || isTophandler arg
  Cast inner _ -> isTophandler inner
  Tick _ inner -> isTophandler inner
  leaf         -> showSDocUnsafe (ppr leaf) == "GHC.TopHandler.runMainIO"