{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Persist.Query.Join.Sql
    ( RunJoin (..)
    ) where

import Database.Persist.Query.Join hiding (RunJoin (..))
import Database.Persist.EntityDef
import qualified Database.Persist.Query.Join as J
import Database.Persist.Store
import Database.Persist.Query.Internal
import Database.Persist.Query.GenericSql
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (Text, concat, null)
import Prelude hiding ((++), unlines, concat, show, null)
import Data.Monoid (Monoid, mappend)
import qualified Data.Text as T
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Control (MonadBaseControl)
import Control.Monad.Logger (MonadLogger)

fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either Text (Entity v)
fromPersistValuesId [] = Left "fromPersistValuesId: No values provided"
fromPersistValuesId (PersistInt64 i:rest) =
    case fromPersistValues rest of
        Left e -> Left e
        Right x -> Right (Entity (Key $ PersistInt64 i) x)
fromPersistValuesId _ = Left "fromPersistValuesId: invalid ID"

class RunJoin a where
    runJoin :: (C.MonadThrow m, C.MonadUnsafeIO m, MonadIO m, MonadBaseControl IO m, MonadLogger m) => a -> SqlPersist m (J.Result a)

instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
    => RunJoin (SelectOneMany SqlPersist one many) where
    runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
        conn <- SqlPersist ask
        C.runResourceT $ liftM go $ withStmt (sql conn) (getFiltsValues conn oneF ++ getFiltsValues conn manyF) C.$$ loop id
      where
        go :: [(Entity b, Maybe (Entity d))]
           -> [(Entity b, [Entity d])]
        go = map (fst . head &&& mapMaybe snd)
           . groupBy ((==) `on` (entityKey . fst))

        loop front = do
            x <- CL.head
            case x of
                Nothing -> return $ front []
                Just vals -> do
                    let (y, z) = splitAt oneCount vals
                    case (fromPersistValuesId y, fromPersistValuesId z) of
                        (Right y', Right z') -> loop (front . (:) (y', Just z'))
                        (Left e, _) -> error $ "selectOneMany: " ++ T.unpack e
                        (Right y', Left e) ->
                            case z of
                                PersistNull:_ -> loop (front . (:) (y', Nothing))
                                _ -> error $ "selectOneMany: " ++ T.unpack e
        oneCount = 1 + length (entityFields $ entityDef one)
        one = dummyFromFilts oneF
        many = dummyFromFilts manyF
        sql conn = concat
            [ "SELECT "
            , T.intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
            , " FROM "
            , escapeName conn $ entityDB $ entityDef one
            , if isOuter then " LEFT JOIN " else " INNER JOIN "
            , escapeName conn $ entityDB $ entityDef many
            , " ON "
            , escapeName conn $ entityDB $ entityDef one
            , ".id = "
            , escapeName conn $ entityDB $ entityDef many
            , "."
            , escapeName conn $ filterName $ eq undefined
            , onFilts
            , whereFilts
            , case ords of
                [] -> ""
                _ -> " ORDER BY " ++ T.intercalate ", " ords
            ]
          where
            filts1 = filterClauseNoWhere True conn oneF
            filts2 = (if isOuter then filterClauseNoWhereOrNull else filterClauseNoWhere) True conn manyF

            whereFilts
                | isOuter =
                    if null filts1
                        then ""
                        else " WHERE " ++ filts1
                | null filts1 && null filts2 = ""
                | null filts1 = " WHERE " ++ filts2
                | null filts2 = " WHERE " ++ filts1
                | otherwise = " WHERE " ++ filts1 ++ " AND " ++ filts2

            onFilts
                | isOuter && not (null filts2) = " AND " ++ filts2
                | otherwise = ""

            orders :: PersistEntity val
                   => [SelectOpt val]
                   -> [SelectOpt val]
            orders x = let (_, _, y) = limitOffsetOrder x in y

            ords = map (orderClause True conn) (orders oneO) ++ map (orderClause True conn) (orders manyO)

addTable :: PersistEntity val
         => Connection
         -> val
         -> Text
         -> Text
addTable conn e s = concat
    [ escapeName conn $ entityDB $ entityDef e
    , "."
    , s
    ]

colsPlusId :: PersistEntity e => Connection -> e -> [Text]
colsPlusId conn e =
    map (addTable conn e) $
    id_ : (map (escapeName conn . fieldDB) cols)
  where
    id_ = escapeName conn $ entityID $ entityDef e
    cols = entityFields $ entityDef e

filterName :: PersistEntity v => Filter v -> DBName
filterName (Filter f _ _) = fieldDB $ persistFieldDef f
filterName (FilterAnd _) = error "expected a raw filter, not an And"
filterName (FilterOr _) = error "expected a raw filter, not an Or"

infixr 5 ++
(++) :: Monoid m => m -> m -> m
(++) = mappend