{-# LANGUAGE AllowAmbiguousTypes #-}
module Database.PostgreSQL.PQTypes.Deriving (
  -- * Helpers, to be used with @deriving via@ (@-XDerivingVia@).
    SQLEnum(..)
  , EnumEncoding(..)
  , SQLEnumAsText(..)
  , EnumAsTextEncoding(..)
    -- * For use in doctests.
  , isInjective
  ) where

import Control.Exception (SomeException(..), throwIO)
import Data.List.Extra (enumerate, nubSort)
import Data.Map.Strict (Map)
import Data.Text (Text)
import Data.Typeable
import Database.PostgreSQL.PQTypes
import qualified Data.Map.Strict as Map

-- | Helper newtype to be used with @deriving via@ to derive @(PQFormat, ToSQL,
-- FromSQL)@ instances for enums, given an instance of 'EnumEncoding'.
--
-- /Hint:/ non-trivial 'Enum' instances can be derived using the @generic-data@
-- package!
--
-- >>> :{
-- data Colours = Blue | Black | Red | Mauve | Orange
--   deriving (Eq, Show, Enum, Bounded)
-- instance EnumEncoding Colours where
--   type EnumBase Colours = Int16
--   encodeEnum = \case
--     Blue   -> 1
--     Black  -> 7
--     Red    -> 2
--     Mauve  -> 6
--     Orange -> 3
-- :}
--
-- /Note:/ To get SQL-specific instances use @DerivingVia@:
--
-- @
-- data Colours = ...
--   ...
--   deriving (PQFormat, ToSQL, FromSQL) via SQLEnum Colours
-- @
--
-- >>> isInjective (encodeEnum @Colours)
-- True
--
-- >>> decodeEnum @Colours 7
-- Right Black
--
-- Decoding a base value that no constructor is encoded as yields the valid
-- base values instead, compressed into inclusive ranges:
--
-- >>> decodeEnum @Colours 42
-- Left [(1,3),(6,7)]
newtype SQLEnum a = SQLEnum a

-- | Describes how the values of an enumeration @a@ are stored in the database
-- as values of a base type 'EnumBase'. Used by 'SQLEnum'.
--
-- Only 'encodeEnum' has to be implemented; decoding is derived by inverting
-- the encoding over every value of @a@.
class
  ( -- The semantic type needs to be finitely enumerable.
    Enum a
  , Bounded a
    -- The base type needs to be enumerable and ordered.
  , Enum (EnumBase a)
  , Ord (EnumBase a)
  ) => EnumEncoding a where
  -- | The type the enum is stored as (e.g. @Int16@).
  type EnumBase a

  -- | Encode @a@ as a base type.
  --
  -- The encoding should be injective (check with 'isInjective'): the default
  -- 'decodeEnumMap' is built with 'Map.fromList', which keeps only the last
  -- value for a duplicated key, so colliding values would silently collapse.
  encodeEnum :: a -> EnumBase a

  -- | Decode base type to an @a@. If the conversion fails, a list of valid
  -- ranges is returned instead.
  --
  -- /Note:/ The default implementation looks up values in 'decodeEnumMap' and
  -- can be overridden for performance if necessary.
  decodeEnum :: EnumBase a -> Either [(EnumBase a, EnumBase a)] a
  decodeEnum b = case Map.lookup b (decodeEnumMap @a) of
    Just value -> Right value
    -- Not a valid encoding: report every valid base value, compressed into
    -- inclusive ranges. The explicit @a is needed because 'Map.keys' alone
    -- does not fix the value type of the map.
    Nothing    -> Left . intervals $ Map.keys (decodeEnumMap @a)

  -- | Include the inverse map as a top-level part of the 'EnumEncoding'
  -- instance to ensure it is only computed once by GHC.
  decodeEnumMap :: Map (EnumBase a) a
  decodeEnumMap =
    Map.fromList [ (encodeEnum value, value) | value <- enumerate ]

-- | An 'SQLEnum' uses the PostgreSQL format of its base type.
instance PQFormat (EnumBase a) => PQFormat (SQLEnum a) where
  pqFormat = pqFormat @(EnumBase a)

-- | Values are sent to the database as their 'encodeEnum' image, using the
-- base type's 'ToSQL' instance.
instance
  ( EnumEncoding a
  , PQFormat (EnumBase a)
  , ToSQL (EnumBase a)
  ) => ToSQL (SQLEnum a) where
  type PQDest (SQLEnum a) = PQDest (EnumBase a)
  toSQL (SQLEnum value) = toSQL (encodeEnum value)

-- | Values are read with the base type's 'FromSQL' instance and then decoded
-- with 'decodeEnum'. A base value that no constructor is encoded as raises a
-- 'RangeError' carrying the valid ranges and the offending value.
instance
  ( EnumEncoding a
  , PQFormat (EnumBase a)
  , FromSQL (EnumBase a)
  , Show (EnumBase a)
  , Typeable (EnumBase a)
  ) => FromSQL (SQLEnum a) where
  type PQBase (SQLEnum a) = PQBase (EnumBase a)
  fromSQL base = do
    b <- fromSQL base
    case decodeEnum b of
      Left validRange -> throwIO $ SomeException RangeError
        { reRange = validRange
        , reValue = b
        }
      Right value -> pure (SQLEnum value)

-- | A special case of 'SQLEnum', where the enum is to be encoded as text
-- ('SQLEnum' can't be used because 'EnumEncoding' requires an 'Enum' instance
-- for the base type, i.e. the codomain of 'encodeEnum', which 'Text' is not).
--
-- >>> :{
-- data Person = Alfred | Bertrand | Charles
--   deriving (Eq, Show, Enum, Bounded)
-- instance EnumAsTextEncoding Person where
--   encodeEnumAsText = \case
--     Alfred   -> "alfred"
--     Bertrand -> "bertrand"
--     Charles  -> "charles"
-- :}
--
-- /Note:/ To get SQL-specific instances use @DerivingVia@:
--
-- @
-- data Person = ...
--   ...
--   deriving (PQFormat, ToSQL, FromSQL) via SQLEnumAsText Person
-- @
--
-- >>> isInjective (encodeEnumAsText @Person)
-- True
--
-- >>> decodeEnumAsText @Person "bertrand"
-- Right Bertrand
--
-- Decoding text that no constructor is encoded as yields the list of valid
-- values instead:
--
-- >>> decodeEnumAsText @Person "batman"
-- Left ["alfred","bertrand","charles"]
newtype SQLEnumAsText a = SQLEnumAsText a

-- | Describes how the values of an enumeration @a@ are stored in the database
-- as 'Text'. Used by 'SQLEnumAsText'.
--
-- Only 'encodeEnumAsText' has to be implemented; decoding is derived by
-- inverting the encoding over every value of @a@.
class (Enum a, Bounded a) => EnumAsTextEncoding a where
  -- | Encode @a@ as 'Text'.
  --
  -- The encoding should be injective (check with 'isInjective'): the default
  -- 'decodeEnumAsTextMap' is built with 'Map.fromList', which keeps only the
  -- last value for a duplicated key, so colliding values would silently
  -- collapse.
  encodeEnumAsText :: a -> Text

  -- | Decode 'Text' to an @a@. If the conversion fails, a list of valid values
  -- is returned instead.
  --
  -- /Note:/ The default implementation looks up values in 'decodeEnumAsTextMap'
  -- and can be overridden for performance if necessary.
  decodeEnumAsText :: Text -> Either [Text] a
  decodeEnumAsText text = case Map.lookup text (decodeEnumAsTextMap @a) of
    Just value -> Right value
    -- Not a valid encoding: report every valid value (ascending order, as
    -- returned by 'Map.keys').
    Nothing    -> Left $ Map.keys (decodeEnumAsTextMap @a)

  -- | Include the inverse map as a top-level part of the 'EnumAsTextEncoding'
  -- instance to ensure it is only computed once by GHC.
  decodeEnumAsTextMap :: Map Text a
  decodeEnumAsTextMap =
    Map.fromList [ (encodeEnumAsText value, value) | value <- enumerate ]

-- | An 'SQLEnumAsText' uses the PostgreSQL format of 'Text'.
instance EnumAsTextEncoding a => PQFormat (SQLEnumAsText a) where
  pqFormat = pqFormat @Text

-- | Values are sent to the database as their 'encodeEnumAsText' image, using
-- the 'ToSQL' instance of 'Text'.
instance EnumAsTextEncoding a => ToSQL (SQLEnumAsText a) where
  type PQDest (SQLEnumAsText a) = PQDest Text
  toSQL (SQLEnumAsText value) = toSQL (encodeEnumAsText value)

-- | Values are read as 'Text' and then decoded with 'decodeEnumAsText'. Text
-- that no constructor is encoded as raises an 'InvalidValue' carrying the
-- offending text and the list of valid values.
instance EnumAsTextEncoding a => FromSQL (SQLEnumAsText a) where
  type PQBase (SQLEnumAsText a) = PQBase Text
  fromSQL base = do
    text <- fromSQL base
    case decodeEnumAsText text of
      Left validValues -> throwIO $ SomeException InvalidValue
        { ivValue       = text
        , ivValidValues = Just validValues
        }
      Right value -> pure (SQLEnumAsText value)

-- | To be used in doctests to prove injectivity of encoding functions.
--
-- Every pair of distinct values of @a@ is compared once, so the cost is
-- quadratic in the number of values; intended for small enumerations.
--
-- >>> isInjective (id :: Bool -> Bool)
-- True
--
-- >>> isInjective (\(_ :: Bool) -> False)
-- False
isInjective :: (Enum a, Bounded a, Eq a, Eq b) => (a -> b) -> Bool
isInjective f = go [minBound .. maxBound]
  where
    -- Compare each value only with the values that follow it: a collision
    -- between x and y is the same as one between y and x, so each unordered
    -- pair needs checking just once. Equal inputs are skipped, as they may
    -- legitimately share an image.
    go []       = True
    go (x : xs) = and [ f x /= f y | y <- xs, x /= y ] && go xs

-- | Internal helper: given a list of values, decompose it into a list of
-- intervals.
--
-- The result contains maximal runs of consecutive values (per 'succ') as
-- inclusive @(lower, upper)@ pairs, in ascending order, with duplicates
-- removed.
--
-- >>> intervals [42,2,1,0,3,88,-1,43,42]
-- [(-1,3),(42,43),(88,88)]
--
-- prop> nubSort xs == concatMap (\(l,r) -> [l .. r]) (intervals xs)
intervals :: forall a. (Enum a, Ord a) => [a] -> [(a, a)]
intervals values =
  case nubSort values of
    []                  -> []
    (first : ascending) -> accumIntervals (first, first) ascending
  where
    -- Extend the current interval while the next value is the successor of
    -- its upper bound; otherwise close it and start a new one. The input is
    -- sorted and deduplicated, so 'succ' is only applied to an upper bound
    -- that has a strictly larger value after it.
    accumIntervals :: (a, a) -> [a] -> [(a, a)]
    accumIntervals current [] = [current]
    accumIntervals (lower, upper) (next : rest)
      | succ upper == next = accumIntervals (lower, next) rest
      | otherwise          = (lower, upper) : accumIntervals (next, next) rest

-- $setup
-- >>> import Data.Int