module Polysemy.Plugin.InlineRecursiveCalls
  ( inlineRecursiveCalls
  ) where

import BasicTypes
import Control.Monad
import Control.Monad.Trans.State
import CoreMonad
import CoreSyn
import Data.Monoid
import Data.Traversable
import GHC
import Generics.SYB
import HscTypes
import IdInfo
import Name
import UniqSupply
import Unique
import Var


-- | Core-to-Core pass entry point.
--
-- Creates a fresh 'UniqSupply' with 'mkSplitUniqSupply', runs
-- 'loopbreakBinds' over every binding group in 'mg_binds' with that supply
-- threaded through as state, and returns the 'ModGuts' with the rewritten
-- bindings. All other fields of the 'ModGuts' are left as they were.
inlineRecursiveCalls :: ModGuts -> CoreM ModGuts
inlineRecursiveCalls mg = do
  supply <- liftIO (mkSplitUniqSupply '\x264a')
  rewritten <- evalStateT (traverse loopbreakBinds (mg_binds mg)) supply
  pure mg { mg_binds = rewritten }


-- | 'CoreM' extended with a 'UniqSupply' held in 'StateT' state; 'getUniq'
-- draws fresh 'Unique's from that state.
type CoreSupplyM = StateT UniqSupply CoreM


-- | Draw one 'Unique' from the supply held in state and store the remainder
-- of the supply back, so successive calls yield successive uniques.
getUniq :: CoreSupplyM Unique
getUniq = state takeUniqFromSupply


-- | Does the expression mention the given binder?
--
-- Generically visits every sub-term of the expression and reports 'True' as
-- soon as any of them is a 'Var' node for @n@ (as decided by 'matches').
containsName :: CoreBndr -> CoreExpr -> Bool
containsName n = getAny . everything (<>) (mkQ mempty (matches n))


-- | @'Any' 'True'@ exactly when the expression is itself a 'Var' node for the
-- given binder; every other expression shape yields @'Any' 'False'@. Only the
-- top node is inspected here; 'containsName' supplies the traversal.
matches :: CoreBndr -> CoreExpr -> Any
matches n expr = Any $ case expr of
  Var n' -> n == n'
  _      -> False


-- | @replace n n' e@ rewrites every 'Var' node in @e@ that refers to @n@ so
-- that it refers to @n'@ instead. All other expression nodes are returned
-- unchanged.
replace :: Id -> Id -> Expr CoreBndr -> Expr CoreBndr
replace n n' = everywhere (mkT redirect)
  where
    -- Applied by 'everywhere' to each sub-expression of type 'Expr CoreBndr'.
    redirect :: Expr CoreBndr -> Expr CoreBndr
    redirect (Var v) | v == n = Var n'
    redirect other = other


-- | Split a self-referential binding @n = b@ into two bindings:
--
-- * @n@, with its usage info zapped and its inline pragma set to
--   'alwaysInlinePragma', bound to @b@ in which every reference to @n@ has
--   been redirected to the fresh binder @n'@;
--
-- * @n'@, a fresh local variable with the same occurrence name, type and
--   'idDetails' as @n@, carrying 'neverInlinePragma', bound to @Var n@.
--
-- The pair is returned as a list so callers can splice it into a 'Rec' group.
loopbreaker :: CoreBndr -> CoreExpr -> CoreSupplyM [(Var, CoreExpr)]
loopbreaker n b = do
  u <- getUniq
  let origInfo = idInfo n
      -- 'zapUsageInfo' returns a 'Maybe'. Treat 'Nothing' as "nothing to
      -- zap" and keep the original info, rather than crashing on an
      -- irrefutable @Just@ pattern.
      info = case zapUsageInfo origInfo of
        Just zapped -> zapped
        Nothing     -> origInfo
      info' = setInlinePragInfo info alwaysInlinePragma
      -- The fresh binder starts from 'vanillaIdInfo' and is marked
      -- never-inline; presumably this makes it the loop breaker of the
      -- resulting recursive group.
      n' = mkLocalVar
             (idDetails n)
             (mkInternalName u (occName n) noSrcSpan)
             (idType n)
         $ setInlinePragInfo vanillaIdInfo neverInlinePragma
  pure [ (lazySetIdInfo n info', replace n n' b)
       , (n', Var n)
       ]


-- TODO(sandy): Make this only break loops in functions whose type ends in `Sem
-- * * -> Sem * *` for wildcards `*`

-- | Apply 'loopbreaker' to every binding whose right-hand side mentions its
-- own binder.
--
-- A 'NonRec' binding that mentions itself becomes a 'Rec' group holding the
-- pair produced by 'loopbreaker'; otherwise it is returned untouched. In a
-- 'Rec' group each pair is handled independently (only references to its own
-- binder are considered) and the results are concatenated in order.
loopbreakBinds
    :: Bind CoreBndr
    -> CoreSupplyM (Bind CoreBndr)
loopbreakBinds bind = case bind of
  NonRec n b
    | containsName n b -> Rec <$> loopbreaker n b
    | otherwise        -> pure bind
  Rec pairs -> Rec . concat <$> traverse breakPair pairs
  where
    -- Self-referential pairs are expanded; the rest are kept as singletons.
    breakPair (n, b)
      | containsName n b = loopbreaker n b
      | otherwise        = pure [(n, b)]