{-# LANGUAGE GADTs, OverloadedStrings, ScopedTypeVariables, FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances, DefaultSignatures, DeriveGeneric #-}
-- | Types representable as columns in Selda's subset of SQL.
module Database.Selda.SqlType
  ( SqlType (..), SqlEnum (..)
  , Lit (..), UUID, RowID, ID, SqlValue (..), SqlTypeRep (..)
  , invalidRowId, isInvalidRowId, toRowId, fromRowId
  , fromId, toId, invalidId, isInvalidId, untyped
  , compLit, litType
  , sqlDateTimeFormat, sqlDateFormat, sqlTimeFormat
  ) where
import Control.Applicative ((<|>))
import Data.ByteString (ByteString, empty)
import qualified Data.ByteString.Lazy as BSL
import Data.Maybe (fromJust)
import Data.Proxy
import Data.Text (Text, pack, unpack)
import Data.Time
import Data.Typeable
import Data.UUID.Types (UUID, toString, fromByteString, nil)
import GHC.Generics (Generic)
import Text.Read (readMaybe)

-- | Format string used to represent date and time when
--   representing timestamps as text.
--   If at all possible, use 'SqlUTCTime' instead.
sqlDateTimeFormat :: String
sqlDateTimeFormat = "%F %H:%M:%S%Q%z"

-- | Format string used to represent date when
--   representing dates as text.
--   If at all possible, use 'SqlDate' instead.
sqlDateFormat :: String
sqlDateFormat = "%F"

-- | Format string used to represent time of day when
--   representing time as text.
--   If at all possible, use 'SqlTime' instead.
sqlTimeFormat :: String
sqlTimeFormat = "%H:%M:%S%Q%z"

-- | Representation of an SQL type.
data SqlTypeRep
  = TText
  | TRowID
  | TInt
  | TFloat
  | TBool
  | TDateTime
  | TDate
  | TTime
  | TBlob
  | TUUID
  | TJSON
    deriving (Show, Eq, Ord)

-- | Any datatype representable in (Selda's subset of) SQL.
--
--   Every method has a default that is available whenever the type is an
--   'SqlEnum'; such types are stored via their text representation.
class Typeable a => SqlType a where
  -- | Create a literal of this type.
  mkLit :: a -> Lit a
  default mkLit :: (Typeable a, SqlEnum a) => a -> Lit a
  mkLit = LCustom TText . LText . toText

  -- | The SQL representation for this type.
  --   Defaults to the representation of 'defaultValue'.
  sqlType :: Proxy a -> SqlTypeRep
  sqlType _ = litType (defaultValue :: Lit a)

  -- | Convert an SqlValue into this type.
  fromSql :: SqlValue -> a
  default fromSql :: (Typeable a, SqlEnum a) => SqlValue -> a
  fromSql = fromText . fromSql

  -- | Default value when using 'def' at this type.
  defaultValue :: Lit a
  default defaultValue :: (Typeable a, SqlEnum a) => Lit a
  defaultValue = mkLit (minBound :: a)

-- | Any type that's bounded, enumerable and has a text representation, and
--   thus representable as a Selda enumerable.
--
--   While it would be more efficient to store enumerables as integers, this
--   makes hand-rolled SQL touching the values inscrutable, and will break if
--   the user a) derives Enum and b) changes the order of their constructors.
--   Long-term, this should be implemented in PostgreSQL as a proper enum
--   anyway, which mostly renders the performance argument moot.
class (Typeable a, Bounded a, Enum a) => SqlEnum a where
  -- | Render a value as the text stored in the database.
  toText :: a -> Text
  -- | Recover a value from its stored text.
  fromText :: Text -> a

