{-# LANGUAGE CPP #-}
module Composite.Opaleye.TH where

import Control.Lens ((<&>))
import qualified Data.ByteString.Char8 as BSC8
import Data.List.Split (splitOn)
import Data.Maybe (fromMaybe)
import Data.Profunctor.Product.Default (Default, def)
import Data.Traversable (for)
import Database.PostgreSQL.Simple (ResultError(ConversionFailed, Incompatible, UnexpectedNull))
import Database.PostgreSQL.Simple.FromField (FromField, fromField, typename, returnError)
import Language.Haskell.TH
  ( Q, Name, mkName, nameBase, newName, pprint, reify
  , Info(TyConI), Dec(DataD), Con(NormalC)
  , conT
  , dataD, instanceD
  , lamE, varE, caseE, conE
  , conP, varP, wildP, litP, stringL
  , caseE, match
  , funD, clause
  , normalB, normalGE, guardedB
  , cxt
  )
import Language.Haskell.TH.Syntax (lift)
import Opaleye
  ( Column, DefaultFromField, ToFields, fromPGSFromField, defaultFromField
  )
import Opaleye.Internal.PGTypes (IsSqlType, showSqlType, literalColumn)
import Opaleye.Internal.HaskellDB.PrimQuery (Literal(StringLit))

-- |Return the final component of a dot-separated name, e.g. @"myschema.myenum"@ becomes @"myenum"@.
--
-- A string containing no @.@ is returned unchanged, and a string ending in @.@ yields @""@.
getLastComponent :: String -> String
getLastComponent =
  -- Walk in from the end of the string and keep everything up to (but excluding) the last '.'.
  reverse . takeWhile (/= '.') . reverse

-- |Derive the various instances required to make a Haskell enumeration map to a PostgreSQL @enum@ type.
--
-- In @deriveOpaleyeEnum ''HaskellType "schema.sqltype" hsConToSqlValue@, @''HaskellType@ is the sum type (data declaration) to make instances for, 
-- @"schema.sqltype"@ is the PostgreSQL type name, and @hsConToSqlValue@ is a function to map names of constructors to SQL values.
--
-- The function @hsConToSqlValue@ is of the type @String -> Maybe String@ in order to make using 'stripPrefix' convenient. The function is applied to each
-- constructor name and for @Just value@ that value is used, otherwise for @Nothing@ the constructor name is used.
--
-- For example, given the Haskell type:
--
-- @
--     data MyEnum = MyFoo | MyBar
-- @
--
-- And PostgreSQL type:
--
-- @
--     CREATE TYPE myenum AS ENUM('foo', 'bar');
-- @
--
-- The splice:
--
-- @
--     deriveOpaleyeEnum ''MyEnum "myschema.myenum" ('stripPrefix' "my" . 'map' 'toLower')
-- @
--
-- Will create @PGMyEnum@ and instances required to use @MyEnum@ / @Column MyEnum@ in Opaleye.
--
-- The Haskell generated by this splice for the example is something like:
--
-- @
--     data PGMyEnum
--
--     instance 'IsSqlType' PGMyEnum where
--       'showSqlType' _ = "myschema.myenum"
--
--     instance 'FromField' MyEnum where
--       'fromField' f mbs = do
--         tname <- 'typename' f
--         case mbs of
--           _ | 'getLastComponent' ('BSC8.unpack' tname) /= "myenum" -> 'returnError' 'Incompatible' f ""
--           Just "foo" -> pure MyFoo
--           Just "bar" -> pure MyBar
--           Just other -> 'returnError' 'ConversionFailed' f ("Unexpected myschema.myenum value: " <> 'BSC8.unpack' other)
--           Nothing    -> 'returnError' 'UnexpectedNull' f ""
--
--     instance 'DefaultFromField' PGMyEnum MyEnum where
--       defaultFromField = 'fromPGSFromField'
--
--     instance 'Default' 'ToFields' MyEnum ('Column' PGMyEnum) where
--       def = 'ToFields' $ \ a ->
--         'literalColumn' . 'StringLit' $ case a of
--           MyFoo -> "foo"
--           MyBar -> "bar"
-- @
deriveOpaleyeEnum :: Name -> String -> (String -> Maybe String) -> Q [Dec]
deriveOpaleyeEnum hsName sqlName hsConToSqlValue = do
  let sqlTypeName = mkName $ "PG" ++ nameBase hsName
      sqlType = conT sqlTypeName
      hsType = conT hsName
      -- The generated 'fromField' compares only the last dot-separated component of the reported
      -- type name, so strip any schema qualifier from the expected name as well.
      unqualSqlName = getLastComponent sqlName

  -- Only plain @data@ declarations are accepted; anything else aborts the splice.
  rawCons <- reify hsName >>= \ case
    TyConI (DataD _cxt _name _tvVarBndrs _maybeKind cons _derivingCxt) ->
      pure cons
    other ->
      fail $ "expected " <> show hsName <> " to name a data declaration, not:\n" <> pprint other

  -- Every constructor must be a regular constructor without fields.
  nullaryCons <- for rawCons $ \ case
    NormalC conName [] ->
      pure conName
    other ->
      fail $ "expected every constructor of " <> show hsName <> " to be a regular nullary constructor, not:\n" <> pprint other

  -- Pair each constructor with its SQL value, falling back to the bare constructor name when
  -- the user-supplied mapping returns Nothing.
  let conPairs = nullaryCons <&> \ conName ->
        (conName, fromMaybe (nameBase conName) (hsConToSqlValue (nameBase conName)))

  -- Empty data type (no constructors, no deriving) naming the SQL type.
  sqlTypeDecl <-
    dataD
      (cxt [])
      sqlTypeName
      []
      Nothing
      []
#if MIN_VERSION_template_haskell(2,12,0)
      []
#else
      (cxt [])
#endif

  -- instance IsSqlType PG<Type> where showSqlType _ = "<sqlName>"
  isSqlTypeInst <- instanceD (cxt []) [t| IsSqlType $sqlType |] . (:[]) $ do
    funD 'showSqlType
      [ clause
          [wildP]
          (normalB (lift sqlName))
          []
      ]

  -- instance FromField <Type>, decoding the textual enum value into a constructor.
  fromFieldInst <- instanceD (cxt []) [t| FromField $hsType |] . (:[]) $ do
    field <- newName "field"
    mbs   <- newName "mbs"
    tname <- newName "tname"
    other <- newName "other"

    let bodyCase = caseE (varE mbs) $
          -- First alternative: a guarded wildcard that rejects the field when the reported
          -- type name (last component only) is not the expected one.
          [ match
              wildP
              (guardedB [ normalGE [| getLastComponent (BSC8.unpack $(varE tname)) /= $(lift unqualSqlName) |]
                                   [| returnError Incompatible $(varE field) "" |] ])
              []
          ] ++
          -- One alternative per constructor, matching its SQL value literally.
          (
            conPairs <&> \ (conName, value) ->
              match
                [p| Just $(litP $ stringL value) |]
                (normalB [| pure $(conE conName) |])
                []
          ) ++
          -- Fallbacks: an unrecognised value, then SQL NULL.
          [ match
              [p| Just $(varP other) |]
              (normalB [| returnError ConversionFailed $(varE field) ("Unexpected " <> $(lift sqlName) <> " value: " <> BSC8.unpack $(varE other)) |])
              []
          , match
              [p| Nothing |]
              (normalB [| returnError UnexpectedNull $(varE field) "" |])
              []
          ]

    funD 'fromField
      [ clause
          [varP field, varP mbs]
          (normalB [|
            do
              $(varP tname) <- typename $(varE field)
              $bodyCase
            |])
          []
      ]

  -- instance DefaultFromField PG<Type> <Type>, delegating to the FromField instance above.
  defaultFromFieldInst <- instanceD (cxt []) [t| DefaultFromField $sqlType $hsType |] . (:[]) $
    funD 'defaultFromField
      [ clause
          []
          (normalB [| fromPGSFromField |])
          []
      ]

  -- instance Default ToFields <Type> (Column PG<Type>), rendering a constructor as a string literal.
  defaultInst <- instanceD (cxt []) [t| Default ToFields $hsType (Column $sqlType) |] . (:[]) $ do
    s <- newName "s"
    let body = lamE [varP s] $
          caseE (varE s) $
            conPairs <&> \ (conName, value) ->
              match
                (conP conName [])
                (normalB $ lift value)
                []

    funD 'def
      [ clause
          []
          (normalB [| ToFields (literalColumn . StringLit . $body) |])
          []
      ]

  pure [sqlTypeDecl, isSqlTypeInst, fromFieldInst, defaultFromFieldInst, defaultInst]