{-# language CPP #-}

module Network.HaskellNet.Auth
where

import Crypto.Hash.MD5
import Data.Text.Encoding.Base64 as B64

#if MIN_VERSION_base64(0,5,0)
import Data.Base64.Types as B64
#endif

import Data.Word
import Data.List
import Data.Bits
import Data.Array
import qualified Data.ByteString as B
import qualified Data.Text as T

-- | Account name handed to the response builders ('plain', 'login',
-- 'cramMD5', 'auth'). A plain alias for 'String'; no validation is applied.
type UserName = String
-- | Secret handed to the response builders alongside a 'UserName'.
-- A plain alias for 'String'; the XOAUTH2 branch of 'auth' inserts it
-- verbatim after @auth=@.
type Password = String

-- | Authorization types supported by the <https://www.ietf.org/rfc/rfc4954.txt RFC4954>
data AuthType
    = PLAIN     -- ^ Response is built by 'plain'.
    | LOGIN     -- ^ Response is built by 'login'.
    | CRAM_MD5  -- ^ Response is built by 'cramMD5'; shown as @CRAM-MD5@.
    | XOAUTH2   -- ^ Response is assembled directly in 'auth'.
    deriving Eq

-- | Renders each 'AuthType' as its mechanism name. The text differs from the
-- constructor name only for 'CRAM_MD5', which is shown as @CRAM-MD5@.
instance Show AuthType where
    -- The name is parenthesised whenever the surrounding precedence exceeds
    -- that of function application (10), so @show (Just PLAIN)@ yields
    -- @Just (PLAIN)@ while @show PLAIN@ yields @PLAIN@.
    showsPrec d at = showParen (d > appPrec) $ showString (mechanismName at)
      where
        appPrec :: Int
        appPrec = 10

        mechanismName :: AuthType -> String
        mechanismName PLAIN    = "PLAIN"
        mechanismName LOGIN    = "LOGIN"
        mechanismName CRAM_MD5 = "CRAM-MD5"
        mechanismName XOAUTH2  = "XOAUTH2"

-- | Base64-encode a 'String'.
--
-- The input is packed into 'T.Text', encoded with @encodeBase64@ from
-- "Data.Text.Encoding.Base64", and unpacked again. From base64-0.5.0 on the
-- encoder returns a wrapped value, which is unwrapped with @extractBase64@.
b64Encode :: String -> String
b64Encode = T.unpack . encode . T.pack
  where
    encode :: T.Text -> T.Text
    encode =
#if MIN_VERSION_base64(0,5,0)
        B64.extractBase64 . B64.encodeBase64
#else
        B64.encodeBase64
#endif



-- | Base64-decode a 'String' using the lenient decoder.
--
-- The input is packed into 'T.Text', decoded with @decodeBase64Lenient@ from
-- "Data.Text.Encoding.Base64", and unpacked again. No error is reported to
-- the caller: the result is whatever the lenient decoder produces.
-- From base64-0.5.0 on the input is first wrapped with @assertBase64@.
b64Decode :: String -> String
b64Decode = T.unpack . decode . T.pack
  where
    decode :: T.Text -> T.Text
    decode =
#if MIN_VERSION_base64(0,5,0)
        B64.decodeBase64Lenient . B64.assertBase64
#else
        B64.decodeBase64Lenient
#endif

-- | Render a list of bytes as lowercase hexadecimal, two digits per byte.
--
-- The empty list yields the empty string.
showOctet :: [Word8] -> String
showOctet = concatMap hexPair
  where
    -- High nibble first, then low nibble; both are always within 0..15,
    -- so the table lookup cannot go out of bounds.
    hexPair :: Word8 -> String
    hexPair octet = [digits ! hi, digits ! lo]
      where
        (hi, lo) = octet `divMod` 16

    digits :: Array Word8 Char
    digits = listArray (0, 15) "0123456789abcdef"

-- | MD5 digest of a list of bytes, returned as a list of bytes.
--
-- Packs the input into a strict 'B.ByteString', applies 'hash' from
-- "Crypto.Hash.MD5", and unpacks the digest.
hashMD5 :: [Word8] -> [Word8]
hashMD5 = B.unpack . hash . B.pack

-- | HMAC-MD5 of a message under a key, following the RFC 2104 construction
-- @H((K xor opad) ++ H((K xor ipad) ++ text))@ with a 64-byte block.
--
-- Note the argument order: the message comes first, the key second.
--
-- Both strings are converted to bytes one 'Char' at a time via
-- @toEnum . fromEnum@; a character whose code point does not fit in a
-- 'Word8' makes 'toEnum' raise an error.
hmacMD5 :: String -> String -> [Word8]
hmacMD5 text key = hashMD5 (okey ++ hashMD5 (ikey ++ toOctets text))
  where
    blockSize :: Int
    blockSize = 64

    toOctets :: String -> [Word8]
    toOctets = map (toEnum . fromEnum)

    rawKey :: [Word8]
    rawKey = toOctets key

    -- A key longer than one block is replaced by its digest. Only the key
    -- is hashed; the zero padding is appended afterwards. Hashing key and
    -- padding together would leave a 16-byte key and silently truncate the
    -- inner and outer pads.
    shortKey :: [Word8]
    shortKey
        | length rawKey > blockSize = hashMD5 rawKey
        | otherwise                 = rawKey

    -- Right-pad with zero bytes up to exactly one block.
    paddedKey :: [Word8]
    paddedKey = shortKey ++ replicate (blockSize - length shortKey) 0

    -- Inner and outer keys: the padded key xor-ed with 0x36 / 0x5c.
    ikey, okey :: [Word8]
    ikey = map (xor 0x36) paddedKey
    okey = map (xor 0x5c) paddedKey

-- | Build the PLAIN response: the Base64 encoding of an empty first field,
-- the user name and the password, separated by NUL characters
-- (@\"\\0user\\0pass\"@ before encoding).
plain :: UserName -> Password -> String
plain user pass = b64Encode (intercalate "\0" ["", user, pass])

-- | Build the LOGIN responses: the user name and the password, each
-- Base64-encoded separately and returned as a @(user, password)@ pair.
login :: UserName -> Password -> (String, String)
login user pass = (b64Encode user, b64Encode pass)

-- | Build the CRAM-MD5 response: the Base64 encoding of the user name, a
-- single space, and the lowercase hex HMAC-MD5 digest of the challenge keyed
-- with the password.
--
-- The challenge is used exactly as given.
-- NOTE(review): presumably callers Base64-decode the server challenge before
-- passing it in — confirm at the call sites.
cramMD5 :: String -> UserName -> Password -> String
cramMD5 challenge user pass = b64Encode (user ++ " " ++ digest)
  where
    digest :: String
    digest = showOctet (hmacMD5 challenge pass)

-- | Build the client response string for the given mechanism.
--
-- The second argument is the server challenge; only 'CRAM_MD5' uses it, the
-- other mechanisms ignore it.
--
-- * 'PLAIN': the result of 'plain'.
-- * 'LOGIN': the two encoded values from 'login' joined by a single space.
-- * 'CRAM_MD5': the result of 'cramMD5'.
-- * 'XOAUTH2': Base64 of @user=U^Aauth=P^A^A@, where @^A@ is the byte 0x01
--   and @P@ is the password argument inserted verbatim.
--   NOTE(review): XOAUTH2 normally expects @Bearer \<token\>@ after @auth=@ —
--   confirm that callers supply that prefix.
auth :: AuthType -> String -> UserName -> Password -> String
auth PLAIN    _ u p = plain u p
auth LOGIN    _ u p = unwords [encodedUser, encodedPass]
  where
    (encodedUser, encodedPass) = login u p
auth CRAM_MD5 c u p = cramMD5 c u p
auth XOAUTH2  _ u p = b64Encode ("user=" ++ u ++ "\001auth=" ++ p ++ "\001\001")