{-# LANGUAGE CPP          #-}
{-# LANGUAGE ViewPatterns #-}

module THSH.Internal.HsExprUtils
  ( RdrName
  , findFreeVariables
  ) where

import           GHC                         (GenLocated (..), Located, SrcSpan, locA, unLoc)
import qualified GHC.Hs.Expr                 as HsExpr (GRHS (..), GRHSs (..), HsExpr (..), Match (..), MatchGroup (..))
import qualified GHC.Hs.Extension
import           GHC.Types.Name.Reader       (RdrName (..))
import qualified Language.Haskell.Syntax.Pat as Pat
--
import           Data.Data                   (Data, gmapQ)
import           Data.Typeable               (Typeable, cast)


-- | Collect the variable occurrences found in every parsed expression
-- ('HsExpr.HsExpr' 'GHC.Hs.Extension.GhcPs') reachable from @item@.
--
-- @item@ is traversed generically with 'gmapQ', so it may be an expression
-- itself or any 'Data' value that contains expressions. Every
-- 'HsExpr.HsVar' occurrence is returned with its source span, in traversal
-- order and without de-duplication.
--
-- Binding is only partially understood:
--
-- * For a lambda with exactly one match and exactly one right-hand side,
--   names bound by 'Pat.VarPat' patterns anywhere in its argument patterns
--   are removed from the occurrences found in its body. Expressions nested
--   inside those patterns (e.g. view patterns) are not searched.
--
-- * Any other construct (other lambda shapes, @let@, @case@, @do@, etc.) is
--   traversed without removing the names it binds, so uses of those names
--   are reported as free.
findFreeVariables :: Data a => a -> [(SrcSpan, RdrName)]
findFreeVariables item = allNames
  where
    -- Calling 'f' directly on the root is enough to catch a top-level
    -- variable: 'f' checks its own argument with 'cast' before recursing
    -- into children via 'gmapQ'.
    allVars :: [Located RdrName]
    allVars = f item

    allNames :: [(SrcSpan, RdrName)]
    allNames = map toPair allVars

    -- Split a located name into a (span, name) pair.
    toPair :: Located RdrName -> (SrcSpan, RdrName)
    toPair (L loc name) = (loc, name)

    -- Collect variable occurrences in every parsed expression reachable from
    -- the argument, pruning names bound by a simple lambda.
    -- NOTE(review): the constructor arities below track the GHC AST of each
    -- release; a newer GHC may need an additional CPP branch.
    f :: forall a. (Data a, Typeable a) => a -> [Located RdrName]
    f expr = case cast @_ @(HsExpr.HsExpr GHC.Hs.Extension.GhcPs) expr of
#if MIN_VERSION_ghc(9,2,0)
      -- The identifier location is an annotated span (SrcSpanAnnN);
      -- 'locA' extracts the plain 'SrcSpan' from it.
      Just (HsExpr.HsVar _ l@(L ann _)) -> [L (locA ann) (unLoc l)]
#else
      Just (HsExpr.HsVar _ l) -> [l]
#endif

#if MIN_VERSION_ghc(9,10,0)
      Just (HsExpr.HsLam _ _ (HsExpr.MG _ (unLoc -> (map unLoc -> [HsExpr.Match _ _ (map unLoc -> ps) (HsExpr.GRHSs _ [unLoc -> HsExpr.GRHS _ _ (unLoc -> e)] _)])))) -> filter keepVar subVars
#elif MIN_VERSION_ghc(9,6,0)
      Just (HsExpr.HsLam _ (HsExpr.MG _ (unLoc -> (map unLoc -> [HsExpr.Match _ _ (map unLoc -> ps) (HsExpr.GRHSs _ [unLoc -> HsExpr.GRHS _ _ (unLoc -> e)] _)])))) -> filter keepVar subVars
#else
      Just (HsExpr.HsLam _ (HsExpr.MG _ (unLoc -> (map unLoc -> [HsExpr.Match _ _ (map unLoc -> ps) (HsExpr.GRHSs _ [unLoc -> HsExpr.GRHS _ _ (unLoc -> e)] _)])) _)) -> filter keepVar subVars
#endif
        where
          -- Drop occurrences whose name is bound by the patterns of this
          -- lambda.
          keepVar :: GenLocated l RdrName -> Bool
          keepVar (L _ n) = n `notElem` subPats

          -- Occurrences in the lambda body only; expressions inside the
          -- patterns themselves are not searched.
          subVars :: [Located RdrName]
          subVars = f e

          -- Names bound by the argument patterns of the lambda.
          subPats :: [RdrName]
          subPats = concatMap findPats ps
      _ -> concat (gmapQ f expr)

    -- Collect the names bound by 'Pat.VarPat' patterns anywhere inside the
    -- argument. Only 'Pat.VarPat' is recognised; binders represented by
    -- other constructors (presumably including as-pattern names) are missed.
    findPats :: forall a. (Data a, Typeable a) => a -> [RdrName]
    findPats p = case cast @_ @(Pat.Pat GHC.Hs.Extension.GhcPs) p of
      Just (Pat.VarPat _ (unLoc -> name)) -> [name]
      _                                   -> concat (gmapQ findPats p)