module Database.Persist.GenericSql
( SqlPersist (..)
, Connection
, ConnectionPool
, Statement
, runSqlConn
, runSqlPool
, Key (..)
, rawSql
, Entity(..)
, Single(..)
, Migration
, parseMigration
, parseMigration'
, printMigration
, getMigration
, runMigration
, runMigrationSilent
, runMigrationUnsafe
, migrate
, commit
, rollback
) where
import qualified Prelude as P
import Prelude hiding ((++), unlines, concat, show)
import Control.Applicative ((<$>), (<*>))
import Control.Arrow ((&&&))
import Database.Persist.Store
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import Data.Conduit.Pool
import Database.Persist.GenericSql.Internal
import Database.Persist.GenericSql.Migration
import qualified Database.Persist.GenericSql.Raw as R
import Database.Persist.GenericSql.Raw (SqlPersist (..))
import Control.Monad.Trans.Control (MonadBaseControl, control)
import qualified Control.Exception as E
import Control.Exception (throw)
import Data.Text (Text, pack, unpack, concat)
import qualified Data.Text as T
import Web.PathPieces (PathPiece (..))
import qualified Data.Text.Read
import Data.Monoid (Monoid, mappend)
import Database.Persist.EntityDef
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
-- | A pool of database 'Connection's, as consumed by 'runSqlPool'.
type ConnectionPool = Pool Connection
-- | URL-piece conversion for SQL keys.
--
-- Only integer ('PersistInt64') keys are supported: rendering any other key
-- throws 'PersistInvalidField', and parsing accepts an optionally signed
-- decimal number with no trailing characters.
instance PathPiece (Key SqlPersist entity) where
    toPathPiece key =
        case key of
            Key (PersistInt64 n) -> toPathPiece n
            _ -> throw $ PersistInvalidField $ "Invalid Key: " ++ show key
    fromPathPiece txt =
        either (const Nothing) fromParsed $
            Data.Text.Read.signed Data.Text.Read.decimal txt
      where
        -- Accept the parse only if it consumed the whole input.
        fromParsed (n, rest)
            | T.null rest = Just $ Key $ PersistInt64 n
            | otherwise = Nothing
-- | Execute a SQL statement with positional parameters via 'R.execute'.
execute' :: MonadIO m => Text -> [PersistValue] -> SqlPersist m ()
execute' sql params = R.execute sql params
-- | Borrow a connection from the pool for the duration of the action and run
-- the action on it as one transaction via 'runSqlConn'.
runSqlPool :: (MonadBaseControl IO m, MonadIO m) => SqlPersist m a -> Pool Connection -> m a
runSqlPool action pool = withResource pool (\conn -> runSqlConn action conn)
-- | Run a 'SqlPersist' action on a single connection as one transaction.
--
-- The transaction is begun before the action runs and committed after it
-- returns. If the action throws, the transaction is rolled back and the
-- exception propagates (see 'onException').
runSqlConn :: (MonadBaseControl IO m, MonadIO m) => SqlPersist m a -> Connection -> m a
runSqlConn (SqlPersist r) conn = do
    let stmtLookup = R.getStmt' conn
    liftIO (begin conn stmtLookup)
    result <- runReaderT r conn `onException` liftIO (rollbackC conn stmtLookup)
    liftIO (commitC conn stmtLookup)
    return result
