{-# LANGUAGE CPP #-}

{- |
Copyright: (c) 2020 Kowainik
SPDX-License-Identifier: MPL-2.0
Maintainer: Kowainik <xrom.xkov@gmail.com>

Patterns for AST and syntax tree nodes search.
-}

module Stan.Pattern.Ast
    ( -- * Type
      PatternAst (..)
    , Literal (..)

      -- * Helpers
    , namesToPatternAst
    , anyNamesToPatternAst

      -- * eDSL
    , app
    , opApp
    , constructor
    , constructorNameIdentifier
    , dataDecl
    , fixity
    , fun
    , guardBranch
    , lazyField
    , range
    , rhs
    , tuple
    , typeSig

      -- * Pattern matching
    , case'
    , lambdaCase
    , patternMatchBranch
    , patternMatchArrow
    , patternMatch_
    , literalPat
    , wildPat

      -- * More low-level interface
    , literalAnns
    ) where

import Stan.Hie.Compat (DeclType, NodeAnnotation, mkNodeAnnotation, conDec)
import Stan.NameMeta (NameMeta (..))
import Stan.Pattern.Edsl (PatternBool (..))
import Stan.Pattern.Type (PatternType)

import qualified Data.Set as Set


{- | Query pattern used to search AST nodes in HIE AST. This data type
tries to mirror HIE AST to ease future matching, so it's quite
low-level, but helper functions are provided.
-}
data PatternAst
    -- | Literal constant in code (see 'Literal' for the supported kinds).
    = PatternAstConstant !Literal
    -- | Name of a specific function, variable or data type.
    | PatternAstName !NameMeta !PatternType
    -- | Variable name.
    | PatternAstVarName !String
    -- | AST node with tags for current node and any children.
    | PatternAstNode
        !(Set NodeAnnotation)  -- ^ Set of context info (pairs of tags)
    -- | AST node with tags for current node and children
    -- patterns. This pattern should match the node exactly.
    | PatternAstNodeExact
        !(Set NodeAnnotation)  -- ^ Set of context info (pairs of tags)
        ![PatternAst]  -- ^ Node children
    -- | AST wildcard, matches anything.
    | PatternAstAnything
    -- | Choice between patterns. Should match either of them.
    | PatternAstOr !PatternAst !PatternAst
    -- | Conjunction of patterns. Should match both of them.
    | PatternAstAnd !PatternAst !PatternAst
    -- | Negation of pattern. Should match everything except this pattern.
    | PatternAstNeg !PatternAst
    -- | AST node with the specified Identifier details (only 'DeclType')
    | PatternAstIdentifierDetailsDecl !DeclType
    deriving stock (Show, Eq)

-- | Boolean combinators for 'PatternAst': every class method is
-- implemented directly by the corresponding 'PatternAst' constructor.
instance PatternBool PatternAst where
    -- Wildcard pattern.
    (?) :: PatternAst
    (?) = PatternAstAnything

    -- Negation of a pattern.
    neg :: PatternAst -> PatternAst
    neg = PatternAstNeg

    -- Either of the two patterns.
    (|||) :: PatternAst -> PatternAst -> PatternAst
    (|||) = PatternAstOr

    -- Both of the two patterns.
    (&&&) :: PatternAst -> PatternAst -> PatternAst
    (&&&) = PatternAstAnd

{- | Literal to look for; carried by the 'PatternAstConstant' constructor.

__Note:__ how each constructor is compared against a literal found in
the source code is decided by the matching code outside of this module.
The per-constructor descriptions below follow the constructor names and
should be confirmed against that matching code.
-}
data Literal
    = ExactNum !Int  -- ^ Numeric literal with exactly this value.
    | ExactStr !ByteString  -- ^ String literal equal to the given bytes.
    | PrefixStr !ByteString  -- ^ String literal starting with the given bytes.
    | ContainStr !ByteString  -- ^ String literal containing the given bytes.
    | AnyLiteral  -- ^ Any literal.
    deriving stock (Show, Eq)

{- | Function that creates 'PatternAst' from the given non-empty list of pairs
'NameMeta' and 'PatternType'.

If the list contains only one pair then the result is a simple
'PatternAstName'. Otherwise it is a right-nested chain of 'PatternAstOr'
over the 'PatternAstName's of all pairs, in list order.
-}
namesToPatternAst :: NonEmpty (NameMeta, PatternType) -> PatternAst
namesToPatternAst ((nm, pat) :| []) = PatternAstName nm pat
namesToPatternAst ((nm, pat) :| x:rest) = PatternAstOr
    (PatternAstName nm pat)
    (namesToPatternAst $ x :| rest)

-- | Like 'namesToPatternAst' but doesn't care about types: every name
-- is paired with the wildcard type pattern '(?)'.
anyNamesToPatternAst :: NonEmpty NameMeta -> PatternAst
anyNamesToPatternAst = namesToPatternAst . fmap (, (?))

-- | @app f x@ is a pattern for function application @f x@.
--
-- Builds an exact @HsApp@ / @HsExpr@ node whose children are @f@ and
-- @x@, in that order.
app :: PatternAst -> PatternAst -> PatternAst
app f x = PatternAstNodeExact (one (mkNodeAnnotation "HsApp" "HsExpr")) [f, x]

-- | @opApp x op y@ is a pattern for operator application @x `op` y@.
--
-- Builds an exact @OpApp@ / @HsExpr@ node whose children are the left
-- operand, the operator and the right operand, in that order.
opApp :: PatternAst -> PatternAst -> PatternAst -> PatternAst
opApp x op y = PatternAstNodeExact (one (mkNodeAnnotation "OpApp" "HsExpr")) [x, op, y]

-- | @range a b@ is a pattern for @[a .. b]@
--
-- Builds an exact @ArithSeq@ / @HsExpr@ node whose children are the
-- lower and the upper bound patterns, in that order.
range :: PatternAst -> PatternAst -> PatternAst
range from to = PatternAstNodeExact (one (mkNodeAnnotation "ArithSeq" "HsExpr")) [from, to]

-- | 'lambdaCase' is a pattern for @\case@ expression (not considering branches).
--
-- The constructor tag depends on the compiler version: @HsLamCase@ on
-- GHC older than 9.10 and @HsLam@ otherwise.
-- NOTE(review): the @HsLam@ tag presumably also covers ordinary lambdas
-- on GHC >= 9.10 - confirm against the HIE output of that compiler.
lambdaCase :: PatternAst
lambdaCase = PatternAstNode (one (mkNodeAnnotation
#if __GLASGOW_HASKELL__ < 910
    "HsLamCase"
#else
    "HsLam"
#endif
    "HsExpr"))

-- | 'case'' is a pattern for @case EXP of@ expression (not considering branches).
--
-- It is a 'PatternAstNode' with the @HsCase@ / @HsExpr@ annotation.
case' :: PatternAst
case' = PatternAstNode (one (mkNodeAnnotation "HsCase" "HsExpr"))

-- | Pattern to represent one pattern matching branch.
--
-- It is a 'PatternAstNode' with the @Match@ / @Match@ annotation.
patternMatchBranch :: PatternAst
patternMatchBranch = PatternAstNode (one (mkNodeAnnotation "Match" "Match"))

{- | Pattern for @_@ in pattern matching: a 'PatternAstNode' with the
@WildPat@ / @Pat@ annotation.

__Note:__ present on GHC >= 8.10 only.
-}
wildPat :: PatternAst
wildPat = PatternAstNode (one (mkNodeAnnotation "WildPat" "Pat"))

{- | Pattern for literals in pattern matching: either an @NPat@ / @Pat@
node or a @LitPat@ / @Pat@ node.

__Note:__ present on GHC >= 8.10 only.
-}
literalPat :: PatternAst
literalPat = PatternAstNode (one (mkNodeAnnotation "NPat" "Pat"))
    ||| PatternAstNode (one (mkNodeAnnotation "LitPat" "Pat"))

-- | Pattern to represent one pattern matching branch on @_@.
--
-- Builds an exact @Match@ / @Match@ node. On GHC >= 8.10 its children
-- are 'wildPat' followed by @'patternMatchArrow' val@; on older
-- compilers the only child is @'patternMatchArrow' val@.
patternMatch_ :: PatternAst -> PatternAst
patternMatch_ val = PatternAstNodeExact (one (mkNodeAnnotation "Match" "Match"))
#if __GLASGOW_HASKELL__ >= 810
    $ wildPat :
#endif
    [patternMatchArrow val]

-- | Pattern to represent right side of the pattern matching, e.g. @-> "foo"@.
--
-- Builds an exact @GRHS@ / @GRHS@ node with the given pattern as its
-- only child.
patternMatchArrow :: PatternAst -> PatternAst
patternMatchArrow x = PatternAstNodeExact (one (mkNodeAnnotation "GRHS" "GRHS")) [x]

{- | Pattern for the top-level fixity declaration:

@
infixr 7 ***, +++, ???
@

It is a 'PatternAstNode' with the @FixitySig@ / @FixitySig@ annotation.
-}
fixity :: PatternAst
fixity = PatternAstNode $ one (mkNodeAnnotation "FixitySig" "FixitySig")

{- | Pattern for the function type signature declaration:

@
foo :: Some -> Type
@

It is a 'PatternAstNode' with the @TypeSig@ / @Sig@ annotation.
-}
typeSig :: PatternAst
typeSig = PatternAstNode $ one (mkNodeAnnotation "TypeSig" "Sig")

-- | Version-dependent @HsBindLR@ annotation shared by 'fun' and
-- 'lazyRecordField': the constructor tag is @AbsBinds@ on GHC older
-- than 9.4 and @XHsBindsLR@ otherwise.
absBinds :: NodeAnnotation
absBinds =
#if __GLASGOW_HASKELL__ < 904
  mkNodeAnnotation "AbsBinds" "HsBindLR"
#else
  mkNodeAnnotation "XHsBindsLR" "HsBindLR"
#endif

{- | Pattern for the function definition:

@
foo x y = ...
@

It is a 'PatternAstNode' built from three annotations: 'absBinds',
@FunBind@ / @HsBindLR@ and @Match@ / @Match@.
-}
fun :: PatternAst
fun = PatternAstNode $ Set.fromList
    [ absBinds
    , mkNodeAnnotation "FunBind"  "HsBindLR"
    , mkNodeAnnotation "Match"    "Match"
    ]

{- | @data@ or @newtype@ declaration: a 'PatternAstNode' with the
@DataDecl@ / @TyClDecl@ annotation.
-}
dataDecl :: PatternAst
dataDecl = PatternAstNode $ one (mkNodeAnnotation "DataDecl" "TyClDecl")

{- | Constructor of a plain data type or newtype. Children of node
that matches this pattern are constructor fields.

It is a 'PatternAstNode' with the @ConDeclH98@ / @ConDecl@ annotation.
-}
constructor :: PatternAst
constructor = PatternAstNode $ one (mkNodeAnnotation "ConDeclH98" "ConDecl")

{- | Constructor name Identifier info: a
'PatternAstIdentifierDetailsDecl' pattern carrying the 'conDec'
declaration type.
-}
constructorNameIdentifier :: PatternAst
constructorNameIdentifier = PatternAstIdentifierDetailsDecl conDec

{- | Lazy data type field. Comes in two shapes:

1. Record field, like: @foo :: Text@ (see 'lazyRecordField')
2. Simple type: @Int@ (see 'type_')
-}
lazyField :: PatternAst
lazyField = lazyRecordField ||| type_

{- | Pattern for any occurrence of a plain type. Covers the following
cases:

* Simple type: Int, Bool, a
* Higher-kinded type: Maybe Int, Either String a
* Type in parenthesis: (Int)
* Tuples: (Int, Bool)
* List type: [Int]
* Function type: Int -> Bool

Each case is a 'PatternAstNode' with a single @HsType@ annotation; the
cases are combined with '|||'.
-}
type_ :: PatternAst
type_ =
    PatternAstNode (one (mkNodeAnnotation "HsTyVar" "HsType"))  -- simple type: Int, Bool
    |||
    PatternAstNode (one (mkNodeAnnotation "HsAppTy" "HsType"))  -- composite: Maybe Int
    |||
    PatternAstNode (one (mkNodeAnnotation "HsParTy" "HsType"))  -- type in ()
    |||
    PatternAstNode (one (mkNodeAnnotation "HsTupleTy" "HsType"))  -- tuple types: (Int, Bool)
    |||
    PatternAstNode (one (mkNodeAnnotation "HsListTy" "HsType"))  -- list types: [Int]
    |||
    PatternAstNode (one (mkNodeAnnotation "HsFunTy" "HsType"))  -- function types: Int -> Bool

{- | Pattern for the field without the explicit bang pattern:

@
someField :: Int
@

Builds an exact @ConDeclField@ / @ConDeclField@ node with two children:
a node annotated with 'absBinds' and @FunBind@ / @HsBindLR@, followed
by a plain type ('type_').
-}
lazyRecordField :: PatternAst
lazyRecordField = PatternAstNodeExact
    (one (mkNodeAnnotation "ConDeclField" "ConDeclField"))
    [ PatternAstNode
        (Set.fromList
            [ absBinds
            , mkNodeAnnotation "FunBind" "HsBindLR"
            ]
        )
    , type_
    ]

{- | Pattern for tuples:

* Type signatures: foo :: (Int, Int, Int, Int)
* Literals: (True, 0, [], Nothing)

Either a @HsTupleTy@ / @HsType@ node or an @ExplicitTuple@ / @HsExpr@
node.
-}
tuple :: PatternAst
tuple =
    PatternAstNode (one (mkNodeAnnotation "HsTupleTy" "HsType"))  -- tuple type
    |||
    PatternAstNode (one (mkNodeAnnotation "ExplicitTuple" "HsExpr"))  -- tuple literal
{- | Pattern for a single @guard@ branch:

@
    | x < y = ...
@

It is a 'PatternAstNode' with the @BodyStmt@ / @StmtLR@ annotation.
-}
guardBranch :: PatternAst
guardBranch = PatternAstNode $ one (mkNodeAnnotation "BodyStmt" "StmtLR")

{- | Pattern for the right-hand-side. Usually an equality sign.

@
   foo = baz
@

It is a 'PatternAstNode' with the @GRHS@ / @GRHS@ annotation.
-}
rhs :: PatternAst
rhs = PatternAstNode $ one (mkNodeAnnotation "GRHS" "GRHS")

-- | Annotations for constants: 0, "foo", etc.
--
-- This is the @HsOverLit@ / @HsExpr@ node annotation.
literalAnns :: NodeAnnotation
literalAnns = mkNodeAnnotation "HsOverLit" "HsExpr"