{-# LANGUAGE CPP #-}
{-# LANGUAGE Unsafe #-}

-- | A plugin that identifies and reports on uses of recursion. The name evokes
--   a language pragma – implying a @Recursion@ pragma that is enabled by
--   default.
module NoRecursion (plugin) where

import safe "base" Control.Applicative (Applicative (pure))
import safe "base" Control.Category (Category ((.)))
import safe "base" Control.Exception (ErrorCall (ErrorCall), throwIO)
import safe "base" Control.Monad ((=<<))
import safe "base" Data.Bool (not, (&&), (||))
import safe "base" Data.Either (Either (Left), either)
import safe "base" Data.Foldable (Foldable (foldMap, toList), all)
import safe "base" Data.Function (($))
import safe "base" Data.Functor (Functor (fmap), (<$>))
import safe "base" Data.List (filter, intercalate, isPrefixOf, null)
import safe "base" Data.List.NonEmpty (NonEmpty, nonEmpty)
import safe "base" Data.Maybe (maybe)
import safe "base" Data.Semigroup (Semigroup ((<>)))
import safe "base" Data.String (String)
import safe "base" Data.Tuple (fst)
#if MIN_VERSION_ghc(9, 0, 0)
import qualified "ghc" GHC.Plugins as Plugins
#else
import qualified "ghc" GhcPlugins as Plugins
#endif

-- | The base plugin that 'plugin' extends. On GHC 8.6.1 and later this is
--   'Plugins.defaultPlugin' with 'Plugins.pluginRecompile' set to
--   'Plugins.purePlugin'; on older compilers (which lack that field) it is
--   'Plugins.defaultPlugin' unchanged.
defaultPurePlugin :: Plugins.Plugin
#if MIN_VERSION_ghc(8, 6, 1)
defaultPurePlugin =
  Plugins.defaultPlugin {Plugins.pluginRecompile = Plugins.purePlugin}
#else
defaultPurePlugin = Plugins.defaultPlugin
#endif

-- | The plugin entry point. Command-line options passed to the plugin are
--   ignored; the list of Core-to-Core passes is rewritten by 'install'.
plugin :: Plugins.Plugin
plugin =
  defaultPurePlugin {Plugins.installCoreToDos = \_opts -> pure . install}

-- | Prepends the 'noRecursionPass' plugin pass (labelled
--   @"add NoRecursion rule"@) to the given list of Core-to-Core passes.
install :: [Plugins.CoreToDo] -> [Plugins.CoreToDo]
install = (Plugins.CoreDoPluginPass "add NoRecursion rule" noRecursionPass :)

-- | Checks the module's top-level bindings with 'failOnRecursion'.
--
--   If any recursion is reported, an 'ErrorCall' is thrown in 'IO' whose
--   message lists one formatted 'RecursionRecord' per line. Otherwise the
--   'Plugins.ModGuts' are returned with 'Plugins.mg_binds' set to the bindings
--   that 'failOnRecursion' handed back.
noRecursionPass :: Plugins.ModGuts -> Plugins.CoreM Plugins.ModGuts
noRecursionPass guts = do
  dflags <- Plugins.getDynFlags
  either
    ( \recs ->
        Plugins.liftIO . throwIO . ErrorCall $
          "something recursive:\n"
            <> intercalate "\n" (toList $ formatRecursionRecord dflags <$> recs)
    )
    (\binds -> pure guts {Plugins.mg_binds = binds})
    . failOnRecursion dflags
    $ Plugins.mg_binds guts

-- | A single report of recursion. The list holds the binders enclosing the
--   recursive group (outermost first, since 'addBindingReference' prepends
--   each enclosing binder on the way out); it is empty for a group found at
--   the top level. The 'NonEmpty' holds the binders of the recursive group
--   itself.
data RecursionRecord b = RecursionRecord [b] (NonEmpty b)

-- | Renders a 'RecursionRecord' as a one-line, human-readable message.
--
--   The message starts with @"at the top level"@ when the record has no
--   enclosing context, or with @"in "@ followed by the enclosing binders
--   joined by @" >> "@ otherwise. It then lists the recursive binders,
--   separated by @", "@. Binders are rendered with 'Plugins.showSDoc' using the
--   supplied 'Plugins.DynFlags'.
formatRecursionRecord ::
  (Plugins.Outputable b) => Plugins.DynFlags -> RecursionRecord b -> String
formatRecursionRecord dflags (RecursionRecord context recs) =
  maybe
    "at the top level"
    (\v -> "in " <> intercalate " >> " (render <$> toList v))
    (nonEmpty context)
    <> ", the following bindings were recursive: "
    <> intercalate ", " (render <$> toList recs)
  where
    -- Pretty-print a single binder to a plain 'String'.
    render = Plugins.showSDoc dflags . Plugins.ppr

-- | Collects every 'RecursionRecord' found in the given bindings (via
--   'recursiveCallsForBind').
--
--   Returns the bindings unchanged if no records remain after filtering, or
--   'Left' with the remaining records otherwise. A record is dropped when it
--   has no enclosing context and every binder in its group renders with a
--   @$c@ or @$f@ prefix.
failOnRecursion ::
  (Plugins.Outputable b) =>
  Plugins.DynFlags ->
  [Plugins.Bind b] ->
  Either (NonEmpty (RecursionRecord b)) [Plugins.Bind b]
failOnRecursion dflags original =
  maybe (pure original) Left
    . nonEmpty
    -- __TODO__: Default method implementations seem to cause mutual recursion
    --           with the instance, so here we filter them out, but this
    --           probably lets some real mutual recursion slip through.
    . filter
      ( \(RecursionRecord context recs) ->
          not $
            null context
              && all
                ( \var ->
                    let v = Plugins.showSDoc dflags $ Plugins.ppr var
                     in "$c" `isPrefixOf` v || "$f" `isPrefixOf` v
                )
                recs
      )
    $ recursiveCallsForBind =<< original

-- | Prepends the given binder to the context of every record, recording that
--   the reported recursion occurs inside that binder's definition.
addBindingReference :: b -> [RecursionRecord b] -> [RecursionRecord b]
addBindingReference var =
  fmap (\(RecursionRecord context recs) -> RecursionRecord (var : context) recs)

-- | Finds the recursion introduced by, or nested within, a single binding.
--
--   * For a 'Plugins.NonRec' binding, this is whatever recursion is found in
--     its right-hand side, with the binder added to each record's context.
--   * For a 'Plugins.Rec' group, the same is done for every binding in the
--     group. If the group is non-empty, a record for the group itself (with
--     an empty context) is placed in front of those results.
recursiveCallsForBind :: Plugins.Bind b -> [RecursionRecord b]
recursiveCallsForBind = \case
  Plugins.NonRec v rhs -> addBindingReference v $ collectRecursiveCalls rhs
  Plugins.Rec binds ->
    let nestedRecursion =
          foldMap
            (\(v, rhs) -> addBindingReference v $ collectRecursiveCalls rhs)
            binds
     in maybe
          nestedRecursion
          (\bnds -> RecursionRecord [] (fst <$> bnds) : nestedRecursion)
          $ nonEmpty binds

-- | This collects all identifiable recursion points in an expression.
--
--   The expression is traversed structurally: sub-expressions of
--   applications, case scrutinees and alternatives, casts, lambda bodies,
--   ticks, and @let@ bodies are searched, and each @let@ binding is handed to
--   'recursiveCallsForBind'. Coercions, literals, types, and variables
--   contribute nothing.
collectRecursiveCalls :: Plugins.Expr b -> [RecursionRecord b]
collectRecursiveCalls = \case
  Plugins.App f a -> collectRecursiveCalls f <> collectRecursiveCalls a
  Plugins.Case scrut _ _ alts ->
    collectRecursiveCalls scrut <> foldMap recursiveCallsForAlt alts
  Plugins.Cast e _ -> collectRecursiveCalls e
  Plugins.Coercion _ -> []
  Plugins.Lam _ body -> collectRecursiveCalls body
  Plugins.Let bind e -> recursiveCallsForBind bind <> collectRecursiveCalls e
  Plugins.Lit _ -> []
  Plugins.Tick _ body -> collectRecursiveCalls body
  Plugins.Type _ -> []
  Plugins.Var _ -> []

-- | Collects the recursion found in the right-hand side of a case
--   alternative. The two definitions differ only in how the alternative is
--   pattern-matched: a dedicated 'Plugins.Alt' constructor from GHC 9.2.0 on,
--   a triple before that.
recursiveCallsForAlt :: Plugins.Alt b -> [RecursionRecord b]
#if MIN_VERSION_ghc(9, 2, 0)
recursiveCallsForAlt (Plugins.Alt _ _ rhs) = collectRecursiveCalls rhs
#else
recursiveCallsForAlt (_, _, rhs) = collectRecursiveCalls rhs
#endif