module Database.PostgreSQL.Simple.Internal where
import           Control.Applicative
import           Control.Exception
import           Control.Concurrent.MVar
import           Control.Monad(MonadPlus(..))
import           Data.ByteString(ByteString)
import qualified Data.ByteString.Char8 as B8
import           Data.Char (ord)
import           Data.Int (Int64)
import qualified Data.IntMap as IntMap
import           Data.IORef
import           Data.Maybe(fromMaybe)
import           Data.String
import           Data.Typeable
import           Data.Word
import           Database.PostgreSQL.LibPQ(Oid(..))
import qualified Database.PostgreSQL.LibPQ as PQ
import           Database.PostgreSQL.LibPQ(ExecStatus(..))
import           Database.PostgreSQL.Simple.Ok
import           Database.PostgreSQL.Simple.Types (Query(..))
import           Database.PostgreSQL.Simple.TypeInfo.Types(TypeInfo)
import           Control.Monad.Trans.State.Strict
import           Control.Monad.Trans.Reader
import           Control.Monad.Trans.Class
import           GHC.IO.Exception
#if !defined(mingw32_HOST_OS)
import           Control.Concurrent(threadWaitRead, threadWaitWrite)
#endif
data Field = Field {
     result   :: !PQ.Result
   , column   ::  !PQ.Column
   , typeOid  ::  !PQ.Oid
     
     
   }
type TypeInfoCache = IntMap.IntMap TypeInfo
data Connection = Connection {
     connectionHandle  ::  !(MVar PQ.Connection)
   , connectionObjects ::  !(MVar TypeInfoCache)
   , connectionTempNameCounter ::  !(IORef Int64)
   }
data SqlError = SqlError {
     sqlState       :: ByteString
   , sqlExecStatus  :: ExecStatus
   , sqlErrorMsg    :: ByteString
   , sqlErrorDetail :: ByteString
   , sqlErrorHint   :: ByteString
   } deriving (Show, Typeable)
fatalError :: ByteString -> SqlError
fatalError msg = SqlError "" FatalError msg "" ""
instance Exception SqlError
data QueryError = QueryError {
      qeMessage :: String
    , qeQuery :: Query
    } deriving (Eq, Show, Typeable)
instance Exception QueryError
data ConnectInfo = ConnectInfo {
      connectHost :: String
    , connectPort :: Word16
    , connectUser :: String
    , connectPassword :: String
    , connectDatabase :: String
    } deriving (Eq,Read,Show,Typeable)
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo {
                       connectHost = "127.0.0.1"
                     , connectPort = 5432
                     , connectUser = "postgres"
                     , connectPassword = ""
                     , connectDatabase = ""
                     }
connect :: ConnectInfo -> IO Connection
connect = connectPostgreSQL . postgreSQLConnectionString
connectPostgreSQL :: ByteString -> IO Connection
connectPostgreSQL connstr = do
    conn <- connectdb connstr
    stat <- PQ.status conn
    case stat of
      PQ.ConnectionOk -> do
          connectionHandle  <- newMVar conn
          connectionObjects <- newMVar (IntMap.empty)
          connectionTempNameCounter <- newIORef 0
          let wconn = Connection{..}
          version <- PQ.serverVersion conn
          let settings
                | version < 80200 = "SET datestyle TO ISO"
                | otherwise       = "SET standard_conforming_strings TO on;SET datestyle TO ISO"
          _ <- execute_ wconn settings
          return wconn
      _ -> do
          msg <- maybe "connectPostgreSQL error" id <$> PQ.errorMessage conn
          throwIO $ fatalError msg
connectdb :: ByteString -> IO PQ.Connection
#if defined(mingw32_HOST_OS)
connectdb = PQ.connectdb
#else
connectdb conninfo = do
    conn <- PQ.connectStart conninfo
    loop conn
  where
    funcName = "Database.PostgreSQL.Simple.connectPostgreSQL"
    loop conn = do
      status <- PQ.connectPoll conn
      case status of
        PQ.PollingFailed  -> throwLibPQError conn "connection failed"
        PQ.PollingReading -> do
                                mfd <- PQ.socket conn
                                case mfd of
                                  Nothing -> throwIO $! fdError funcName
                                  Just fd -> do
                                      threadWaitRead fd
                                      loop conn
        PQ.PollingWriting -> do
                                mfd <- PQ.socket conn
                                case mfd of
                                  Nothing -> throwIO $! fdError funcName
                                  Just fd -> do
                                      threadWaitWrite fd
                                      loop conn
        PQ.PollingOk      -> return conn
