{-# LANGUAGE CPP #-}
{-# LANGUAGE Unsafe #-}
module NoRecursion (plugin) where
import safe "base" Control.Applicative (Applicative (pure))
import safe "base" Control.Category (Category ((.)))
import safe "base" Control.Exception (ErrorCall (ErrorCall), throwIO)
import safe "base" Control.Monad ((=<<))
import safe "base" Data.Bool (not, (&&), (||))
import safe "base" Data.Either (Either (Left), either)
import safe "base" Data.Foldable (Foldable (foldMap, toList), all)
import safe "base" Data.Function (($))
import safe "base" Data.Functor (Functor (fmap), (<$>))
import safe "base" Data.List (filter, intercalate, isPrefixOf, null)
import safe "base" Data.List.NonEmpty (NonEmpty, nonEmpty)
import safe "base" Data.Maybe (maybe)
import safe "base" Data.Semigroup (Semigroup ((<>)))
import safe "base" Data.String (String)
import safe "base" Data.Tuple (fst)
#if MIN_VERSION_ghc(9, 0, 0)
import qualified "ghc" GHC.Plugins as Plugins
#else
import qualified "ghc" GhcPlugins as Plugins
#endif
-- | GHC's default plugin record, marked as a pure plugin where the compiler
-- supports that.
--
-- On GHC >= 8.6.1 (the version the CPP guard checks) the
-- 'Plugins.pluginRecompile' field is set to 'Plugins.purePlugin', which tells
-- GHC the plugin does not by itself force recompilation. On older compilers
-- the untouched 'Plugins.defaultPlugin' is used.
defaultPurePlugin :: Plugins.Plugin
defaultPurePlugin =
#if MIN_VERSION_ghc(8, 6, 1)
  Plugins.defaultPlugin {Plugins.pluginRecompile = Plugins.purePlugin}
#else
  Plugins.defaultPlugin
#endif

-- | The plugin exported to GHC.
--
-- It starts from 'defaultPurePlugin' and hooks 'Plugins.installCoreToDos' so
-- that the Core-to-Core pipeline is rewritten by 'install'. Any command-line
-- options given to the plugin are ignored.
plugin :: Plugins.Plugin
plugin =
  defaultPurePlugin
    { Plugins.installCoreToDos = \_options todos -> pure (install todos)
    }

-- | Prepend the NoRecursion plugin pass ('noRecursionPass') to the given
-- Core-to-Core pipeline, so it runs ahead of every pass already in the list.
install :: [Plugins.CoreToDo] -> [Plugins.CoreToDo]
install todos =
  Plugins.CoreDoPluginPass "add NoRecursion rule" noRecursionPass : todos

-- | The Core pass installed by this plugin.
--
-- It runs 'failOnRecursion' over the module's bindings ('Plugins.mg_binds').
-- If any recursion is reported, it throws an 'ErrorCall' (via 'throwIO' in
-- 'Plugins.liftIO') whose message is @something recursive:@ followed by one
-- line per record, each rendered by 'formatRecursionRecord'. Otherwise the
-- bindings returned by 'failOnRecursion' are stored back into the
-- 'Plugins.ModGuts', which is returned.
--
-- NOTE(review): GHC's top-level error handler may present a non-GHC
-- exception such as 'ErrorCall' to the user as an internal compiler panic;
-- consider reporting through GHC's own error machinery instead — confirm
-- against the GHC versions this plugin supports.
noRecursionPass :: Plugins.ModGuts -> Plugins.CoreM Plugins.ModGuts
noRecursionPass guts = do
  dflags <- Plugins.getDynFlags
  let checked = failOnRecursion dflags (Plugins.mg_binds guts)
  either
    ( \recs ->
        Plugins.liftIO . throwIO . ErrorCall $
          "something recursive:\n"
            <> intercalate "\n" (toList (formatRecursionRecord dflags <$> recs))
    )
    (\binds -> pure (guts {Plugins.mg_binds = binds}))
    checked

-- | One recursive binding group found by the plugin, together with where it
-- was found.
--
-- The first field is the chain of bindings enclosing the group, outermost
-- first. It is empty when the group sits at the top level of the module. The
-- second field holds the binders of the recursive group itself.
data RecursionRecord b
  = RecursionRecord
      [b]
      (NonEmpty b)

-- | Render a 'RecursionRecord' as a one-line, human-readable description.
--
-- The text starts with the location: either @at the top level@ (empty
-- context) or @in@ followed by the enclosing bindings joined with @ >> @.
-- Then come @, the following bindings were recursive: @ and the
-- comma-separated recursive binders. Every binder is pretty-printed with
-- 'Plugins.ppr' and shown using the supplied 'Plugins.DynFlags'.
formatRecursionRecord ::
  (Plugins.Outputable b) => Plugins.DynFlags -> RecursionRecord b -> String
formatRecursionRecord dflags (RecursionRecord context recs) =
  location
    <> ", the following bindings were recursive: "
    <> intercalate ", " (render <$> toList recs)
  where
    render :: (Plugins.Outputable a) => a -> String
    render = Plugins.showSDoc dflags . Plugins.ppr

    location :: String
    location =
      if null context
        then "at the top level"
        else "in " <> intercalate " >> " (render <$> context)

-- | Check a module's bindings for recursion.
--
-- Every 'RecursionRecord' produced by 'recursiveCallsForBind' is reported,
-- with one exception: a record whose context is empty (a top-level group) and
-- whose binders all pretty-print with a @$c@ or @$f@ prefix is skipped. Those
-- are presumably GHC-generated instance-method and dictionary binders rather
-- than user code.
--
-- Returns the reported records in 'Left' when there is at least one.
-- Otherwise returns the original bindings unchanged.
failOnRecursion ::
  (Plugins.Outputable b) =>
  Plugins.DynFlags ->
  [Plugins.Bind b] ->
  Either (NonEmpty (RecursionRecord b)) [Plugins.Bind b]
failOnRecursion dflags original =
  let candidates = recursiveCallsForBind =<< original
      -- Keep a record unless it is top-level AND made up only of $c/$f
      -- binders (written here in De Morgan form).
      reported =
        filter
          ( \(RecursionRecord context recs) ->
              not (null context)
                || not
                  ( all
                      ( \var ->
                          let name = Plugins.showSDoc dflags (Plugins.ppr var)
                           in "$c" `isPrefixOf` name || "$f" `isPrefixOf` name
                      )
                      recs
                  )
          )
          candidates
   in maybe (pure original) Left (nonEmpty reported)

-- | Push @enclosing@ onto the front of the context of every given record.
--
-- Callers pass the records found inside a binding's right-hand side together
-- with that binding's binder. Because this happens as the traversal returns
-- outward, outer bindings end up before inner ones in each context.
addBindingReference :: b -> [RecursionRecord b] -> [RecursionRecord b]
addBindingReference enclosing = fmap extend
  where
    extend (RecursionRecord context recs) =
      RecursionRecord (enclosing : context) recs

-- | Collect the recursion records for one binding group.
--
-- * A 'Plugins.NonRec' binding contributes only what is found inside its
--   right-hand side, with its binder pushed onto each record's context.
-- * A 'Plugins.Rec' group contributes a record for the group itself (empty
--   context, listing every binder of the group). That record is followed by
--   whatever is found inside each right-hand side, tagged with the binder
--   that owns it.
recursiveCallsForBind :: Plugins.Bind b -> [RecursionRecord b]
recursiveCallsForBind bind =
  case bind of
    Plugins.NonRec binder rhs -> underBinder (binder, rhs)
    Plugins.Rec pairs ->
      let nested = foldMap underBinder pairs
       in maybe nested (\binders -> RecursionRecord [] binders : nested) $
            nonEmpty (fst <$> pairs)
  where
    -- Records nested inside one right-hand side, attributed to its binder.
    underBinder :: (a, Plugins.Expr a) -> [RecursionRecord a]
    underBinder (binder, rhs) =
      addBindingReference binder $ collectRecursiveCalls rhs

-- | Walk a Core expression and collect the recursion records of every
-- @let@-bound binding group nested inside it, via 'recursiveCallsForBind'.
--
-- Leaf constructors ('Plugins.Var', 'Plugins.Lit', 'Plugins.Type',
-- 'Plugins.Coercion') contribute nothing. For lambdas, casts, ticks and case
-- expressions only the sub-expressions are visited, not the binders,
-- coercions, ticks or types attached to them.
collectRecursiveCalls :: Plugins.Expr b -> [RecursionRecord b]
collectRecursiveCalls (Plugins.Var _) = []
collectRecursiveCalls (Plugins.Lit _) = []
collectRecursiveCalls (Plugins.App fun arg) =
  collectRecursiveCalls fun <> collectRecursiveCalls arg
collectRecursiveCalls (Plugins.Lam _ body) = collectRecursiveCalls body
collectRecursiveCalls (Plugins.Let bind body) =
  recursiveCallsForBind bind <> collectRecursiveCalls body
collectRecursiveCalls (Plugins.Case scrut _ _ alts) =
  collectRecursiveCalls scrut <> foldMap recursiveCallsForAlt alts
collectRecursiveCalls (Plugins.Cast inner _) = collectRecursiveCalls inner
collectRecursiveCalls (Plugins.Tick _ inner) = collectRecursiveCalls inner
collectRecursiveCalls (Plugins.Type _) = []
collectRecursiveCalls (Plugins.Coercion _) = []

-- | Collect the recursion records inside a case alternative's right-hand
-- side. The constructor and pattern binders of the alternative are ignored.
--
-- From GHC 9.2 an alternative is the 'Plugins.Alt' data constructor. Before
-- that it is a plain triple, so the projection is chosen by CPP.
recursiveCallsForAlt :: Plugins.Alt b -> [RecursionRecord b]
recursiveCallsForAlt = collectRecursiveCalls . altRhs
  where
    altRhs :: Plugins.Alt a -> Plugins.Expr a
#if MIN_VERSION_ghc(9, 2, 0)
    altRhs (Plugins.Alt _ _ rhs) = rhs
#else
    altRhs (_, _, rhs) = rhs
#endif