{-# LANGUAGE AllowAmbiguousTypes #-}
module Database.PostgreSQL.PQTypes.Deriving (
  -- * Helpers, to be used with @deriving via@ (@-XDerivingVia@).
    SQLEnum(..)
  , EnumEncoding(..)
  , SQLEnumAsText(..)
  , EnumAsTextEncoding(..)
    -- * For use in doctests.
  , isInjective
  ) where

import Control.Exception (SomeException(..), throwIO)
import Data.List.Extra (enumerate, nubSort)
import Data.Map.Strict (Map)
import Data.Text (Text)
import Data.Typeable
import Database.PostgreSQL.PQTypes
import qualified Data.Map.Strict as Map

-- | Newtype wrapper meant to be the target of @deriving via@
-- (@-XDerivingVia@): for any type with an 'EnumEncoding' instance it supplies
-- the @(PQFormat, ToSQL, FromSQL)@ instances, storing every value of the enum
-- as its 'encodeEnum' encoding.
--
-- /Hint:/ non-trivial 'Enum' instances can be derived with the
-- <https://hackage.haskell.org/package/generic-data generic-data> package.
--
-- >>> :{
-- data Colours = Blue | Black | Red | Mauve | Orange
--   deriving (Eq, Show, Enum, Bounded)
-- instance EnumEncoding Colours where
--   type EnumBase Colours = Int16
--   encodeEnum = \case
--     Blue   -> 1
--     Black  -> 7
--     Red    -> 2
--     Mauve  -> 6
--     Orange -> 3
-- :}
--
-- /Note:/ the SQL-specific instances themselves come from @DerivingVia@:
--
-- @
-- data Colours = ...
--   ...
--   deriving (PQFormat, ToSQL, FromSQL) via SQLEnum Colours
-- @
--
-- The encoding should be checked for injectivity:
--
-- >>> isInjective (encodeEnum @Colours)
-- True
--
-- Decoding a known base value yields the corresponding constructor:
--
-- >>> decodeEnum @Colours 7
-- Right Black
--
-- Decoding an unknown base value reports the valid ranges instead:
--
-- >>> decodeEnum @Colours 42
-- Left [(1,3),(6,7)]
newtype SQLEnum a
  = SQLEnum a

-- | Describes how an enumeration type @a@ is stored in the database: every
-- value of @a@ is mapped by 'encodeEnum' to a value of the base type
-- @'EnumBase' a@. Meant to be used together with 'SQLEnum'.
--
-- 'encodeEnum' should be injective (verify it with 'isInjective'). The default
-- 'decodeEnumMap' is built with 'Map.fromList', which keeps the last value for
-- a duplicated key. So if two values share an encoding, only the one that
-- comes later in 'enumerate' order can be decoded.
class
  ( -- The semantic type needs to be finitely enumerable.
    Enum a
  , Bounded a
    -- The base type needs to be enumerable and ordered.
  , Enum (EnumBase a)
  , Ord (EnumBase a)
  ) => EnumEncoding a where
  -- | The type that values of @a@ are encoded as.
  type EnumBase a

  -- | Encode @a@ as a base type.
  encodeEnum :: a -> EnumBase a

  -- | Decode base type to an @a@. If the conversion fails, the valid base
  -- values are returned instead, compressed into a list of closed ranges
  -- @(low, high)@.
  --
  -- /Note:/ The default implementation looks up values in 'decodeEnumMap' and
  -- can be overridden for performance if necessary.
  decodeEnum :: EnumBase a -> Either [(EnumBase a, EnumBase a)] a
  decodeEnum b =
    maybe (Left . intervals $ Map.keys (decodeEnumMap @a)) Right
      $ Map.lookup b (decodeEnumMap @a)

  -- | Inverse of 'encodeEnum'. It is a member of the 'EnumEncoding' instance
  -- (rather than being built inside 'decodeEnum') so that GHC computes it only
  -- once.
  decodeEnumMap :: Map (EnumBase a) a
  decodeEnumMap = Map.fromList [ (encodeEnum a, a) | a <- enumerate ]

-- | An @'SQLEnum' a@ uses the same format as its base type @'EnumBase' a@.
instance PQFormat (EnumBase a) => PQFormat (SQLEnum a) where
  pqFormat = pqFormat @(EnumBase a)

-- | The wrapped value is converted with 'encodeEnum' and then sent using the
-- 'ToSQL' instance of its base type.
instance
  ( EnumEncoding a
  , PQFormat (EnumBase a)
  , ToSQL (EnumBase a)
  ) => ToSQL (SQLEnum a) where
  type PQDest (SQLEnum a) = PQDest (EnumBase a)
  toSQL (SQLEnum a) = toSQL $ encodeEnum a

-- | The value is read using the 'FromSQL' instance of the base type and then
-- decoded with 'decodeEnum'. If the base value does not decode, a 'RangeError'
-- carrying the valid ranges and the offending value is thrown.
instance
  ( EnumEncoding a
  , PQFormat (EnumBase a)
  , FromSQL (EnumBase a)
  , Show (EnumBase a)
  , Typeable (EnumBase a)
  ) => FromSQL (SQLEnum a) where
  type PQBase (SQLEnum a) = PQBase (EnumBase a)
  fromSQL base = do
    b <- fromSQL base
    case decodeEnum b of
      Left validRange -> throwIO $ SomeException RangeError
        { reRange = validRange
        , reValue = b
        }
      Right a -> return $ SQLEnum a

-- | A special case of 'SQLEnum', where the enum is to be encoded as text.
-- 'SQLEnum' can't be used here: 'EnumEncoding' requires an 'Enum' instance for
-- the codomain of 'encodeEnum', and 'Text' has no such instance.
--
-- >>> :{
-- data Person = Alfred | Bertrand | Charles
--   deriving (Eq, Show, Enum, Bounded)
-- instance EnumAsTextEncoding Person where
--   encodeEnumAsText = \case
--     Alfred   -> "alfred"
--     Bertrand -> "bertrand"
--     Charles  -> "charles"
-- :}
--
-- /Note:/ To get SQL-specific instances use @DerivingVia@:
--
-- @
-- data Person = ...
--   ...
--   deriving (PQFormat, ToSQL, FromSQL) via SQLEnumAsText Person
-- @
--
-- >>> isInjective (encodeEnumAsText @Person)
-- True
--
-- >>> decodeEnumAsText @Person "bertrand"
-- Right Bertrand
--
-- >>> decodeEnumAsText @Person "batman"
-- Left ["alfred","bertrand","charles"]
newtype SQLEnumAsText a = SQLEnumAsText a

-- | Describes how an enumeration type @a@ is stored in the database as
-- 'Text'. Meant to be used together with 'SQLEnumAsText'.
--
-- 'encodeEnumAsText' should be injective (verify it with 'isInjective'). The
-- default 'decodeEnumAsTextMap' is built with 'Map.fromList', which keeps the
-- last value for a duplicated key. So if two values share an encoding, only
-- the one that comes later in 'enumerate' order can be decoded.
class (Enum a, Bounded a) => EnumAsTextEncoding a where
  -- | Encode @a@ as 'Text'.
  encodeEnumAsText :: a -> Text

  -- | Decode 'Text' to an @a@. If the conversion fails, a list of valid values
  -- is returned instead. The default implementation lists them in ascending
  -- order.
  --
  -- /Note:/ The default implementation looks up values in
  -- 'decodeEnumAsTextMap' and can be overridden for performance if necessary.
  decodeEnumAsText :: Text -> Either [Text] a
  decodeEnumAsText text =
    maybe (Left $ Map.keys (decodeEnumAsTextMap @a)) Right
      $ Map.lookup text (decodeEnumAsTextMap @a)

  -- | Include the inverse map as a top-level part of the 'EnumAsTextEncoding'
  -- instance to ensure it is only computed once by GHC.
  decodeEnumAsTextMap :: Map Text a
  decodeEnumAsTextMap =
    Map.fromList [ (encodeEnumAsText a, a) | a <- enumerate ]

-- | An @'SQLEnumAsText' a@ uses the same format as 'Text'.
instance EnumAsTextEncoding a => PQFormat (SQLEnumAsText a) where
  pqFormat = pqFormat @Text

-- | The wrapped value is converted with 'encodeEnumAsText' and then sent using
-- the 'ToSQL' instance of 'Text'.
instance EnumAsTextEncoding a => ToSQL (SQLEnumAsText a) where
  type PQDest (SQLEnumAsText a) = PQDest Text
  toSQL (SQLEnumAsText a) = toSQL $ encodeEnumAsText a

-- | The value is read as 'Text' and then decoded with 'decodeEnumAsText'. If
-- the text does not decode, an 'InvalidValue' carrying the offending text and
-- the list of valid values is thrown.
instance EnumAsTextEncoding a => FromSQL (SQLEnumAsText a) where
  type PQBase (SQLEnumAsText a) = PQBase Text
  fromSQL base = do
    text <- fromSQL base
    case decodeEnumAsText text of
      Left validValues -> throwIO $ SomeException InvalidValue
        { ivValue       = text
        , ivValidValues = Just validValues
        }
      Right a -> return $ SQLEnumAsText a

-- | To be used in doctests to prove injectivity of encoding functions.
-- Returns 'True' iff no two distinct values of the finite domain
-- @[minBound .. maxBound]@ are mapped to equal results. Each unordered pair of
-- values is compared once, and the search stops at the first collision.
--
-- >>> isInjective (id :: Bool -> Bool)
-- True
--
-- >>> isInjective (\(_ :: Bool) -> False)
-- False
isInjective :: (Enum a, Bounded a, Eq a, Eq b) => (a -> b) -> Bool
isInjective f = go [minBound .. maxBound]
  where
    -- Compare the head against every later value, then recurse on the tail.
    -- Comparing only "later" values suffices because equality is symmetric.
    go []          = True
    go (x : later) =
      let fx = f x
      in all (\y -> x == y || fx /= f y) later && go later

-- | Internal helper: given a list of values, decompose it into a list of
-- intervals.
--
-- The result is a list of closed intervals @(low, high)@ in ascending order.
-- Together they cover exactly the distinct input values. Consecutive values,
-- in the sense of 'succ', are merged into a single interval.
--
-- >>> intervals [42,2,1,0,3,88,-1,43,42]
-- [(-1,3),(42,43),(88,88)]
--
-- prop> nubSort xs == concatMap (\(l,r) -> [l .. r]) (intervals xs)
intervals :: forall a. (Enum a, Ord a) => [a] -> [(a, a)]
intervals xs = case nubSort xs of
  []                    -> []
  first : ascendingRest -> accumIntervals (first, first) ascendingRest
  where
    -- Extend the current interval while the next value is the successor of
    -- its upper bound; otherwise emit it and start a new interval.
    accumIntervals :: (a, a) -> [a] -> [(a, a)]
    accumIntervals (lower, upper) [] = [(lower, upper)]
    accumIntervals (lower, upper) (next : rest)
      -- 'next' is strictly greater than 'upper' (the input was sorted and
      -- deduplicated). Assuming 'Ord' and 'Enum' agree, 'upper' is therefore
      -- not the last value of the type, and 'succ' cannot fail.
      | succ upper == next = accumIntervals (lower, next) rest
      | otherwise          = (lower, upper) : accumIntervals (next, next) rest

-- $setup
-- >>> import Data.Int