#endif
postgreSQLConnectionString :: ConnectInfo -> ByteString
postgreSQLConnectionString connectInfo = fromString connstr
  where
    connstr = str "host="     connectHost
            $ num "port="     connectPort
            $ str "user="     connectUser
            $ str "password=" connectPassword
            $ str "dbname="   connectDatabase
            $ []
    str name field
      | null value = id
      | otherwise  = showString name . quote value . space
        where value = field connectInfo
    num name field
      | value <= 0 = id
      | otherwise  = showString name . shows value . space
        where value = field connectInfo
    quote str rest = '\'' : foldr delta ('\'' : rest) str
       where
         delta c cs = case c of
                        '\\' -> '\\' : '\\' : cs
                        '\'' -> '\\' : '\'' : cs
                        _    -> c : cs
    space [] = []
    space xs = ' ':xs
oid2int :: Oid -> Int
oid2int (Oid x) = fromIntegral x
exec :: Connection
     -> ByteString
     -> IO PQ.Result
#if defined(mingw32_HOST_OS)
exec conn sql =
    withConnection conn $ \h -> do
        mres <- PQ.exec h sql
        case mres of
          Nothing  -> throwLibPQError h "PQexec returned no results"
          Just res -> return res
#else
exec conn sql =
    withConnection conn $ \h -> do
        success <- PQ.sendQuery h sql
        if success
        then awaitResult h Nothing
        else throwLibPQError h "PQsendQuery failed"
  where
    awaitResult h mres = do
        mfd <- PQ.socket h
        case mfd of
          Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
          Just fd -> do
             threadWaitRead fd
             _ <- PQ.consumeInput h  
             getResult h mres
    getResult h mres = do
        isBusy <- PQ.isBusy h
        if isBusy
        then awaitResult h mres
        else do
          mres' <- PQ.getResult h
          case mres' of
            Nothing -> case mres of
                         Nothing  -> throwLibPQError h "PQgetResult returned no results"
                         Just res -> return res
            Just res -> do
                status <- PQ.resultStatus res
                case status of
                   PQ.EmptyQuery    -> getResult h mres'
                   PQ.CommandOk     -> getResult h mres'
                   PQ.TuplesOk      -> getResult h mres'
                   PQ.CopyOut       -> return res
                   PQ.CopyIn        -> return res
                   PQ.BadResponse   -> getResult h mres'
                   PQ.NonfatalError -> getResult h mres'
                   PQ.FatalError    -> getResult h mres'
#endif
execute_ :: Connection -> Query -> IO Int64
execute_ conn q@(Query stmt) = do
  result <- exec conn stmt
  finishExecute conn q result
finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
finishExecute _conn q result = do
    status <- PQ.resultStatus result
    case status of
      PQ.EmptyQuery -> throwIO $ QueryError "execute: Empty query" q
      PQ.CommandOk -> do
          ncols <- PQ.nfields result
          if ncols /= 0
          then throwIO $ QueryError ("execute resulted in " ++ show ncols ++
                                     "-column result") q
          else do
            nstr <- PQ.cmdTuples result
            return $ case nstr of
                       Nothing  -> 0   
                       Just str -> toInteger str
      PQ.TuplesOk -> do
          ncols <- PQ.nfields result
          throwIO $ QueryError ("execute resulted in " ++ show ncols ++
                                 "-column result") q
      PQ.CopyOut ->
          throwIO $ QueryError "execute: COPY TO is not supported" q
      PQ.CopyIn ->
          throwIO $ QueryError "execute: COPY FROM is not supported" q
      PQ.BadResponse   -> throwResultError "execute" result status
      PQ.NonfatalError -> throwResultError "execute" result status
      PQ.FatalError    -> throwResultError "execute" result status
    where
     toInteger str = B8.foldl' delta 0 str
                where
                  delta acc c =
                    if '0' <= c && c <= '9'
                    then 10 * acc + fromIntegral (ord c  ord '0')
                    else error ("finishExecute:  not an int: " ++ B8.unpack str)
