module Database.Persist.Join.Sql
( RunJoin (..)
) where
import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
#if MIN_VERSION_monad_control(0, 3, 0)
import Control.Monad.Trans.Control (MonadBaseControl)
#define MBCIO MonadBaseControl IO
#else
import Control.Monad.IO.Control (MonadControlIO)
#define MBCIO MonadControlIO
#endif
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)
import Control.Monad.IO.Class (MonadIO)
fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either String (Key SqlPersist v, v)
fromPersistValuesId [] = Left "fromPersistValuesId: No values provided"
fromPersistValuesId (PersistInt64 i:rest) =
case fromPersistValues rest of
Left e -> Left e
Right x -> Right (Key $ PersistInt64 i, x)
fromPersistValuesId _ = Left "fromPersistValuesId: invalid ID"
class RunJoin a where
runJoin :: (MonadIO m, MBCIO m) => a -> SqlPersist m (J.Result a)
instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
=> RunJoin (SelectOneMany SqlPersist one many) where
runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
conn <- SqlPersist ask
liftM go $ withStmt (sql conn) (getFiltsValues conn oneF ++ getFiltsValues conn manyF) $ loop id
where
go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))
loop front popper = do
x <- popper
case x of
Nothing -> return $ front []
Just vals -> do
let (y, z) = splitAt oneCount vals
case (fromPersistValuesId y, fromPersistValuesId z) of
(Right y', Right z') -> loop (front . (:) (y', Just z')) popper
(Left e, _) -> error $ "selectOneMany: " ++ e
(Right y', Left e) ->
case z of
PersistNull:_ -> loop (front . (:) (y', Nothing)) popper
_ -> error $ "selectOneMany: " ++ e
oneCount = 1 + length (tableColumns $ entityDef one)
one = dummyFromFilts oneF
many = dummyFromFilts manyF
sql conn = pack $ concat
[ "SELECT "
, intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
, " FROM "
, escapeName conn $ rawTableName $ entityDef one
, if isOuter then " LEFT JOIN " else " INNER JOIN "
, escapeName conn $ rawTableName $ entityDef many
, " ON "
, escapeName conn $ rawTableName $ entityDef one
, ".id = "
, escapeName conn $ rawTableName $ entityDef many
, "."
, escapeName conn $ RawName $ filterName $ eq undefined
, filts
, if null ords
then ""
else " ORDER BY " ++ intercalate ", " ords
]
where
filts1 = filterClauseNoWhere True conn oneF
filts2 = filterClauseNoWhere True conn manyF
orders :: PersistEntity val => [SelectOpt val] -> [SelectOpt val]
orders = third3 . limitOffsetOrder
filts
| null filts1 && null filts2 = ""
| null filts1 = " WHERE " ++ filts2
| null filts2 = " WHERE " ++ filts1
| otherwise = " WHERE " ++ filts1 ++ " AND " ++ filts2
ords = map (orderClause True conn) (orders oneO) ++ map (orderClause True conn) (orders manyO)
addTable :: PersistEntity val =>
Connection -> val -> [Char] -> [Char]
addTable conn e s = concat [escapeName conn $ rawTableName $ entityDef e, ".", s]
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e =
map (addTable conn e) $
id_ : (map (\(x, _, _) -> escapeName conn x) cols)
where
id_ = unRawName $ rawTableIdName $ entityDef e
cols = tableColumns $ entityDef e
filterName :: PersistEntity v => Filter v -> String
filterName (Filter f _ _) = columnName $ persistColumnDef f
filterName (FilterAnd _) = error "expected a raw filter, not an And"
filterName (FilterOr _) = error "expected a raw filter, not an Or"