module Database.Persist.Join.Sql
( RunJoin (..)
) where
import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Control.Monad.IO.Control (MonadControlIO)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)
-- | Decode one SQL row segment whose first column is an integer primary key
-- and whose remaining columns are the entity's fields.
--
-- Returns @Left@ with a diagnostic when the segment is empty, when the
-- first column is not a 'PersistInt64', or when 'fromPersistValues' rejects
-- the remaining columns.
fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either String (Key SqlPersist v, v)
fromPersistValuesId vals =
    case vals of
        [] -> Left "fromPersistValuesId: No values provided"
        PersistInt64 i : rest ->
            -- Pair the decoded entity with its key; propagate decode errors as-is.
            either Left (\ent -> Right (Key (PersistInt64 i), ent)) (fromPersistValues rest)
        _ -> Left "fromPersistValuesId: invalid ID"
-- | SQL-backend counterpart of the generic join runner: a join description
-- of type @a@ that can be executed inside 'SqlPersist'.
class RunJoin a where
    -- | Execute the join, yielding the result type that 'J.Result' assigns
    -- to this join description.
    runJoin :: (MonadControlIO m) => a -> SqlPersist m (J.Result a)
-- | Run a one-to-many join as a single SQL statement.
--
-- Both entities' id and field columns are selected from an INNER JOIN
-- (or LEFT JOIN when @isOuter@ is set). The joined table's foreign-key
-- field, taken from @eq@, is matched against the \"one\" table's id column.
-- The flat row stream is then folded back into @((key, one), [(key, many)])@
-- pairs.
--
-- The folding step ('groupBy') only merges /adjacent/ rows with equal parent
-- keys. The statement therefore always orders by the parent's id column,
-- placed after the user's @oneO@ orderings and before the @manyO@ orderings.
-- This keeps every parent's rows contiguous without overriding the user's
-- requested parent order.
--
-- NOTE(review): only the orderings of @oneO@ / @manyO@ are used (via
-- 'limitOffsetOrder'); limit/offset options appear to be ignored by this
-- backend — confirm this is intended.
instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
    => RunJoin (SelectOneMany SqlPersist one many) where
    runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
        conn <- SqlPersist ask
        -- Placeholder values: "one" filters first, then "many" filters,
        -- matching the order of the WHERE clause built in 'sql'.
        liftM go $ withStmt (sql conn) (getFiltsValues conn oneF ++ getFiltsValues conn manyF) $ loop id
      where
        -- Collapse consecutive rows sharing the same parent key into one
        -- entry, dropping the 'Nothing' children produced by outer joins.
        -- 'groupBy' never yields an empty group, so 'head' cannot fail.
        go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
        go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))

        -- Drain the statement, accumulating rows in a difference list.
        -- Each row is split into the parent's columns (id + fields) and the
        -- child's columns; a child whose id column is NULL (outer join with
        -- no match) becomes 'Nothing'.
        loop front popper = do
            x <- popper
            case x of
                Nothing -> return $ front []
                Just vals -> do
                    let (y, z) = splitAt oneCount vals
                    case (fromPersistValuesId y, fromPersistValuesId z) of
                        (Right y', Right z') -> loop (front . (:) (y', Just z')) popper
                        (Left e, _) -> error $ "selectOneMany: " ++ e
                        (Right y', Left e) ->
                            case z of
                                PersistNull:_ -> loop (front . (:) (y', Nothing)) popper
                                _ -> error $ "selectOneMany: " ++ e

        -- Number of leading columns belonging to the parent: id + fields,
        -- mirroring the column list produced by 'colsPlusId'.
        oneCount = 1 + length (tableColumns $ entityDef one)
        one = dummyFromFilts oneF
        many = dummyFromFilts manyF

        sql conn = pack $ concat
            [ "SELECT "
            , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
            , " FROM "
            , escapeName conn $ rawTableName $ entityDef one
            , if isOuter then " LEFT JOIN " else " INNER JOIN "
            , escapeName conn $ rawTableName $ entityDef many
            , " ON "
            , oneIdCol
            , " = "
            , addTable conn many $ escapeName conn $ RawName $ filterName $ eq undefined
            , filts
            -- 'ords' always contains at least the parent id column.
            , " ORDER BY "
            , intercalate ", " ords
            ]
          where
            -- The parent's id column, qualified exactly as in 'colsPlusId'
            -- so the join, ordering and selected columns all agree.
            oneIdCol = addTable conn one $ unRawName $ rawTableIdName $ entityDef one
            filts1 = filterClauseNoWhere True conn oneF
            filts2 = filterClauseNoWhere True conn manyF
            orders :: PersistEntity val => [SelectOpt val] -> [SelectOpt val]
            orders = third3 . limitOffsetOrder
            filts
                | null filts1 && null filts2 = ""
                | null filts1 = " WHERE " ++ filts2
                | null filts2 = " WHERE " ++ filts1
                | otherwise = " WHERE " ++ filts1 ++ " AND " ++ filts2
            -- User parent orderings first, then the parent id as a tiebreak
            -- (keeps each parent's rows adjacent for 'go'), then the user's
            -- child orderings within each parent.
            ords = map (orderClause True conn) (orders oneO)
                ++ [oneIdCol]
                ++ map (orderClause True conn) (orders manyO)
-- | Prefix a SQL fragment with the entity's escaped table name, giving
-- @table.fragment@. The fragment is appended verbatim, so callers must
-- escape it themselves when needed.
addTable :: PersistEntity val => Connection -> val -> [Char] -> [Char]
addTable conn e s = qualifier ++ "." ++ s
  where
    qualifier = escapeName conn (rawTableName (entityDef e))
-- | Table-qualified select list for an entity: its id column followed by
-- every field column, in 'tableColumns' order.
--
-- The id column name is used unescaped (via 'unRawName'), while field
-- column names go through 'escapeName'.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e =
    qualify idCol : [qualify (escapeName conn colName) | (colName, _, _) <- tableColumns def]
  where
    def = entityDef e
    qualify = addTable conn e
    idCol = unRawName (rawTableIdName def)
-- | Column name referenced by a single, non-compound filter.
--
-- Calls 'error' for 'FilterAnd' and 'FilterOr', which have no single column.
filterName :: PersistEntity v => Filter v -> String
filterName flt =
    case flt of
        FilterAnd _ -> error "expected a raw filter, not an And"
        FilterOr _ -> error "expected a raw filter, not an Or"
        Filter field _ _ -> columnName (persistColumnDef field)