-- |
-- Module      :   Grisette.Internal.TH.DeriveBuiltin
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.DeriveBuiltin
  ( deriveBuiltinExtra,
    deriveBuiltin,
    deriveBuiltins,
  )
where

import Control.Monad (when)
import Grisette.Internal.TH.DeriveInstanceProvider
  ( Strategy (strategyClassName),
  )
import Grisette.Internal.TH.DeriveTypeParamHandler
  ( NatShouldBePositive (NatShouldBePositive),
    PrimaryConstraint (PrimaryConstraint),
    SomeDeriveTypeParamHandler (SomeDeriveTypeParamHandler),
  )
import Grisette.Internal.TH.DeriveWithHandlers
  ( deriveWithHandlers,
  )
import Grisette.Internal.TH.Util
  ( classNumParam,
    classParamKinds,
    kindNumParam,
  )
import Language.Haskell.TH (Dec, Name, Q)

-- | Derive a builtin class for a type, with extra handlers.
--
-- The class being derived is @'strategyClassName' strategy@. Every parameter
-- of that class is instantiated with the same type @name@. So the class must
-- have at least one parameter, and all of its parameters must share one kind.
--
-- The type-parameter handlers passed to 'deriveWithHandlers' are, in order:
--
-- 1. 'NatShouldBePositive';
-- 2. the caller-supplied @extraHandlers@;
-- 3. one @'PrimaryConstraint' c False@ handler for each class @c@ in
--    @constraints@.
--
-- Fails in 'Q' in any of these cases:
--
-- * the derived class has no parameters;
-- * its parameters have differing kinds;
-- * some class in @constraints@ has a different number of parameters than
--   the derived class.
deriveBuiltinExtra ::
  -- | Extra type-parameter handlers, placed after 'NatShouldBePositive' and
  -- before the constraint handlers.
  [SomeDeriveTypeParamHandler] ->
  -- | Forwarded unchanged to 'deriveWithHandlers'.
  Bool ->
  -- | Deriving strategy; its class name is the class being derived.
  Strategy ->
  -- | Classes turned into 'PrimaryConstraint' handlers. Each must have the
  -- same number of parameters as the derived class.
  [Name] ->
  -- | The type to derive the instance for.
  Name ->
  Q [Dec]
deriveBuiltinExtra
  extraHandlers
  ignoreBodyConstraints
  strategy
  constraints
  name = do
    let finalCtxName = strategyClassName strategy
    numParam <- classNumParam finalCtxName
    when (numParam == 0) $
      fail "deriveBuiltin: the class must have at least one parameter"
    kinds <- classParamKinds finalCtxName
    case kinds of
      -- Zero-parameter classes were rejected above, so this branch is only
      -- reachable if 'classNumParam' and 'classParamKinds' disagree.
      [] ->
        fail $
          "deriveBuiltin: the class must have at least one parameter, bug, "
            <> "should not happen"
      (k : ks) -> do
        -- The same type fills every class parameter, so all parameters
        -- must share one kind.
        when (any (/= k) ks) $
          fail "deriveBuiltin: all parameters must have the same kind"
        constraintNumParams <- mapM classNumParam constraints
        when (any (/= numParam) constraintNumParams) $
          fail $
            "deriveBuiltin: all constraints must have the same number of "
              <> "parameters as the results"
        -- Computed from the shared parameter kind and forwarded to
        -- 'deriveWithHandlers'. Presumably the arity of that kind — TODO
        -- confirm against 'kindNumParam'.
        numDrop <- kindNumParam k
        -- NOTE(review): the meaning of PrimaryConstraint's Bool field is not
        -- visible here; 'False' is kept as in the original.
        let constraintHandlers =
              map
                (\c -> SomeDeriveTypeParamHandler (PrimaryConstraint c False))
                constraints
        deriveWithHandlers
          ( SomeDeriveTypeParamHandler NatShouldBePositive
              : (extraHandlers <> constraintHandlers)
          )
          strategy
          ignoreBodyConstraints
          numDrop
          -- Every class parameter is instantiated with the same type.
          (replicate numParam name)

-- | Derive a builtin class for a type.
--
-- Equivalent to 'deriveBuiltinExtra' with two arguments fixed:
--
-- * no extra type-parameter handlers;
-- * its 'Bool' argument (forwarded to 'deriveWithHandlers') set to 'True'.
deriveBuiltin :: Strategy -> [Name] -> Name -> Q [Dec]
deriveBuiltin = deriveBuiltinExtra [] True

-- | Derive builtin classes for a list of types.
--
-- Runs 'deriveBuiltin' once for each type name, in list order, with the same
-- strategy and constraints. The generated declarations are concatenated.
deriveBuiltins :: Strategy -> [Name] -> [Name] -> Q [Dec]
deriveBuiltins strategy constraints names =
  concat <$> traverse (deriveBuiltin strategy constraints) names