{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
module Database.Persist.Join.Sql
    ( RunJoin (..)
    ) where

import Database.Persist.Join hiding (RunJoin (..))
import qualified Database.Persist.Join as J
import Database.Persist.Base
import Control.Monad (liftM)
import Data.Maybe (mapMaybe)
import Data.List (intercalate, groupBy)
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal hiding (withStmt)
import Database.Persist.GenericSql.Raw (withStmt)
import Control.Monad.Trans.Reader (ask)
import Control.Monad.IO.Control (MonadControlIO)
import Data.Function (on)
import Control.Arrow ((&&&))
import Data.Text (pack)

-- | Decode a result row whose first value is an integer primary key and whose
-- remaining values are the entity's fields.
--
-- Returns 'Left' with a message when the row is empty, when the leading value
-- is not a 'PersistInt64', or when 'fromPersistValues' rejects the remaining
-- values (its error is passed through unchanged).
fromPersistValuesId :: PersistEntity v => [PersistValue] -> Either String (Key SqlPersist v, v)
fromPersistValuesId row =
    case row of
        PersistInt64 i : fields ->
            either Left (\entity -> Right (Key (PersistInt64 i), entity))
                   (fromPersistValues fields)
        [] -> Left "fromPersistValuesId: No values provided"
        _ -> Left "fromPersistValuesId: invalid ID"

-- | Join descriptions that can be executed against a SQL backend.
class RunJoin a where
    -- | Execute the join description and return its 'J.Result'.
    runJoin
        :: MonadControlIO m
        => a
        -> SqlPersist m (J.Result a)

-- Runs a 'SelectOneMany' as a single SQL JOIN statement.
--
-- Each result row holds the parent ("one") id and columns followed by the
-- child ("many") id and columns.  Consecutive rows with the same parent are
-- collapsed into one (parent, children) entry.  For an outer join, a row whose
-- child part begins with NULL is treated as "no child" and contributes nothing
-- to the child list.
instance (PersistEntity one, PersistEntity many, Eq (Key SqlPersist one))
    => RunJoin (SelectOneMany SqlPersist one many) where
    runJoin (SelectOneMany oneF oneO manyF manyO eq _getKey isOuter) = do
        conn <- SqlPersist ask
        liftM go $ withStmt (sql conn) (getFiltsValues conn oneF ++ getFiltsValues conn manyF) $ loop id
      where
        -- Collapse runs of rows sharing a parent key.  groupBy never yields an
        -- empty group, so 'head' cannot fail here.  This only works if all
        -- rows of a parent are adjacent.  The ORDER BY built in 'sql' (which
        -- always includes the parent id) guarantees that.
        go :: Eq a => [((a, b), Maybe (c, d))] -> [((a, b), [(c, d)])]
        go = map (fst . head &&& mapMaybe snd) . groupBy ((==) `on` (fst . fst))

        -- Drain the statement, accumulating decoded rows in a difference list
        -- so they come out in query order.
        loop front popper = do
            x <- popper
            case x of
                Nothing -> return $ front []
                Just vals -> do
                    -- The first oneCount values belong to the parent, the rest
                    -- to the child (same order as colsPlusId produces).
                    let (y, z) = splitAt oneCount vals
                    case (fromPersistValuesId y, fromPersistValuesId z) of
                        (Right y', Right z') -> loop (front . (:) (y', Just z')) popper
                        (Left e, _) -> error $ "selectOneMany: " ++ e
                        (Right y', Left e) ->
                            -- A NULL child id means the outer join found no
                            -- child for this parent.
                            case z of
                                PersistNull:_ -> loop (front . (:) (y', Nothing)) popper
                                _ -> error $ "selectOneMany: " ++ e
        -- Parent id column plus the parent's data columns.
        oneCount = 1 + length (tableColumns $ entityDef one)
        one = dummyFromFilts oneF
        many = dummyFromFilts manyF
        sql conn = pack $ concat
            [ "SELECT "
            , intercalate "," $ colsPlusId conn one ++ colsPlusId conn many
            , " FROM "
            , escapeName conn $ rawTableName $ entityDef one
            , if isOuter then " LEFT JOIN " else " INNER JOIN "
            , escapeName conn $ rawTableName $ entityDef many
            , " ON "
            , oneId
            , " = "
            -- 'eq' is only used to discover which child column refers to the
            -- parent.  filterName ignores everything but the field, so the
            -- placeholder key is never inspected.
            , addTable conn many $ escapeName conn $ RawName $ filterName $ eq undefined
            , filts
            , " ORDER BY "
            , intercalate ", " ords
            ]
          where
            -- Table-qualified parent id column, in the same form colsPlusId
            -- selects it, so the join respects a custom id column name.
            oneId = addTable conn one $ unRawName $ rawTableIdName $ entityDef one

            filts1 = filterClauseNoWhere True conn oneF
            filts2 = filterClauseNoWhere True conn manyF

            orders :: PersistEntity val => [SelectOpt val] -> [SelectOpt val]
            orders = third3 . limitOffsetOrder

            filts
                | null filts1 && null filts2 = ""
                | null filts1 = " WHERE " ++ filts2
                | null filts2 = " WHERE " ++ filts1
                | otherwise = " WHERE " ++ filts1 ++ " AND " ++ filts2

            -- The parent id sits between the parent and child orderings.  It
            -- breaks ties among parents, so every parent's rows are contiguous
            -- (required by 'go').  The caller's orderings still take precedence.
            ords = map (orderClause True conn) (orders oneO)
                ++ [oneId]
                ++ map (orderClause True conn) (orders manyO)

-- | Qualify a column expression with the entity's escaped table name,
-- producing @table.column@.  The column string is used verbatim.
addTable :: PersistEntity val =>
           Connection -> val -> [Char] -> [Char]
addTable conn ent column = tableName ++ '.' : column
  where
    tableName = escapeName conn $ rawTableName $ entityDef ent

-- | Table-qualified column list for an entity: the id column first, then
-- every data column in 'tableColumns' order.
--
-- NOTE(review): data columns are escaped but the id column is not; confirm
-- this is intentional before relying on custom id names that need quoting.
colsPlusId :: PersistEntity e => Connection -> e -> [String]
colsPlusId conn e = [addTable conn e column | column <- idColumn : dataColumns]
  where
    ed = entityDef e
    idColumn = unRawName (rawTableIdName ed)
    dataColumns = [escapeName conn name | (name, _, _) <- tableColumns ed]

-- | Column name targeted by a single, non-compound filter.  Only the filter's
-- field is examined.  Compound ('FilterAnd' / 'FilterOr') filters are
-- rejected with 'error'.
filterName :: PersistEntity v => Filter v -> String
filterName filt =
    case filt of
        Filter field _ _ -> columnName (persistColumnDef field)
        FilterAnd _ -> error "expected a raw filter, not an And"
        FilterOr _ -> error "expected a raw filter, not an Or"