module Database.Selda.PostgreSQL.Encoding
( toSqlValue, fromSqlValue, fromSqlType
, readInt
) where
#ifdef __HASTE__
-- Haste build: every exported name is bound to 'undefined'. The names exist
-- so the module's export list is satisfied, but forcing any of them throws.
toSqlValue :: a
toSqlValue = undefined

fromSqlValue :: a
fromSqlValue = undefined

fromSqlType :: a
fromSqlType = undefined

readInt :: a
readInt = undefined
#else
import Control.Exception (throw)
import qualified Data.ByteString as BS
import Data.ByteString.Builder
import Data.ByteString.Char8 (unpack)
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Text as Text
import Data.Text.Encoding
import Database.PostgreSQL.LibPQ (Oid (..), Format (..))
import Database.Selda.Backend
import Unsafe.Coerce
-- | OIDs of the PostgreSQL built-in types this module sends as parameter
--   types and recognises in result columns.
boolType, blobType, intType, textType, doubleType :: Oid
dateType, timeType, timestampType :: Oid

boolType      = Oid 16    -- bool
blobType      = Oid 17    -- bytea
intType       = Oid 20    -- int8 (bigint); values are encoded as 8 bytes
textType      = Oid 25    -- text
doubleType    = Oid 701   -- float8 (double precision)
dateType      = Oid 1082  -- date
timeType      = Oid 1083  -- time
timestampType = Oid 1114  -- timestamp
-- | Encode a literal as a libpq query parameter: the parameter's type OID,
--   its encoded bytes, and whether those bytes are in binary or text format.
--   Returns 'Nothing' for SQL NULL.
--
--   Booleans, integers, doubles, text and blobs are sent in binary format
--   (integers and doubles as 8 big-endian bytes). Dates, times and timestamps
--   are sent as their textual representation. 'LJust' and 'LCustom' are
--   unwrapped and their payload is encoded.
fromSqlValue :: Lit a -> Maybe (Oid, BS.ByteString, Format)
fromSqlValue (LBool b)     = Just (boolType, toBS $ if b then word8 1 else word8 0, Binary)
fromSqlValue (LInt n)      = Just (intType, toBS $ int64BE (fromIntegral n), Binary)
-- 'doubleBE' writes the IEEE-754 bit pattern of the Double in big-endian
-- order. The previous 'unsafeCoerce' of a Double to an Int64 reinterpreted
-- one boxed type as another, which is undefined behaviour and can yield
-- garbage once GHC unboxes the value into a floating-point register.
fromSqlValue (LDouble f)   = Just (doubleType, toBS $ doubleBE f, Binary)
-- NUL characters are stripped because PostgreSQL text values cannot hold them.
fromSqlValue (LText s)     = Just (textType, encodeUtf8 $ Text.filter (/= '\0') s, Binary)
fromSqlValue (LDateTime s) = Just (timestampType, encodeUtf8 s, Text)
fromSqlValue (LTime s)     = Just (timeType, encodeUtf8 s, Text)
fromSqlValue (LDate s)     = Just (dateType, encodeUtf8 s, Text)
fromSqlValue (LBlob b)     = Just (blobType, b, Binary)
fromSqlValue (LNull)       = Nothing
fromSqlValue (LJust x)     = fromSqlValue x
fromSqlValue (LCustom l)   = fromSqlValue l
-- | The PostgreSQL type OID used to represent a Selda column type.
--   Row identifiers share the integer type's OID.
fromSqlType :: SqlTypeRep -> Oid
fromSqlType rep =
  case rep of
    TRowID    -> intType
    TInt      -> intType
    TBool     -> boolType
    TFloat    -> doubleType
    TText     -> textType
    TBlob     -> blobType
    TDate     -> dateType
    TTime     -> timeType
    TDateTime -> timestampType
-- | Decode a result value, given its PostgreSQL type OID and the raw bytes
--   returned for it.
--
--   Every decoder here parses a textual form: ASCII digits for integers,
--   'read' for doubles, hex with a two-byte prefix for blobs, and fixed
--   spellings for booleans. This presumably expects results in text format
--   (NOTE(review): confirm against the caller). Text, date, time and
--   timestamp columns are all returned as UTF-8-decoded strings.
--
--   Calls 'error' for an unrecognised OID and for a blob that lacks the hex
--   prefix. 'read' and 'decodeUtf8' throw on malformed input.
toSqlValue :: Oid -> BS.ByteString -> SqlValue
toSqlValue t val
  | t == boolType    = SqlBool $ readBool val
  | t == intType     = SqlInt $ readInt val
  | t == doubleType  = SqlFloat $ read (unpack val)
  | t == blobType    = SqlBlob $ pgDecode val
  | t `elem` textish = SqlString (decodeUtf8 val)
  | otherwise        = error $ "BUG: result with unknown type oid: " ++ show t
  where
    -- Blobs arrive as the two bytes '\' 'x' followed by two hex digits per
    -- byte. The prefix is compared via 'BS.take' so that inputs shorter than
    -- two bytes reach the error message below instead of an index error.
    pgDecode s
      | BS.take 2 s == hexPrefix = BS.pack $ go $ BS.drop 2 s
      | otherwise =
          error $ "bad blob string from postgres: " ++ show s
      where
        hexPrefix = BS.pack [92, 120]
        -- Numeric value of the hex digit at index n: lowercase letters
        -- ('a' = 97) and uppercase letters ('A' = 65) map to 10-15, and
        -- '0'-'9' (48-57) map to 0-9. Non-hex bytes are not rejected.
        hex n bs =
          case BS.index bs n of
            c | c >= 97   -> c - 87
              | c >= 65   -> c - 55
              | otherwise -> c - 48
        -- Two hex digits form one output byte. A trailing lone digit is
        -- dropped.
        go bs
          | BS.length bs >= 2 = (16 * hex 0 bs + hex 1 bs) : go (BS.drop 2 bs)
          | otherwise         = []
    -- OIDs whose values are handed back as UTF-8 text.
    textish = [textType, timestampType, timeType, dateType]
    -- Only these spellings decode as False. Anything else decodes as True.
    readBool "f"     = False
    readBool "0"     = False
    readBool "false" = False
    readBool "n"     = False
    readBool "no"    = False
    readBool "off"   = False
    readBool _       = True
-- | Parse a base-10 'Int' from ASCII digits, optionally preceded by a
--   single @-@.
--
--   No validation is performed. Non-digit bytes are folded in as if they were
--   digits, which gives a meaningless result rather than an error, and values
--   outside the 'Int' range wrap around. An empty string or a lone @-@
--   yields 0.
readInt :: BS.ByteString -> Int
readInt s =
  case BS.uncons s of
    Just (c, rest) | c == asciiDash -> negate $! readDigits rest
    _                               -> readDigits s
  where
    asciiDash = 45
    asciiZero = 48
    -- Strict left fold, so no thunk chain builds up in the accumulator.
    readDigits = BS.foldl' step 0
    step acc w = acc * 10 + (fromIntegral w - asciiZero)
-- | Run a 'Builder' and return the result as a strict 'BS.ByteString'.
toBS :: Builder -> BS.ByteString
toBS = unChunk . toLazyByteString

-- | Flatten a lazy 'LBS.ByteString' into a strict one.
--
--   'LBS.toStrict' returns a single-chunk input as-is without copying and
--   concatenates only when there are several chunks, so no manual case
--   analysis on 'LBS.toChunks' is needed.
unChunk :: LBS.ByteString -> BS.ByteString
unChunk = LBS.toStrict
#endif