{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}
-- | SQL-backend execution of the join combinators from
-- "Database.Persist.Join".
module Database.Persist.Join.Sql
    ( RunJoin (..)
    ) where

import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
#if MIN_VERSION_monad_control(0, 3, 0)
import Control.Monad.Trans.Control (MonadBaseControl)
#define MBCIO MonadBaseControl IO
#else
import Control.Monad.IO.Control (MonadControlIO)
#define MBCIO MonadControlIO
#endif
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)
import Control.Monad.IO.Class (MonadIO)

-- | Decode one entity from a row slice whose first value is a
-- 'PersistInt64' key and whose remaining values are handed to
-- 'fromPersistValues'.
--
-- Returns 'Left' for an empty slice, for a first value that is not a
-- 'PersistInt64' (including 'PersistNull'), or when 'fromPersistValues'
-- rejects the remaining values.
fromPersistValuesId :: PersistEntity v
                    => [PersistValue]
                    -> Either String (Key SqlPersist v, v)
fromPersistValuesId [] = Left "fromPersistValuesId: No values provided"
fromPersistValuesId (PersistInt64 i:rest) =
    case fromPersistValues rest of
        Left e -> Left e
        Right x -> Right (Key $ PersistInt64 i, x)
fromPersistValuesId _ = Left "fromPersistValuesId: invalid ID"

-- | Join descriptions that can be executed against a SQL connection.
class RunJoin a where
    runJoin :: (MonadIO m, MBCIO m) => a -> SqlPersist m (J.Result a)

-- | Runs a one-to-many join as a single @SELECT ... JOIN ...@ statement
-- (@LEFT JOIN@ when @isOuter@, otherwise @INNER JOIN@). It then groups the
-- returned rows so that each "one" entity is paired with the list of its
-- "many" entities.
instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
    => RunJoin (SelectOneMany SqlPersist one many) where
    runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
        conn <- SqlPersist ask
        -- Placeholder values are supplied in the same order the WHERE
        -- clause is built: "one" filters first, then "many" filters.
        liftM go
            $ withStmt (sql conn)
                       (getFiltsValues conn oneF ++ getFiltsValues conn manyF)
            $ loop id
      where
        -- Collapse consecutive rows sharing the same "one" key into a single
        -- entry. 'groupBy' only merges adjacent rows, which is why 'sql'
        -- always orders by the "one" table's id (see 'ords'). 'head' is safe
        -- here because 'groupBy' never yields an empty group.
        go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
        go = map (fst . head &&& mapMaybe snd)
           . groupBy ((==) `on` (fst . fst))

        -- Pull rows from the statement, accumulating them with a difference
        -- list so they come out in result-set order. The first 'oneCount'
        -- values of a row decode the "one" entity and the rest decode the
        -- "many" entity. If the "many" decode fails and its first value is
        -- 'PersistNull', the row is kept with 'Nothing' for the "many" side;
        -- with a LEFT JOIN this is presumably the "no match" case. Any other
        -- decode failure is fatal.
        loop front popper = do
            x <- popper
            case x of
                Nothing -> return $ front []
                Just vals -> do
                    let (y, z) = splitAt oneCount vals
                    case (fromPersistValuesId y, fromPersistValuesId z) of
                        (Right y', Right z') ->
                            loop (front . (:) (y', Just z')) popper
                        (Left e, _) -> error $ "selectOneMany: " ++ e
                        (Right y', Left e) -> case z of
                            PersistNull:_ ->
                                loop (front . (:) (y', Nothing)) popper
                            _ -> error $ "selectOneMany: " ++ e

        -- Number of leading row values that belong to the "one" entity:
        -- its id column plus every column in 'tableColumns'.
        oneCount = 1 + length (tableColumns $ entityDef one)
        one = dummyFromFilts oneF
        many = dummyFromFilts manyF

        sql conn = pack $ concat
            [ "SELECT "
            , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
            , " FROM "
            , escapeName conn $ rawTableName $ entityDef one
            , if isOuter then " LEFT JOIN " else " INNER JOIN "
            , escapeName conn $ rawTableName $ entityDef many
            , " ON "
            -- Join on the entity's actual id column (the same one
            -- 'colsPlusId' selects) rather than a hard-coded "id".
            , oneIdCol
            , " = "
            , escapeName conn $ rawTableName $ entityDef many
            , "."
            , escapeName conn $ RawName $ filterName $ eq undefined
            , filts
            , " ORDER BY "
            , intercalate ", " ords
            ]
          where
            filts1 = filterClauseNoWhere True conn oneF
            filts2 = filterClauseNoWhere True conn manyF

            -- Only the ordering part of the select options is used; any
            -- limit/offset entries are dropped.
            orders :: PersistEntity val => [SelectOpt val] -> [SelectOpt val]
            orders = third3 . limitOffsetOrder

            filts
                | null filts1 && null filts2 = ""
                | null filts1 = " WHERE " ++ filts2
                | null filts2 = " WHERE " ++ filts1
                | otherwise = " WHERE " ++ filts1 ++ " AND " ++ filts2

            -- The "one" id is inserted after the caller's "one" ordering and
            -- before the "many" ordering. This keeps every row of a given
            -- "one" entity contiguous, which 'go' needs, while still
            -- honouring the caller's requested order as the primary sort
            -- keys.
            ords = map (orderClause True conn) (orders oneO)
                ++ [oneIdCol]
                ++ map (orderClause True conn) (orders manyO)

            -- Table-qualified id column of the "one" entity, formatted
            -- exactly as in 'colsPlusId'.
            oneIdCol = addTable conn one $ unRawName $ rawTableIdName $ entityDef one

-- | Prefix @s@ with the entity's escaped table name and a dot, producing
-- @table.s@. The caller decides whether @s@ itself is escaped.
addTable :: PersistEntity val => Connection -> val -> [Char] -> [Char]
addTable conn e s = concat [escapeName conn $ rawTableName $ entityDef e, ".", s]

-- | Table-qualified SELECT list for an entity: its id column
-- ('rawTableIdName', not escaped) followed by every column from
-- 'tableColumns' (escaped), in that order.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e =
    map (addTable conn e) $ id_ : map (\(x, _, _) -> escapeName conn x) cols
  where
    id_ = unRawName $ rawTableIdName $ entityDef e
    cols = tableColumns $ entityDef e

-- | Column name targeted by a plain 'Filter'. Combined filters
-- ('FilterAnd' / 'FilterOr') have no single column and raise an error.
filterName :: PersistEntity v => Filter v -> String
filterName (Filter f _ _) = columnName $ persistColumnDef f
filterName (FilterAnd _) = error "expected a raw filter, not an And"
filterName (FilterOr _) = error "expected a raw filter, not an Or"