{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Record.Internal.Plugin.Record (
Record(..)
, Field(..)
, StockDeriving(..)
, RecordDeriving(..)
, viewRecord
) where
import Control.Monad.Except
import Data.Traversable (for)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception (Exception (..))
import Data.Record.Internal.Plugin.Options (LargeRecordOptions)
-- | A @data@ declaration with exactly one record constructor, in the form
-- produced by 'viewRecord'.
data Record = Record {
      recordTyName    :: LRdrName
      -- ^ Name of the declared type
    , recordTyVars    :: [LHsTyVarBndr GhcPs]
      -- ^ Type variable binders of the declaration
    , recordConName   :: LRdrName
      -- ^ Name of the single record constructor
    , recordFields    :: [Field]
      -- ^ Constructor fields, in declaration order
    , recordDerivings :: [RecordDeriving]
      -- ^ All deriving clauses of the declaration, flattened into one list
    , recordOptions   :: LargeRecordOptions
      -- ^ Options exactly as passed to 'viewRecord'
    , recordAnnLoc    :: SrcSpan
      -- ^ Source span exactly as passed to 'viewRecord'
      --   (NOTE(review): presumably the location of the triggering
      --   annotation; confirm at the call site)
    }
-- | One field of a record constructor.
data Field = Field {
      fieldName  :: LRdrName
      -- ^ Name of the field
    , fieldType  :: LHsType GhcPs
      -- ^ Declared type of the field
    , fieldIndex :: Int
      -- ^ Zero-based position of the field within the constructor
    }
-- | The classes accepted in a stock (or strategy-less) deriving clause.
--
-- 'viewRecordDeriving' maps a class to one of these constructors by
-- comparing its 'nameBase' against the constructor's spelling.
data StockDeriving =
    Eq
  | Show
  | Ord
  | Generic
-- | A single class requested in one of the record's deriving clauses.
data RecordDeriving
  = DeriveStock StockDeriving
    -- ^ From a @deriving@ clause with no strategy or the @stock@ strategy
  | DeriveAnyClass (LHsType GhcPs)
    -- ^ From a @deriving anyclass@ clause; the class type is kept verbatim
-- | Recognise a @data@ declaration that has exactly one record constructor
-- and turn it into a 'Record'.
--
-- Every field receives its zero-based position in the constructor as its
-- 'fieldIndex'. Deriving clauses are interpreted by 'viewRecordDerivings',
-- which may itself throw. Any declaration of a different shape (not a
-- @data@ declaration, several constructors, a non-record constructor, ...)
-- is rejected with 'InvalidDeclaration', carrying the whole declaration.
viewRecord ::
     MonadError Exception m
  => SrcSpan            -- ^ stored unchanged as 'recordAnnLoc'
  -> LargeRecordOptions -- ^ stored unchanged as 'recordOptions'
  -> LHsDecl GhcPs      -- ^ declaration to inspect
  -> m Record
viewRecord annLoc options = \case
    DataD tyName tyVars [RecC conName fields] derivs -> do
      -- Fields first, then derivings: same effect order as before.
      mkFields  <- traverse viewField fields
      derivings <- viewRecordDerivings derivs
      pure $ Record {
          recordTyName    = tyName
        , recordTyVars    = tyVars
        , recordConName   = conName
        , recordFields    = zipWith (\ix mkField -> mkField ix) [0 ..] mkFields
        , recordDerivings = derivings
        , recordOptions   = options
        , recordAnnLoc    = annLoc
        }
    other ->
      throwError (InvalidDeclaration other)
-- | Turn one @(name, type)@ pair of a record constructor into a 'Field'
-- that still needs its positional index (filled in by 'viewRecord').
--
-- Despite the 'MonadError' constraint, this currently never fails.
viewField ::
     MonadError Exception m
  => (LRdrName, LHsType GhcPs) -> m (Int -> Field)
viewField (fName, fType) = pure (Field fName fType)
-- | Interpret every deriving clause with 'viewRecordDeriving', in order,
-- and concatenate the results. The first clause that fails aborts the
-- whole conversion.
viewRecordDerivings ::
     MonadError Exception m
  => [LHsDerivingClause GhcPs] -> m [RecordDeriving]
viewRecordDerivings clauses =
    concat <$> traverse viewRecordDeriving clauses
-- | Interpret a single deriving clause.
--
-- * No strategy, or @stock@: every class must be one of @Show@, @Eq@,
--   @Ord@ or @Generic@, compared via 'nameBase'. Any other class (or any
--   type that is not a 'ConT') throws 'UnsupportedStockDeriving'.
-- * @anyclass@: every class is accepted and wrapped in 'DeriveAnyClass'.
-- * Any other explicit strategy throws 'UnsupportedStrategy'.
-- * A clause not matched by the 'DerivClause' pattern yields no derivings.
--   (NOTE(review): presumably only GHC's extension constructor reaches
--   this case; confirm in the TemplateHaskellStyle shim.)
viewRecordDeriving :: forall m.
     MonadError Exception m
  => LHsDerivingClause GhcPs -> m [RecordDeriving]
viewRecordDeriving = \case
    DerivClause Nothing                       classes -> stock classes
    DerivClause (Just (L _ StockStrategy))    classes -> stock classes
    DerivClause (Just (L _ AnyclassStrategy)) classes ->
      pure (map DeriveAnyClass (NE.toList classes))
    DerivClause (Just strategy)               _       ->
      throwError (UnsupportedStrategy strategy)
    _ ->
      pure []
  where
    -- Classes of a stock clause, each of which must be supported.
    stock :: NonEmpty (LHsType GhcPs) -> m [RecordDeriving]
    stock = traverse toStock . NE.toList

    -- Map one class to its 'StockDeriving' constructor by its base name.
    toStock :: LHsType GhcPs -> m RecordDeriving
    toStock ty = case ty of
      ConT (nameBase -> cls)
        | Just known <- lookup cls supportedStock -> pure (DeriveStock known)
      _ -> throwError (UnsupportedStockDeriving ty)

    -- Class base names accepted for stock deriving.
    supportedStock :: [(String, StockDeriving)]
    supportedStock =
      [ ("Show",    Show)
      , ("Eq",      Eq)
      , ("Ord",     Ord)
      , ("Generic", Generic)
      ]