module Optics.TH.Internal.Utils where

import Control.Monad
import Data.Maybe
import Data.List
import Language.Haskell.TH
import Language.Haskell.TH.Datatype.TyVarBndr
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Language.Haskell.TH.Datatype as D

import Data.Set.Optics
import Language.Haskell.TH.Optics.Internal
import Optics.Core

-- | Apply a type to a list of argument types, left to right:
-- @appsT f [a, b]@ builds the type @f a b@. With no arguments the head type
-- is returned unchanged.
appsT :: TypeQ -> [TypeQ] -> TypeQ
-- Strict 'foldl'' instead of lazy 'foldl': each partial application is forced
-- as the fold proceeds rather than accumulating a chain of thunks.
appsT = foldl' appT

-- | Apply an expression to a list of argument expressions, left to right:
-- @appsE1 f [x, y]@ builds the expression @f x y@. With no arguments the
-- function expression is returned unchanged.
appsE1 :: ExpQ -> [ExpQ] -> ExpQ
-- Strict 'foldl'' instead of lazy 'foldl', matching 'appsT'.
appsE1 = foldl' appE

-- | Build a tuple type from a list of component types. A single component is
-- returned as-is; any other number of components (including zero) becomes the
-- tuple type constructor of that arity applied to all of them.
toTupleT :: [TypeQ] -> TypeQ
toTupleT components =
  case components of
    [only] -> only
    _      -> appsT (tupleT arity) components
  where
    arity = length components

-- | Build a tuple expression from a list of component expressions. A single
-- expression is returned unwrapped; any other number is handed to 'tupE'.
toTupleE :: [ExpQ] -> ExpQ
toTupleE exprs
  | [only] <- exprs = only
  | otherwise       = tupE exprs

-- | Build a tuple pattern from a list of component patterns. A single pattern
-- is returned unwrapped; any other number is handed to 'tupP'.
toTupleP :: [PatQ] -> PatQ
toTupleP pats =
  case pats of
    [only] -> only
    _      -> tupP pats

-- | Apply the type constructor with the given name to a list of argument
-- types, left to right: @conAppsT n [a, b]@ is the pure 'Type' @n a b@.
-- With no arguments the result is just @ConT n@.
conAppsT :: Name -> [Type] -> Type
-- Strict 'foldl'' instead of lazy 'foldl', matching 'appsT'.
conAppsT conName = foldl' AppT (ConT conName)

-- | Generate @count@ fresh names from a base name, suffixed with @1@ through
-- @count@ (for base @x@: @x1@, @x2@, ...). A non-positive count yields no names.
newNames :: String {- ^ base name -} -> Int {- ^ count -} -> Q [Name]
newNames base n = traverse freshFor [1 .. n]
  where
    freshFor :: Int -> Q Name
    freshFor i = newName (base ++ show i)

-- We substitute concrete types with type variables and match them with concrete
-- types in the instance context. This significantly improves type inference as
-- GHC can match the instance more easily, but costs dependence on TypeFamilies
-- and UndecidableInstances.
--
-- Given a type @ty@ and a base name @n@, returns a fresh type variable (named
-- after @n@) paired with the equality predicate between that variable and
-- @ty@, as built by 'D.equalPred'.
eqSubst :: Type -> String -> Q (Type, Pred)
eqSubst ty n = withConstraint . VarT <$> newName n
  where
    withConstraint placeholder = (placeholder, D.equalPred placeholder ty)

-- | Fill in kind information for a type using only the datatype's own type
-- parameters; 'addKindInfo'' with no additional types.
addKindInfo :: D.DatatypeInfo -> Type -> Type
addKindInfo datatype ty = addKindInfo' [] datatype ty

-- | Fill in kind variables using info from datatype type parameters.
--
-- Every kind-annotated type variable (@SigT (VarT n) k@) found among
-- @additionalInfo@ followed by the datatype's instance types contributes a
-- mapping @n -> SigT (VarT n) k@, and the resulting map is applied with
-- 'substType'. Since 'M.fromList' keeps the last value for a duplicated key,
-- the datatype's own instance types take precedence over @additionalInfo@.
addKindInfo' :: [Type] -> D.DatatypeInfo -> Type -> Type
addKindInfo' additionalInfo di = substType kindSubst
  where
    kindSubst :: M.Map Name Type
    kindSubst = M.fromList (mapMaybe var (additionalInfo ++ D.datatypeInstTypes di))

    -- If the type is a data/newtype family instance, we need to fill in all of
    -- the kinds for weird cases such as:
    --
    -- data family KDF (a :: k)
    -- data instance KDF (a :: Type) = Kinded3 { _kdf :: Proxy a }
    --
    -- Otherwise we only need info about kind variables.
    --
    -- More info at https://github.com/ekmett/lens/pull/945.
    isDataFamily :: Bool
    isDataFamily = D.datatypeVariant di `elem` [D.DataInstance, D.NewtypeInstance]

    -- Keep a kinded variable when we're in a data family instance, or when its
    -- kind itself mentions type variables; drop everything else.
    var :: Type -> Maybe (Name, Type)
    var ty = case ty of
      SigT (VarT n) k
        | isDataFamily || has typeVars k -> Just (n, ty)
      _ -> Nothing

