{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Database.Persist.Join.Sql
    ( RunJoin (..)
    ) where

import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql (SqlPersist (SqlPersist))
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Control.Monad.IO.Control (MonadControlIO)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)

-- | Decode a row whose first value is the entity's key and whose remaining
-- values are the entity's fields.
--
-- Returns 'Left' with a message when the row is empty or when
-- 'fromPersistValues' rejects the field values; otherwise returns the key
-- (built with 'toPersistKey') paired with the decoded entity.
fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either String (Key v, v)
fromPersistValuesId row =
    case row of
        [] -> Left "fromPersistValuesId: No values provided"
        keyVal : fieldVals ->
            either Left (\ent -> Right (toPersistKey keyVal, ent)) $
                fromPersistValues fieldVals

-- | Join descriptions that can be executed in the 'SqlPersist' monad.
class RunJoin a where
    -- | Execute the join description @a@ and return its 'J.Result'.
    runJoin :: MonadControlIO m => a -> SqlPersist m (J.Result a)

-- | One-to-many joins are executed by 'selectOneMany'', which issues a
-- single SQL @SELECT@ with a join clause.
instance ( PersistEntity one
         , PersistEntity many
         , Eq (Key one)
         ) => RunJoin (SelectOneMany one many) where
    runJoin query = selectOneMany' query

-- | Execute a one-to-many join as a single SQL @SELECT@ and collapse the
-- flat result rows into a list of parents, each paired with its children.
--
-- The generated statement selects the id and columns of the \"one\" table
-- followed by the id and columns of the \"many\" table, joined with
-- @LEFT JOIN@ when @isOuter@ is set and @INNER JOIN@ otherwise. Filters
-- from both sides go into one @WHERE@ clause and orderings from both sides
-- into one @ORDER BY@ clause (one-side entries first).
--
-- Rows are merged with 'groupBy', so rows belonging to the same parent are
-- combined only when they are adjacent in the result set.
--
-- Calls 'error' with a @\"selectOneMany: \"@ prefix when a row cannot be
-- decoded.
--
-- NOTE(review): the result type variables @b@ and @d@ are not tied to
-- @val@ and @val1@ in this signature; presumably callers always
-- instantiate them to the same entity types -- confirm before relying on it.
selectOneMany' :: (MonadControlIO m,
                  PersistEntity d,
                  PersistEntity val1,
                  PersistEntity val,
                  PersistEntity b,
                  Eq (Key b)) =>
                 SelectOneMany val val1 -> SqlPersist m [((Key b, b), [(Key d, d)])]
selectOneMany' (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
    conn <- SqlPersist ask
    liftM go $ withStmt (sql conn) (getFiltsValues oneF ++ getFiltsValues manyF) $ loop id
  where
    -- Collapse adjacent rows that share the same parent key. 'groupBy'
    -- never produces an empty group, so 'head' cannot fail here.
    go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
    go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))
    -- Drain the statement, accumulating decoded rows in a difference list.
    loop front popper = do
        x <- popper
        case x of
            Nothing -> return $ front []
            Just vals -> do
                -- The first oneCount values belong to the parent, the
                -- remainder to the child.
                let (y, z) = splitAt oneCount vals
                case fromPersistValuesId y of
                    Left e -> error $ "selectOneMany: " ++ e
                    Right y' ->
                        case z of
                            -- A NULL child id means the outer join found no
                            -- matching child row. Check this before decoding:
                            -- a child whose fields are all nullable would
                            -- otherwise parse an all-NULL row successfully and
                            -- show up as a bogus child with a NULL key.
                            PersistNull:_ -> loop (front . (:) (y', Nothing)) popper
                            _ ->
                                case fromPersistValuesId z of
                                    Right z' -> loop (front . (:) (y', Just z')) popper
                                    Left e -> error $ "selectOneMany: " ++ e
    -- Number of values per row taken by the parent: its id plus its columns.
    oneCount = 1 + length (tableColumns $ entityDef one)
    -- Dummy values used only to look up the entity definitions.
    one = dummyFromFilts oneF
    many = dummyFromFilts manyF
    sql conn = pack $ concat
        [ "SELECT "
        , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
        , " FROM "
        , escapeName conn $ rawTableName $ entityDef one
        , if isOuter then " LEFT JOIN " else " INNER JOIN "
        , escapeName conn $ rawTableName $ entityDef many
        , " ON "
        , escapeName conn $ rawTableName $ entityDef one
        , ".id = "
        , escapeName conn $ rawTableName $ entityDef many
        , "."
          -- eq is applied to undefined only to obtain the name of the
          -- child's foreign-key field; presumably the value is never forced.
        , escapeName conn $ RawName $ persistFilterToFieldName $ eq undefined
        , if null filts
            then ""
            else " WHERE " ++ intercalate " AND " filts
        , if null ords
            then ""
            else " ORDER BY " ++ intercalate ", " ords
        ]
      where
        filts = map (filterClause True conn) oneF ++ map (filterClause True conn) manyF
        ords = map (orderClause True conn) oneO ++ map (orderClause True conn) manyO

-- | Prefix a column reference with the escaped table name of the given
-- entity, separated by a dot. The entity value is passed only to
-- 'entityDef' to find its table name.
addTable :: PersistEntity val =>
           Connection -> val -> [Char] -> [Char]
addTable conn e s = qualifier ++ "." ++ s
  where
    qualifier = escapeName conn (rawTableName (entityDef e))

-- | Table-qualified column references for an entity: the literal @id@
-- column first, followed by every column from 'tableColumns' with its name
-- escaped for the connection.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e =
    [ addTable conn e col | col <- "id" : escapedCols ]
  where
    escapedCols =
        [ escapeName conn name | (name, _, _) <- tableColumns (entityDef e) ]