module Database.MySQL.Simple
    (
      execute
    , query
    , query_
    , formatQuery
    ) where

import Blaze.ByteString.Builder (fromByteString, toByteString)
import Control.Applicative ((<$>), pure)
import Control.DeepSeq (NFData(..))
import Control.Monad.Fix (fix)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Monoid (mappend, mempty)
import Database.MySQL.Base (Connection)
import Database.MySQL.Simple.Param (Action(..), inQuotes)
import Database.MySQL.Simple.QueryParams (QueryParams(..))
import Database.MySQL.Simple.QueryResults (QueryResults(..))
import Database.MySQL.Simple.Types (Query(..))
import qualified Data.ByteString.Char8 as B
import qualified Database.MySQL.Base as Base

formatQuery :: QueryParams q => Connection -> Query -> q -> IO ByteString
formatQuery conn (Query template) qs
    | '?' `B.notElem` template = return template
    | otherwise =
        toByteString . zipParams (split template) <$> mapM sub (renderParams qs)
  where sub (Plain b)  = pure b
        sub (Escape s) = (inQuotes . fromByteString) <$> Base.escape conn s
        split q = fromByteString h : if B.null t then [] else split (B.tail t)
            where (h,t) = B.break (=='?') q
        zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
        zipParams [] []         = mempty
        zipParams [] _ = fmtError "more parameters than '?' characters"
        zipParams _ [] = fmtError "more '?' characters than parameters"

execute :: (QueryParams q) => Connection -> Query -> q -> IO Int64
execute conn template qs = do
  Base.query conn =<< formatQuery conn template qs
  ncols <- Base.fieldCount (Left conn)
  if ncols /= 0
    then error "execute: executed a select!"
    else Base.affectedRows conn
  
query :: (QueryParams q, QueryResults r) => Connection -> Query -> q -> IO [r]
query conn template qs = do
  Base.query conn =<< formatQuery conn template qs
  finishQuery conn
  
query_ :: (QueryResults r) => Connection -> Query -> IO [r]
query_ conn (Query q) = do
  Base.query conn q
  finishQuery conn

finishQuery :: (QueryResults r) => Connection -> IO [r]
finishQuery conn = do
  r <- Base.storeResult conn
  ncols <- Base.fieldCount (Right r)
  if ncols == 0
    then return []
    else do
      fs <- Base.fetchFields r
      flip fix [] $ \loop acc -> do
        row <- Base.fetchRow r
        case row of
          [] -> return (reverse acc)
          _  -> let c = convertResults fs row
                in rnf c `seq` loop (c:acc)

fmtError :: String -> a
fmtError msg = error $ "Database.MySQL.formatQuery: " ++ msg