throwResultError :: ByteString -> PQ.Result -> PQ.ExecStatus -> IO a
throwResultError _ result status = do
    errormsg  <- fromMaybe "" <$>
                 PQ.resultErrorField result PQ.DiagMessagePrimary
    detail    <- fromMaybe "" <$>
                 PQ.resultErrorField result PQ.DiagMessageDetail
    hint      <- fromMaybe "" <$>
                 PQ.resultErrorField result PQ.DiagMessageHint
    state     <- maybe "" id <$> PQ.resultErrorField result PQ.DiagSqlstate
    throwIO $ SqlError { sqlState = state
                       , sqlExecStatus = status
                       , sqlErrorMsg = errormsg
                       , sqlErrorDetail = detail
                       , sqlErrorHint = hint }
disconnectedError :: SqlError
disconnectedError = fatalError "connection disconnected"
withConnection :: Connection -> (PQ.Connection -> IO a) -> IO a
withConnection Connection{..} m = do
    withMVar connectionHandle $ \conn -> do
        if PQ.isNullConnection conn
          then throwIO disconnectedError
          else m conn
close :: Connection -> IO ()
close Connection{..} =
    mask $ \restore -> (do
            conn <- takeMVar connectionHandle
            restore (PQ.finish conn)
        `finally` do
            putMVar connectionHandle =<< PQ.newNullConnection
        )
newNullConnection :: IO Connection
newNullConnection = do
    connectionHandle  <- newMVar =<< PQ.newNullConnection
    connectionObjects <- newMVar IntMap.empty
    connectionTempNameCounter <- newIORef 0
    return Connection{..}
data Row = Row {
     row        ::  !PQ.Row
   , rowresult  :: !PQ.Result
   }
newtype RowParser a = RP { unRP :: ReaderT Row (StateT PQ.Column Conversion) a }
   deriving ( Functor, Applicative, Alternative, Monad )
liftRowParser :: IO a -> RowParser a
liftRowParser = RP . lift . lift . liftConversion
newtype Conversion a = Conversion { runConversion :: Connection -> IO (Ok a) }
liftConversion :: IO a -> Conversion a
liftConversion m = Conversion (\_ -> Ok <$> m)
instance Functor Conversion where
   fmap f m = Conversion $ \conn -> (fmap . fmap) f (runConversion m conn)
instance Applicative Conversion where
   pure a    = Conversion $ \_conn -> pure (pure a)
   mf <*> ma = Conversion $ \conn -> do
                   okf <- runConversion mf conn
                   case okf of
                     Ok f -> (fmap . fmap) f (runConversion ma conn)
                     Errors errs -> return (Errors errs)
instance Alternative Conversion where
   empty     = Conversion $ \_conn -> pure empty
   ma <|> mb = Conversion $ \conn -> do
                   oka <- runConversion ma conn
                   case oka of
                     Ok _     -> return oka
                     Errors _ -> (oka <|>) <$> runConversion mb conn
instance Monad Conversion where
   return a = Conversion $ \_conn -> return (return a)
   m >>= f = Conversion $ \conn -> do
                 oka <- runConversion m conn
                 case oka of
                   Ok a -> runConversion (f a) conn
                   Errors err -> return (Errors err)
instance MonadPlus Conversion where
   mzero = empty
   mplus = (<|>)
conversionMap :: (Ok a -> Ok b) -> Conversion a -> Conversion b
conversionMap f m = Conversion $ \conn -> f <$> runConversion m conn
conversionError :: Exception err => err -> Conversion a
conversionError err = Conversion $ \_ -> return (Errors [SomeException err])
newTempName :: Connection -> IO Query
newTempName Connection{..} = do
    !n <- atomicModifyIORef connectionTempNameCounter
          (\n -> let !n' = n+1 in (n', n'))
    return $! Query $ B8.pack $ "temp" ++ show n
fdError :: ByteString -> IOError
fdError funcName = IOError {
                     ioe_handle      = Nothing,
                     ioe_type        = ResourceVanished,
                     ioe_location    = B8.unpack funcName,
                     ioe_description = "failed to fetch file descriptor",
                     ioe_errno       = Nothing,
                     ioe_filename    = Nothing
                   }
libPQError :: ByteString -> IOError
libPQError desc = IOError {
                    ioe_handle      = Nothing,
                    ioe_type        = OtherError,
                    ioe_location    = "libpq",
                    ioe_description = B8.unpack desc,
                    ioe_errno       = Nothing,
                    ioe_filename    = Nothing
                  }
throwLibPQError :: PQ.Connection -> ByteString -> IO a
throwLibPQError conn default_desc = do
  msg <- maybe default_desc id <$> PQ.errorMessage conn
  throwIO $! libPQError msg