{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | SQL-backed variants of the join combinators from "Database.Persist.Join".
-- Each join is executed as a single @SELECT ... JOIN ...@ statement.
module Database.Persist.Join.Sql
    ( RunJoin (..)
    ) where

import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Control.Monad.IO.Control (MonadControlIO)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)

-- | Decode a row segment whose first value is the entity's key and whose
-- remaining values are handed to 'fromPersistValues'.
--
-- Only a 'PersistInt64' key is accepted; an empty list or any other leading
-- value (including 'PersistNull') yields 'Left'.
fromPersistValuesId :: PersistEntity v
                    => [PersistValue] -> Either String (Key SqlPersist v, v)
fromPersistValuesId [] = Left "fromPersistValuesId: No values provided"
fromPersistValuesId (PersistInt64 i:rest) =
    case fromPersistValues rest of
        Left e -> Left e
        Right x -> Right (Key $ PersistInt64 i, x)
fromPersistValuesId _ = Left "fromPersistValuesId: invalid ID"

-- | Execute a join description against a SQL connection, producing the
-- same result type as 'J.runJoin'.
class RunJoin a where
    runJoin :: MonadControlIO m => a -> SqlPersist m (J.Result a)

-- | Runs a 'SelectOneMany' as one @JOIN@ query.
--
-- Each result row holds the "one" entity's id and columns followed by the
-- "many" entity's id and columns. Consecutive rows sharing the same "one"
-- key are folded into a single entry. An all-NULL many side (an unmatched
-- @LEFT JOIN@ row) contributes no "many" element. The @_getKey@ field is
-- not used by this implementation.
instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
      => RunJoin (SelectOneMany SqlPersist one many) where
    runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
        conn <- SqlPersist ask
        -- The many-side filters are rendered in the ON clause, which comes
        -- before the WHERE clause in the SQL text, so their placeholder
        -- values must be bound first.
        liftM go $ withStmt (sql conn)
            (getFiltsValues conn manyF ++ getFiltsValues conn oneF)
            $ loop id
      where
        -- groupBy never yields an empty group, so 'head' is safe here.
        -- Grouping relies on rows for the same "one" key being adjacent,
        -- which the ORDER BY built in 'sql' guarantees.
        go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
        go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))

        -- Drain the statement, accumulating decoded rows in a difference list.
        loop front popper = do
            x <- popper
            case x of
                Nothing -> return $ front []
                Just vals -> do
                    let (y, z) = splitAt oneCount vals
                    case (fromPersistValuesId y, fromPersistValuesId z) of
                        (Right y', Right z') ->
                            loop (front . (:) (y', Just z')) popper
                        (Left e, _) -> error $ "selectOneMany: " ++ e
                        (Right y', Left e) ->
                            case z of
                                -- A NULL many-side id marks an unmatched
                                -- outer-join row.
                                PersistNull:_ ->
                                    loop (front . (:) (y', Nothing)) popper
                                _ -> error $ "selectOneMany: " ++ e

        -- Number of leading values belonging to the "one" side (id + columns).
        oneCount = 1 + length (tableColumns $ entityDef one)
        one = dummyFromFilts oneF
        many = dummyFromFilts manyF

        sql conn = pack $ concat
            [ "SELECT "
            , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
            , " FROM "
            , escapeName conn $ rawTableName $ entityDef one
            , if isOuter then " LEFT JOIN " else " INNER JOIN "
            , escapeName conn $ rawTableName $ entityDef many
            , " ON "
            , oneId
            , " = "
            -- 'eq' is applied to 'undefined' only to learn which many-side
            -- column it filters on; 'filterName' ignores the filter's value.
            , addTable conn many $ escapeName conn $ RawName $ filterName $ eq undefined
            , manyFilts
            , oneFilts
            , " ORDER BY "
            , intercalate ", " ords
            ]
          where
            -- Same id column name (unescaped) that 'colsPlusId' selects.
            oneId = addTable conn one $ unRawName $ rawTableIdName $ entityDef one
            filtsOne = filterClauseNoWhere True conn oneF
            filtsMany = filterClauseNoWhere True conn manyF
            -- Many-side filters go in the ON clause. In a WHERE clause they
            -- would reject the all-NULL rows of a LEFT JOIN and silently turn
            -- the outer join into an inner one. For an INNER JOIN the two
            -- placements are equivalent.
            manyFilts
                | null filtsMany = ""
                | otherwise = " AND (" ++ filtsMany ++ ")"
            oneFilts
                | null filtsOne = ""
                | otherwise = " WHERE " ++ filtsOne
            -- Keeps only the third component of 'limitOffsetOrder'
            -- (presumably the ORDER BY options); any limit/offset options
            -- are not applied here.
            orders :: PersistEntity val => [SelectOpt val] -> [SelectOpt val]
            orders = third3 . limitOffsetOrder
            -- The one-side id acts as a tie-breaker after the caller's
            -- one-side orders, so all rows of a given "one" are adjacent
            -- (required by 'go'). The many-side orders then sort within
            -- each group.
            ords = map (orderClause True conn) (orders oneO)
                ++ [oneId]
                ++ map (orderClause True conn) (orders manyO)

-- | Qualify a column reference with the entity's escaped table name,
-- producing @table.column@. The column text is used verbatim.
addTable :: PersistEntity val => Connection -> val -> [Char] -> [Char]
addTable conn e s = concat [escapeName conn $ rawTableName $ entityDef e, ".", s]

-- | Table-qualified select list for an entity. The id column comes first
-- (not passed through 'escapeName'), followed by every column from
-- 'tableColumns' (escaped), in that order.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e =
    map (addTable conn e) $ id_ : map (\(x, _, _) -> escapeName conn x) cols
  where
    id_ = unRawName $ rawTableIdName $ entityDef e
    cols = tableColumns $ entityDef e

-- | Column name targeted by a single field filter; the filter's value and
-- operator are not inspected. Compound 'FilterAnd' / 'FilterOr' filters
-- are rejected with 'error'.
filterName :: PersistEntity v => Filter v -> String
filterName (Filter f _ _) = columnName $ persistColumnDef f
filterName (FilterAnd _) = error "expected a raw filter, not an And"
filterName (FilterOr _) = error "expected a raw filter, not an Or"