{-# LANGUAGE CPP #-}
{-# LANGUAGE StaticPointers #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}

-- | Utility Template Haskell macros.

module Control.Distributed.Closure.TH
  ( cstatic
  , cstaticDict
  , cdict
  , cdictFrom
  , withStatic
  ) where

import           Control.Monad (replicateM, unless)
import           Control.Distributed.Closure
import           Data.Generics (everything, mkQ)
import           Data.List (nub)
import           Data.Typeable (Typeable)
import           GHC.StaticPtr
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH
import           Numeric.Natural

-- | @$(cstatic 'foo)@ is an abbreviation for @closure (static foo)@: the
-- named binding is wrapped in a static pointer and lifted into a 'Closure'.
cstatic :: TH.Name -> TH.ExpQ
cstatic fun =
    let ref = TH.varE fun
    in [| closure (static $ref) |]

-- | @$(cstaticDict 'foo)@ is an abbreviation for @closure (static foo) `cap`
-- $cdict@, a common pattern for implicitly supplying the static dictionary
-- when the choice of dictionary is clear from context.
cstaticDict :: TH.Name -> TH.ExpQ
cstaticDict fun = [| $(cstatic fun) `cap` $cdict |]

-- | Abbreviation for @closure (static Dict)@, the zero-argument case of
-- 'cdictFrom'. Example usage:
--
-- @
-- foo :: Closure (Dict (Num a)) -> ...
--
-- foo $cdict ...
-- @
cdict :: TH.ExpQ
cdict = [| closure (static Dict) |]

-- | Create a static dictionary from the given dictionaries. Example usage:
--
-- @
-- $(cdictFrom 2) $cdict $cdict :: Closure (Dict (Eq a, Show a))
-- @
--
-- @$(cdictFrom n)@ expands to a lambda over @n@ dictionary closures
-- @x1 .. xn@. Its body applies (with 'cap') a static function taking @n@
-- 'Dict' patterns and returning 'Dict' to each of them, in order.
-- @$(cdictFrom 0)@ is therefore just @closure (static Dict)@.
cdictFrom :: Natural -> TH.ExpQ
cdictFrom n0 = do
    -- One fresh variable per incoming dictionary closure.
    names <- replicateM (fromIntegral n0) (TH.newName "x")
    -- closure (static (\Dict -> ... -> Dict)) `cap` x1 `cap` ... `cap` xn
    let applied = foldl (\acc x -> [| $acc `cap` $(TH.varE x) |])
                        [| closure (static $(staticFun n0)) |]
                        names
    -- \x1 -> ... -> \xn -> applied
    foldr (\nm body -> [| \ $(TH.varP nm) -> $body |]) applied names
  where
    -- \Dict -> ... -> \Dict -> Dict, with one Dict pattern per argument.
    staticFun :: Natural -> TH.ExpQ
    staticFun 0 = [| Dict |]
    staticFun n = [| \Dict -> $(staticFun (n - 1)) |]

-- | Compute free variables of a type.
--
-- Collects the name of every 'TH.VarT' node found anywhere inside the type
-- via a generic traversal, removing duplicates with 'nub' (the first
-- occurrence is kept). Occurrences of variables bound by a nested @forall@
-- are collected as well.
fvT :: TH.Type -> [TH.Name]
fvT ty = nub (everything (++) (mkQ [] varName) ty)
  where
    varName :: TH.Type -> [TH.Name]
    varName (TH.VarT nm) = [nm]
    varName _ = []

-- | Chain closure expressions with 'cap', associating to the left:
-- @caps [f, x, y]@ builds @(f `cap` x) `cap` y@. The list must be non-empty
-- ('foldl1' fails on an empty list).
caps :: [TH.ExpQ] -> TH.ExpQ
caps = foldl1 capE
  where
    capE fun arg = [| $fun `cap` $arg|]

-- XXX It turns out that GHC's newName doesn't produce really fresh names: call
-- newName twice to define two new top-level bindings and you'll find they
-- share the same name. A workaround mentioned in
-- https://ghc.haskell.org/trac/ghc/ticket/5398 (now
-- https://gitlab.haskell.org/ghc/ghc/-/issues/5398) is this snippet of code.
-- It bakes the unique of a 'TH.NameU' name into its occurrence name, so @occ@
-- becomes @occ_<unique>@. Names of any other flavour are returned unchanged.
mangleName :: TH.Name -> TH.Name
mangleName name@(TH.Name occ flavour)
  | TH.NameU u <- flavour = TH.Name (suffixed u) flavour
  | otherwise = name
  where
    suffixed :: TH.Uniq -> TH.OccName
    suffixed uniq = TH.mkOccName (concat [TH.occString occ, "_", show uniq])

-- | Auto-generates the 'Static' instances corresponding to the given class
-- instances. Example:
--
-- @
-- data T a = T a
--
-- withStatic [d| instance Show a => Show (T a) where ... |]
-- ======>
-- instance Show a => Show (T a) where ...
-- static_helper :: forall a. Dict (Show a) -> Dict (Show (T a))
-- static_helper = \\Dict -> Dict
-- instance (Typeable (Dict (Show (T a))), Typeable (Dict (Show a)), Static (Show a))
--   => Static (Show (T a)) where
--   closureDict = closure (static static_helper) `cap` closureDict
-- @
--
-- (@static_helper@ actually receives a fresh, uniquely suffixed name.) Each
-- instance in the quote is kept as is. The exception is 'Typeable' instances:
-- they only request the corresponding 'Static' instance and are dropped from
-- the output. All other declarations are passed through unchanged.
--
-- You will probably want to enable @FlexibleContexts@ and @ScopedTypeVariables@
-- in modules that use 'withStatic'. The splice fails at compile time if
-- 'Typeable' constraints are generated while @ScopedTypeVariables@ is off.
-- 'withStatic' can also handle non-user generated instances like 'Typeable'
-- instances: just write @instance Typeable T@.
withStatic :: TH.DecsQ -> TH.DecsQ
withStatic = (>>= go)
  where
    -- Abort the splice with a readable error unless @ext@ is enabled in the
    -- module where 'withStatic' is spliced.
    checkExtension :: TH.Extension -> TH.Q ()
    checkExtension ext = do
      enabled <- TH.isExtEnabled ext
      unless enabled $
        fail $ "withStatic requires the language extension " ++ show ext

    -- Walk the declarations: expand every instance declaration and pass
    -- everything else through unchanged.
    go :: [TH.Dec] -> TH.DecsQ
    go [] = return []
#if MIN_VERSION_template_haskell(2,11,0)
    go (ins@(TH.InstanceD overlap cxt hd _):decls) = do
#else
    go (ins@(TH.InstanceD cxt hd _):decls) = do
#endif
        let n = length cxt
        -- Dict c for every constraint c of the instance context ...
        dictsigs <- mapM (\c -> [t| Dict $(return c) |]) cxt
        -- ... and Dict hd for the instance head.
        retsig <- [t| Dict $(return hd) |]
        -- newName alone does not give distinct top-level names (see
        -- mangleName), hence the mangling.
        f <- mangleName <$> TH.newName "static_helper"
        -- \Dict -> ... -> \Dict -> Dict, one pattern per constraint.
        fbody <- foldr (\_ body -> [| \Dict -> $body |]) [| Dict |] cxt
        -- tyf = Dict c1 -> ... -> Dict cn -> Dict hd
        let tyf = foldr (\a b -> TH.ArrowT `TH.AppT` a `TH.AppT` b) retsig dictsigs
#if MIN_VERSION_template_haskell(2,17,0)
            -- template-haskell 2.17 (GHC 9.0) added the specificity flag to
            -- TyVarBndr; earlier versions, including 2.16, lack it.
            specifiedPlainTV :: TH.Name -> TH.TyVarBndr TH.Specificity
            specifiedPlainTV nm = TH.PlainTV nm TH.SpecifiedSpec
#else
            specifiedPlainTV :: TH.Name -> TH.TyVarBndr
            specifiedPlainTV = TH.PlainTV
#endif
            -- f :: forall <type variables of tyf>. tyf
            sigf = TH.SigD f (TH.ForallT (map specifiedPlainTV (fvT tyf)) [] tyf)
            declf = TH.ValD (TH.VarP f) (TH.NormalB fbody) []
        -- closureDict = closure (static f :: StaticPtr tyf) `cap` closureDict ...
        -- with one closureDict argument per constraint.
        methods <- (:[]) <$>
          TH.valD
            (TH.varP 'closureDict)
            (TH.normalB (caps ( [| closure (static $(TH.varE f) :: StaticPtr $(return tyf)) |]
                              : replicate n [| closureDict |]
                              )))
            []
        -- Typeable is requested only for Dict types mentioning type variables.
        typeableConstraints <-
          sequence [ [t| Typeable $(return d) |]
                   | d <- retsig : dictsigs
                   , not (null (fvT d))
                   ]
        -- Presumably needed because the StaticPtr annotation in closureDict
        -- refers to the instance's type variables.
        unless (null typeableConstraints) $
          checkExtension TH.ScopedTypeVariables
        staticcxt <- (typeableConstraints ++) <$>
          mapM (\c -> [t| Static $(return c) |]) cxt
        statichd <- [t| Static $(return hd) |]
#if MIN_VERSION_template_haskell(2,11,0)
        let staticins = TH.InstanceD overlap staticcxt statichd methods
#else
        let staticins = TH.InstanceD staticcxt statichd methods
#endif
        decls' <- go decls
        case hd of
          -- A Typeable "instance" only requests the Static instance; the
          -- declaration itself is not emitted.
          TH.AppT (TH.ConT nm) _ | nm == ''Typeable ->
            return (sigf : declf : staticins : decls')
          _ ->
            return (ins : sigf : declf : staticins : decls')
    go (decl:decls) = (decl:) <$> go decls