{-# LANGUAGE RecordWildCards #-}

-- |
-- Module      :   Grisette.Internal.TH.GADT.Common
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.GADT.Common
  ( CheckArgsResult (..),
    checkArgs,
  )
where

import Control.Monad (when)
import qualified Data.Map as M
import qualified Data.Set as S
import Grisette.Internal.TH.Util (occName)
import Language.Haskell.TH
  ( Name,
    Q,
    Type (VarT),
    newName,
  )
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields),
    DatatypeInfo (datatypeCons, datatypeVars),
    TypeSubstitution (applySubstitution, freeVariables),
    reifyDatatype,
    tvName,
  )
import Language.Haskell.TH.Datatype.TyVarBndr (TyVarBndr_, mapTVName)

-- | The outcome of 'checkArgs' for a GADT: the datatype's constructors
-- rewritten over freshly generated type-variable names, together with those
-- fresh names split into a leading ("kept") group and a trailing ("arg")
-- group.
data CheckArgsResult = CheckArgsResult
  { -- | The datatype's constructors, with every original type variable
    -- substituted by its freshly generated counterpart.
    constructors :: [ConstructorInfo],
    -- | Fresh names for the leading type variables, i.e. all but the last
    -- @n@ requested in 'checkArgs'.
    keptNewNames :: [Name],
    -- | The original leading binders with their names replaced by
    -- 'keptNewNames'.
    keptNewVars :: [TyVarBndr_ ()],
    -- | Fresh names for the trailing @n@ type variables requested in
    -- 'checkArgs'.
    argNewNames :: [Name],
    -- | The original trailing binders with their names replaced by
    -- 'argNewNames'.
    argNewVars :: [TyVarBndr_ ()],
    -- | Whether the given name occurs free in any field of 'constructors'
    -- (after substitution, so callers should query with the fresh names).
    isVarUsedInFields :: Name -> Bool
  }

-- | Check that the number of type parameters requested for a derived
-- instance is valid for a GADT, and generate fresh names for the datatype's
-- type variables.
--
-- The datatype's type variables are split into the leading
-- @length vars - n@ variables ("kept") and the trailing @n@ variables
-- ("arg"). Each variable receives a fresh name whose prefix is the original
-- variable's occurrence name, and the datatype's constructors are rewritten
-- (via 'applySubstitution') to refer to the fresh names.
--
-- Fails in 'Q' (using 'fail') when:
--
-- * @n@ is negative;
-- * @n@ exceeds @maxArgNum@;
-- * the datatype has fewer than @n@ type variables.
checkArgs ::
  -- | Name of the class being derived; used only in error messages.
  String ->
  -- | Maximum number of type parameters supported for this class.
  Int ->
  -- | Name of the datatype to reify.
  Name ->
  -- | Number of trailing type parameters requested.
  Int ->
  Q CheckArgsResult
checkArgs clsName maxArgNum typName n = do
  when (n < 0) $
    fail $
      unlines
        [ "Cannot derive "
            ++ clsName
            ++ " instance with negative type parameters",
          "Requested: " ++ show n,
          "Hint: Use a non-negative number of type parameters"
        ]
  when (n > maxArgNum) $
    fail $
      "Requesting "
        <> clsName
        <> " instance with more than "
        <> show maxArgNum
        <> " type parameters"
  d <- reifyDatatype typName
  let dvars = datatypeVars d
      numVars = length dvars
  -- Report the actual class being derived rather than a hard-coded one.
  when (numVars < n) $
    fail $
      "Requesting "
        <> clsName
        <> " instance with "
        <> show n
        <> " type parameters, while the type "
        <> show typName
        <> " has only "
        <> show numVars
        <> " type variables."
  -- Safe split: the checks above guarantee 0 <= n <= numVars.
  let (keptVars, argVars) = splitAt (numVars - n) dvars
  keptNewNames <- freshNames keptVars
  argNewNames <- freshNames argVars
  let keptNewVars = renameVars keptNewNames keptVars
      argNewVars = renameVars argNewNames argVars
      -- Map every original type variable (in declaration order) to its fresh
      -- counterpart; keptNewNames ++ argNewNames lines up with dvars.
      substMap =
        M.fromList $
          zip (tvName <$> dvars) (VarT <$> keptNewNames ++ argNewNames)
      constructors = applySubstitution substMap $ datatypeCons d
      allFieldsFreeVars =
        S.fromList $ freeVariables $ concatMap constructorFields constructors
      isVarUsedInFields var = S.member var allFieldsFreeVars
  return CheckArgsResult {..}
  where
    -- Generate one fresh name per binder, reusing the binder's occurrence
    -- name as a readable prefix.
    freshNames :: [TyVarBndr_ ()] -> Q [Name]
    freshNames = traverse (newName . occName . tvName)
    -- Replace each binder's name with the corresponding fresh name.
    renameVars :: [Name] -> [TyVarBndr_ ()] -> [TyVarBndr_ ()]
    renameVars = zipWith (mapTVName . const)