{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}

-- |
-- SQL Server client library implemented in Haskell
--
-- [Usage Example](https://github.com/mitsuji/mssql-simple-example/blob/master/app/Main.hs)


module Database.MSSQLServer.Connection
  (
    -- * Connect with the SQL Server
    -- $use

      ConnectInfo(..)
    , defaultConnectInfo
    , Connection(..)
    , connect
    , connectWithoutEncryption
    , close
    , ProtocolError(..)
    , AuthError(..)
    ) where

import qualified Network.Socket as Socket
import Network.Socket (AddrInfo(..),SocketType(..),Socket(..))
import Network.Socket.ByteString (recv)
import Network.Socket.ByteString.Lazy (sendAll)

import Data.Monoid ((<>),mempty)

import qualified Data.ByteString as B

import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Text.Encoding as T

import qualified Data.Binary.Put as Put
import qualified Data.Binary.Get as Get

import Control.Monad (when)
import Control.Exception (Exception(..),throwIO)

import qualified Network.TLS as TLS
import Network.HostName (getHostName)

import Database.Tds.Message
import Database.Tds.Transport (contextNew)

import Data.Word (Word8,Word32)
import Data.Int (Int32)
import Data.Typeable(Typeable)

data ProtocolError = ProtocolError String
                   deriving (Show,Typeable)
instance Exception ProtocolError

data AuthError = AuthError !Info
               deriving (Show,Typeable)
instance Exception AuthError




data ConnectInfo = ConnectInfo { connectHost :: !String
                               , connectPort :: !String
                               , connectDatabase :: !String
                               , connectUser :: !String
                               , connectPassword :: !String
                               , connectEncryption :: !Word8
                               , connectPacketSize :: !Word32
                               , connectOptionFlags1 :: !Word8
                               , connectOptionFlags2 :: !Word8
                               , connectOptionFlags3 :: !Word8
                               , connectTypeFlags :: !Word8
                               , connectTimeZone :: !Int32
                               , connectCollation :: !Collation32
                               , connectLanguage :: !String
                               , connectAppName :: !String
                               , connectServerName :: !String
                               }

defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
  let
    l7 = defaultLogin7
  in ConnectInfo { connectHost = mempty
                 , connectPort = mempty
                 , connectDatabase = T.unpack $ l7Database l7
                 , connectUser = T.unpack $ l7UserName l7
                 , connectPassword = T.unpack $ l7Password l7
                 , connectEncryption = 0x00 -- 0x00: ENCRYPT_OFF (Encrypt login packet only), 0x02: ENCRYPT_NOT_SUP (No encryption)
                 , connectPacketSize = l7PacketSize l7
                 , connectOptionFlags1 = l7OptionFlags1 l7
                 , connectOptionFlags2 = l7OptionFlags2 l7
                 , connectOptionFlags3 = l7OptionFlags3 l7
                 , connectTypeFlags = l7TypeFlags l7
                 , connectTimeZone = l7TimeZone l7
                 , connectCollation = l7Collation l7
                 , connectLanguage = T.unpack $ l7Language l7
                 , connectAppName = T.unpack $ l7AppName l7
                 , connectServerName = T.unpack $ l7ServerName l7
                 }


data Connection = Connection Socket Word32


connect :: ConnectInfo -> IO Connection
connect ci@(ConnectInfo host port _ _ _ encrypt ps _ _ _ _ _ _ _ _ _) = do
  addr <- resolve host port
  sock <- connect' addr

  Prelogin plResOpts <- performPrelogin sock ps encrypt

  PLOEncryption modeEnc:_  <- case filter isPLOEncryption plResOpts of
                                [] -> throwIO $ ProtocolError "connect: PLOEncryption is necessary"
                                xs -> return xs
  PLOMars modeMars:_ <- case filter isPLOMars plResOpts of
                          [] -> throwIO $ ProtocolError "connect: PLOMars is necessary"
                          xs -> return xs
  when (modeEnc/=encrypt)  $ throwIO $ ProtocolError "connect: Server reported unsupported encryption mode"
  when (modeMars/=0) $ throwIO $ ProtocolError "connect: Server reported unsupported mars mode"

  login7 <- newLogin7 ci

  tss <- case encrypt of
    0x00 -> do
      ---
      --- TLS handshake
      ---
      tlsContext <- contextNew sock host
      TLS.handshake tlsContext

      --- 
      --- Login with encrypted packet
      --- 
      TLS.sendData tlsContext $ Put.runPut $ putClientMessage ps $ CMLogin7 login7
      readMessage sock $ Get.runGetIncremental getServerMessage
    0x02 -> do
      --- 
      --- Login without encryipted packet
      --- 
      sendAll sock $ Put.runPut $ putClientMessage ps $ CMLogin7 login7
      readMessage sock $ Get.runGetIncremental getServerMessage

  --- 
  --- Verify Ack
  --- 
  validLoginAck login7 tss

  return $ Connection sock ps



connectWithoutEncryption :: ConnectInfo -> IO Connection
connectWithoutEncryption ci = connect $ ci {connectEncryption = 0x02}


close :: Connection -> IO ()
close (Connection sock _ ) = Socket.close sock




performPrelogin :: Socket -> Word32 -> Word8 -> IO Prelogin
performPrelogin sock ps enc = do
  -- https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868
  --
  -- Prelogin
  --
  -- [TODO] Threadid support
  -- [TODO] Mars support
  let clientPrelogin = Prelogin [ PLOVersion 8 0 341 0
                                , PLOEncryption enc
                                , PLOInstopt "MSSQLServer"
                                , PLOThreadid (Just 1000) -- [TODO]
                                , PLOMars 0 -- [TODO]
                                ]
  sendAll sock $ Put.runPut $ putClientMessage ps $ CMPrelogin clientPrelogin
  serverPrelogin <- readMessage sock $ Get.runGetIncremental getServerMessage

  return serverPrelogin



newLogin7 :: ConnectInfo -> IO Login7
newLogin7 (ConnectInfo _ _ database user pass _ _ optf1 optf2 optf3 typef tz coll lang app serv) = do
  ---
  --- Login7
  ---
  -- [TODO] process ID support
  -- [TODO] MAC address support
  hostname <- getHostName
  let login7 = defaultLogin7 { l7ClientProgVer = 1
                             , l7OptionFlags1 = optf1
                             , l7OptionFlags2 = optf2
                             , l7OptionFlags3 = optf3
                             , l7TypeFlags = typef
                             , l7TimeZone = tz
                             , l7Collation = coll
                             , l7CltIntName = T.pack "mssql-simple"
                             , l7Language = T.pack lang
                             , l7ClientPID = 1 -- [TODO]
                             , l7ClientMacAddr = B.pack [0x00,0x00,0x00,0x00,0x00,0x00] -- [TODO]
                             , l7ClientHostName = T.pack hostname
                             , l7AppName = T.pack app
                             , l7ServerName = T.pack serv
                             , l7UserName = T.pack user
                             , l7Password = T.pack pass
                             , l7Database = T.pack database
                             }
  return login7



validLoginAck :: Login7 -> TokenStreams -> IO ()
validLoginAck login7 (TokenStreams loginResTokenStreams) = do

  let loginAcks   = filter isTSLoginAck loginResTokenStreams
  when (null loginAcks) $ do
    TSError info:_ <- case filter isTSError loginResTokenStreams of
                        [] -> throwIO $ ProtocolError "validLoginAck: TSError is necessary"
                        xs -> return xs
    throwIO $ AuthError info

  let (TSLoginAck _ tdsVersion' _ _):_ = loginAcks
  when (tdsVersion /= tdsVersion') $ throwIO $ ProtocolError "validLoginAck: Server reported unsupported tds version"

  return ()
  where

    isTSLoginAck :: TokenStream -> Bool
    isTSLoginAck (TSLoginAck{}) = True
    isTSLoginAck _ = False

    isTSError :: TokenStream -> Bool
    isTSError (TSError{}) = True
    isTSError _ = False

    printEnvChange :: TokenStream -> IO ()
    printEnvChange (TSEnvChange t o n) = do
      putStr "TSEnvChange: "
      case t of
        1 -> T.putStr $ "Database: "   <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
        2 -> T.putStr $ "Language: "   <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
        3 -> T.putStr $ "Charset: "    <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
        4 -> T.putStr $ "PacketSize: " <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
        5 -> T.putStr $ "DSLID:      " <> T.decodeUtf16LE n
        6 -> T.putStr $ "DSCFlags: "   <> T.decodeUtf16LE n
        7 -> putStr $ "Collaction: " <> show o <> " -> " <> show n
        8 -> putStr $ "BeginTran: " <> show n
        9 -> putStr $ "CommitTran: " <> show o
        10 -> putStr $ "RollbackTran: " <> show o
        11 -> putStr $ "EnlistDTCTran: " <> show o
        12 -> putStr $ "DefactTran: " <> show n
        13 -> T.putStr $ "MirrorPartner: " <> T.decodeUtf16LE n
        15 -> putStr $ "PromoteTran: " <> show n
        16 -> putStr $ "TranManAddr: " <> show n
        17 -> putStr $ "TranEndedr: " <> show o
        18 -> putStr $ "ResetAck: "
        19 -> T.putStr $ "SendsBackInfo: " <> T.decodeUtf16LE n
        20 -> putStr $ "Routing: " <> show n
      putStrLn mempty



isPLOEncryption :: PreloginOption -> Bool
isPLOEncryption (PLOEncryption{}) = True
isPLOEncryption _ = False

isPLOMars :: PreloginOption -> Bool
isPLOMars (PLOMars{}) = True
isPLOMars _ = False

resolve host port = do
  let hints = Socket.defaultHints { addrSocketType = Stream }
  addr:_ <- Socket.getAddrInfo (Just hints) (Just host) (Just port)
  return addr

connect' addr = do
  sock <- Socket.socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
  Socket.connect sock $ addrAddress addr
  return sock

readMessage :: Socket -> Get.Decoder a -> IO a
readMessage sock decoder = do
  bs <- recv sock 512 -- [TODO] optimize
  case Get.pushChunk decoder bs of
    Get.Done _ _ msg -> return msg
    decoder' -> readMessage sock decoder'




-- $use
-- 'connect' and 'close' function could be used as follows.
--
-- > {-# LANGUAGE OverloadedStrings #-}
-- > module Main where
-- >
-- > import Network.Socket (withSocketsDo)
-- > import Control.Exception (bracket)
-- >
-- > import Database.MSSQLServer.Connection
-- > import Database.MSSQLServer.Query
-- >
-- > main :: IO ()
-- > main = do
-- >   let info = defaultConnectInfo { connectHost = "192.168.0.1"
-- >                                 , connectPort = "1433"
-- >                                 , connectDatabase = "some_database"
-- >                                 , connectUser = "some_user"
-- >                                 , connectPassword = "some_password"
-- >                                 }
-- >   withSocketsDo $
-- >     bracket (connect info) close $ \conn -> do
-- >     rs <- sql conn "SELECT 2 + 2" :: IO [Only Int]
-- >     print rs