{-# LANGUAGE OverloadedStrings #-}

module Wingman.Simplify
  ( simplify
  ) where

import Data.Generics (GenericT, everywhere, mkT)
import Data.List.Extra (unsnoc)
import Data.Monoid (Endo (..))
import Development.IDE.GHC.Compat
import GHC.SourceGen (var)
import GHC.SourceGen.Expr (lambda)
import Wingman.CodeGen.Utils
import Wingman.GHC (containsHsVar, fromPatCompat, pattern SingleLet)


------------------------------------------------------------------------------
-- | Views a lambda expression as its argument patterns and its body, hiding
-- the otherwise (extremely) messy 'HsLam' and 'MatchGroup' AST.
--
-- As a matcher, this only succeeds on a lambda with exactly one match whose
-- right-hand side is a single unguarded expression. As a builder, an empty
-- pattern list produces the bare body instead of an argument-less lambda.
pattern Lambda :: [Pat GhcPs] -> HsExpr GhcPs -> HsExpr GhcPs
pattern Lambda pats body <-
  HsLam _ MG
    { mg_alts = L _
        [ L _ Match
            { m_pats  = fmap fromPatCompat -> pats
            , m_grhss = GRHSs
                { grhssGRHSs = [L _ (GRHS _ [] (L _ body))]
                }
            }
        ]
    }
  where
    -- Nothing to bind, so there is no lambda to build: use the body directly.
    Lambda []      expr = expr
    Lambda argPats expr = lambda argPats expr



------------------------------------------------------------------------------
-- | Simplify an expression. A fixed set of rewrites ('simplifyEtaReduce',
-- 'simplifyRemoveParens', 'simplifyCompose' and 'simplifySingleLet') is
-- applied to every node of the tree via 'everywhere', and the whole
-- traversal is repeated a fixed number of times.
simplify :: LHsExpr GhcPs -> LHsExpr GhcPs
simplify = foldEndo (replicate passCount (everywhere rewriteOnce))
  where
    -- Do three passes; this should be good enough for the limited amount of
    -- gas we give to auto. Composing a fixed number of passes avoids the
    -- partial @iterate f x !! n@ idiom while applying @f@ the same number
    -- of times.
    passCount :: Int
    passCount = 3

    -- 'foldEndo' composes right-to-left, so on each node 'simplifySingleLet'
    -- runs first and 'simplifyEtaReduce' runs last.
    rewriteOnce :: GenericT
    rewriteOnce = foldEndo
      [ simplifyEtaReduce
      , simplifyRemoveParens
      , simplifyCompose
      , simplifySingleLet
      ]


------------------------------------------------------------------------------
-- | Like 'foldMap', but for endomorphisms: collapse a container of
-- @a -> a@ functions into a single function by composition.
--
-- Functions earlier in the container are applied later, so
-- @foldEndo [f, g] x == f (g x)@. An empty container yields 'id'.
foldEndo :: Foldable t => t (a -> a) -> a -> a
foldEndo fns = appEndo (foldMap Endo fns)


------------------------------------------------------------------------------
-- | Eta-reduce lambdas. For example, @\x -> (f g) x@ becomes @f g@, and the
-- identity lambda @\x -> x@ becomes @id@. Any other expression is returned
-- unchanged.
simplifyEtaReduce :: GenericT
simplifyEtaReduce = mkT etaReduce
  where
    etaReduce :: HsExpr GhcPs -> HsExpr GhcPs
    -- @\x -> x@ is just 'id'.
    etaReduce (Lambda [VarPat _ (L _ bound)] (HsVar _ (L _ used)))
      | bound == used = var "id"
    -- @\... x -> f x@ drops the trailing @x@. This is only sound when @f@
    -- does not otherwise mention @x@.
    etaReduce
      (Lambda (unsnoc -> Just (leading, VarPat _ (L _ bound)))
              (HsApp _ (L _ fn) (L _ (HsVar _ (L _ used)))))
      | bound == used
      , not (containsHsVar bound fn)
      = Lambda leading fn
    etaReduce expr = expr

------------------------------------------------------------------------------
-- | Eliminate a trivial binding: @let a = b in a@ becomes @b@. Any other
-- expression is returned unchanged.
simplifySingleLet :: GenericT
simplifySingleLet = mkT inlineTrivialLet
  where
    inlineTrivialLet :: HsExpr GhcPs -> HsExpr GhcPs
    inlineTrivialLet (SingleLet name [] rhs (HsVar _ (L _ ref)))
      | ref == name = rhs
    inlineTrivialLet expr = expr


------------------------------------------------------------------------------
-- | Eta-reduce a lambda whose body is a chain of nested applications ending
-- in its last argument, turning the chain into a function composition. For
-- example, @\x -> f (g (h x))@ becomes @f . g . h@. Any other expression is
-- returned unchanged.
simplifyCompose :: GenericT
simplifyCompose = mkT composeChain
  where
    composeChain :: HsExpr GhcPs -> HsExpr GhcPs
    composeChain
      (Lambda (unsnoc -> Just (leading, VarPat _ (L _ bound)))
              (unroll -> (fns@(_ : _), HsVar _ (L _ used))))
      -- Only sound when the dropped argument appears nowhere in the chain.
      | bound == used
      , not (containsHsVar bound fns)
      -- 'fns' is non-empty by the pattern above, so 'foldr1' cannot fail.
      = Lambda leading (foldr1 (infixCall ".") fns)
    composeChain expr = expr


------------------------------------------------------------------------------
-- | Strip parentheses from any expression that 'isAtomicHsExpr' considers
-- atomic, since such an expression never needs them. Everything else is
-- returned unchanged.
simplifyRemoveParens :: GenericT
simplifyRemoveParens = mkT dropParens
  where
    dropParens :: HsExpr GhcPs -> HsExpr GhcPs
    dropParens (HsPar _ (L _ inner))
      | isAtomicHsExpr inner = inner
    dropParens expr = expr


------------------------------------------------------------------------------
-- | Unroll a right-nested chain of function applications. For example,
-- @HsApp f (HsApp g (HsApp h x))@ becomes @([f, g, h], x)@.
--
-- Parentheses ('HsPar') are looked through. The first expression that is
-- neither parenthesised nor an application ends the chain and is returned
-- as the final component.
unroll :: HsExpr GhcPs -> ([HsExpr GhcPs], HsExpr GhcPs)
unroll expr = case expr of
  HsPar _ (L _ inner) -> unroll inner
  HsApp _ (L _ fn) (L _ arg) ->
    let (fns, innermost) = unroll arg
     in (fn : fns, innermost)
  _ -> ([], expr)