{-# LANGUAGE GADTs, BangPatterns, OverloadedStrings, CPP #-}
module Database.Selda.PostgreSQL.Encoding
( toSqlValue, fromSqlValue, fromSqlType
, readInt, readBool
) where
#ifdef __HASTE__
-- Placeholder definitions for Haste builds. Every name in the module's
-- export list must be defined here, otherwise the export list does not
-- resolve; presumably the real implementations are unavailable because
-- LibPQ cannot be used under Haste (TODO confirm).
toSqlValue, fromSqlValue, fromSqlType, readInt, readBool :: a
toSqlValue = undefined
fromSqlValue = undefined
fromSqlType = undefined
readInt = undefined
readBool = undefined
#else
import qualified Data.ByteString as BS
import Data.ByteString.Builder
import Data.ByteString.Char8 (unpack)
import qualified Data.ByteString.Char8 as BSC (map)
import qualified Data.ByteString.Lazy as LBS
import Data.Char (toLower)
import qualified Data.Text as Text
import Data.Text.Encoding
import Database.PostgreSQL.LibPQ (Oid (..), Format (..))
import Database.Selda.Backend
import Unsafe.Coerce
-- PostgreSQL built-in type OIDs for the column types this module sends
-- and decodes. The signature's continuation line is indented on purpose:
-- a line at column 0 would start a new top-level declaration.
blobType, boolType, intType, int32Type, int16Type, textType, doubleType,
  dateType, timeType, timestampType, nameType, varcharType :: Oid
boolType      = Oid 16    -- bool
blobType      = Oid 17    -- bytea
nameType      = Oid 19    -- name
intType       = Oid 20    -- int8 (64-bit integer)
int16Type     = Oid 21    -- int2
int32Type     = Oid 23    -- int4
textType      = Oid 25    -- text
doubleType    = Oid 701   -- float8
varcharType   = Oid 1043  -- varchar
dateType      = Oid 1082  -- date
timeType      = Oid 1083  -- time
timestampType = Oid 1114  -- timestamp (without time zone)
-- | Encode a Selda literal as a query parameter. The result triple holds the
--   parameter's type OID, its bytes, and the wire 'Format' those bytes use.
--   Returns 'Nothing' for 'LNull', which callers send as SQL NULL.
--   'LJust' and 'LCustom' are unwrapped and encoded as their inner literal.
fromSqlValue :: Lit a -> Maybe (Oid, BS.ByteString, Format)
fromSqlValue (LBool b) = Just (boolType, toBS $ if b then word8 1 else word8 0, Binary)
fromSqlValue (LInt n) = Just (intType, toBS $ int64BE (fromIntegral n), Binary)
-- float8 is sent as its IEEE-754 bit pattern in big-endian byte order.
-- 'doubleBE' produces exactly those bytes without coercing a boxed Double
-- into an Int64 via unsafeCoerce.
fromSqlValue (LDouble f) = Just (doubleType, toBS $ doubleBE f, Binary)
-- NUL characters are stripped before UTF-8 encoding (PostgreSQL text
-- values cannot contain them).
fromSqlValue (LText s) = Just (textType, encodeUtf8 $ Text.filter (/= '\0') s, Binary)
-- Date/time literals are already rendered as text, so they use Text format.
fromSqlValue (LDateTime s) = Just (timestampType, encodeUtf8 s, Text)
fromSqlValue (LTime s) = Just (timeType, encodeUtf8 s, Text)
fromSqlValue (LDate s) = Just (dateType, encodeUtf8 s, Text)
fromSqlValue (LBlob b) = Just (blobType, b, Binary)
fromSqlValue (LNull) = Nothing
fromSqlValue (LJust x) = fromSqlValue x
fromSqlValue (LCustom l) = fromSqlValue l
-- | The PostgreSQL type OID that represents a given Selda column type.
--   Row identifiers and ordinary integers share the same integer OID.
fromSqlType :: SqlTypeRep -> Oid
-- Numeric and boolean columns.
fromSqlType TBool     = boolType
fromSqlType TInt      = intType
fromSqlType TRowID    = intType
fromSqlType TFloat    = doubleType
-- Textual and binary columns.
fromSqlType TText     = textType
fromSqlType TBlob     = blobType
-- Temporal columns.
fromSqlType TDateTime = timestampType
fromSqlType TDate     = dateType
fromSqlType TTime     = timeType
-- | Decode one result column that PostgreSQL returned as text into a
--   'SqlValue'. The decoder is chosen by the column's type OID.
--
--   Calls 'error' when the OID is not one this module recognises, or when
--   a bytea value is not in the @\x@-prefixed hex form expected below.
toSqlValue :: Oid -> BS.ByteString -> SqlValue
toSqlValue t val
  | t == boolType = SqlBool $ readBool (BSC.map toLower val)
  | t == intType = SqlInt $ readInt val
  | t == int32Type = SqlInt $ readInt val
  | t == int16Type = SqlInt $ readInt val
  | t == doubleType = SqlFloat $ read (unpack val)
  | t == blobType = SqlBlob $ pgDecode val
  | t `elem` textish = SqlString (decodeUtf8 val)
  | otherwise = error $ "BUG: result with unknown type oid: " ++ show t
  where
    -- Decode bytea in hex output form: a literal backslash-x followed by
    -- two hex digits per byte. The prefix is checked with 'BS.isPrefixOf'
    -- rather than by indexing bytes 0 and 1, so inputs shorter than two
    -- bytes take the descriptive error path instead of failing in 'BS.index'.
    pgDecode s
      | hexPrefix `BS.isPrefixOf` s =
        BS.pack $ go $ BS.drop 2 s
      | otherwise =
        error $ "bad blob string from postgres: " ++ show s
      where
        -- Backslash (92) followed by 'x' (120).
        hexPrefix = BS.pack [92, 120]
        -- Numeric value of the ASCII hex digit at index n of x. The byte is
        -- assumed to be 0-9, a-f or A-F; no validation is done.
        hex n x =
          case BS.index x n of
            c | c >= 97 -> c - 87
              | c >= 65 -> c - 55
              | otherwise -> c - 48
        -- Combine consecutive digit pairs into bytes; a trailing lone digit
        -- is ignored.
        go x
          | BS.length x >= 2 = (16*hex 0 x + (hex 1 x)) : go (BS.drop 2 x)
          | otherwise = []
    -- OIDs whose text representation is returned unchanged as a UTF-8 string.
    textish = [textType, timestampType, timeType, dateType, nameType, varcharType]
-- | Interpret a textual boolean. Only the exact spellings in @falsy@ count
--   as False; any other input, including the empty string, is True. The
--   comparison is case-sensitive, so callers lower-case the input first.
readBool :: BS.ByteString -> Bool
readBool s = s `notElem` falsy
  where
    falsy = ["f", "0", "false", "n", "no", "off"]
-- | Parse a base-10 integer with an optional leading @-@.
--
--   No validation is done: every byte after the optional sign is treated as
--   a decimal digit (its value minus 48). An empty input, or a lone @-@,
--   yields 0. The sign is inspected with 'BS.uncons' instead of the partial
--   'BS.head', so empty input no longer crashes.
readInt :: BS.ByteString -> Int
readInt s =
  case BS.uncons s of
    Just (c, rest) | c == asciiDash -> negate $! digits rest
    _                               -> digits s
  where
    asciiZero = 48
    asciiDash = 45
    -- Strict left fold over the digit bytes. Int overflow wraps, which lets
    -- the final 'negate' still produce minBound for its decimal spelling.
    digits = BS.foldl' (\acc w -> acc * 10 + fromIntegral (w - asciiZero)) 0
-- | Run a 'Builder' to completion and return its output as a single strict
--   'BS.ByteString'.
toBS :: Builder -> BS.ByteString
toBS builder = unChunk (toLazyByteString builder)

-- | Convert a lazy 'LBS.ByteString' to a strict one. When the lazy string
--   holds exactly one chunk, that chunk is returned as-is without copying.
unChunk :: LBS.ByteString -> BS.ByteString
unChunk lazyBytes
  | [onlyChunk] <- pieces = onlyChunk
  | otherwise             = BS.concat pieces
  where
    pieces = LBS.toChunks lazyBytes
#endif