-- | 'PersistStore' for SQL backends.
--
-- Rows are addressed through the entity's ID column, as named by 'entityID'
-- and escaped with the connection's 'escapeName'. This keeps 'get',
-- 'replace' and 'delete' consistent with 'insertKey' (via 'insrepHelper')
-- and 'getBy', which use the same column. They no longer assume the column
-- is literally called @id@.
instance (MonadBaseControl IO m, MonadIO m, C.MonadThrow m, C.MonadUnsafeIO m) => PersistStore SqlPersist m where
    -- Insert a row and return the key generated by the database. The
    -- backend's 'insertSql' gives one of two things: a single statement
    -- whose result row carries the new ID ('Left'), or an insert statement
    -- plus a separate query that fetches the ID ('Right').
    insert val = do
        conn <- SqlPersist ask
        let esql = insertSql conn (entityDB t) (map fieldDB $ entityFields t)
        i <-
            case esql of
                Left sql ->
                    C.runResourceT $ R.withStmt sql vals C.$$ (CL.head >>= generatedId)
                Right (sql1, sql2) -> do
                    execute' sql1 vals
                    C.runResourceT $ R.withStmt sql2 [] C.$$ (CL.head >>= generatedId)
        return $ Key $ PersistInt64 i
      where
        t = entityDef val
        vals = map toPersistValue $ toPersistFields val
        -- Both strategies must yield one row holding one integer. Anything
        -- else is reported with an explicit message rather than an opaque
        -- pattern-match failure.
        generatedId (Just [PersistInt64 i]) = return i
        generatedId Nothing =
            error "SQL insert did not return a result giving the generated ID"
        generatedId (Just vals') =
            error $ "Invalid result from a SQL insert, got: " P.++ P.show vals'

    -- Overwrite every field of the row identified by the key.
    replace k val = do
        conn <- SqlPersist ask
        let t = entityDef val
        let sql = concat
                [ "UPDATE "
                , escapeName conn (entityDB t)
                , " SET "
                , T.intercalate "," (map (go conn . fieldDB) $ entityFields t)
                , " WHERE "
                , escapeName conn (entityID t)
                , "=?"
                ]
        -- Field values bind to the SET placeholders; the key binds last,
        -- to the WHERE placeholder.
        execute' sql $ map toPersistValue (toPersistFields val)
                       `mappend` [unKey k]
      where
        go conn x = escapeName conn x ++ "=?"

    -- Insert a row under a caller-supplied key.
    insertKey = insrepHelper "INSERT"

    -- Delete any existing row with this key, then insert under that key.
    repsert key value = do
        delete key
        insertKey key value

    -- Fetch the row with the given key, or 'Nothing' if no row matches.
    get k = do
        conn <- SqlPersist ask
        let t = entityDef $ dummyFromKey k
        let cols = T.intercalate ","
                 $ map (escapeName conn . fieldDB) $ entityFields t
        let sql = concat
                [ "SELECT "
                , cols
                , " FROM "
                , escapeName conn $ entityDB t
                , " WHERE "
                , escapeName conn $ entityID t
                , "=?"
                ]
        C.runResourceT $ R.withStmt sql [unKey k] C.$$ do
            res <- CL.head
            case res of
                Nothing -> return Nothing
                Just vals ->
                    case fromPersistValues vals of
                        Left e -> error $ unpack $ "get " ++ show (unKey k) ++ ": " ++ e
                        Right v -> return $ Just v

    -- Delete the row with the given key.
    delete k = do
        conn <- SqlPersist ask
        execute' (sql conn) [unKey k]
      where
        t = entityDef $ dummyFromKey k
        sql conn = concat
            [ "DELETE FROM "
            , escapeName conn $ entityDB t
            , " WHERE "
            , escapeName conn $ entityID t
            , "=?"
            ]
-- | Write a row under an explicit key using the given SQL verb (for example
-- @"INSERT"@), i.e. @VERB INTO table(id, f1, ...) VALUES(?, ?, ...)@.
-- The key value is bound to the ID column and the entity's field values to
-- the remaining columns, in declaration order.
insrepHelper :: (MonadIO m, PersistEntity val)
             => Text
             -> Key SqlPersist val
             -> val
             -> SqlPersist m ()
insrepHelper command (Key k) val = do
    conn <- SqlPersist ask
    let esc = escapeName conn
        columns = map esc (entityID tableDef : map fieldDB (entityFields tableDef))
        -- One "?" per column: the ID plus every field.
        placeholders = map (const "?") columns
        sql = T.concat
            [ command, " INTO ", esc (entityDB tableDef)
            , "(", T.intercalate "," columns, ") VALUES("
            , T.intercalate "," placeholders, ")"
            ]
    execute' sql (k : map toPersistValue (toPersistFields val))
  where
    tableDef = entityDef val
-- | 'PersistUnique' for SQL backends: rows are matched by an equality test
-- on every column of the unique constraint, joined with @AND@.
instance (MonadBaseControl IO m, C.MonadUnsafeIO m, MonadIO m, C.MonadThrow m) => PersistUnique SqlPersist m where
    -- Delete the row(s) matching the unique constraint's values.
    deleteBy uniq = do
        conn <- SqlPersist ask
        let esc = escapeName conn
            tableDef = entityDef $ dummyFromUnique uniq
            conditions = [esc col ++ "=?" | (_, col) <- persistUniqueToFieldNames uniq]
            sql = T.concat
                [ "DELETE FROM "
                , esc (entityDB tableDef)
                , " WHERE "
                , T.intercalate " AND " conditions
                ]
        execute' sql (persistUniqueToValues uniq)

    -- Fetch the first row matching the unique constraint. The ID column is
    -- selected first so the key can be rebuilt; only integer keys are
    -- accepted.
    getBy uniq = do
        conn <- SqlPersist ask
        let esc = escapeName conn
            tableDef = entityDef $ dummyFromUnique uniq
            selected = esc (entityID tableDef) : map (esc . fieldDB) (entityFields tableDef)
            conditions = [esc col ++ "=?" | (_, col) <- persistUniqueToFieldNames uniq]
            sql = T.concat
                [ "SELECT "
                , T.intercalate "," selected
                , " FROM "
                , esc (entityDB tableDef)
                , " WHERE "
                , T.intercalate " AND " conditions
                ]
        C.runResourceT $ R.withStmt sql (persistUniqueToValues uniq) C.$$ do
            firstRow <- CL.head
            case firstRow of
                Just (PersistInt64 rowKey : rest) ->
                    either (error . unpack)
                           (return . Just . Entity (Key $ PersistInt64 rowKey))
                           (fromPersistValues rest)
                Just _ -> error "Database.Persist.GenericSql: Bad list in getBy"
                Nothing -> return Nothing
-- | Type-level helper that recovers a key's entity type so 'entityDef' can be
-- selected from it. The returned value is bottom and must never be forced.
dummyFromKey :: Key SqlPersist v -> v
dummyFromKey = const (error "dummyFromKey")
-- | Type-level helper that recovers a unique constraint's entity type so
-- 'entityDef' can be selected from it. The returned value is bottom and must
-- never be forced.
dummyFromUnique :: Unique v b -> v
dummyFromUnique = const (error "dummyFromUnique")
#if MIN_VERSION_monad_control(0, 3, 0)
-- | 'E.onException' lifted to any 'MonadBaseControl' 'IO' monad. Runs the
-- action; if it throws, runs the cleanup action and re-raises the exception.
onException :: MonadBaseControl IO m => m a -> m b -> m a
onException action cleanup =
    control $ \runInBase -> runInBase action `E.onException` runInBase cleanup
#endif
infixr 5 ++
-- | 'Text' concatenation. It stands in for the list-only Prelude operator,
-- which this module hides, so SQL fragments can be joined directly.
(++) :: Text -> Text -> Text
left ++ right = left `mappend` right
-- | Render any 'Show'able value as 'Text' (Prelude's 'P.show' packed). It
-- stands in for Prelude's 'show', which this module hides.
show :: Show a => a -> Text
show value = pack (P.show value)
-- | Wrapper for a one-column 'rawSql' result, e.g. @[Single Int]@.
newtype Single a = Single {unSingle :: a}
    deriving (Eq, Ord, Show, Read)
-- | Execute a raw SQL statement and decode every returned row as @a@.
--
-- Each @??@ in the statement is replaced, in order, by one column-list
-- substitution produced by 'rawSqlCols' for @a@ (for example, one per
-- 'Entity' in a tuple). If there are more substitutions than @??@
-- placeholders, this raises an 'error'. Any @??@ left over after all
-- substitutions are used stays in the SQL unchanged. The parameter list is
-- passed to the statement via 'R.withStmt'. Decoding fails if the first
-- row's width differs from the expected column count.
rawSql :: (RawSql a, C.MonadUnsafeIO m, C.MonadThrow m, MonadIO m, MonadBaseControl IO m) =>
          Text             -- ^ SQL statement, possibly containing @??@ placeholders.
       -> [PersistValue]   -- ^ Values bound to the statement's parameters.
       -> SqlPersist m [a]
rawSql stmt = run
  where
    -- Type proxy: only its type is used (to select the 'RawSql' instance
    -- for the result); the value itself is never evaluated.
    getType :: (x -> SqlPersist m [a]) -> a
    getType = error "rawSql.getType: type proxy must not be evaluated"

    x = getType run
    process = rawSqlProcessRow

    withStmt' colSubsts = R.withStmt $ T.concat $
                          makeSubsts colSubsts $
                          T.splitOn placeholder stmt
      where
        placeholder = "??"
        -- splitOn yields one more piece than there are placeholders, so a
        -- substitution may only be emitted when another piece follows it.
        -- Otherwise a surplus substitution would be glued onto the end of
        -- the query instead of being reported.
        makeSubsts (s:ss) (t:ts@(_:_)) = t : s : makeSubsts ss ts
        makeSubsts []     ts           = [T.intercalate placeholder ts]
        makeSubsts ss     _            = error (P.concat err)
          where
            err = [ "rawSql: there are still ", P.show (length ss)
                  , " '??' placeholder substitutions to be made "
                  , "but all '??' placeholders have already been "
                  , "consumed. Please read 'rawSql's documentation "
                  , "on how '??' placeholders work."
                  ]

    run params = do
        conn <- SqlPersist ask
        let (colCount, colSubsts) = rawSqlCols (escapeName conn) x
        C.runResourceT $ withStmt' colSubsts params C.$$ firstRow colCount

    -- Validate the width of the first row once, then decode all rows.
    firstRow colCount = do
        mrow <- CL.head
        case mrow of
          Nothing -> return []
          Just row
            | colCount == length row -> getter mrow
            | otherwise -> fail $ P.concat
                [ "rawSql: wrong number of columns, got "
                , P.show (length row), " but expected ", P.show colCount
                , " (", rawSqlColCountReason x, ")." ]

    -- Decode rows until the stream ends, accumulating results in order
    -- with a difference list.
    getter = go id
      where
        go acc Nothing = return (acc [])
        go acc (Just row) =
          case process row of
            Left err -> fail (T.unpack err)
            Right r -> CL.head >>= go (acc . (r:))
-- | Result types that 'rawSql' can decode from a row.
class RawSql a where
    -- | Given a column-name escaping function, return how many columns this
    -- type occupies and the column-list substitutions it contributes to
    -- @??@ placeholders. 'rawSql' passes a type proxy (never a real value)
    -- as the @a@ argument, so implementations must not force it.
    rawSqlCols :: (DBName -> Text) -> a -> (Int, [Text])
    -- | Human-readable explanation of the expected column count, used in
    -- 'rawSql''s column-count mismatch message. Also receives a proxy.
    rawSqlColCountReason :: a -> String
    -- | Decode the columns belonging to this type, or return 'Left' with an
    -- error message.
    rawSqlProcessRow :: [PersistValue] -> Either Text a
-- | 'Single' occupies exactly one column and contributes no @??@
-- substitution.
instance PersistField a => RawSql (Single a) where
    rawSqlCols _escape _proxy = (1, [])
    rawSqlColCountReason _ = "one column for a 'Single' data type"
    rawSqlProcessRow row =
        case row of
          [pv] -> fmap Single (fromPersistValue pv)
          _ -> Left "RawSql (Single a): wrong number of columns."
-- | An 'Entity' occupies its ID column followed by one column per field. Its
-- single @??@ substitution is the table-qualified, escaped column list.
instance PersistEntity a => RawSql (Entity a) where
    rawSqlCols escape ent = (1 + length (entityFields tableDef), [columnList])
      where
        -- 'entityVal' (not a constructor pattern) keeps the proxy unforced.
        tableDef = entityDef (entityVal ent)
        qualify col = escape (entityDB tableDef) ++ "." ++ escape col
        columnList = T.intercalate ", "
                   $ map qualify
                   $ entityID tableDef : map fieldDB (entityFields tableDef)
    rawSqlColCountReason ent
        | colCount == 1 = "one column for an 'Entity' data type without fields"
        | otherwise = P.show colCount P.++ " columns for an 'Entity' data type"
      where
        -- Only the count is needed, so the escaping function is never used.
        colCount = fst (rawSqlCols undefined ent)
    rawSqlProcessRow (idCol:fieldCols) =
        Entity <$> fromPersistValue idCol <*> fromPersistValues fieldCols
    rawSqlProcessRow [] = Left "RawSql (Entity a): wrong number of columns."
-- | A pair occupies the columns of its first component followed by those of
-- its second. The larger tuple instances reuse this one by regrouping into
-- nested pairs.
instance (RawSql a, RawSql b) => RawSql (a, b) where
    -- Sum the column counts and concatenate the substitution lists. 'fst' and
    -- 'snd' are used instead of a tuple pattern, presumably so the proxy
    -- argument supplied by 'rawSql' is never forced.
    rawSqlCols e x = rawSqlCols e (fst x) # rawSqlCols e (snd x)
        where (cnta, lsta) # (cntb, lstb) = (cnta + cntb, lsta P.++ lstb)
    rawSqlColCountReason x = rawSqlColCountReason (fst x) P.++ ", " P.++
                             rawSqlColCountReason (snd x)
    -- Split each row at the first component's column count and decode the
    -- two halves independently.
    rawSqlProcessRow =
        let x = getType processRow
            -- Type proxy recovering the pair type from processRow; never
            -- evaluated.
            getType :: (z -> Either y x) -> x
            getType = undefined
            colCountFst = fst $ rawSqlCols undefined (fst x)
            processRow row =
                let (rowFst, rowSnd) = splitAt colCountFst row
                in (,) <$> rawSqlProcessRow rowFst
                       <*> rawSqlProcessRow rowSnd
        -- Force the first component's column count when this method is
        -- evaluated; the let-bound value is then shared by every row
        -- processed.
        in colCountFst `seq` processRow
-- | Triples are handled as the nested pair @((a, b), c)@.
instance (RawSql a, RawSql b, RawSql c) => RawSql (a, b, c) where
    rawSqlCols escape triple = rawSqlCols escape (from3 triple)
    rawSqlColCountReason triple = rawSqlColCountReason (from3 triple)
    rawSqlProcessRow row = to3 <$> rawSqlProcessRow row

-- | Regroup a triple as a nested pair.
from3 :: (a,b,c) -> ((a,b),c)
from3 (x, y, z) = ((x, y), z)

-- | Flatten a nested pair back into a triple.
to3 :: ((a,b),c) -> (a,b,c)
to3 ((x, y), z) = (x, y, z)
-- | 4-tuples are handled as the nested pairs @((a, b), (c, d))@.
instance (RawSql a, RawSql b, RawSql c, RawSql d) => RawSql (a, b, c, d) where
    rawSqlCols escape quad = rawSqlCols escape (from4 quad)
    rawSqlColCountReason quad = rawSqlColCountReason (from4 quad)
    rawSqlProcessRow row = to4 <$> rawSqlProcessRow row

-- | Regroup a 4-tuple as nested pairs.
from4 :: (a,b,c,d) -> ((a,b),(c,d))
from4 (w, x, y, z) = ((w, x), (y, z))

-- | Flatten nested pairs back into a 4-tuple.
to4 :: ((a,b),(c,d)) -> (a,b,c,d)
to4 ((w, x), (y, z)) = (w, x, y, z)
-- | 5-tuples are handled as the triple @((a, b), (c, d), e)@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e)
       => RawSql (a, b, c, d, e) where
    rawSqlCols escape tuple = rawSqlCols escape (from5 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from5 tuple)
    rawSqlProcessRow row = to5 <$> rawSqlProcessRow row

-- | Regroup a 5-tuple as a triple of nested pairs.
from5 :: (a,b,c,d,e) -> ((a,b),(c,d),e)
from5 (v, w, x, y, z) = ((v, w), (x, y), z)

-- | Flatten the nested form back into a 5-tuple.
to5 :: ((a,b),(c,d),e) -> (a,b,c,d,e)
to5 ((v, w), (x, y), z) = (v, w, x, y, z)
-- | 6-tuples are handled as the triple of pairs @((a, b), (c, d), (e, f))@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f)
       => RawSql (a, b, c, d, e, f) where
    rawSqlCols escape tuple = rawSqlCols escape (from6 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from6 tuple)
    rawSqlProcessRow row = to6 <$> rawSqlProcessRow row

