{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

{-|
Module      : Database.MySQL.Protocol.Auth
Description : MySQL Auth Packets
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

Authentication-related packets: the server greeting, the client auth packet and the SSL request.

-}

module Database.MySQL.Protocol.Auth where

import           Control.Applicative
import           Control.Monad
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.Parser
import           Data.Binary.Put
import qualified Data.ByteString                as B
import           Data.ByteString.Char8          as BC
import           Data.Bits
import           Database.MySQL.Protocol.Packet

--------------------------------------------------------------------------------
-- Authentications

-- MySQL capability flags. Each macro is a single-bit mask within the 32-bit
-- capability word. Masks are OR-ed together to build the capabilities this
-- client advertises ('clientCap', 'sslRequest') and AND-ed against a
-- capability word to test for a feature ('putGreeting', 'getGreeting',
-- 'supportTLS'). Because these are CPP object-like macros, every use site
-- is textually replaced by the bare hexadecimal literal before type checking,
-- so the literal takes whatever numeric type the surrounding code demands.
#define CLIENT_LONG_PASSWORD                  0x00000001
#define CLIENT_FOUND_ROWS                     0x00000002
#define CLIENT_LONG_FLAG                      0x00000004
#define CLIENT_CONNECT_WITH_DB                0x00000008
#define CLIENT_NO_SCHEMA                      0x00000010
#define CLIENT_COMPRESS                       0x00000020
#define CLIENT_ODBC                           0x00000040
#define CLIENT_LOCAL_FILES                    0x00000080
#define CLIENT_IGNORE_SPACE                   0x00000100
#define CLIENT_PROTOCOL_41                    0x00000200
#define CLIENT_INTERACTIVE                    0x00000400
#define CLIENT_SSL                            0x00000800
#define CLIENT_IGNORE_SIGPIPE                 0x00001000
#define CLIENT_TRANSACTIONS                   0x00002000
#define CLIENT_RESERVED                       0x00004000
#define CLIENT_SECURE_CONNECTION              0x00008000
#define CLIENT_MULTI_STATEMENTS               0x00010000
#define CLIENT_MULTI_RESULTS                  0x00020000
#define CLIENT_PS_MULTI_RESULTS               0x00040000
#define CLIENT_PLUGIN_AUTH                    0x00080000
#define CLIENT_CONNECT_ATTRS                  0x00100000
#define CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 0x00200000

-- | Server greeting packet (the MySQL initial handshake).
--
-- The wire layout is defined by 'getGreeting' and 'putGreeting'.
data Greeting = Greeting
    { greetingProtocol   :: !Word8         -- ^ protocol version byte
    , greetingVersion    :: !B.ByteString  -- ^ server version string (NUL-terminated on the wire)
    , greetingConnId     :: !Word32        -- ^ connection id
    , greetingSalt1      :: !B.ByteString  -- ^ first part of the auth scramble (8 bytes as decoded)
    , greetingCaps       :: !Word32        -- ^ capability word, both 16-bit halves combined
    , greetingCharset    :: !Word8         -- ^ character set byte
    , greetingStatus     :: !Word16        -- ^ status word
    , greetingSalt2      :: !B.ByteString  -- ^ second part of the auth scramble; decoded as empty unless the secure-connection bit is set
    , greetingAuthPlugin :: !B.ByteString  -- ^ auth plugin name; decoded as empty unless the plugin-auth bit is set
    } deriving (Show, Eq)

-- | Serialize a 'Greeting' in the layout that 'getGreeting' parses.
--
-- Fields are written in this order:
--
-- * protocol version, then the NUL-terminated server version
-- * connection id
-- * salt1, followed by a 0x00 filler byte
-- * low 16 capability bits, charset, status, high 16 capability bits
-- * auth-plugin-data length
-- * 10 reserved zero bytes
-- * salt2, then the auth plugin name, each NUL-terminated
--
-- salt2 is written only when the secure-connection capability bit is set.
-- The auth plugin name is written only when the plugin-auth bit is set.
--
-- salt1 is written as-is, but 'getGreeting' reads back exactly 8 bytes.
-- The NUL-terminated strings must not themselves contain NUL bytes.
putGreeting :: Greeting -> Put
putGreeting (Greeting pv sv cid salt1 cap charset st salt2 authPlugin) = do
    putWord8 pv
    putByteString sv >> putWord8 0x00
    putWord32le cid
    putByteString salt1
    putWord8 0x00                       -- filler, skipped by 'getGreeting'
    putWord16le capL
    putWord8 charset
    putWord16le st
    putWord16le capH
    putWord8 authDataLen
    replicateM_ 10 (putWord8 0x00)      -- reserved
    when (hasCap CLIENT_SECURE_CONNECTION) $
        putByteString salt2 >> putWord8 0x00
    when (hasCap CLIENT_PLUGIN_AUTH) $
        putByteString authPlugin >> putWord8 0x00
  where
    hasCap :: Word32 -> Bool
    hasCap flag = cap .&. flag /= 0

    -- Split the capability word into its two plain 16-bit halves. Forcing
    -- any bits on here would make the peer decode a different capability set.
    capL, capH :: Word16
    capL = fromIntegral (cap .&. 0xFFFF)
    capH = fromIntegral (cap `shiftR` 16)

    -- Length of the combined auth-plugin-data: salt1, salt2 and the NUL
    -- written after salt2. Sent as 0 when plugin auth is not advertised.
    authDataLen :: Word8
    authDataLen
        | hasCap CLIENT_PLUGIN_AUTH = fromIntegral (B.length salt1 + B.length salt2 + 1)
        | otherwise                 = 0x00

-- | Parse a server greeting packet.
--
-- The 32-bit capability word arrives as two 16-bit halves, separated by the
-- charset and status fields, and is recombined into 'greetingCaps'.
-- salt2 is read only when the secure-connection capability bit is set, and
-- the auth plugin name only when the plugin-auth bit is set. Otherwise each
-- is returned as an empty string.
getGreeting :: Get Greeting
getGreeting = do
    pv      <- getWord8
    sv      <- getByteStringNul
    cid     <- getWord32le
    salt1   <- getByteString 8
    skipN 1                 -- filler byte (0x00)
    capL    <- getWord16le
    charset <- getWord8
    status  <- getWord16le
    capH    <- getWord16le
    let cap :: Word32
        cap = fromIntegral capH `shiftL` 16 .|. fromIntegral capL
    -- auth-plugin-data length: not needed, because salt2 is read up to its
    -- NUL terminator instead (see below).
    skipN 1
    skipN 10                -- reserved, all 0x00
    -- This differs from the MySQL documentation, which says to read
    -- MAX(13, auth-plugin-data length - 8) bytes here. Doing so prevented
    -- logging in, while reading up to the NUL terminator works.
    salt2 <- if (cap .&. CLIENT_SECURE_CONNECTION) == 0
        then pure B.empty
        else getByteStringNul
    authPlugin <- if (cap .&. CLIENT_PLUGIN_AUTH) == 0
        then pure B.empty
        else getByteStringNul
    pure (Greeting pv sv cid salt1 cap charset status salt2 authPlugin)

-- | Wire encoding via 'getGreeting' and 'putGreeting'.
instance Binary Greeting where
    get = getGreeting
    put = putGreeting

-- | Client authentication packet.
--
-- 'putAuth' writes every field. 'getAuth' decodes only the fields up to and
-- including 'authName', and returns 'authPassword' and 'authSchema' empty.
data Auth = Auth
    { authCaps      :: !Word32        -- ^ client capability word
    , authMaxPacket :: !Word32        -- ^ maximum packet size
    , authCharset   :: !Word8         -- ^ character set byte
    , authName      :: !B.ByteString  -- ^ user name (NUL-terminated on the wire)
    , authPassword  :: !B.ByteString  -- ^ auth response (one-byte length prefix on the wire)
    , authSchema    :: !B.ByteString  -- ^ initial schema (NUL-terminated on the wire)
    } deriving (Show, Eq)

-- | Parse the leading part of an 'Auth' packet.
--
-- The following are decoded:
--
-- * the capability word, max packet size and charset
-- * the NUL-terminated user name
--
-- The 23 reserved bytes between the charset and the user name are skipped.
-- 'authPassword' and 'authSchema' are NOT decoded and are returned empty.
-- Any bytes after the user name are left unconsumed.
getAuth :: Get Auth
getAuth = do
    caps    <- getWord32le
    maxPkt  <- getWord32le
    charset <- getWord8
    skipN 23                -- reserved, all 0x00
    name    <- getByteStringNul
    pure (Auth caps maxPkt charset name B.empty B.empty)

-- | Serialize an 'Auth' packet.
--
-- Fields are written in this order: capability word, max packet size,
-- charset, 23 reserved zero bytes, NUL-terminated user name, auth response
-- with a one-byte length prefix, and NUL-terminated schema.
--
-- The length prefix is a single byte, so an auth response longer than 255
-- bytes cannot be encoded. That case fails with 'error' rather than
-- writing a truncated length that would silently corrupt the packet.
putAuth :: Auth -> Put
putAuth (Auth cap m c n p s) = do
    when (passwordLen > 0xFF) $
        error ("Database.MySQL.Protocol.Auth.putAuth: auth response is "
               ++ show passwordLen ++ " bytes, maximum is 255")
    putWord32le cap
    putWord32le m
    putWord8 c
    replicateM_ 23 (putWord8 0x00)      -- reserved
    putByteString n >> putWord8 0x00
    putWord8 (fromIntegral passwordLen)
    putByteString p
    putByteString s >> putWord8 0x00
  where
    passwordLen :: Int
    passwordLen = B.length p

-- | Wire encoding via 'getAuth' and 'putAuth'. Decoding is partial:
-- see 'getAuth'.
instance Binary Auth where
    get = getAuth
    put = putAuth

-- | SSL request packet.
--
-- On the wire it has the same fixed-size prefix as an 'Auth' packet
-- (capability word, max packet size, charset, 23 reserved bytes), with no
-- user name. 'sslRequest' builds one with the SSL capability bit set.
data SSLRequest = SSLRequest
    { sslReqCaps      :: !Word32  -- ^ client capability word
    , sslReqMaxPacket :: !Word32  -- ^ maximum packet size
    , sslReqCharset   :: !Word8   -- ^ character set byte
    } deriving (Show, Eq)

-- | Parse an 'SSLRequest': the capability word, max packet size and
-- charset, then skip the 23 reserved bytes that follow.
getSSLRequest :: Get SSLRequest
getSSLRequest = do
    caps    <- getWord32le
    maxPkt  <- getWord32le
    charset <- getWord8
    skipN 23                -- reserved, all 0x00
    pure (SSLRequest caps maxPkt charset)

-- | Serialize an 'SSLRequest': the capability word, max packet size and
-- charset, followed by 23 reserved zero bytes (32 bytes in total).
putSSLRequest :: SSLRequest -> Put
putSSLRequest (SSLRequest cap m c) = do
    putWord32le cap
    putWord32le m
    putWord8 c
    replicateM_ 23 (putWord8 0x00)      -- reserved

-- | Wire encoding via 'getSSLRequest' and 'putSSLRequest'.
instance Binary SSLRequest where
    get = getSSLRequest
    put = putSSLRequest

--------------------------------------------------------------------------------
-- default Capability Flags

-- | Capability word this client advertises by default: the bitwise OR of
-- every flag in the list below. Add or remove a flag by editing the list.
clientCap :: Word32
clientCap = foldr (.|.) 0
    [ CLIENT_LONG_PASSWORD
    , CLIENT_LONG_FLAG
    , CLIENT_CONNECT_WITH_DB
    , CLIENT_IGNORE_SPACE
    , CLIENT_PROTOCOL_41
    , CLIENT_TRANSACTIONS
    , CLIENT_MULTI_STATEMENTS
    , CLIENT_MULTI_RESULTS
    , CLIENT_SECURE_CONNECTION
    ]

-- | Maximum packet size this client advertises: 0x00ffffff, i.e. 2^24 - 1
-- bytes.
clientMaxPacketSize :: Word32
clientMaxPacketSize = 0x00ffffff


-- | Whether the given capability word has the SSL capability bit set, for
-- example to check whether a greeting advertises TLS support.
supportTLS :: Word32 -> Bool
supportTLS caps = caps .&. CLIENT_SSL /= 0

-- | Build the 'SSLRequest' for the given charset. It uses the default
-- client capabilities ('clientCap') plus the SSL bit, and the default
-- maximum packet size ('clientMaxPacketSize').
sslRequest :: Word8 -> SSLRequest
sslRequest charset = SSLRequest (clientCap .|. CLIENT_SSL) clientMaxPacketSize charset