module Database.Persist.GenericSql.Internal
( Connection (..)
, Statement (..)
, withSqlConn
, withSqlPool
, RowPopper
, mkColumns
, Column (..)
, UniqueDef'
, refName
, tableColumns
, rawFieldName
, rawTableName
, rawTableIdName
, RawName (..)
, filterClause
, filterClauseNoWhere
, getFieldName
, dummyFromFilts
, orderClause
, getFiltsValues
) where
import qualified Data.Map as Map
import Data.IORef
import Control.Monad.IO.Class
import Data.Pool
import Database.Persist.Base
import Data.Maybe (fromMaybe)
import Control.Arrow
import Control.Monad.IO.Control (MonadControlIO)
import Control.Exception.Control (bracket)
import Database.Persist.Util (nullable)
import Data.List (intercalate)
import Data.Text (Text)
type RowPopper m = m (Maybe [PersistValue])
data Connection = Connection
{ prepare :: Text -> IO Statement
, insertSql :: RawName -> [RawName] -> Either Text (Text, Text)
, stmtMap :: IORef (Map.Map Text Statement)
, close :: IO ()
, migrateSql :: forall v. PersistEntity v
=> (Text -> IO Statement) -> v
-> IO (Either [Text] [(Bool, Text)])
, begin :: (Text -> IO Statement) -> IO ()
, commitC :: (Text -> IO Statement) -> IO ()
, rollbackC :: (Text -> IO Statement) -> IO ()
, escapeName :: RawName -> String
, noLimit :: String
}
data Statement = Statement
{ finalize :: IO ()
, reset :: IO ()
, execute :: [PersistValue] -> IO ()
, withStmt :: forall a m. MonadControlIO m
=> [PersistValue] -> (RowPopper m -> m a) -> m a
}
withSqlPool :: MonadControlIO m
=> IO Connection -> Int -> (Pool Connection -> m a) -> m a
withSqlPool mkConn = createPool mkConn close'
withSqlConn :: MonadControlIO m => IO Connection -> (Connection -> m a) -> m a
withSqlConn open = bracket (liftIO open) (liftIO . close')
close' :: Connection -> IO ()
close' conn = do
readIORef (stmtMap conn) >>= mapM_ finalize . Map.elems
close conn
mkColumns :: PersistEntity val => val -> ([Column], [UniqueDef'])
mkColumns val =
(cols, uniqs)
where
colNameMap = map (columnName &&& rawFieldName) $ entityColumns t
uniqs = map (RawName *** map (unjustLookup colNameMap))
$ map (uniqueName &&& uniqueColumns)
$ entityUniques t
cols = zipWith go (tableColumns t) $ toPersistFields $ halfDefined `asTypeOf` val
unjustLookup m a = fromMaybe (error $ "Column not found: " ++ a)
$ lookup a m
t = entityDef val
tn = rawTableName t
go (name, t', as) p =
Column name (nullable as) (sqlType p) (def as) (ref name t' as)
def [] = Nothing
def (('d':'e':'f':'a':'u':'l':'t':'=':d):_) = Just d
def (_:rest) = def rest
ref c t' [] =
let l = length t'
(f, b) = splitAt (l 2) t'
in if b == "Id"
then Just (RawName f, refName tn c)
else Nothing
ref _ _ ("noreference":_) = Nothing
ref c _ (('r':'e':'f':'e':'r':'e':'n':'c':'e':'=':x):_) =
Just (RawName x, refName tn c)
ref c x (_:y) = ref c x y
refName :: RawName -> RawName -> RawName
refName (RawName table) (RawName column) =
RawName $ table ++ '_' : column ++ "_fkey"
data Column = Column
{ cName :: RawName
, cNull :: Bool
, cType :: SqlType
, cDefault :: Maybe String
, cReference :: (Maybe (RawName, RawName))
}
getSqlValue :: [String] -> Maybe String
getSqlValue (('s':'q':'l':'=':x):_) = Just x
getSqlValue (_:x) = getSqlValue x
getSqlValue [] = Nothing
getIdNameValue :: [String] -> Maybe String
getIdNameValue (('i':'d':'=':x):_) = Just x
getIdNameValue (_:x) = getIdNameValue x
getIdNameValue [] = Nothing
tableColumns :: EntityDef -> [(RawName, String, [String])]
tableColumns = map (\a@(ColumnDef _ y z) -> (rawFieldName a, y, z)) . entityColumns
type UniqueDef' = (RawName, [RawName])
rawFieldName :: ColumnDef -> RawName
rawFieldName (ColumnDef n _ as) = RawName $
case getSqlValue as of
Just x -> x
Nothing -> n
rawTableName :: EntityDef -> RawName
rawTableName t = RawName $
case getSqlValue $ entityAttribs t of
Nothing -> entityName t
Just x -> x
rawTableIdName :: EntityDef -> RawName
rawTableIdName t = RawName $
case getIdNameValue $ entityAttribs t of
Nothing -> "id"
Just x -> x
newtype RawName = RawName { unRawName :: String }
deriving (Eq, Ord)
getFiltsValues :: forall val. PersistEntity val => Connection -> [Filter val] -> [PersistValue]
getFiltsValues conn = snd . filterClauseHelper False False conn
filterClause :: PersistEntity val
=> Bool
-> Connection -> [Filter val] -> String
filterClause b c = fst . filterClauseHelper b True c
filterClauseNoWhere :: PersistEntity val
=> Bool
-> Connection -> [Filter val] -> String
filterClauseNoWhere b c = fst . filterClauseHelper b False c
filterClauseHelper :: PersistEntity val
=> Bool
-> Bool
-> Connection -> [Filter val] -> (String, [PersistValue])
filterClauseHelper includeTable includeWhere conn filters =
(if not (null sql) && includeWhere
then " WHERE " ++ sql
else sql, vals)
where
(sql, vals) = combineAND filters
combineAND = combine " AND "
combine s fs =
(intercalate s $ map wrapP a, concat b)
where
(a, b) = unzip $ map go fs
wrapP x = concat ["(", x, ")"]
go (FilterAnd fs) = combineAND fs
go (FilterOr fs) = combine " OR " fs
go (Filter field value pfilter) =
case (isNull, pfilter, varCount) of
(True, Eq, _) -> (name ++ " IS NULL", [])
(True, Ne, _) -> (name ++ " IS NOT NULL", [])
(False, Ne, _) -> (name ++ " IS NULL", [])
(_, In, 0) -> ("1=2", [])
(False, In, _) -> (name ++ " IN " ++ qmarks, allVals)
(True, In, _) -> (concat
[ "("
, name
, " IS NULL OR "
, name
, " IN "
, qmarks
, ")"
], notNullVals)
(_, NotIn, 0) -> ("1=1", [])
(False, NotIn, _) -> (concat
[ "("
, name
, " IS NULL OR "
, name
, " NOT IN "
, qmarks
, ")"
], notNullVals)
(True, NotIn, _) -> (concat
[ "("
, name
, " IS NOT NULL AND "
, name
, " NOT IN "
, qmarks
, ")"
], notNullVals)
_ -> (name ++ showSqlFilter pfilter ++ "?", allVals)
where
filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
filterValueToPersistValues v = map toPersistValue $ either return id v
isNull = any (== PersistNull) allVals
notNullVals = filter (/= PersistNull) allVals
allVals = filterValueToPersistValues value
t = entityDef $ dummyFromFilts [Filter field value pfilter]
name =
(if includeTable
then (++) (escapeName conn (rawTableName t) ++ ".")
else id)
$ escapeName conn $ getFieldName t $ columnName $ persistColumnDef field
qmarks = case value of
Left _ -> "?"
Right x ->
let x' = filter (/= PersistNull) $ map toPersistValue x
in '(' : intercalate "," (map (const "?") x') ++ ")"
varCount = case value of
Left _ -> 1
Right x -> length x
showSqlFilter Eq = "="
showSqlFilter Ne = "<>"
showSqlFilter Gt = ">"
showSqlFilter Lt = "<"
showSqlFilter Ge = ">="
showSqlFilter Le = "<="
showSqlFilter In = " IN "
showSqlFilter NotIn = " NOT IN "
showSqlFilter (BackendSpecificFilter s) = s
dummyFromFilts :: [Filter v] -> v
dummyFromFilts _ = error "dummyFromFilts"
getFieldName :: EntityDef -> String -> RawName
getFieldName t s = rawFieldName $ tableColumn t s
tableColumn :: EntityDef -> String -> ColumnDef
tableColumn t s | s == id_ = ColumnDef id_ "Int64" []
where id_ = unRawName $ rawTableIdName t
tableColumn t s = go $ entityColumns t
where
go [] = error $ "Unknown table column: " ++ s
go (ColumnDef x y z:rest)
| x == s = ColumnDef x y z
| otherwise = go rest
dummyFromOrder :: SelectOpt a -> a
dummyFromOrder _ = undefined
orderClause :: PersistEntity val => Bool -> Connection -> SelectOpt val -> String
orderClause includeTable conn o =
case o of
Asc x -> name x
Desc x -> name x ++ " DESC"
_ -> error $ "expected Asc or Desc, not limit or offset"
where
cd x = persistColumnDef x
t = entityDef $ dummyFromOrder o
name x =
(if includeTable
then (++) (escapeName conn (rawTableName t) ++ ".")
else id)
$ escapeName conn $ getFieldName t $ columnName $ cd x