{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns        #-}

-- | Convert GHC AST definitions of records into our own representation, 'Record'.
module Data.Record.Internal.Plugin.Record (
    Record(..)
  , Field(..)
  , StockDeriving(..)
  , RecordDeriving(..)
  , viewRecord
  ) where

import Control.Monad.Except
import Data.Traversable (for)
import Data.List.NonEmpty (NonEmpty)

import qualified Data.List.NonEmpty as NE

import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception (Exception (..))
import Data.Record.Internal.Plugin.Options (LargeRecordOptions)

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | A representation for records that can be processed by large-records.
data Record = Record {
      -- | Name of the type constructor
      recordTyName    :: LRdrName

      -- | Type variable binders of the declaration
    , recordTyVars    :: [LHsTyVarBndr GhcPs]

      -- | Name of the (single) record constructor
    , recordConName   :: LRdrName

      -- | Fields, in declaration order (see 'fieldIndex')
    , recordFields    :: [Field]

      -- | Translated @deriving@ clauses
    , recordDerivings :: [RecordDeriving]

      -- | Options for this record
    , recordOptions   :: LargeRecordOptions

      -- | The location of the @ANN@ pragma
      --
      -- We use this as the location of the new identifiers we generate.
    , recordAnnLoc    :: SrcSpan
    }

-- | A single field of a 'Record'.
data Field = Field {
      -- | Field name
      fieldName       :: LRdrName

      -- | Field type, with any strictness annotation removed
    , fieldType       :: LHsType GhcPs

      -- | Strictness annotation that was attached to the field type
    , fieldStrictness :: HsSrcBang

      -- | Zero-based position of the field within the record
    , fieldIndex      :: Int
    }

-- | Derived classes that we can support.
data StockDeriving
  = Eq       -- ^ @deriving Eq@
  | Show     -- ^ @deriving Show@
  | Ord      -- ^ @deriving Ord@
  | Generic  -- ^ @deriving Generic@

-- | A representation for @deriving@ clauses.
--
-- NOTE: We support @DeriveAnyClass@ style derivation, because this does not
-- depend on the internal representation we choose, but only on the default
-- implementation in the class, which typically depends on generics. For
-- example, it makes it possible to define things like
--
-- > data UserT (f :: Type -> Type) = User {
-- >       userEmail :: Columnar f Text
-- >       -- .. other fields ..
-- >     }
-- >   deriving stock (Show, Eq)
-- >   deriving anyclass (Beamable)
--
-- For now we do /not/ support newtype deriving or deriving-via, since this
-- /does/ depend on the internal record representation. See discussion at
-- <https://github.com/well-typed/large-records/pull/42>.
data RecordDeriving
  = DeriveStock StockDeriving
    -- ^ A class from a @deriving@ clause with no strategy or with @stock@
  | DeriveAnyClass (LHsType GhcPs)
    -- ^ A class from a @deriving anyclass@ clause

{-------------------------------------------------------------------------------
  Views
-------------------------------------------------------------------------------}

-- | Recognise a record declaration and translate it to a 'Record'.
--
-- The declaration must match the 'DataD' view with a single 'RecC'
-- (record-syntax) constructor; anything else is rejected with
-- 'InvalidDeclaration'. Fields are numbered from 0 in declaration order.
viewRecord ::
     MonadError Exception m
  => SrcSpan             -- ^ Location of the @ANN@ pragma ('recordAnnLoc')
  -> LargeRecordOptions  -- ^ Options for this record ('recordOptions')
  -> LHsDecl GhcPs       -- ^ Declaration to inspect
  -> m Record
viewRecord annLoc options decl =
    case decl of
      DataD tyName tyVars [RecC conName fields] derivs -> do
        -- 'viewField' leaves the field index open; it is supplied below.
        mkFields  <- mapM viewField fields
        derivings <- viewRecordDerivings derivs
        pure Record {
            recordTyName    = tyName
          , recordTyVars    = tyVars
          , recordConName   = conName
          , recordFields    = zipWith ($) mkFields [0 ..]
          , recordDerivings = derivings
          , recordOptions   = options
          , recordAnnLoc    = annLoc
          }
      _otherwise ->
        throwError $ InvalidDeclaration decl

-- | Translate one record field.
--
-- The field's type is split into the type proper ('getBangType', wrapped with
-- 'parensT') and its strictness annotation ('getBangStrictness'). The field
-- index is not known here, so the result still expects it.
--
-- This never throws; the 'MonadError' constraint is currently unused.
viewField ::
     MonadError Exception m
  => (LRdrName, LHsType GhcPs) -> m (Int -> Field)
viewField (name, typ) =
    pure $ Field name (parensT (getBangType typ)) (getBangStrictness typ)

-- | Translate all @deriving@ clauses of a declaration.
--
-- The results of the individual clauses are concatenated in order. The first
-- error raised while translating a clause is propagated.
viewRecordDerivings ::
     MonadError Exception m
  => [LHsDerivingClause GhcPs] -> m [RecordDeriving]
viewRecordDerivings = fmap concat . traverse viewRecordDeriving

-- | Translate a single @deriving@ clause.
--
-- * No strategy, or @stock@: every class must be one of the supported
--   'StockDeriving' classes, otherwise 'UnsupportedStockDeriving' is thrown.
-- * @anyclass@: every class is accepted as 'DeriveAnyClass'.
-- * Any other strategy is rejected with 'UnsupportedStrategy'.
-- * A clause not matched by any of the 'DerivClause' patterns yields no
--   derivings.
viewRecordDeriving :: forall m.
     MonadError Exception m
  => LHsDerivingClause GhcPs -> m [RecordDeriving]
viewRecordDeriving = \case
    DerivClause Nothing tys ->
      goStock tys
    DerivClause (Just (L _ StockStrategy {})) tys ->
      goStock tys
    DerivClause (Just (L _ AnyclassStrategy {})) tys ->
      pure $ map DeriveAnyClass (NE.toList tys)
    DerivClause (Just strategy) _ ->
      throwError (UnsupportedStrategy strategy)
    _otherwise ->
      pure []
  where
    -- Classes are matched by 'nameBase' (presumably the unqualified name).
    -- 'm' is the one bound by the outer signature (ScopedTypeVariables).
    goStock :: NonEmpty (LHsType GhcPs) -> m [RecordDeriving]
    goStock tys = for (NE.toList tys) $ \case
        ConT (nameBase -> "Show")    -> pure $ DeriveStock Show
        ConT (nameBase -> "Eq")      -> pure $ DeriveStock Eq
        ConT (nameBase -> "Ord")     -> pure $ DeriveStock Ord
        ConT (nameBase -> "Generic") -> pure $ DeriveStock Generic
        ty                           -> throwError (UnsupportedStockDeriving ty)