{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Record.Internal.Plugin.Record (
Record(..)
, Field(..)
, StockDeriving(..)
, RecordDeriving(..)
, viewRecord
) where
import Control.Monad.Except
import Data.Traversable (for)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception (Exception (..))
import Data.Record.Internal.Plugin.Options (LargeRecordOptions)
-- | A large-record declaration, as recognised by 'viewRecord'.
data Record = Record {
      -- | Name of the declared type
      recordTyName :: LRdrName

      -- | Type variables bound by the declaration head
    , recordTyVars :: [LHsTyVarBndr GhcPs]

      -- | Name of the (single) record constructor
    , recordConName :: LRdrName

      -- | The record's fields, in declaration order
      --
      -- The position of each field in this list agrees with its 'fieldIndex'.
    , recordFields :: [Field]

      -- | Classes requested in the declaration's deriving clauses
    , recordDerivings :: [RecordDeriving]

      -- | Options that were passed to 'viewRecord' alongside the declaration
    , recordOptions :: LargeRecordOptions

      -- | Source span that was passed to 'viewRecord'
      --
      -- NOTE(review): presumably the location of the annotation that
      -- requested large-record treatment; confirm against the caller.
    , recordAnnLoc :: SrcSpan
    }
-- | A single field of a 'Record'.
data Field = Field {
      -- | Name of the field
      fieldName :: LRdrName

      -- | Type of the field
      --
      -- As produced by 'viewField': the declared type with its bang
      -- annotation split off (via 'getBangType') and wrapped in parentheses.
    , fieldType :: LHsType GhcPs

      -- | Strictness annotation that was split off the declared type
    , fieldStrictness :: HsSrcBang

      -- | Zero-based position of the field within its record
    , fieldIndex :: Int
    }
-- | Classes that may be requested with stock deriving.
--
-- 'viewRecordDeriving' recognises these by their unqualified class name.
data StockDeriving =
    Eq
  | Show
  | Ord
  | Generic
-- | One class requested in a deriving clause of a large record.
data RecordDeriving =
    -- | Class requested with the @stock@ strategy, or with no explicit
    -- strategy at all
    DeriveStock StockDeriving

    -- | Class requested with the @anyclass@ strategy; the type is the class
    -- exactly as it appears in the deriving clause
  | DeriveAnyClass (LHsType GhcPs)
-- | Interpret a declaration as a large record.
--
-- The declaration must be a @data@ declaration with exactly one constructor,
-- and that constructor must use record syntax. Any other declaration is
-- rejected with 'InvalidDeclaration'.
--
-- Fields are numbered from 0 in declaration order. The deriving clauses are
-- interpreted by 'viewRecordDerivings', whose errors propagate unchanged.
viewRecord ::
     MonadError Exception m
  => SrcSpan -> LargeRecordOptions -> LHsDecl GhcPs -> m Record
viewRecord annLoc options (DataD tyName tyVars [RecC conName fields] derivs) = do
    -- Each 'viewField' result still expects its index; it is supplied below.
    mkFields  <- traverse viewField fields
    derivings <- viewRecordDerivings derivs
    return $ Record {
        recordTyName    = tyName
      , recordTyVars    = tyVars
      , recordConName   = conName
      , recordFields    = [mkField ix | (mkField, ix) <- zip mkFields [0 ..]]
      , recordDerivings = derivings
      , recordOptions   = options
      , recordAnnLoc    = annLoc
      }
viewRecord _ _ decl =
    throwError (InvalidDeclaration decl)
-- | Interpret a single record field, given its name and declared type.
--
-- The declared type is split into two parts:
--
-- * the type itself, as returned by 'getBangType' and wrapped with 'parensT';
-- * the strictness annotation, as returned by 'getBangStrictness'.
--
-- The field's index is not known at this point. The result is therefore a
-- function from that index to the finished 'Field'.
--
-- This never throws; the 'MonadError' constraint only keeps the signature
-- uniform with the other @view*@ functions.
viewField ::
     MonadError Exception m
  => (LRdrName, LHsType GhcPs) -> m (Int -> Field)
viewField (name, typ) = pure (Field name bareType strictness)
  where
    bareType   = parensT (getBangType typ)
    strictness = getBangStrictness typ
-- | Interpret every deriving clause of a declaration.
--
-- Each clause goes through 'viewRecordDeriving'. The per-clause results are
-- concatenated, preserving clause order. The first error raised by any clause
-- aborts the whole traversal.
viewRecordDerivings ::
     MonadError Exception m
  => [LHsDerivingClause GhcPs] -> m [RecordDeriving]
viewRecordDerivings clauses =
    concat <$> mapM viewRecordDeriving clauses
-- | Interpret a single deriving clause.
--
-- Outcome by strategy:
--
-- * No strategy, or @stock@: every class must be one of the
--   'StockDeriving' classes, matched by unqualified name. Otherwise
--   'UnsupportedStockDeriving' is thrown for the first offending type.
-- * @anyclass@: every class is accepted as 'DeriveAnyClass'.
-- * Any other strategy: rejected with 'UnsupportedStrategy'.
--
-- A clause that does not match the 'DerivClause' pattern contributes nothing.
viewRecordDeriving :: forall m.
     MonadError Exception m
  => LHsDerivingClause GhcPs -> m [RecordDeriving]
viewRecordDeriving clause =
    case clause of
      DerivClause mStrategy tys ->
        case mStrategy of
          Nothing ->
            goStock tys
          Just (L _ StockStrategy {}) ->
            goStock tys
          Just (L _ AnyclassStrategy {}) ->
            pure [DeriveAnyClass ty | ty <- NE.toList tys]
          Just strategy ->
            throwError (UnsupportedStrategy strategy)
      _otherwise ->
        pure []
  where
    -- Interpret each class of a stock (or strategy-less) clause, left to right.
    goStock :: NonEmpty (LHsType GhcPs) -> m [RecordDeriving]
    goStock tys = for (NE.toList tys) viewStock

    -- Recognise one stock-derivable class by its unqualified name.
    viewStock :: LHsType GhcPs -> m RecordDeriving
    viewStock ty =
        case ty of
          ConT (nameBase -> cls)
            | Just stock <- lookup cls stockClasses ->
                pure (DeriveStock stock)
          _otherwise ->
            throwError (UnsupportedStockDeriving ty)

    -- Class names accepted for stock deriving, paired with their tags.
    stockClasses :: [(String, StockDeriving)]
    stockClasses = [
        ("Show"    , Show)
      , ("Eq"      , Eq)
      , ("Ord"     , Ord)
      , ("Generic" , Generic)
      ]