{-# LANGUAGE TupleSections #-}

-- |
-- Module      :   Grisette.Internal.TH.DeriveWithHandlers
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.DeriveWithHandlers
  ( deriveWithHandlers,
  )
where

import Control.Monad (foldM, unless, when)
import Data.List (transpose)
import qualified Data.Map as M
import Data.Maybe (fromMaybe, mapMaybe)
import Grisette.Internal.TH.DeriveInstanceProvider
  ( DeriveInstanceProvider (instanceDeclaration),
  )
import Grisette.Internal.TH.DeriveTypeParamHandler
  ( DeriveTypeParamHandler (handleBody, handleTypeParams),
    SomeDeriveTypeParamHandler,
  )
import Grisette.Internal.TH.Util
  ( allSameKind,
    dropNTypeParam,
    reifyDatatypeWithFreshNames,
    substDataType,
  )
import Language.Haskell.TH (Dec, Name, Q)
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields),
    DatatypeInfo (datatypeCons, datatypeVars),
    datatypeType,
    reifyDatatype,
    tvName,
  )

-- | Transpose a matrix stored as a list of rows.
--
-- When the input has no rows at all, 'transpose' cannot know how many rows
-- the result should have. The caller therefore passes that count as the
-- first argument and gets back that many empty rows. For any non-empty input
-- the count is ignored, and the result is exactly @'transpose' rows@.
transposeMatrix :: Int -> [[a]] -> [[a]]
transposeMatrix resultRows rows
  | null rows = replicate resultRows []
  | otherwise = transpose rows

-- | Derive instances for a list of types with a list of handlers and a
-- provider.
--
-- The arguments are, in order:
--
-- * The type parameter handlers. They are folded left to right over the
--   aligned type variables (via 'handleTypeParams'). Unless body constraints
--   are ignored, they are also run in order over the constructor fields
--   (via 'handleBody').
-- * The instance provider. It builds the final declarations through
--   'instanceDeclaration'.
-- * Whether to skip collecting constraints from the constructor fields.
--   This must be 'True' when any trailing type parameters are dropped.
-- * The number of trailing type parameters to drop from every type. It must
--   be non-negative and no larger than any type's number of parameters.
-- * The names of the types, which must not be empty. A single name is
--   reified with 'reifyDatatype'; several names are reified with
--   'reifyDatatypeWithFreshNames'.
--
-- After dropping, all types must have the same number of type parameters.
-- Parameters at the same position must also be aligned, which is checked
-- position-wise with 'allSameKind'. Every invalid argument is reported with
-- 'fail' in 'Q'.
deriveWithHandlers ::
  (DeriveInstanceProvider provider) =>
  [SomeDeriveTypeParamHandler] ->
  provider ->
  Bool ->
  Int ->
  [Name] ->
  Q [Dec]
deriveWithHandlers
  handlers
  provider
  ignoreBodyConstraints
  numDroppedTailTypes
  names = do
    when (numDroppedTailTypes < 0) $
      fail "deriveWithHandlers: numDroppedTailTypes must be non-negative"
    when (numDroppedTailTypes > 0 && not ignoreBodyConstraints) $
      fail $
        "deriveWithHandlers: ignoreBodyConstraints must be True if "
          <> "numDroppedTailTypes > 0"
    when (null names) $
      fail "deriveWithHandlers: no types provided"
    datatypes <- case names of
      [_] -> mapM reifyDatatype names
      _ -> mapM reifyDatatypeWithFreshNames names

    -- Without this check, 'drop' would silently leave such a type with no
    -- parameters instead of reporting the bad request.
    when (any ((< numDroppedTailTypes) . length . datatypeVars) datatypes) $
      fail $
        "deriveWithHandlers: numDroppedTailTypes exceeds the number of "
          <> "type parameters of some type"
    -- The type variables of each type, minus the trailing dropped ones.
    let keptVars =
          map
            (reverse . drop numDroppedTailTypes . reverse . datatypeVars)
            datatypes
    -- 'transpose' skips missing entries in ragged input. Without this check,
    -- types with different numbers of parameters would produce short
    -- columns that pass the alignment check below and reach the handlers
    -- misaligned.
    let arities = map length keptVars
    unless (and $ zipWith (==) arities (drop 1 arities)) $
      fail $
        "deriveWithHandlers: all types must have the same number of "
          <> "type parameters"

    -- One row per type-variable position; each row has one entry per type.
    let tyVars = transposeMatrix 0 keptVars
    unless (all allSameKind tyVars) $
      fail "deriveWithHandlers: all type variables must be aligned"
    -- Each position starts with no substitutions and no constraints. Every
    -- handler in turn may rewrite them.
    tyVarsWithConstraints <-
      foldM
        (flip $ handleTypeParams (length datatypes))
        (map (\tyVarList -> ((,Nothing) <$> tyVarList, Nothing)) tyVars)
        handlers

    let allTyVarsConstraints =
          concatMap (fromMaybe [] . snd) tyVarsWithConstraints
    -- Back to one row per type. The explicit row count covers the case of
    -- types with no (kept) parameters, where there are no rows to transpose.
    let tvWithSubst =
          transposeMatrix (length datatypes) $ fst <$> tyVarsWithConstraints
    -- Per type, map each variable that was given a substitution to that type.
    let substMaps =
          map
            (M.fromList . mapMaybe (\(tv, mty) -> (tvName tv,) <$> mty))
            tvWithSubst
    let substedTypes = zipWith substDataType datatypes substMaps
    tys <-
      mapM (dropNTypeParam numDroppedTailTypes . datatypeType) substedTypes
    allConstraints <-
      if ignoreBodyConstraints
        then return allTyVarsConstraints
        else do
          bodyConstraints <- handleBodyWithHandlers substedTypes
          return $ allTyVarsConstraints ++ bodyConstraints
    instanceDeclaration
      provider
      (fst <$> tyVarsWithConstraints)
      allConstraints
      tys
    where
      -- Run every handler's 'handleBody' over the zipped constructor fields
      -- of the given types, concatenating the results in handler order.
      handleBodyWithHandlers dts = do
        let zippedFields = zipFields (datatypeCons <$> dts)
        concat <$> traverse (`handleBody` zippedFields) handlers
      -- Flatten the fields of all constructors of each type into one list,
      -- then transpose. Row i holds the i-th flattened field of each type
      -- that has one, because 'transpose' skips missing entries.
      zipFields consPerType =
        transpose $ concatMap constructorFields <$> consPerType