{-# LANGUAGE CPP #-}
module Composite.Opaleye.TH where

import Control.Lens ((<&>))
import qualified Data.ByteString.Char8 as BSC8
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.Profunctor.Product.Default (Default, def)
import Data.Traversable (for)
import Database.PostgreSQL.Simple (ResultError(ConversionFailed, Incompatible, UnexpectedNull))
import Database.PostgreSQL.Simple.FromField (FromField, fromField, typename, returnError)
import Language.Haskell.TH
  ( Q, Name, mkName, nameBase, newName, pprint, reify
  , Info(TyConI), Dec(DataD), Con(NormalC)
  , conT
  , dataD, instanceD
  , lamE, varE, caseE, conE
  , conP, varP, wildP, litP, stringL
  , caseE, match
  , funD, clause
  , normalB, normalGE, guardedB
  , cxt
  )
import Language.Haskell.TH.Syntax (lift)
import Opaleye
  ( Column, Constant(..), QueryRunnerColumnDefault, ToFields, fieldQueryRunnerColumn, queryRunnerColumnDefault
  )
import Opaleye.Internal.PGTypes (IsSqlType, showSqlType, literalColumn)
import Opaleye.Internal.HaskellDB.PrimQuery (Literal(StringLit))

-- |Derive the various instances required to make a Haskell enumeration map to a PostgreSQL @enum@ type.
--
-- In @deriveOpaleyeEnum ''HaskellType "schema.sqltype" hsConToSqlValue@, @''HaskellType@ is the sum type (data declaration) to make instances for,
-- @"schema.sqltype"@ is the PostgreSQL type name, and @hsConToSqlValue@ is a function to map names of constructors to SQL values.
--
-- The function @hsConToSqlValue@ is of the type @String -> Maybe String@ in order to make using 'stripPrefix' convenient. The function is applied to each
-- constructor name and for @Just value@ that value is used, otherwise for @Nothing@ the constructor name is used.
--
-- The SQL type name is used verbatim for 'showSqlType'. When decoding, the type name reported by 'typename' is compared against the SQL type name
-- with any schema qualifier (everything up to the last @.@) and one pair of surrounding double quotes removed, because PostgreSQL reports the bare
-- type name. PostgreSQL folds unquoted identifiers to lower case, so unquoted names should be written in lower case here.
--
-- The splice fails at compile time if the named type is not a data declaration whose constructors are all nullary, or if two constructors map to the
-- same SQL value. The generated code does not rely on @OverloadedStrings@ being enabled at the splice site.
--
-- For example, given the Haskell type:
--
-- @
--     data MyEnum = MyFoo | MyBar
-- @
--
-- And PostgreSQL type:
--
-- @
--     CREATE TYPE myenum AS ENUM('foo', 'bar');
-- @
--
-- The splice:
--
-- @
--     deriveOpaleyeEnum ''MyEnum "myschema.myenum" ('stripPrefix' "my" . 'map' 'toLower')
-- @
--
-- Will create @PGMyEnum@ and instances required to use @MyEnum@ / @Column MyEnum@ in Opaleye.
--
-- The Haskell generated by this splice for the example is something like:
--
-- @
--     data PGMyEnum
--
--     instance 'IsSqlType' PGMyEnum where
--       'showSqlType' _ = "myschema.myenum"
--
--     instance 'FromField' MyEnum where
--       'fromField' f mbs = do
--         tname <- 'typename' f
--         case fmap 'BSC8.unpack' mbs of
--           _ | tname /= 'BSC8.pack' "myenum" -> 'returnError' 'Incompatible' f ""
--           Just "foo" -> pure MyFoo
--           Just "bar" -> pure MyBar
--           Just other -> 'returnError' 'ConversionFailed' f ("Unexpected myschema.myenum value: " <> other)
--           Nothing    -> 'returnError' 'UnexpectedNull' f ""
--
--     instance 'QueryRunnerColumnDefault' PGMyEnum MyEnum where
--       queryRunnerColumnDefault = 'fieldQueryRunnerColumn'
--
--     instance 'Default' 'ToFields' MyEnum ('Column' PGMyEnum) where
--       def = 'Constant' $ \ a ->
--         'literalColumn' . 'StringLit' $ case a of
--           MyFoo -> "foo"
--           MyBar -> "bar"
-- @
deriveOpaleyeEnum :: Name -> String -> (String -> Maybe String) -> Q [Dec]
deriveOpaleyeEnum hsName sqlName hsConToSqlValue = do
  let sqlTypeName = mkName $ "PG" ++ nameBase hsName
      sqlType = conT sqlTypeName
      hsType = conT hsName

  rawCons <- reify hsName >>= \ case
    TyConI (DataD _cxt _name _tvVarBndrs _maybeKind cons _derivingCxt) ->
      pure cons
    other ->
      fail $ "expected " <> show hsName <> " to name a data declaration, not:\n" <> pprint other

  nullaryCons <- for rawCons $ \ case
    NormalC conName [] ->
      pure conName
    other ->
      fail $ "expected every constructor of " <> show hsName <> " to be a regular nullary constructor, not:\n" <> pprint other

  let conPairs = nullaryCons <&> \ conName ->
        (conName, fromMaybe (nameBase conName) (hsConToSqlValue (nameBase conName)))

      -- A SQL value produced by more than one constructor would make every constructor after the first
      -- undecodable (the generated FromField case would have overlapping alternatives), so reject it here.
      sqlValues = map snd conPairs
      duplicateValues = [ v | (i, v) <- zip [0 :: Int ..] sqlValues, v `elem` take i sqlValues ]

  case duplicateValues of
    [] -> pure ()
    dups ->
      fail $ "expected every constructor of " <> show hsName <> " to map to a distinct SQL value, but these are duplicated: " <> show dups

  sqlTypeDecl <-
    dataD
      (cxt [])
      sqlTypeName
      []
      Nothing
      []
#if MIN_VERSION_template_haskell(2,12,0)
      []
#else
      (cxt [])
#endif

  isSqlTypeInst <- instanceD (cxt []) [t| IsSqlType $sqlType |] . (:[]) $ do
    funD 'showSqlType
      [ clause
          [wildP]
          (normalB (lift sqlName))
          []
      ]

  fromFieldInst <- instanceD (cxt []) [t| FromField $hsType |] . (:[]) $ do
    field <- newName "field"
    mbs   <- newName "mbs"
    tname <- newName "tname"
    other <- newName "other"

    -- The raw field is converted to String before matching so that the string-literal patterns and the
    -- comparisons below typecheck whether or not OverloadedStrings is enabled where the splice is used.
    let bodyCase = caseE [| fmap BSC8.unpack $(varE mbs) |] $
          [ match
              wildP
              -- BSC8.pack pins the lifted name to String, whatever form 'lift' produces for it.
              (guardedB [ normalGE [| $(varE tname) /= BSC8.pack $(lift unqualifiedSqlName) |]
                                   [| returnError Incompatible $(varE field) "" |] ])
              []
          ] ++
          (
            conPairs <&> \ (conName, value) ->
              match
                [p| Just $(litP $ stringL value) |]
                (normalB [| pure $(conE conName) |])
                []
          ) ++
          [ match
              [p| Just $(varP other) |]
              (normalB [| returnError ConversionFailed $(varE field) ("Unexpected " <> $(lift sqlName) <> " value: " <> $(varE other)) |])
              []
          , match
              [p| Nothing |]
              (normalB [| returnError UnexpectedNull $(varE field) "" |])
              []
          ]

    funD 'fromField
      [ clause
          [varP field, varP mbs]
          (normalB [|
            do
              $(varP tname) <- typename $(varE field)
              $bodyCase
            |])
          []
      ]

  queryRunnerColumnDefaultInst <- instanceD (cxt []) [t| QueryRunnerColumnDefault $sqlType $hsType |] . (:[]) $
    funD 'queryRunnerColumnDefault
      [ clause
          []
          (normalB [| fieldQueryRunnerColumn |])
          []
      ]

  defaultInst <- instanceD (cxt []) [t| Default ToFields $hsType (Column $sqlType) |] . (:[]) $ do
    s <- newName "s"
    let body = lamE [varP s] $
          caseE (varE s) $
            conPairs <&> \ (conName, value) ->
              match
                (conP conName [])
                (normalB $ lift value)
                []

    funD 'def
      [ clause
          []
          (normalB [| Constant (literalColumn . StringLit . $body) |])
          []
      ]

  pure [sqlTypeDecl, isSqlTypeInst, fromFieldInst, queryRunnerColumnDefaultInst, defaultInst]
  where
    -- 'typename' reports the bare type name, without a schema qualifier or identifier quoting (the example
    -- above compares against "myenum" for "myschema.myenum"), so keep only the last dot-separated component
    -- and drop one pair of surrounding double quotes if present.
    unqualifiedSqlName :: String
    unqualifiedSqlName = stripQuotes . reverse . takeWhile (/= '.') . reverse $ sqlName

    -- Remove a single pair of enclosing double quotes; any other string is returned unchanged.
    stripQuotes :: String -> String
    stripQuotes ('"' : rest) | ('"' : revInner) <- reverse rest = reverse revInner
    stripQuotes str = str