-- | Catch-all instance: any bounded enumeration with 'Show' and 'Read'
--   round-trips through its 'show' output.
instance {-# OVERLAPPABLE #-}
    (Typeable a, Bounded a, Enum a, Show a, Read a) => SqlEnum a where
  toText = pack . show
  -- Text that does not parse raises an error naming the offending value,
  -- instead of the context-free @Prelude.read: no parse@.
  fromText t = case readMaybe (unpack t) of
    Just x  -> x
    Nothing -> error $ "fromText: unreadable enum value: " ++ unpack t

-- | An SQL literal.
data Lit a where
  LText     :: !Text       -> Lit Text
  LInt      :: !Int        -> Lit Int
  LDouble   :: !Double     -> Lit Double
  LBool     :: !Bool       -> Lit Bool
  LDateTime :: !UTCTime    -> Lit UTCTime
  LDate     :: !Day        -> Lit Day
  LTime     :: !TimeOfDay  -> Lit TimeOfDay
  LJust     :: SqlType a => !(Lit a) -> Lit (Maybe a)
  LBlob     :: !ByteString -> Lit ByteString
  LNull     :: SqlType a => Lit (Maybe a)
  LCustom   :: SqlTypeRep  -> Lit a -> Lit b
  LUUID     :: !UUID       -> Lit UUID

-- | The SQL type representation for the given literal.
litType :: Lit a -> SqlTypeRep
litType LText{}         = TText
litType LInt{}          = TInt
litType LDouble{}       = TFloat
litType LBool{}         = TBool
litType LDateTime{}     = TDateTime
litType LDate{}         = TDate
litType LTime{}         = TTime
litType (LJust inner)   = litType inner
litType LBlob{}         = TBlob
litType nullLit@LNull   = sqlType (elemProxy nullLit)
  where
    -- A null carries no payload, so its representation is looked up from
    -- the type underneath the 'Maybe'.
    elemProxy :: Lit (Maybe b) -> Proxy b
    elemProxy _ = Proxy
litType (LCustom rep _) = rep
litType LUUID{}         = TUUID

-- | Two literals are equal exactly when 'compLit' reports 'EQ'.
instance Eq (Lit a) where
  lhs == rhs = case compLit lhs rhs of
    EQ -> True
    _  -> False

-- | Ordering is delegated to 'compLit'.
instance Ord (Lit a) where
  compare lhs rhs = compLit lhs rhs

-- | Constructor tag for all literals. Used for Ord instance.
litConTag :: Lit a -> Int
litConTag lit = case lit of
  LText{}     -> 0
  LInt{}      -> 1
  LDouble{}   -> 2
  LBool{}     -> 3
  LDateTime{} -> 4
  LDate{}     -> 5
  LTime{}     -> 6
  LJust{}     -> 7
  LBlob{}     -> 8
  LNull       -> 9
  LCustom{}   -> 10
  LUUID{}     -> 11

-- | Order two literals, which may be of different types.
--   Literals built with the same constructor are ordered by their payload
--   ('LCustom' ignores its type tag and compares the wrapped literals);
--   every other pairing is ordered by constructor tag.
compLit :: Lit a -> Lit b -> Ordering
compLit (LText l)     (LText r)     = compare l r
compLit (LInt l)      (LInt r)      = compare l r
compLit (LDouble l)   (LDouble r)   = compare l r
compLit (LBool l)     (LBool r)     = compare l r
compLit (LDateTime l) (LDateTime r) = compare l r
compLit (LDate l)     (LDate r)     = compare l r
compLit (LTime l)     (LTime r)     = compare l r
compLit (LBlob l)     (LBlob r)     = compare l r
compLit (LJust l)     (LJust r)     = compLit l r
compLit (LCustom _ l) (LCustom _ r) = compLit l r
compLit (LUUID l)     (LUUID r)     = compare l r
compLit l             r             = compare (litConTag l) (litConTag r)

-- | Some value that is representable in SQL.
data SqlValue where
  SqlInt     :: !Int        -> SqlValue
  SqlFloat   :: !Double     -> SqlValue
  SqlString  :: !Text       -> SqlValue
  SqlBool    :: !Bool       -> SqlValue
  SqlBlob    :: !ByteString -> SqlValue
  SqlUTCTime :: !UTCTime    -> SqlValue
  SqlTime    :: !TimeOfDay  -> SqlValue
  SqlDate    :: !Day        -> SqlValue
  SqlNull    :: SqlValue

-- | Renders the constructor name followed by the shown payload.
instance Show SqlValue where
  show val = case val of
    SqlInt n     -> "SqlInt " ++ show n
    SqlFloat f   -> "SqlFloat " ++ show f
    SqlString s  -> "SqlString " ++ show s
    SqlBool b    -> "SqlBool " ++ show b
    SqlBlob b    -> "SqlBlob " ++ show b
    SqlUTCTime t -> "SqlUTCTime " ++ show t
    SqlTime t    -> "SqlTime " ++ show t
    SqlDate d    -> "SqlDate " ++ show d
    SqlNull      -> "SqlNull"

-- | Renders the payload only; 'LCustom' wrappers are transparent and
--   nullable literals show as @Just ...@ or @Nothing@.
instance Show (Lit a) where
  show lit = case lit of
    LText txt     -> show txt
    LInt int      -> show int
    LDouble dbl   -> show dbl
    LBool bool    -> show bool
    LDateTime ts  -> show ts
    LDate day     -> show day
    LTime tod     -> show tod
    LBlob blob    -> show blob
    LJust inner   -> "Just " ++ show inner
    LNull         -> "Nothing"
    LCustom _ raw -> show raw
    LUUID uuid    -> toString uuid

-- | A row identifier for some table.
--   This is the type of auto-incrementing primary keys.
newtype RowID = RowID Int
  deriving (Eq, Ord, Typeable, Generic)

-- | A row identifier is shown as its bare integer.
instance Show RowID where
  show = show . fromRowId

-- | A row identifier which is guaranteed to not match any row in any table.
invalidRowId :: RowID
invalidRowId = toRowId (-1)

-- | Is the given row identifier invalid? I.e. is it guaranteed to not match any
--   row in any table? Any negative identifier counts as invalid.
isInvalidRowId :: RowID -> Bool
isInvalidRowId rid = fromRowId rid < 0

-- | Create a row identifier from an integer.
--   Use with caution, preferably only when reading user input.
toRowId :: Int -> RowID
toRowId n = RowID n

-- | Inspect a row identifier.
fromRowId :: RowID -> Int
fromRowId (RowID n) = n

-- | A typed row identifier.
--   Generic tables should use this instead of 'RowID'.
--   Use 'untyped' to erase the type of a row identifier, and @cast@ from the
--   "Database.Selda.Unsafe" module if you for some reason need to add a type
--   to a row identifier.
newtype ID a = ID {untyped :: RowID}
  deriving (Eq, Ord, Typeable, Generic)

-- | A typed identifier is shown exactly like the 'RowID' it wraps.
instance Show (ID a) where
  show ident = show (untyped ident)

-- | Create a typed row identifier from an integer.
--   Use with caution, preferably only when reading user input.
toId :: Int -> ID a
toId n = ID (toRowId n)

-- | Inspect a typed row identifier, yielding the underlying integer.
fromId :: ID a -> Int
fromId ident = fromRowId (untyped ident)

-- | A typed row identifier which is guaranteed to not match any row in any
--   table.
invalidId :: ID a
invalidId = ID invalidRowId

-- | Is the given typed row identifier invalid? I.e. is it guaranteed to not
--   match any row in any table?
isInvalidId :: ID a -> Bool
isInvalidId ident = isInvalidRowId (untyped ident)

-- | Row identifiers are integer literals tagged as 'TRowID'.
instance SqlType RowID where
  mkLit rid = LCustom TRowID (LInt (fromRowId rid))
  sqlType _ = TRowID
  fromSql val = case val of
    SqlInt n -> toRowId n
    _        -> error $ "fromSql: RowID column with non-int value: " ++ show val
  defaultValue = mkLit invalidRowId

-- | Typed identifiers reuse the 'RowID' representation.
instance Typeable a => SqlType (ID a) where
  mkLit ident = LCustom TRowID (mkLit (untyped ident))
  sqlType _ = TRowID
  fromSql val = ID (fromSql val)
  defaultValue = mkLit invalidId

instance SqlType Int where
  mkLit = LInt
  sqlType _ = TInt
  fromSql val = case val of
    SqlInt n -> n
    _        -> error $ "fromSql: int column with non-int value: " ++ show val
  defaultValue = LInt 0

instance SqlType Double where
  mkLit = LDouble
  sqlType _ = TFloat
  fromSql val = case val of
    SqlFloat d -> d
    _          -> error $ "fromSql: float column with non-float value: " ++ show val
  defaultValue = LDouble 0

instance SqlType Text where
  mkLit = LText
  sqlType _ = TText
  fromSql val = case val of
    SqlString txt -> txt
    _             -> error $ "fromSql: text column with non-text value: " ++ show val
  defaultValue = LText ""

-- | Besides proper booleans, integer values are accepted: zero reads as
--   'False' and any other integer as 'True'.
instance SqlType Bool where
  mkLit = LBool
  sqlType _ = TBool
  fromSql val = case val of
    SqlBool b -> b
    SqlInt 0  -> False
    SqlInt _  -> True
    _         -> error $ "fromSql: bool column with non-bool value: " ++ show val
  defaultValue = LBool False

-- | Timestamps may also arrive as text in 'sqlDateTimeFormat'.
instance SqlType UTCTime where
  mkLit = LDateTime
  sqlType _ = TDateTime
  fromSql val = case val of
    SqlUTCTime t -> t
    SqlString s  ->
      case withWeirdTimeZone sqlDateTimeFormat (unpack s) of
        Just t  -> t
        Nothing -> error $ "fromSql: bad datetime string: " ++ unpack s
    _ -> error $ "fromSql: datetime column with non-datetime value: " ++ show val
  -- Modified Julian Day 40587 is 1970-01-01.
  defaultValue = LDateTime (UTCTime (ModifiedJulianDay 40587) 0)

-- | Dates may also arrive as text in 'sqlDateFormat'.
instance SqlType Day where
  mkLit = LDate
  sqlType _ = TDate
  fromSql val = case val of
    SqlDate d   -> d
    SqlString s ->
      case parseTimeM True defaultTimeLocale sqlDateFormat (unpack s) of
        Just d  -> d
        Nothing -> error $ "fromSql: bad date string: " ++ unpack s
    _ -> error $ "fromSql: date column with non-date value: " ++ show val
  -- Modified Julian Day 40587 is 1970-01-01.
  defaultValue = LDate (ModifiedJulianDay 40587)

-- | Times of day may also arrive as text in 'sqlTimeFormat'.
instance SqlType TimeOfDay where
  mkLit = LTime
  sqlType _ = TTime
  fromSql val = case val of
    SqlTime t   -> t
    SqlString s ->
      case withWeirdTimeZone sqlTimeFormat (unpack s) of
        Just t  -> t
        Nothing -> error $ "fromSql: bad time string: " ++ unpack s
    _ -> error $ "fromSql: time column with non-time value: " ++ show val
  defaultValue = LTime (TimeOfDay 0 0 0)

-- | Both PostgreSQL and SQLite do weird things with time
--   zones.
--   Long term solution is to use proper binary types internally for
--   time values, so this is really just an interim solution.
--
--   Parsing is attempted three times, and the first success wins: the input
--   with @"00"@ appended, the input unchanged, and the input with @"+0000"@
--   appended.
--   NOTE(review): presumably this completes an hour-only offset such as
--   @+01@ and supplies a UTC offset when none is present; confirm against
--   the backends.
withWeirdTimeZone :: ParseTime t => String -> String -> Maybe t
withWeirdTimeZone fmt s =
      attempt (s ++ "00")
  <|> attempt s
  <|> attempt (s ++ "+0000")
  where
    -- Parse one candidate string against the caller's format.
    attempt candidate = parseTimeM True defaultTimeLocale fmt candidate

instance SqlType ByteString where
  mkLit = LBlob
  sqlType _ = TBlob
  fromSql (SqlBlob x) = x
  fromSql v           = error $ "fromSql: blob column with non-blob value: " ++ show v
  defaultValue = LBlob empty

-- | Lazy bytestrings are stored as strict blobs.
instance SqlType BSL.ByteString where
  mkLit = LCustom TBlob . LBlob . BSL.toStrict
  sqlType _ = TBlob
  fromSql (SqlBlob x) = BSL.fromStrict x
  fromSql v           = error $ "fromSql: blob column with non-blob value: " ++ show v
  defaultValue = LCustom TBlob (LBlob empty)

-- | @defaultValue@ for UUIDs is the all-zero RFC4122 nil UUID.
instance SqlType UUID where
  mkLit = LUUID
  sqlType _ = TUUID
  -- A blob that does not decode to a UUID is reported with the same kind of
  -- descriptive @fromSql@ error as every other decoding failure, instead of
  -- a context-free 'fromJust' failure.
  fromSql (SqlBlob x) =
    case fromByteString (BSL.fromStrict x) of
      Just uuid -> uuid
      Nothing   -> error $ "fromSql: UUID column with malformed blob: " ++ show x
  fromSql v = error $ "fromSql: UUID column with non-blob value: " ++ show v
  defaultValue = LUUID nil

-- | Nullable columns: 'SqlNull' reads as 'Nothing', anything else is decoded
--   at the underlying type.
instance SqlType a => SqlType (Maybe a) where
  mkLit (Just x) = LJust $ mkLit x
  mkLit Nothing  = LNull
  sqlType _ = sqlType (Proxy :: Proxy a)
  fromSql SqlNull = Nothing
  fromSql x       = Just $ fromSql x
  defaultValue = LNull

-- | Uses the class defaults, which require an 'SqlEnum' instance.
instance SqlType Ordering