module Database.Persist.GenericSql.Internal
( Connection (..)
, Statement (..)
, withSqlConn
, withSqlPool
, RowPopper
, mkColumns
, Column (..)
, UniqueDef
, refName
, tableColumns
, rawFieldName
, rawTableName
, RawName (..)
, filterClause
, getFieldName
, dummyFromFilts
, getFiltsValues
, orderClause
) where
import Control.Arrow
import qualified Control.Exception as E
import Data.IORef
import Data.List (intercalate)
import Data.Maybe (fromJust)

import Control.Exception.Control (bracket)
import Control.Monad.IO.Class
import Control.Monad.IO.Control (MonadControlIO)
import qualified Data.Map as Map
import Data.Pool
import Data.Text (Text)

import Database.Persist.Base
import Database.Persist.Util (nullable)
-- | A monadic action that yields at most one result row per call.
-- NOTE(review): presumably 'Nothing' signals that no rows remain — confirm
-- against the backend implementations.
type RowPopper m = m (Maybe [PersistValue])
-- | A database connection represented as a record of operations. The code
-- in this module reaches the database only through these fields (e.g.
-- 'close'' uses 'stmtMap' and 'close'; 'filterClause' and 'orderClause' use
-- 'escapeName').
data Connection = Connection
    { prepare :: Text -> IO Statement
      -- ^ Turn SQL text into a 'Statement'.
    , insertSql :: RawName -> [RawName] -> Either Text (Text, Text)
      -- ^ Build insert SQL from a table name and its column names.
      -- NOTE(review): the meaning of the Left/Right alternatives is
      -- backend-specific and not visible here — confirm.
    , stmtMap :: IORef (Map.Map Text Statement)
      -- ^ Cache of statements, presumably keyed by their SQL text.
      -- 'close'' finalizes every entry before closing the connection.
    , close :: IO ()
      -- ^ Close the underlying connection. Callers here go through 'close'',
      -- which finalizes cached statements first.
    , migrateSql :: forall v. PersistEntity v
                 => (Text -> IO Statement) -> v
                 -> IO (Either [Text] [(Bool, Text)])
      -- ^ Compute migration SQL for an entity, given a statement-preparing
      -- function. NOTE(review): presumably 'Left' carries error messages and
      -- the 'Bool' in each 'Right' pair flags a property of the statement
      -- (e.g. unsafe) — confirm.
    , begin :: (Text -> IO Statement) -> IO ()
      -- ^ Start a transaction, using the supplied statement-preparing function.
    , commit :: (Text -> IO Statement) -> IO ()
      -- ^ Commit a transaction, using the supplied statement-preparing function.
    , rollback :: (Text -> IO Statement) -> IO ()
      -- ^ Roll back a transaction, using the supplied statement-preparing
      -- function.
    , escapeName :: RawName -> String
      -- ^ Quote an identifier for this backend's SQL dialect.
    , noLimit :: String
      -- ^ NOTE(review): presumably the SQL used when a query has an offset
      -- but no limit — confirm against its callers.
    }
-- | A prepared statement, as produced by 'prepare'.
data Statement = Statement
    { finalize :: IO ()
      -- ^ Release the statement. 'close'' calls this for every statement
      -- cached in 'stmtMap'.
    , reset :: IO ()
      -- ^ NOTE(review): presumably returns the statement to a re-executable
      -- state — confirm against the backends.
    , execute :: [PersistValue] -> IO ()
      -- ^ Run the statement with the given values. No result is returned.
    , withStmt :: forall a m. MonadControlIO m
               => [PersistValue] -> (RowPopper m -> m a) -> m a
      -- ^ Run the statement with the given values and hand the callback a
      -- 'RowPopper' for reading the result rows.
    }
