{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE PatternSynonyms #-}

-- | Utility functions for working with TH
module Data.Record.Internal.TH.Util (
    -- * Folding
    appsT
  , arrT
    -- * Constructing lists (variations on 'listE')
  , vectorE
  , plistT
  , ptupleT
    -- * Simplified construction
  , simpleFn
  , simplePatSynType
    -- * Dealing with type variables
  , tyVarName
  , tyVarType
    -- * Bang
  , pattern DefaultBang
    -- * Extensions
  , requiresExtensions
  ) where

import Control.Monad
import Data.List (intercalate)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax

import qualified Data.Vector as V

import qualified Data.Record.Internal.TH.Name as N

{-------------------------------------------------------------------------------
  Folding
-------------------------------------------------------------------------------}

-- | Repeated type application
--
-- @appsT f [x1, .., xN]@ constructs something like
--
-- > f x1 .. xN
--
-- Arguments are applied left to right, so the result is nested as
-- @((f x1) x2) .. xN@. With an empty argument list the head @f@ is returned
-- unchanged.
appsT :: Q Type -> [Q Type] -> Q Type
appsT t ts = foldl appT t ts

-- | Repeated application of @(->)@
--
-- @arrT [x1, .., xN] y@ constructs something like
--
-- > x1 -> .. -> xN -> y
--
-- The arrows nest to the right; with an empty argument list the result type
-- @y@ is returned unchanged.
arrT :: [Q Type] -> Q Type -> Q Type
arrT ts t = foldr (\a b -> arrowT `appT` a `appT` b) t ts

{-------------------------------------------------------------------------------
  Constructing lists (variations on 'listE')
-------------------------------------------------------------------------------}

-- | Construct a vector expression
--
-- @vectorE f [x1, .., xN]@ constructs something like
--
-- > V.fromList [f x1, .., f xN]
--
-- where each element expression is produced by applying @f@ to the
-- corresponding input.
vectorE :: (a -> Q Exp) -> [a] -> Q Exp
vectorE f elems = [| V.fromList $(listE (map f elems)) |]

-- | Construct a promoted (type-level) list
--
-- @plistT [t1, .., tN]@ constructs something like
--
-- > '[t1, .., tN]
--
-- built from the promoted cons and nil constructors.
plistT :: [Q Type] -> Q Type
plistT = foldr cons nil
  where
    nil :: Q Type
    nil = promotedNilT

    cons :: Q Type -> Q Type -> Q Type
    cons t ts = promotedConsT `appT` t `appT` ts

-- | Construct a promoted (type-level) tuple
--
-- @ptupleT [t1, .., tN]@ constructs something like
--
-- > '(t1, .., tN)
--
-- The promoted tuple constructor of arity @N@ is applied to all components.
ptupleT :: [Q Type] -> Q Type
ptupleT ts = appsT (promotedTupleT (length ts)) ts

{-------------------------------------------------------------------------------
  Simplified construction
-------------------------------------------------------------------------------}

-- | Construct simple function
--
-- @simpleFn n typ body@ constructs something like
--
-- > f :: typ
-- > f = body
--
-- The result contains exactly two declarations: the type signature followed
-- by a value binding with no arguments and no local declarations.
simpleFn :: N.Name 'VarName flavour -> Q Type -> Q Exp -> Q [Dec]
simpleFn fnName qTyp qBody = do
    typ  <- qTyp
    body <- qBody
    return [
          SigD fnName' typ
        , ValD (VarP fnName') (NormalB body) []
        ]
  where
    -- Both declarations must refer to the same TH name
    fnName' :: Name
    fnName' = N.toTH fnName

-- | Construct simple pattern synonym type
--
-- @simplePatSynType xs [t1, .., tn] s@ constructs something like
--
-- > pattern foo :: forall xs. t1 -> .. -> tn -> s
--
-- The type is built as two nested foralls, both with an empty context: the
-- outer one binds @xs@, the inner one binds no variables. (This mirrors the
-- two-level shape that template-haskell documents for 'PatSynType'.)
simplePatSynType :: [TyVarBndr] -> [Q Type] -> Q Type -> Q PatSynType
simplePatSynType tvars fieldTypes resultType =
      forallT tvars (cxt [])
    $ forallT []    (cxt [])
    $ arrT fieldTypes resultType

{-------------------------------------------------------------------------------
  Dealing with type variables
-------------------------------------------------------------------------------}

-- | Name of the variable bound by a type variable binder
--
-- Any kind annotation on the binder is ignored.
tyVarName :: TyVarBndr -> Name
tyVarName (PlainTV  n)   = n
tyVarName (KindedTV n _) = n

-- | Type referring to the variable bound by a type variable binder
--
-- @tyVarType bndr@ is the type variable named @'tyVarName' bndr@.
tyVarType :: TyVarBndr -> Q Type
tyVarType = varT . tyVarName

{-------------------------------------------------------------------------------
  Bang
-------------------------------------------------------------------------------}

-- | A 'Bang' with no unpackedness and no strictness annotation
pattern DefaultBang :: Bang
pattern DefaultBang = Bang NoSourceUnpackedness NoSourceStrictness

{-------------------------------------------------------------------------------
  Extensions
-------------------------------------------------------------------------------}

-- | Check that the specified extensions are enabled
--
-- To improve user experience, we report all missing extensions at once (rather
-- than giving an error for the first missing one).
--
-- If every extension is enabled (in particular, if the list is empty) this is
-- a no-op; otherwise it 'fail's with a message listing the disabled ones.
requiresExtensions :: Quasi m => [Extension] -> m ()
requiresExtensions exts = runQ $ do
    -- Keep only the extensions that are /not/ currently enabled
    disabled <- filterM (fmap not . isExtEnabled) exts
    unless (null disabled) $
      fail $ "Please enable " ++ intercalate ", " (map show disabled)