{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Typed combinators for building the @WHERE@ / @ORDER BY@ / @LIMIT@ /
-- @OFFSET@ part of a SQL statement and rendering it to a 'ByteString'.
module Database.Quibble
  ( Query
  , RowsOf
  , DBWhere(..), DBRender(..)
  , Expr(..)
  , SortExpr(..)
  , HasTable, HasColumn
  , Inline(..)
  , query, rowsOf
  , where_, orderBy, limit, offset
  , asc, desc
  , (.&&), (.||), (.==), (./=), true, false
  , isNull, isNotNull
  , module Reexport
  ) where

import Data.ByteString ( ByteString )
import Data.ByteString.Builder ( Builder )
import Data.Foldable ( fold )
import Data.Int ( Int16, Int32, Int64 )
import Data.Kind ( Type )
import Data.MonoTraversable ( MonoFoldable, Element, ofoldMap )
import Data.Proxy ( Proxy(..) )
import Data.Sequence ( Seq )
import Data.String ( IsString(..) )
import Data.Time ( UTCTime )
import Data.Time.Format.ISO8601 ( iso8601Show )
import Data.Text ( Text )
import Data.Text.Conversions ( convertText, UTF8(..) )
import Data.UUID ( UUID )
import Data.Word ( Word16, Word32, Word64 )
import Optics.Lens ( Lens', lens )
import Optics.Setter ( over )
import GHC.OverloadedLabels ( IsLabel(..) )
import GHC.TypeLits ( KnownSymbol, Symbol, symbolVal )

import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as BL
import qualified Data.Sequence as Seq
import qualified Data.Text.Lazy as TL
import qualified Data.UUID as UUID

import Data.Function as Reexport ( (&) )

-- | A SQL expression, stored as the 'Builder' of its rendered text. Both type
-- parameters are phantom: @ctx@ identifies the context the expression belongs
-- to and @ty@ is the type of the value the expression denotes.
--
-- The constructor is exposed in case you need to unsafely construct
-- expressions. But you shouldn't rely on it too much.
newtype Expr (ctx :: Type) (ty :: Type) = Expr { unExpr :: Builder }

-- | You can think of this as a wrapper around 'Expr' that adds sorting direction.
-- Unfortunately, we can't allow @ASC NULLS LAST@ or friends here because that's
-- not portable across databases.
data SortExpr (ctx :: Type)
  = Asc Builder
    -- ^ Rendered as the expression followed by @ASC@.
  | Desc Builder
    -- ^ Rendered as the expression followed by @DESC@.
  | CustomSort ByteString Builder
    -- ^ Rendered as the expression (second field) followed by the raw
    -- direction text (first field). The direction text is emitted verbatim.

-- | Allows us to lookup table name by type.
class KnownSymbol tbl => HasTable (ctx :: Type) (tbl :: Symbol) | ctx -> tbl

-- | Allows us to infer the column type from the table type and column name.
class KnownSymbol col => HasColumn (ctx :: Type) (col :: Symbol) (ty :: Type) | ctx col -> ty

-- | The accumulated clauses of a query; build one with 'query' and the
-- combinators below, render it with 'renderSQL'.
data Query (ctx :: Type) = Query
  { qryWhereCond :: Maybe (Expr ctx Bool)
    -- ^ Condition for the @WHERE@ clause, if any.
  , qryLimit :: Maybe Word64
    -- ^ Row count for the @LIMIT@ clause, if any.
  , qryOffset :: Maybe Word64
    -- ^ Row count for the @OFFSET@ clause, if any.
  , qrySort :: Seq (SortExpr ctx)
    -- ^ Sort expressions for @ORDER BY@, highest precedence first.
  }

-- | A row selection that carries only an optional @WHERE@ condition; build one
-- with 'rowsOf'.
data RowsOf (ctx :: Type) = RowsOf
  { rowsWhereCond :: Maybe (Expr ctx Bool)
    -- ^ Condition for the @WHERE@ clause, if any.
  }

-- | Render a single sort expression together with its direction.
sqlSortExpr :: SortExpr ctx -> Builder
sqlSortExpr sort = case sort of
  Asc expr -> expr <> " ASC"
  Desc expr -> expr <> " DESC"
  CustomSort sortDir expr -> expr <> " " <> Builder.byteString sortDir

-- | Render the clauses of a 'Query' in the order @WHERE@, @ORDER BY@, @LIMIT@,
-- @OFFSET@. Each clause that is present is preceded by a newline; absent
-- clauses contribute nothing.
sqlQuery :: Query ctx -> ByteString
sqlQuery qry =
  BL.toStrict $ Builder.toLazyByteString $ mconcat
    [ foldMap (\cond -> "\nWHERE " <> unExpr cond) (qryWhereCond qry)
    , if Seq.null (qrySort qry)
        then mempty
        else "\nORDER BY " <> fold (Seq.intersperse ", " (sqlSortExpr <$> qrySort qry))
    , foldMap (\n -> "\nLIMIT " <> Builder.word64Dec n) (qryLimit qry)
    , foldMap (\n -> "\nOFFSET " <> Builder.word64Dec n) (qryOffset qry)
    ]

-- | Render the @WHERE@ clause of a 'RowsOf', or the empty string when no
-- condition has been set. Unlike 'sqlQuery', no leading newline is emitted.
sqlRowsOf :: RowsOf ctx -> ByteString
sqlRowsOf rows =
  BL.toStrict $ Builder.toLazyByteString $
    foldMap (\cond -> "WHERE " <> unExpr cond) (rowsWhereCond rows)

-- | Sort by the given expression in ascending order.
asc :: Expr ctx ty -> SortExpr ctx
asc (Expr bldr) = Asc bldr

-- | Sort by the given expression in descending order.
desc :: Expr ctx ty -> SortExpr ctx
desc (Expr bldr) = Desc bldr

-- | A 'Query' is meant to represent everything in a @SELECT@ statement, other
-- than the columns selection and the joined tables.
-- This is meant to be used with @-XTypeApplications@, like @query \@Foo@.
query :: forall ctx. Query ctx
query = Query
  { qryWhereCond = Nothing
  , qryLimit = Nothing
  , qryOffset = Nothing
  , qrySort = Seq.empty
  }

-- | A 'RowsOf' is meant to represent the conditions of an @UPDATE@ or @DELETE@.
-- This is meant to be used with @-XTypeApplications@, like @rowsOf \@Foo@.
rowsOf :: forall ctx. RowsOf ctx
rowsOf = RowsOf
  { rowsWhereCond = Nothing
  }

-- | Things that carry an optional @WHERE@ condition over context @ctx@.
class DBWhere qry ctx where
  -- | Access to the optional condition stored in @qry@.
  whereCond :: Lens' qry (Maybe (Expr ctx Bool))

  -- | A convenience function for when the user doesn't want to specify any conditions.
  allRows :: qry

-- The instance heads below match any @ctx'@; the equality constraint then
-- forces it to be the @ctx@ of the query type itself.
instance (ctx ~ ctx') => DBWhere (Query ctx) ctx' where
  whereCond = lens qryWhereCond (\qry cond -> qry { qryWhereCond = cond })
  allRows = query @ctx

instance (ctx ~ ctx') => DBWhere (RowsOf ctx) ctx' where
  whereCond = lens rowsWhereCond (\rows cond -> rows { rowsWhereCond = cond })
  allRows = rowsOf @ctx

-- | Things that can be rendered to SQL text.
class DBRender a where
  renderSQL :: a -> ByteString

instance DBRender (Query ctx) where
  renderSQL = sqlQuery

instance DBRender (RowsOf ctx) where
  renderSQL = sqlRowsOf

-- | Add a condition to the 'Query' or 'RowsOf'. If @where_@ is used multiple times,
-- the condition will be @AND@ed together.
--
-- > rowsOf @Foo
-- >   & where_ (...)
where_ :: DBWhere qry ctx => Expr ctx Bool -> qry -> qry
where_ cond qry = over whereCond (Just . maybe cond (.&& cond)) qry

-- | Set the number of rows specified by a @LIMIT@. If @limit@ is used multiple
-- times, only the result of the last call matters.
--
-- > query @Foo
-- >   & limit 50
limit :: Word64 -> Query ctx -> Query ctx
limit rows qry = qry { qryLimit = Just rows }

-- | Set the number of rows specified by an @OFFSET@. If @offset@ is used
-- multiple times, only the result of the last call matters.
--
-- > query @Foo
-- >   & offset 100
--
-- Note that in the general case, your query will still pay the cost of
-- looking up all the rows that were skipped! Be careful when using @OFFSET@.
offset :: Word64 -> Query ctx -> Query ctx
offset rows qry = qry { qryOffset = Just rows }

-- | Specify how the output should be sorted. If @orderBy@ is used multiple times,
-- all sort expressions are concatenated together, with later calls having
-- lower sort precedence.
--
-- @-XOverloadedLists@ can help make specifying the list of sort expressions
-- simpler.
--
-- > query @Foo
-- >   & orderBy [asc #col1, desc #col2]
orderBy :: Seq (SortExpr ctx) -> Query ctx -> Query ctx
orderBy sorts qry = qry { qrySort = qrySort qry <> sorts }

-- | Render a sequence of characters as an @E'...'@ escape-string literal.
--
-- Backslashes and single quotes are doubled, @?@ is written as @\\x3F@, and
-- every character above code point 127 is written as a @\\U@ escape followed
-- by eight hexadecimal digits. All other characters are emitted as-is.
escapeCharSeq :: (MonoFoldable t, Element t ~ Char) => t -> Builder
escapeCharSeq t = mconcat ["E'", ofoldMap escapeChar t, "'"]
  where
    escapeChar :: Char -> Builder
    escapeChar '\\' = "\\\\"
    escapeChar '\'' = "''"
    -- NOTE(review): presumably escaped so that a literal question mark can't
    -- be mistaken for a parameter placeholder downstream; confirm.
    escapeChar '?' = "\\x3F"
    escapeChar c
      -- The eight hex digits must be introduced by @\U@; without the prefix
      -- they would be inserted into the literal as plain text.
      | fromEnum c > 127 = "\\U" <> Builder.int32HexFixed (fromIntegral $ fromEnum c)
      | otherwise = Builder.char7 c

-- | Render binary data as a quoted, hex-encoded literal of the form @'\\x...'@.
escapeByteString :: ByteString -> Builder
escapeByteString bs = mconcat ["'\\x", Builder.byteStringHex bs, "'"]

-- String literals typed as byte strings are UTF-8 encoded and then rendered
-- like any other binary value.
instance IsString (Expr ctx ByteString) where
  fromString s = let UTF8 bs = convertText s in Expr (escapeByteString bs)

instance IsString (Expr ctx BL.ByteString) where
  fromString s = let UTF8 bs = convertText s in Expr (escapeByteString bs)

instance IsString (Expr ctx Text) where
  fromString s = Expr (escapeCharSeq s)

instance IsString (Expr ctx TL.Text) where
  fromString s = Expr (escapeCharSeq s)

instance IsString (Expr ctx String) where
  fromString s = Expr (escapeCharSeq s)

-- | Arithmetic on expressions renders the corresponding SQL; numeric literals
-- are rendered with the 'Show' instance of the underlying type.
instance (Show a, Num a) => Num (Expr ctx a) where
  (+) (Expr l) (Expr r) = Expr (mconcat ["(", l, " + ", r, ")"])
  (-) (Expr l) (Expr r) = Expr (mconcat ["(", l, " - ", r, ")"])
  (*) (Expr l) (Expr r) = Expr (mconcat ["(", l, " * ", r, ")"])
  -- The operand is wrapped in its own parentheses: a bare "-" prefix in front
  -- of an operand that itself starts with "-" (a nested negation or a negative
  -- literal) would render "--", which starts a SQL line comment.
  negate (Expr l) = Expr (mconcat ["(-(", l, "))"])
  abs (Expr x) = Expr (mconcat ["ABS(", x, ")"])
  signum (Expr x) = Expr (mconcat ["SIGN(", x, ")"])
  fromInteger i =
    let UTF8 bs = convertText $ show (fromInteger i :: a)
    in Expr (Builder.byteString bs)

-- | Division renders the SQL @/@ operator; fractional literals are rendered
-- with the 'Show' instance of the underlying type.
instance (Show a, Fractional a) => Fractional (Expr ctx a) where
  (/) (Expr l) (Expr r) = Expr (mconcat ["(", l, " / ", r, ")"])
  fromRational r =
    let UTF8 bs = convertText $ show (fromRational r :: a)
    in Expr (Builder.byteString bs)

-- This instance is not currently Unicode-aware, because we don't use
-- the quoted syntax for identifiers.
--
-- A label @#col@ renders as @tbl.col@, where @tbl@ is the table name attached
-- to @ctx@ through 'HasTable'.
instance
  ( KnownSymbol tbl
  , HasTable ctx tbl
  , HasColumn ctx col ty
  ) => IsLabel col (Expr ctx ty) where
  fromLabel = Expr $ mconcat
    [ Builder.string7 $ symbolVal $ Proxy @tbl
    , "."
    , Builder.string7 $ symbolVal $ Proxy @col
    ]

-- | The SQL literal @TRUE@.
true :: Expr ctx Bool
true = Expr "TRUE"

-- | The SQL literal @FALSE@.
false :: Expr ctx Bool
false = Expr "FALSE"

-- | Test an expression with @IS NULL@. The result is parenthesised so it can
-- be nested inside other expressions.
isNull :: Expr ctx a -> Expr ctx Bool
isNull (Expr x) = Expr (mconcat ["(", x, " IS NULL)"])

-- | Test an expression with @IS NOT NULL@. The result is parenthesised so it
-- can be nested inside other expressions.
isNotNull :: Expr ctx a -> Expr ctx Bool
isNotNull (Expr x) = Expr (mconcat ["(", x, " IS NOT NULL)"])

-- NOTE: '.&&' binds looser than '.||' here, which is the opposite of '&&' and
-- '||' in Haskell. The rendered SQL is always fully parenthesised, so the
-- fixities only affect how the Haskell source is parsed. Add explicit
-- parentheses when mixing the two operators.
infixl 0 .&&
-- | Logical conjunction, rendered as a parenthesised @AND@.
(.&&) :: Expr ctx Bool -> Expr ctx Bool -> Expr ctx Bool
(.&&) (Expr l) (Expr r) = Expr (mconcat ["(", l, " AND ", r, ")"])

infixl 1 .||
-- | Logical disjunction, rendered as a parenthesised @OR@.
(.||) :: Expr ctx Bool -> Expr ctx Bool -> Expr ctx Bool
(.||) (Expr l) (Expr r) = Expr (mconcat ["(", l, " OR ", r, ")"])

-- | Inequality, rendered as a parenthesised @<>@ comparison.
(./=) :: Expr ctx a -> Expr ctx a -> Expr ctx Bool
(./=) (Expr l) (Expr r) = Expr (mconcat ["(", l, " <> ", r, ")"])

-- | Equality, rendered as a parenthesised @=@ comparison.
(.==) :: Expr ctx a -> Expr ctx a -> Expr ctx Bool
(.==) (Expr l) (Expr r) = Expr (mconcat ["(", l, " = ", r, ")"])

-- | Haskell values that can be written directly into the SQL text as a literal.
class Inline ty where
  inline :: ty -> Expr ctx ty

instance Inline ByteString where
  inline s = Expr (escapeByteString s)

instance Inline BL.ByteString where
  inline s = Expr (escapeByteString $ BL.toStrict s)

instance Inline Text where
  inline s = Expr (escapeCharSeq s)

instance Inline TL.Text where
  inline s = Expr (escapeCharSeq s)

instance Inline String where
  inline s = Expr (escapeCharSeq s)

instance Inline Bool where
  inline True = true
  inline False = false

instance Inline Int16 where
  inline n = Expr (Builder.int16Dec n)

instance Inline Int32 where
  inline n = Expr (Builder.int32Dec n)

instance Inline Int64 where
  inline n = Expr (Builder.int64Dec n)

instance Inline Word16 where
  inline n = Expr (Builder.word16Dec n)

instance Inline Word32 where
  inline n = Expr (Builder.word32Dec n)

instance Inline Word64 where
  inline n = Expr (Builder.word64Dec n)

-- | Rendered as the quoted textual form produced by 'UUID.toASCIIBytes'.
instance Inline UUID where
  inline uuid = Expr $ mconcat ["'", Builder.byteString $ UUID.toASCIIBytes uuid, "'"]

-- | Rendered as the quoted ISO 8601 text produced by 'iso8601Show'.
instance Inline UTCTime where
  inline t = Expr $ mconcat ["'", Builder.string7 $ iso8601Show t, "'"]