{-# LANGUAGE QuasiQuotes, NamedFieldPuns #-}

module CallCount.TcPlugin (callCount, optCallCount) where

import Language.Haskell.Printf (s)
import Data.Maybe (fromMaybe)
import Data.IORef (IORef)
import GHC.Corroborate
import NoOp.Plugin

-- | State shared between 'tcPluginInit' and 'tcPluginSolve': a mutable
-- counter holding the number that will be printed on the next
-- 'tcPluginSolve' call.
newtype State = State {callref :: IORef Int}

-- | A plugin that counts the number of times its 'tcPluginSolve' function is
-- called when GHC is type checking.
--
-- This is 'optCallCount' with no options, so each count is prefixed with
-- @GHC-TcPlugin@. It falls back to 'noOp' if 'optCallCount' returns 'Nothing'.
callCount :: TcPlugin
callCount = fromMaybe noOp $ optCallCount []

-- | This plugin does no type checking.
--
-- Any options passed are echoed as a prefix before the call count each time
-- 'tcPluginSolve' is called. Multiple options are concatenated without a
-- separator, so options @A@ and @B@ give the prefix @AB@. If no options are
-- passed then __GHC-TcPlugin__ is used as the prefix instead, as this test
-- suite shows:
--
-- @
-- Building test suite 'test-counter-foobar-main'
-- [1 of 2] Compiling FooBar
-- >>> GHC-TcPlugin #1
-- >>> GHC-TcPlugin #2
-- [2 of 2] Compiling Main
-- @
--
-- The options passed to 'optCallCount' are echoed as a prefix in the test
-- suites built by @./build-opts.sh@.
--
-- @
-- > cat ./build-opts.sh
-- # The steps in .github\/workflows\/cabal.yml related to passing options to plugins.
-- # You might like to run cabal update and cabal clean before running this script.
-- cabal build test-in-turn
-- cabal build test-in-line
-- cabal build test-in-turn-each
-- cabal build test-in-line-each
-- @
--
-- @
-- > ./build-opts.sh
-- ...
-- Building test suite 'test-in-turn'
-- [1 of 1] Compiling Main
-- >>> AB #1
-- >>> AB #1
-- >>> AB #2
-- >>> AB #2
-- ...
-- Building test suite 'test-in-line'
-- [1 of 1] Compiling Main
-- >>> AB #1
-- >>> AB #1
-- >>> AB #2
-- >>> AB #2
-- ...
-- Building test suite 'test-in-turn-each'
-- [1 of 1] Compiling Main
-- >>> B #1
-- >>> A #1
-- >>> B #2
-- >>> A #2
-- ...
-- Building test suite 'test-in-line-each'
-- [1 of 1] Compiling Main
-- >>> B #1
-- >>> A #1
-- >>> B #2
-- >>> A #2
-- @
optCallCount :: [CommandLineOption] -> Maybe TcPlugin
optCallCount opts =
    Just $
        TcPlugin
            { -- The counter starts at 1 so the first solver call prints "#1".
              tcPluginInit = State <$> unsafeTcPluginTcM (newMutVar 1)
            , -- Every solver call prints the prefix and current count, then
              -- bumps the counter. No constraints are solved or emitted.
              tcPluginSolve = \State{callref = c} _ _ _ -> do
                n <- unsafeTcPluginTcM $ readMutVar c
                tcPluginIO . putStrLn $ [s|>>> %s #%d|] msg n
                unsafeTcPluginTcM $ writeMutVar c (n + 1)
                return $ TcPluginOk [] []
            , -- Nothing to clean up.
              tcPluginStop = const $ return ()
            }
  where
    -- The prefix depends only on the options, so it is computed once here
    -- rather than on every solver call. Options are joined with no separator.
    msg :: CommandLineOption
    msg = if null opts then "GHC-TcPlugin" else mconcat opts