{-# LANGUAGE GADTs, BangPatterns, OverloadedStrings, CPP #-}
-- | Encoding/decoding for PostgreSQL.
module Database.Selda.PostgreSQL.Encoding
  ( toSqlValue, fromSqlValue, fromSqlType, readInt64, readBool
  ) where
#ifdef __HASTE__

-- | Stand-ins for the Haste build (@__HASTE__@), where this module imports
--   none of the PostgreSQL libraries. Each one is 'undefined', so forcing
--   any of them at runtime crashes.
toSqlValue :: a
toSqlValue = undefined

fromSqlValue :: a
fromSqlValue = undefined

fromSqlType :: a
fromSqlType = undefined

readInt64 :: a
readInt64 = undefined

readBool :: a
readBool = undefined

#else

import Control.Applicative ((<|>))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Char (toLower)
import qualified Data.Text as T
import Data.Time (utc, localToUTCTimeOfDay)
import Database.PostgreSQL.LibPQ (Oid (..), Format (Binary))
import Database.Selda.Backend
import PostgreSQL.Binary.Encoding as Enc
import PostgreSQL.Binary.Decoding as Dec
import qualified Data.UUID.Types as UUID (toByteString)
import Data.Int (Int16, Int32, Int64)

-- | PostgreSQL type OIDs for every type this backend encodes or decodes.
--   These are the fixed OIDs of PostgreSQL's built-in types; the comment on
--   each line names the PostgreSQL type it identifies.
blobType, boolType, intType, int32Type, int16Type, textType, doubleType,
  dateType, timeType, timestampType, nameType, varcharType, uuidType,
  jsonbType :: Oid
boolType      = Oid 16    -- bool
intType       = Oid 20    -- int8
int32Type     = Oid 23    -- int4
int16Type     = Oid 21    -- int2 (only ever decoded, never sent)
textType      = Oid 25    -- text
nameType      = Oid 19    -- name
doubleType    = Oid 701   -- float8
dateType      = Oid 1082  -- date
timeType      = Oid 1266  -- timetz (time with time zone)
timestampType = Oid 1184  -- timestamptz (timestamp with time zone)
blobType      = Oid 17    -- bytea
varcharType   = Oid 1043  -- varchar
uuidType      = Oid 2950  -- uuid
jsonbType     = Oid 3802  -- jsonb

-- | Render a postgresql-binary 'Enc.Encoding' as a strict byte string,
--   ready to be passed to libpq as a binary parameter.
bytes :: Enc.Encoding -> BS.ByteString
bytes = Enc.encodingBytes

-- | Convert a literal into a PostgreSQL parameter triple: the parameter's
--   type OID, its binary encoding, and the wire format (always 'Binary').
--
--   'LNull' yields 'Nothing'. 'LJust' and 'LCustom' wrappers are unwrapped
--   and their payload is encoded as usual, except that a blob tagged as
--   'TJSON' is sent as @jsonb@ rather than @bytea@.
fromSqlValue :: Lit a -> Maybe (Oid, BS.ByteString, Format)
fromSqlValue lit = case lit of
  LBool b     -> param boolType      (Enc.bool b)
  LInt64 n    -> param intType       (Enc.int8_int64 n)
  LInt32 n    -> param int32Type     (Enc.int4_int32 n)
  LDouble f   -> param doubleType    (Enc.float8 f)
  LText s     -> param textType      (Enc.text_strict s)
  LDateTime t -> param timestampType (Enc.timestamptz_int t)
  -- Times of day are sent as timetz, pinned to the UTC zone.
  LTime t     -> param timeType      (Enc.timetz_int (t, utc))
  LDate d     -> param dateType      (Enc.date d)
  LUUID x     -> param uuidType      (Enc.uuid x)
  LBlob b     -> param blobType      (Enc.bytea_strict b)
  LNull       -> Nothing
  LJust x     -> fromSqlValue x
  -- Must precede the generic LCustom case so JSON blobs get the jsonb OID.
  LCustom TJSON (LBlob b) -> param jsonbType (Enc.jsonb_bytes b)
  LCustom _ l -> fromSqlValue l
  where
    -- Every non-null parameter is sent in binary format.
    param :: Oid -> Enc.Encoding -> Maybe (Oid, BS.ByteString, Format)
    param oid enc = Just (oid, bytes enc, Binary)

-- | Get the PostgreSQL type OID corresponding to a Selda type
--   representation. Row identifiers ('TRowID') are stored as 64-bit
--   integers and so share 'intType'.
fromSqlType :: SqlTypeRep -> Oid
fromSqlType TBool     = boolType
fromSqlType TInt64    = intType
fromSqlType TInt32    = int32Type
fromSqlType TFloat    = doubleType
fromSqlType TText     = textType
fromSqlType TDateTime = timestampType
fromSqlType TDate     = dateType
fromSqlType TTime     = timeType
fromSqlType TBlob     = blobType
fromSqlType TRowID    = intType
fromSqlType TUUID     = uuidType
fromSqlType TJSON     = jsonbType

-- | Convert a binary-format PostgreSQL result value, tagged with its type
--   OID, into an @SqlValue@.
--
--   @int2@ values are widened to 'SqlInt32'; @uuid@ and @jsonb@ values are
--   returned as 'SqlBlob'; @text@, @name@ and @varchar@ all become
--   'SqlString'; @timetz@ values are converted to a UTC time of day.
--   Calls 'error' for an OID this backend does not handle, or (via 'parse')
--   if the bytes cannot be decoded.
toSqlValue :: Oid -> BS.ByteString -> SqlValue
toSqlValue t val
  | t == boolType      = SqlBool    $ parse Dec.bool val
  | t == intType       = SqlInt64   $ parse (Dec.int :: Value Int64) val
  | t == int32Type     = SqlInt32   $ parse (Dec.int :: Value Int32) val
  | t == int16Type     = SqlInt32   $ fromIntegral
                                    $ parse (Dec.int :: Value Int16) val
  | t == doubleType    = SqlFloat   $ parse Dec.float8 val
  | t == blobType      = SqlBlob    $ parse Dec.bytea_strict val
  | t == uuidType      = SqlBlob    $ uuid2bs $ parse Dec.uuid val
  | t == timestampType = SqlUTCTime $ parse parseTimestamp val
  | t == timeType      = SqlTime    $ toTime $ parse parseTime val
  | t == dateType      = SqlDate    $ parse Dec.date val
  | t == jsonbType     = SqlBlob    $ parse (Dec.jsonb_bytes pure) val
  | t `elem` textish   = SqlString  $ parse Dec.text_strict val
  | otherwise          = error $ "BUG: result with unknown type oid: " ++ show t
  where
    -- Try the integer timestamp representation first, then fall back to
    -- the float one.
    parseTimestamp = Dec.timestamptz_int <|> Dec.timestamptz_float
    parseTime = Dec.timetz_int <|> Dec.timetz_float

    -- Shift a (local time of day, zone) pair to the UTC time of day,
    -- discarding the day offset.
    toTime (tod, tz) = snd $ localToUTCTimeOfDay tz tod

    uuid2bs = LBS.toStrict . UUID.toByteString

    -- OIDs whose values are decoded as plain text.
    textish :: [Oid]
    textish = [textType, nameType, varcharType]

-- | Run a binary decoder over a raw result value.
--
--   Calls 'error' if decoding fails. The message keeps the
--   @"unable to decode value"@ prefix and appends the decoder's own error
--   text, which the previous version discarded.
parse :: Value a -> BS.ByteString -> a
parse p x = either decodeFailure id (valueParser p x)
  where
    decodeFailure err = error $ "unable to decode value: " ++ T.unpack err

-- | Decode a binary-format PostgreSQL @int8@ value into an 'Int64'.
--   Calls 'error' (via 'parse') if the bytes cannot be decoded.
readInt64 :: BS.ByteString -> Int64
readInt64 = parse (Dec.int :: Value Int64)

-- | Interpret a textual boolean, ignoring case.
--
--   After lower-casing each character, the strings @f@, @0@, @false@, @n@,
--   @no@ and @off@ mean 'False'. Anything else, including the empty string,
--   means 'True'.
readBool :: T.Text -> Bool
readBool s = T.map toLower s `notElem` falseWords
  where
    falseWords :: [T.Text]
    falseWords = map T.pack ["f", "0", "false", "n", "no", "off"]

#endif