{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Record.Internal.Plugin.Record (
Record(..)
, Field(..)
, StockDeriving(..)
, RecordDeriving(..)
, viewRecord
) where
import Control.Monad.Except
import Data.Traversable (for)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception (Exception (..))
import Data.Record.Internal.Plugin.Options (LargeRecordOptions)
-- | A @data@ declaration accepted by 'viewRecord': a single record
-- constructor together with its deriving clauses and the plugin options
-- and source span that 'viewRecord' was called with.
data Record = Record
  { recordTyName    :: LRdrName
    -- ^ Name of the declared type.
  , recordTyVars    :: [LHsTyVarBndr GhcPs]
    -- ^ Binders for the declared type's parameters, as written.
  , recordConName   :: LRdrName
    -- ^ Name of the (only) record constructor.
  , recordFields    :: [Field]
    -- ^ The constructor's fields in declaration order, each tagged with
    -- its zero-based position.
  , recordDerivings :: [RecordDeriving]
    -- ^ Derivings from all deriving clauses, flattened in source order.
  , recordOptions   :: LargeRecordOptions
    -- ^ Options handed to 'viewRecord', stored unchanged.
  , recordAnnLoc    :: SrcSpan
    -- ^ Source span handed to 'viewRecord' by its caller; presumably the
    -- location of the annotation that requested the plugin (confirm at
    -- the call site).
  }
-- | One field of a large record, as produced by 'viewField' and then
-- numbered by 'viewRecord'.
data Field = Field
  { fieldName       :: LRdrName
    -- ^ The field (selector) name as written in the declaration.
  , fieldType       :: LHsType GhcPs
    -- ^ The field type after 'getBangType', wrapped with 'parensT'.
  , fieldStrictness :: HsSrcBang
    -- ^ 'getBangStrictness' of the field type as originally written.
  , fieldIndex      :: Int
    -- ^ Zero-based position of the field within its constructor.
  }
-- | Classes that may be requested in a @stock@ (or strategy-less) deriving
-- clause of a large record. 'viewRecordDeriving' recognises them by the
-- base name of the class.
data StockDeriving
  = Eq
  | Show
  | Ord
  | Generic
-- | A single class to derive for a large record.
data RecordDeriving
  = DeriveStock StockDeriving
    -- ^ A recognised class from a @stock@ or strategy-less clause.
  | DeriveAnyClass (LHsType GhcPs)
    -- ^ A class from a @deriving anyclass@ clause, kept as the type written.
-- | Recognise a @data@ declaration with exactly one constructor, written in
-- record syntax, and collect what the plugin needs to know about it.
--
-- Each field is interpreted with 'viewField' and numbered from 0 in
-- declaration order. The deriving clauses are interpreted with
-- 'viewRecordDerivings', whose errors propagate. Any other shape of
-- declaration is rejected with 'InvalidDeclaration'.
viewRecord ::
     MonadError Exception m
  => SrcSpan             -- ^ stored as 'recordAnnLoc'
  -> LargeRecordOptions  -- ^ stored as 'recordOptions'
  -> LHsDecl GhcPs       -- ^ the candidate declaration
  -> m Record
viewRecord annLoc options decl
  | DataD tyName tyVars [RecC conName fields] derivs <- decl = do
      mkFields  <- traverse viewField fields
      derivings <- viewRecordDerivings derivs
      -- Each 'viewField' result still awaits its position; supply it here.
      let numberedFields = [ mk ix | (mk, ix) <- zip mkFields [0 ..] ]
      pure Record
        { recordTyName    = tyName
        , recordTyVars    = tyVars
        , recordConName   = conName
        , recordFields    = numberedFields
        , recordDerivings = derivings
        , recordOptions   = options
        , recordAnnLoc    = annLoc
        }
  | otherwise = throwError (InvalidDeclaration decl)
-- | Interpret one @(name, type)@ pair of a record constructor.
--
-- The result is a 'Field' that still needs its position ('fieldIndex'),
-- which 'viewRecord' fills in. The type goes through 'getBangType' and is
-- wrapped with 'parensT' to become 'fieldType'; 'getBangStrictness' of the
-- original type becomes 'fieldStrictness'. This function never calls
-- 'throwError'.
viewField ::
     MonadError Exception m
  => (LRdrName, LHsType GhcPs) -> m (Int -> Field)
viewField (name, typ) = pure (Field name bareType strictness)
  where
    bareType   = parensT (getBangType typ)
    strictness = getBangStrictness typ
-- | Interpret every deriving clause with 'viewRecordDeriving' and
-- concatenate the per-clause results, preserving clause order. Errors
-- raised by 'viewRecordDeriving' propagate.
viewRecordDerivings ::
     MonadError Exception m
  => [LHsDerivingClause GhcPs] -> m [RecordDeriving]
viewRecordDerivings clauses = concat <$> traverse viewRecordDeriving clauses
-- | Interpret a single deriving clause.
--
-- * No strategy, or @stock@: every class must be @Show@, @Eq@, @Ord@ or
--   @Generic@ (matched on 'nameBase'); anything else is rejected with
--   'UnsupportedStockDeriving'.
-- * @anyclass@: every class is kept as written, as 'DeriveAnyClass'.
-- * Any other strategy: rejected with 'UnsupportedStrategy'.
-- * A clause not matched by 'DerivClause' contributes no derivings.
viewRecordDeriving :: forall m.
     MonadError Exception m
  => LHsDerivingClause GhcPs -> m [RecordDeriving]
viewRecordDeriving clause =
    case clause of
      DerivClause mStrategy tys ->
        case mStrategy of
          Nothing                        -> goStock tys
          Just (L _ StockStrategy {})    -> goStock tys
          Just (L _ AnyclassStrategy {}) ->
            pure (map DeriveAnyClass (NE.toList tys))
          Just strategy                  ->
            throwError (UnsupportedStrategy strategy)
      _ ->
        pure []
  where
    -- Every class in a stock clause must be one we know how to derive.
    goStock :: NonEmpty (LHsType GhcPs) -> m [RecordDeriving]
    goStock tys = for (NE.toList tys) classifyStock

    classifyStock :: LHsType GhcPs -> m RecordDeriving
    classifyStock ty = case ty of
      ConT (stockClass . nameBase -> Just stock) -> pure (DeriveStock stock)
      _                                          ->
        throwError (UnsupportedStockDeriving ty)

    -- Map a class base name to the stock class it denotes, if supported.
    stockClass :: String -> Maybe StockDeriving
    stockClass cls = case cls of
      "Show"    -> Just Show
      "Eq"      -> Just Eq
      "Ord"     -> Just Ord
      "Generic" -> Just Generic
      _         -> Nothing