-- | Template Haskell wants type variables declared in a forall, so
-- we find all free type variables in a given type and declare them.
-- This is 'quantifyType'' with nothing excluded.
quantifyType :: [TyVarBndrSpec] -> Cxt -> Type -> Type
quantifyType vars cx t = quantifyType' S.empty vars cx t

-- | This function works like 'quantifyType' except that it takes
-- a set of variables to exclude from quantification.
--
-- The binders in @vars@ (converted with 'tyVarBndrToType') and every kinded
-- type variable of @t@ (collected via 'typeVarsKinded') are passed through
-- 'D.freeVariablesWellScoped' and flagged 'SpecifiedSpec'. Binders whose
-- name is in @exclude@ are then dropped, and the rest are bound by a
-- 'ForallT' over context @cx@ and body @t@.
quantifyType' :: S.Set Name -> [TyVarBndrSpec] -> Cxt -> Type -> Type
quantifyType' exclude vars cx t = ForallT binders cx t
  where
    candidates :: [Type]
    candidates = map tyVarBndrToType vars ++ S.toList (setOf typeVarsKinded t)

    binders :: [TyVarBndrSpec]
    binders =
      [ bndr
      | bndr <- changeTVFlags SpecifiedSpec (D.freeVariablesWellScoped candidates)
      , not (D.tvName bndr `S.member` exclude)
      ]

-- | Transform a 'TyVarBndr' into a 'Type' so it's suitable e.g. for
-- freeVariablesWellScoped or type substitution. A binder without a kind
-- becomes a bare 'VarT'; a kinded binder becomes that variable wrapped in a
-- 'SigT' carrying its kind. The binder's flag is ignored.
tyVarBndrToType :: TyVarBndr_ flag -> Type
tyVarBndrToType = elimTV VarT (SigT . VarT)

-- | Fail with an explanatory message unless the required language extensions
-- are enabled.
--
-- The first argument names what is being generated and appears in the error
-- message. The second is a list of requirements. Each requirement is a list of
-- alternative extensions, any one of which satisfies it. For example, if you
-- need either GADTs or ExistentialQuantification, you'd write:
--
-- > requireExtensions "optics" [[GADTs, ExistentialQuantification]]
--
-- But if you need TypeFamilies and MultiParamTypeClasses, then you'd write:
--
-- > requireExtensions "optics" [[TypeFamilies], [MultiParamTypeClasses]]
--
-- For each unsatisfied requirement, only its first alternative is named in
-- the error message. A requirement given as an empty list is never reported.
requireExtensions :: String -> [[Extension]] -> Q ()
requireExtensions what extLists = do
  -- Taken from the persistent library
  unsatisfied <- filterM noneEnabled extLists
  case concatMap (take 1) unsatisfied of
    [] -> pure ()
    [extension] -> failWith
      [ "Generating " ++ what ++ " requires the "
      , show extension
      , " language extension. Please enable it by copy/pasting this line to the top of your file:\n\n"
      , extensionToPragma extension
      , "\n\nTo enable it in a GHCi session, use the following command:\n\n"
      , ":seti -X" ++ show extension
      ]
    extensions -> failWith
      [ "Generating " ++ what ++ " requires the following language extensions:\n\n"
      , intercalate "\n" [ "- " ++ show ext | ext <- extensions ]
      , "\n\nPlease enable the extensions by copy/pasting these lines into the top of your file:\n\n"
      , intercalate "\n" (map extensionToPragma extensions)
      , "\n\nTo enable them in a GHCi session, use the following command:\n\n"
      , ":seti " ++ unwords [ "-X" ++ show ext | ext <- extensions ]
      ]
  where
    -- A requirement is unsatisfied when none of its alternatives is enabled.
    noneEnabled :: [Extension] -> Q Bool
    noneEnabled alternatives = not . or <$> traverse isExtEnabled alternatives

    -- Abort the splice with the concatenation of the message fragments.
    failWith :: [String] -> Q ()
    failWith = fail . concat

    extensionToPragma :: Extension -> String
    extensionToPragma ext = "{-# LANGUAGE " ++ show ext ++ " #-}"

-- | Check the extensions needed to generate 'LabelOptic' instances. Either
-- TypeFamilies or GADTs satisfies the fourth requirement; the others each
-- need their single listed extension.
requireExtensionsForLabels :: Q ()
requireExtensionsForLabels = requireExtensions "LabelOptic instances" labelRequirements
  where
    labelRequirements :: [[Extension]]
    labelRequirements =
      [ [DataKinds]
      , [FlexibleInstances]
      , [MultiParamTypeClasses]
      , [TypeFamilies, GADTs]
      , [UndecidableInstances]
      ]

-- | Check the extensions needed to generate field optics: both
-- FlexibleInstances and FunctionalDependencies must be enabled.
requireExtensionsForFields :: Q ()
requireExtensionsForFields = requireExtensions "field optics" fieldRequirements
  where
    fieldRequirements :: [[Extension]]
    fieldRequirements =
      [ [FlexibleInstances]
      , [FunctionalDependencies]
      ]

------------------------------------------------------------------------
-- Support for generating inline pragmas
------------------------------------------------------------------------

-- | A one-element list holding an @INLINE@ pragma for the given method name,
-- using 'FunLike' rule matching and 'AllPhases'.
inlinePragma :: Name -> [DecQ]
inlinePragma methodName = pure (pragInlD methodName Inline FunLike AllPhases)