{-# HLINT ignore "Unused LANGUAGE pragma" #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

-- |
-- Module      :   Grisette.Internal.TH.GADT.DeriveGADT
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.GADT.DeriveGADT
  ( deriveGADT,
    deriveGADTAll,
    deriveGADTAllExcept,
  )
where

import qualified Data.Map as M
import qualified Data.Set as S
import Grisette.Internal.Core.Data.Class.EvalSym
  ( EvalSym,
    EvalSym1,
    EvalSym2,
  )
import Grisette.Internal.Core.Data.Class.ExtractSym
  ( ExtractSym,
    ExtractSym1,
    ExtractSym2,
  )
import Grisette.Internal.Core.Data.Class.Mergeable
  ( Mergeable,
    Mergeable1,
    Mergeable2,
    Mergeable3,
  )
import Grisette.Internal.TH.GADT.DeriveEvalSym
  ( deriveGADTEvalSym,
    deriveGADTEvalSym1,
    deriveGADTEvalSym2,
  )
import Grisette.Internal.TH.GADT.DeriveExtractSym
  ( deriveGADTExtractSym,
    deriveGADTExtractSym1,
    deriveGADTExtractSym2,
  )
import Grisette.Internal.TH.GADT.DeriveMergeable (genMergeable, genMergeable', genMergeableAndGetMergingInfoResult)
import Language.Haskell.TH (Dec, Name, Q)

-- | Table from a type class 'Name' to the Template Haskell procedure that
-- derives an instance of that class for the type with the given 'Name'.
--
-- The 'Mergeable' family ('Mergeable' through 'Mergeable3') is intentionally
-- absent from this table: 'deriveGADT' routes those classes to the
-- 'genMergeable' machinery instead, so that the merging information produced
-- for one arity can be reused for the others.
deriveProcedureMap :: M.Map Name (Name -> Q [Dec])
deriveProcedureMap =
  M.fromList
    [ (''EvalSym, deriveGADTEvalSym),
      (''EvalSym1, deriveGADTEvalSym1),
      (''EvalSym2, deriveGADTEvalSym2),
      (''ExtractSym, deriveGADTExtractSym),
      (''ExtractSym1, deriveGADTExtractSym1),
      (''ExtractSym2, deriveGADTExtractSym2)
    ]

-- | Derive an instance of a single class for the type named @typName@, using
-- the procedure registered for @className@ in 'deriveProcedureMap'.
--
-- Fails in 'Q' with an error message when no procedure is registered for the
-- class (this includes the 'Mergeable' family, which 'deriveProcedureMap'
-- does not list).
deriveSingleGADT :: Name -> Name -> Q [Dec]
deriveSingleGADT typName className =
  case M.lookup className deriveProcedureMap of
    Just procedure -> procedure typName
    Nothing ->
      fail $ "No derivation available for class " ++ show className

-- | Derive the specified classes for a GADT with the given name.
--
-- Supports the following classes:
--
-- * 'Mergeable'
-- * 'Mergeable1'
-- * 'Mergeable2'
-- * 'Mergeable3'
-- * 'EvalSym'
-- * 'EvalSym1'
-- * 'EvalSym2'
-- * 'ExtractSym'
-- * 'ExtractSym1'
-- * 'ExtractSym2'
--
-- Duplicate class names in the input are ignored. Requesting any other class
-- makes the splice fail (via 'deriveSingleGADT').
deriveGADT :: Name -> [Name] -> Q [Dec]
deriveGADT typName classNames = do
  -- Deduplicate the requested classes, then send the 'Mergeable' family to
  -- 'deriveMergeables' and every other class to 'deriveSingleGADT'.
  let (otherClasses, mergeableArities) =
        splitMergeable $ S.toList $ S.fromList classNames
  otherDecs <- concat <$> mapM (deriveSingleGADT typName) otherClasses
  mergeableDecs <- deriveMergeables mergeableArities
  return $ otherDecs ++ mergeableDecs
  where
    -- Generate the 'Mergeable' instances for the given arities (0 for
    -- 'Mergeable', 1 for 'Mergeable1', ...). When more than one arity is
    -- requested, the merging info returned while generating the first one is
    -- passed to 'genMergeable'' for all remaining arities.
    deriveMergeables :: [Int] -> Q [Dec]
    deriveMergeables [] = return []
    deriveMergeables [n] = genMergeable typName n
    deriveMergeables (n : rest) = do
      (info, firstDecs) <- genMergeableAndGetMergingInfoResult typName n
      restDecs <- traverse (genMergeable' info typName) rest
      return $ firstDecs ++ concatMap snd restDecs

    -- Partition class names into the non-'Mergeable' classes and the arities
    -- of the requested 'Mergeable' classes, preserving relative order.
    splitMergeable :: [Name] -> ([Name], [Int])
    splitMergeable = foldr classify ([], [])
      where
        classify x (ns, is) = case mergeableArity x of
          Just arity -> (ns, arity : is)
          Nothing -> (x : ns, is)

    -- The arity of a 'Mergeable'-family class, or 'Nothing' for any other
    -- class.
    mergeableArity :: Name -> Maybe Int
    mergeableArity x =
      lookup x $
        zip [''Mergeable, ''Mergeable1, ''Mergeable2, ''Mergeable3] [0 ..]

-- | Derive all (non-functor) classes related to Grisette for a GADT with the
-- given name.
--
-- Classes that are derived by this procedure are:
--
-- * 'Mergeable'
-- * 'EvalSym'
-- * 'ExtractSym'
--
-- Note that it is okay to derive for non-GADT types using this procedure, and
-- it will be slightly more efficient.
deriveGADTAll :: Name -> Q [Dec]
-- Delegating with an empty exclusion list keeps the list of "all" classes in
-- a single place ('deriveGADTAllExcept').
deriveGADTAll typName = deriveGADTAllExcept typName []

-- | Derive all (non-functor) classes related to Grisette ('Mergeable',
-- 'EvalSym' and 'ExtractSym') for a GADT with the given name, except the
-- specified classes.
--
-- Names in the exclusion list that are not among these classes have no
-- effect.
deriveGADTAllExcept :: Name -> [Name] -> Q [Dec]
deriveGADTAllExcept typName classNames =
  deriveGADT typName . S.toList $
    S.fromList [''Mergeable, ''EvalSym, ''ExtractSym]
      `S.difference` S.fromList classNames