module Database.Persist.Join.Sql
( RunJoin (..)
) where
import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql (SqlPersist (SqlPersist))
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Control.Monad.IO.Control (MonadControlIO)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)
-- | Decode one entity together with its key from a flat list of column
-- values.
--
-- The first value is turned into the key with 'toPersistKey'; the remaining
-- values are handed to 'fromPersistValues' to build the entity itself.
-- Yields 'Left' with a message when the list is empty or when
-- 'fromPersistValues' rejects the field values.
fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either String (Key v, v)
fromPersistValuesId vals =
    case vals of
        [] -> Left "fromPersistValuesId: No values provided"
        keyVal : fieldVals ->
            either
                Left
                (\entity -> Right (toPersistKey keyVal, entity))
                (fromPersistValues fieldVals)
-- | Join descriptions that can be executed inside 'SqlPersist'.
--
-- The result type is whatever "Database.Persist.Join" associates with the
-- description through @J.Result@.
class RunJoin a where
    -- | Execute the join description and return its result.
    runJoin
        :: MonadControlIO m
        => a
        -> SqlPersist m (J.Result a)
-- One-to-many joins are executed by delegating to selectOneMany'.
-- @Eq (Key one)@ is required because result rows are grouped by the key of
-- the \"one\" side.
instance ( PersistEntity one
         , PersistEntity many
         , Eq (Key one)
         ) => RunJoin (SelectOneMany one many) where
    runJoin query = selectOneMany' query
-- | Execute a one-to-many join as a single SQL statement and regroup the
-- flat result rows into one entry per parent (\"one\") row, each paired with
-- the list of its child (\"many\") rows.
--
-- The generated statement selects the id and columns of both tables, joins
-- them with @LEFT JOIN@ when the query's outer flag is set and @INNER JOIN@
-- otherwise, matches the parent's @id@ against the child field named by the
-- query's equality filter, and appends the filters and orderings of both
-- sides (parent first, then child). Filter values are passed to 'withStmt'
-- in the same parent-then-child order.
--
-- A row whose child id is @PersistNull@ (the null-extended row an outer join
-- produces for a parent without children) contributes no child. This is
-- checked /before/ the child columns are parsed: otherwise a child entity
-- whose fields all accept NULL would parse successfully and show up as a
-- bogus child carrying a key built from @PersistNull@.
--
-- Calls 'error' when a parent row, or a child row with a non-null id, cannot
-- be decoded.
--
-- NOTE(review): rows are grouped with 'groupBy', which only merges
-- /adjacent/ rows with equal parent keys. If the statement's ordering does
-- not keep all rows of a parent contiguous (for example, when only child
-- orderings are given), the same parent can appear in several result
-- entries. Confirm whether callers rely on the current row order before
-- adding a parent-id tiebreaker to the ORDER BY clause.
selectOneMany' :: (MonadControlIO m,
                   PersistEntity d,
                   PersistEntity val1,
                   PersistEntity val,
                   PersistEntity b,
                   Eq (Key b)) =>
                  SelectOneMany val val1 -> SqlPersist m [((Key b, b), [(Key d, d)])]
selectOneMany' (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
    conn <- SqlPersist ask
    liftM go $ withStmt (sql conn) (getFiltsValues oneF ++ getFiltsValues manyF) $ loop id
  where
    -- Collapse runs of adjacent rows sharing a parent key into a single
    -- entry. 'groupBy' never yields an empty group, so 'head' is safe here.
    go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
    go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))

    -- Drain the statement's row source, accumulating decoded rows in a
    -- difference list (@front@) so that row order is preserved.
    loop front popper = do
        x <- popper
        case x of
            Nothing -> return $ front []
            Just vals -> do
                -- The first oneCount values belong to the parent; the rest
                -- belong to the child.
                let (y, z) = splitAt oneCount vals
                case fromPersistValuesId y of
                    Left e -> error $ "selectOneMany: " ++ e
                    Right y' ->
                        case z of
                            -- Null child id: parent without a matching
                            -- child. Decide this before parsing so that an
                            -- all-nullable child entity is not fabricated.
                            PersistNull:_ -> loop (front . (:) (y', Nothing)) popper
                            _ ->
                                case fromPersistValuesId z of
                                    Right z' -> loop (front . (:) (y', Just z')) popper
                                    Left e -> error $ "selectOneMany: " ++ e

    -- Number of selected values per parent row: the id plus one per table
    -- column, mirroring what colsPlusId emits.
    oneCount = 1 + length (tableColumns $ entityDef one)

    -- Dummy values used only to look up entity definitions for each side.
    one = dummyFromFilts oneF
    many = dummyFromFilts manyF

    sql conn = pack $ concat
        [ "SELECT "
        , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
        , " FROM "
        , escapeName conn $ rawTableName $ entityDef one
        , if isOuter then " LEFT JOIN " else " INNER JOIN "
        , escapeName conn $ rawTableName $ entityDef many
        , " ON "
        , escapeName conn $ rawTableName $ entityDef one
        , ".id = "
        , escapeName conn $ rawTableName $ entityDef many
        , "."
          -- Only the field name of the equality filter is needed, so the
          -- key argument is never inspected and 'undefined' is passed.
        , escapeName conn $ RawName $ persistFilterToFieldName $ eq undefined
        , if null filts
            then ""
            else " WHERE " ++ intercalate " AND " filts
        , if null ords
            then ""
            else " ORDER BY " ++ intercalate ", " ords
        ]
      where
        filts = map (filterClause True conn) oneF ++ map (filterClause True conn) manyF
        ords = map (orderClause True conn) oneO ++ map (orderClause True conn) manyO
-- | Qualify a column reference with the escaped table name of the given
-- entity, producing @table.column@. The column text is used as supplied;
-- only the table name is escaped here.
addTable :: PersistEntity val =>
            Connection -> val -> [Char] -> [Char]
addTable conn e s = tableName ++ "." ++ s
  where
    tableName = escapeName conn (rawTableName (entityDef e))
-- | Table-qualified column references for an entity: the literal @id@
-- column first, followed by every column of the entity's definition with
-- its name escaped for the given connection.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e = map qualify ("id" : escapedCols)
  where
    qualify = addTable conn e
    escapedCols = map (escapeName conn . columnName) (tableColumns (entityDef e))
    -- Each column description is a triple; only its first component, the
    -- name, is used.
    columnName (name, _, _) = name