-- | Run an action with a pool of connections created by @mkConn@.
--
-- 'close'' is handed to 'createPool' as the release function, so cached
-- statements are finalized before each pooled connection is closed. The
-- 'Int' is passed through to 'createPool' unchanged (presumably the maximum
-- number of connections).
withSqlPool :: MonadControlIO m
            => IO Connection -> Int -> (Pool Connection -> m a) -> m a
withSqlPool mkConn maxConns usePool =
    createPool mkConn close' maxConns usePool
-- | Open a single connection, run an action with it, and always release it.
--
-- Release goes through 'close'', so cached statements are finalized before
-- the connection itself is closed, even if the action throws.
withSqlConn :: MonadControlIO m => IO Connection -> (Connection -> m a) -> m a
withSqlConn open action =
    bracket (liftIO open) (\conn -> liftIO (close' conn)) action
-- | Finalize every statement cached in 'stmtMap', then close the connection.
--
-- 'close' runs even if finalizing one of the statements throws. A failing
-- finalizer therefore cannot leak the underlying connection. The exception
-- still propagates after 'close' has run.
close' :: Connection -> IO ()
close' conn =
    (readIORef (stmtMap conn) >>= mapM_ finalize . Map.elems)
        `E.finally` close conn
-- | Derive the SQL column descriptions and unique constraints for an entity.
--
-- The argument is used only for its type. Field values come from
-- 'halfDefined' (constrained to the same type), so that 'sqlType' can be
-- computed for each field.
--
-- Each column's default comes from its first @default=...@ attribute. Its
-- foreign key is decided by the first @noreference@ / @reference=TABLE@
-- attribute. Failing those, a Haskell type named @<Table>Id@ implies a
-- reference to @<Table>@.
mkColumns :: PersistEntity val => val -> ([Column], [UniqueDef])
mkColumns val =
    (cols, uniqs)
  where
    t = entityDef val
    tn = rawTableName t

    -- Haskell field name -> SQL column name; used to resolve the field
    -- names listed in unique constraints.
    colNameMap = map ((\(x, _, _) -> x) &&& rawFieldName) $ entityColumns t

    uniqs = map (RawName *** map lookupColumn) $ entityUniques t

    -- A unique constraint naming a nonexistent field is a definition error;
    -- report which one instead of failing inside 'fromJust'.
    lookupColumn fieldName =
        case lookup fieldName colNameMap of
            Just column -> column
            Nothing -> error $ "mkColumns: a unique constraint on "
                            ++ entityName t
                            ++ " refers to unknown field "
                            ++ fieldName

    cols = zipWith go (tableColumns t) $ toPersistFields $ halfDefined `asTypeOf` val

    go (name, t', as) p =
        Column name (nullable as) (sqlType p) (def as) (ref name t' as)

    def [] = Nothing
    def (('d':'e':'f':'a':'u':'l':'t':'=':d):_) = Just d
    def (_:rest) = def rest

    -- No explicit reference attribute: infer one from a "<Table>Id" type.
    -- The "Id" suffix is the last two characters; a bare "Id" has no table
    -- name in front of it, so it implies no reference.
    ref c t' [] =
        let (f, b) = splitAt (length t' - 2) t'
        in if b == "Id" && not (null f)
            then Just (RawName f, refName tn c)
            else Nothing
    ref _ _ ("noreference":_) = Nothing
    ref c _ (('r':'e':'f':'e':'r':'e':'n':'c':'e':'=':x):_) =
        Just (RawName x, refName tn c)
    ref c x (_:y) = ref c x y
-- | Name of the foreign-key constraint for a column:
-- @<table>_<column>_fkey@.
refName :: RawName -> RawName -> RawName
refName (RawName table) (RawName column) =
    RawName (concat [table, "_", column, "_fkey"])
-- | Description of a single SQL column, as built by 'mkColumns'.
data Column = Column
    { cName :: RawName
      -- ^ Unescaped column name (the @sql=@ attribute if present, otherwise
      -- the field name).
    , cNull :: Bool
      -- ^ Whether the column may hold NULL, as decided by 'nullable' on the
      -- field's attributes.
    , cType :: SqlType
      -- ^ SQL type, from 'sqlType' of the field.
    , cDefault :: Maybe String
      -- ^ Text of the field's first @default=@ attribute, if any.
    , cReference :: (Maybe (RawName, RawName))
      -- ^ Foreign key, if any: (referenced table, constraint name).
    }
-- | The value of the first @sql=...@ attribute, if there is one.
getSqlValue :: [String] -> Maybe String
getSqlValue = foldr pick Nothing
  where
    -- Stop at the first match; later attributes are never inspected.
    pick ('s':'q':'l':'=':value) _ = Just value
    pick _ later = later
-- | The entity's columns with each field name replaced by its SQL column
-- name. The type string and attributes are kept as they are.
tableColumns :: EntityDef -> [(RawName, String, [String])]
tableColumns t = map withSqlName (entityColumns t)
  where
    withSqlName col@(_, typ, attrs) = (rawFieldName col, typ, attrs)
-- | A unique constraint: (constraint name, SQL names of its columns), as
-- built by 'mkColumns'.
type UniqueDef = (RawName, [RawName])
-- | SQL column name for a field: its @sql=@ attribute if present, otherwise
-- the field's own name.
rawFieldName :: (String, String, [String]) -> RawName
rawFieldName (n, _, as) = RawName (maybe n id (getSqlValue as))
-- | SQL table name for an entity: its @sql=@ attribute if present, otherwise
-- the entity name.
rawTableName :: EntityDef -> RawName
rawTableName t =
    RawName (maybe (entityName t) id (getSqlValue (entityAttribs t)))
-- | An unescaped SQL identifier (table, column or constraint name). Pass it
-- through 'escapeName' before splicing it into SQL text.
newtype RawName = RawName { unRawName :: String }
    deriving (Eq, Ord)
-- | Render one filter as a SQL boolean expression over its column.
--
-- When @includeTable@ is set, the column is qualified with its escaped
-- table name.
--
-- The @?@ placeholders emitted here are meant to line up with the values
-- 'getFiltsValues' returns for the same filter. That function never binds
-- 'PersistNull', so NULL operands are rendered as @IS NULL@ /
-- @IS NOT NULL@ tests or as a literal @NULL@, never as a placeholder.
filterClause :: PersistEntity val
             => Bool
             -> Connection -> Filter val -> String
filterClause includeTable conn f =
    case (isNull, filt, varCount) of
        (True, Eq, _) -> name ++ " IS NULL"
        (True, Ne, _) -> name ++ " IS NOT NULL"
        (False, Ne, _) -> concat
            [ "("
            , name
            , " IS NULL OR "
            , name
            , "<>?)"
            ]
        -- Membership in an empty list is always false.
        (_, In, 0) -> "1=2"
        (False, In, _) -> name ++ " IN " ++ qmarks
        (True, In, _)
            -- Only NULLs were listed: "IN ()" is invalid SQL, so the NULL
            -- test is the whole condition.
            | null boundValues -> name ++ " IS NULL"
            | otherwise -> concat
                [ "("
                , name
                , " IS NULL OR "
                , name
                , " IN "
                , qmarks
                , ")"
                ]
        -- Non-membership in an empty list is always true.
        (_, NotIn, 0) -> "1=1"
        (False, NotIn, _) -> concat
            [ "("
            , name
            , " IS NULL OR "
            , name
            , " NOT IN "
            , qmarks
            , ")"
            ]
        (True, NotIn, _)
            -- Only NULLs were listed: "NOT IN ()" is invalid SQL.
            | null boundValues -> name ++ " IS NOT NULL"
            | otherwise -> concat
                [ "("
                , name
                , " IS NOT NULL AND "
                , name
                , " NOT IN "
                , qmarks
                , ")"
                ]
        -- Plain comparisons. An unbound (NULL) operand is written literally
        -- to keep the placeholder count in sync with 'getFiltsValues'. The
        -- comparison then matches no rows, just as a bound NULL would.
        _ -> name ++ showSqlFilter filt ++ operand
  where
    filt = persistFilterToFilter f
    isNull = any (== PersistNull)
           $ either return id
           $ persistFilterToValue f
    -- Exactly the values 'getFiltsValues' binds for this filter.
    boundValues = case persistFilterToValue f of
        Left PersistNull -> []
        Left x -> [x]
        Right xs -> filter (/= PersistNull) xs
    operand = if null boundValues then "NULL" else "?"
    t = entityDef $ dummyFromFilts [f]
    name =
        (if includeTable
            then (++) (escapeName conn (rawTableName t) ++ ".")
            else id)
        $ escapeName conn $ getFieldName t $ persistFilterToFieldName f
    qmarks = case persistFilterToValue f of
        Left _ -> "?"
        Right x ->
            let x' = filter (/= PersistNull) x
            in '(' : intercalate "," (map (const "?") x') ++ ")"
    varCount = case persistFilterToValue f of
        Left _ -> 1
        Right x -> length x
    showSqlFilter Eq = "="
    showSqlFilter Ne = "<>"
    showSqlFilter Gt = ">"
    showSqlFilter Lt = "<"
    showSqlFilter Ge = ">="
    showSqlFilter Le = "<="
    showSqlFilter In = " IN "
    showSqlFilter NotIn = " NOT IN "
-- | Type witness: recovers the entity type from a list of filters. The
-- result must never be forced; doing so raises an error.
dummyFromFilts :: [Filter v] -> v
dummyFromFilts = const (error "dummyFromFilts")
-- | SQL column name for the named field of an entity.
getFieldName :: EntityDef -> String -> RawName
getFieldName t = rawFieldName . tableColumn t
-- | Look up a field definition by its Haskell name.
--
-- @"id"@ is answered directly with a synthetic @Int64@ column. Any other
-- name must be one of the entity's columns; an unknown name raises an error.
tableColumn :: EntityDef -> String -> (String, String, [String])
tableColumn _ "id" = ("id", "Int64", [])
tableColumn t s =
    case [ col | col@(colName, _, _) <- entityColumns t, colName == s ] of
        col : _ -> col
        [] -> error $ "Unknown table column: " ++ s
-- | The values to bind for a list of filters, in filter order.
--
-- 'PersistNull' is never bound. A lone NULL operand contributes nothing,
-- and NULLs inside a value list are dropped.
getFiltsValues :: PersistEntity val => [Filter val] -> [PersistValue]
getFiltsValues filts =
    [ v | filt <- filts, v <- bindable (persistFilterToValue filt) ]
  where
    bindable (Left PersistNull) = []
    bindable (Left single) = [single]
    bindable (Right many) = filter (/= PersistNull) many
-- | Type witness: recovers the entity type from an 'Order'.
--
-- In this module it is only passed to 'entityDef', which is assumed not to
-- force it. Forcing it is a bug and fails with a message naming this
-- function, matching 'dummyFromFilts', instead of an anonymous 'undefined'.
dummyFromOrder :: Order a -> a
dummyFromOrder _ = error "dummyFromOrder"
-- | Render one sort order as a SQL @ORDER BY@ term: the escaped column,
-- optionally qualified with its escaped table name, followed by @ DESC@
-- for descending orders.
orderClause :: PersistEntity val => Bool -> Connection -> Order val -> String
orderClause includeTable conn o =
    qualified ++ direction
  where
    t = entityDef $ dummyFromOrder o
    column = escapeName conn $ getFieldName t $ persistOrderToFieldName o
    qualified
        | includeTable = escapeName conn (rawTableName t) ++ "." ++ column
        | otherwise = column
    direction = case persistOrderToOrder o of
        Asc -> ""
        Desc -> " DESC"