-- | Regroup a 6-tuple as a triple of pairs.
from6 :: (a,b,c,d,e,f) -> ((a,b),(c,d),(e,f))
from6 (u, v, w, x, y, z) = ((u, v), (w, x), (y, z))

-- | Flatten a triple of pairs back into a 6-tuple.
to6 :: ((a,b),(c,d),(e,f)) -> (a,b,c,d,e,f)
to6 ((u, v), (w, x), (y, z)) = (u, v, w, x, y, z)
-- | 7-tuples are handled as the 4-tuple @((a, b), (c, d), (e, f), g)@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g)
       => RawSql (a, b, c, d, e, f, g) where
    rawSqlCols escape tuple = rawSqlCols escape (from7 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from7 tuple)
    rawSqlProcessRow row = to7 <$> rawSqlProcessRow row

-- | Regroup a 7-tuple as a 4-tuple of nested pairs.
from7 :: (a,b,c,d,e,f,g) -> ((a,b),(c,d),(e,f),g)
from7 (t1, t2, t3, t4, t5, t6, t7) = ((t1, t2), (t3, t4), (t5, t6), t7)

-- | Flatten the nested form back into a 7-tuple.
to7 :: ((a,b),(c,d),(e,f),g) -> (a,b,c,d,e,f,g)
to7 ((t1, t2), (t3, t4), (t5, t6), t7) = (t1, t2, t3, t4, t5, t6, t7)
-- | 8-tuples are handled as the 4-tuple of pairs
-- @((a, b), (c, d), (e, f), (g, h))@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g, RawSql h)
       => RawSql (a, b, c, d, e, f, g, h) where
    rawSqlCols escape tuple = rawSqlCols escape (from8 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from8 tuple)
    rawSqlProcessRow row = to8 <$> rawSqlProcessRow row

-- | Regroup an 8-tuple as a 4-tuple of pairs.
from8 :: (a,b,c,d,e,f,g,h) -> ((a,b),(c,d),(e,f),(g,h))
from8 (t1, t2, t3, t4, t5, t6, t7, t8) = ((t1, t2), (t3, t4), (t5, t6), (t7, t8))

-- | Flatten a 4-tuple of pairs back into an 8-tuple.
to8 :: ((a,b),(c,d),(e,f),(g,h)) -> (a,b,c,d,e,f,g,h)
to8 ((t1, t2), (t3, t4), (t5, t6), (t7, t8)) = (t1, t2, t3, t4, t5, t6, t7, t8)