-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Derive.Superclass
-- Copyright   :  (c) Song Zhang
-- License     :  BSD-style (see the LICENSE file)
-- 
-- Maintainer  :  haskell.zhang.song `at` hotmail.com
-- Stability   :  experimental
-- Portability :  non-portable
--
-----------------------------------------------------------------------------

module Data.Derive.Superclass
  ( deriving_superclasses
  , strategy_deriving_superclasses
  , newtype_deriving_superclasses
  , gnds
  ) where

import           Control.Monad
import           Control.Monad.Trans
import           Control.Monad.Trans.State
import           Data.Derive.TopDown.CxtGen
import           Data.Derive.TopDown.IsInstance
import           Data.Derive.TopDown.Lib
import           Data.List                      ( foldl1'
                                                , nub
                                                )
import           Language.Haskell.TH

-- | Worker behind every public entry point of this module.
--
-- It emits a standalone deriving declaration for class @cn@ at type @tn@
-- (applied to its type variables), using the optional strategy @st@. It then
-- recurses into every superclass found in @cn@'s declaration, always for the
-- same type @tn@.
--
-- Only classes whose parameter has kind @*@ or @* -> *@ are supported. For a
-- higher-order class (as decided by 'isHigherOrderClass'):
--
--   * the last type parameter of @tn@ is left unapplied;
--   * an empty instance context is used.
--
-- Otherwise the context comes from 'genInferredContext'.
--
-- The 'StateT' state is the list of instance heads (e.g. @Eq (T a)@) already
-- emitted during this run. A head is skipped when it is already in the state
-- or when 'isInstance'' reports an existing instance. This also keeps shared
-- superclasses from being derived twice.
deriving_superclasses'
  :: Maybe DerivStrategy -> ClassName -> TypeName -> StateT [Type] Q [Dec]
deriving_superclasses' st cn tn = do
  pnames             <- lift $ reifyTypeParameters tn
  isCnHighOrderClass <- lift $ isHigherOrderClass cn
  -- Type variables that tn is applied to. A class over (* -> *) leaves the
  -- last one unapplied. Calling 'init' on an empty list would die with an
  -- uninformative message, so that case is reported explicitly.
  pns <- if not isCnHighOrderClass
    then return pnames
    else case pnames of
      [] -> error $ "cannot derive " ++ show cn ++ " for " ++ show tn
                 ++ ": the class expects a type constructor with at least one parameter"
      _  -> return (init pnames)
  -- tn applied to its variables. The list is never empty, so foldl1' is safe,
  -- and with no variables the result is just (ConT tn).
  let t  = foldl1' AppT (ConT tn : map VarT pns)
      tp = AppT (ConT cn) t
  -- Heads emitted earlier in this run are skipped without asking GHC.
  seen <- gets (elem tp)
  alreadyDerived <- if seen then return True else lift $ isInstance' cn [t]
  if alreadyDerived
    then return []
    else do
      classContext <- if isCnHighOrderClass
        then return []
        else lift $ genInferredContext cn tn
      -- Record this head before recursing, so a superclass reached along
      -- several paths is derived only once.
      modify (tp :)
      ci <- lift $ reify cn
      case ci of
        ClassI (ClassD ctx _ _ _ _) _ -> do
          -- NOTE(review): getTypeConstructor is assumed to return the head of
          -- each superclass constraint (ConT for an ordinary class) — confirm.
          superDecls <- fmap (nub . concat) . forM (map getTypeConstructor ctx) $ \superCln ->
            case superCln of
              ConT className -> deriving_superclasses' st className tn
              x -> error $ "cannot generate class for " ++ show x
          return $ StandaloneDerivD st classContext tp : superDecls
        _ -> error $ show cn ++ " is not a type class"

{- | Derive a standalone instance of class @cn@ for type @tn@, together with
instances of all of @cn@'s superclasses. No explicit deriving strategy is
attached, so GHC's default applies.

Note: it cannot be used with mutually recursive types.

For mutually recursive types, derive the instances for all of them in a single
splice. For mutually recursive types @T1@ and @T2@:

@
fmap concat (sequence [(deriving_superclasses ''Ord ''T1), (deriving_superclasses ''Ord ''T2)])
@
-}
deriving_superclasses :: ClassName -> TypeName -> Q [Dec]
deriving_superclasses cn tn =
  flip evalStateT [] $ deriving_superclasses' Nothing cn tn

-- | Like 'deriving_superclasses', but the given 'DerivStrategy' is attached to
-- every generated standalone deriving declaration, including the ones for
-- superclasses.
strategy_deriving_superclasses
  :: DerivStrategy  -- ^ deriving strategy
  -> ClassName      -- ^ class name
  -> TypeName       -- ^ type name
  -> Q [Dec]
strategy_deriving_superclasses st cn tn =
  flip evalStateT [] $ deriving_superclasses' (Just st) cn tn

-- | Derive class @cn@ and all of its superclasses for type @tn@, using the
-- @newtype@ deriving strategy for every generated instance.
newtype_deriving_superclasses :: ClassName -> TypeName -> Q [Dec]
newtype_deriving_superclasses cn tn =
  strategy_deriving_superclasses NewtypeStrategy cn tn
NewtypeStrategy

-- | Short alias for 'newtype_deriving_superclasses' (generalized newtype
-- deriving of a class together with all of its superclasses).
gnds :: ClassName -> TypeName -> Q [Dec]
gnds cn tn = newtype_deriving_superclasses cn tn