module Network.HaskellNet.Auth
where

import Crypto.Hash.MD5
import Data.Text.Encoding.Base64 as B64

import Data.Word
import Data.List
import Data.Bits
import Data.Array
import qualified Data.ByteString as B
import qualified Data.Text as T

-- | Login name presented to the server during authentication.
type UserName = String

-- | Secret presented during authentication (for CRAM-MD5 it is used as the
-- HMAC key rather than sent directly).
type Password = String

-- | Authentication mechanisms supported, as registered for the SMTP AUTH
-- extension in <https://www.ietf.org/rfc/rfc4954.txt RFC 4954>.
data AuthType
    = PLAIN     -- ^ the @PLAIN@ mechanism
    | LOGIN     -- ^ the @LOGIN@ mechanism
    | CRAM_MD5  -- ^ the @CRAM-MD5@ mechanism
    deriving Eq

-- | Shows each mechanism under its protocol name (@CRAM-MD5@, not the
-- constructor name @CRAM_MD5@).  Unlike a derived instance, the name is
-- wrapped in parentheses when shown at a precedence above 10.
instance Show AuthType where
    showsPrec prec mech =
        showParen (prec > 10) . showString $ case mech of
            PLAIN    -> "PLAIN"
            LOGIN    -> "LOGIN"
            CRAM_MD5 -> "CRAM-MD5"

-- | Base64-encode a 'String'.
--
-- The input is packed into 'T.Text' and passed to 'B64.encodeBase64'.
-- NOTE(review): that function is expected to encode the text's UTF-8 bytes
-- with padding — confirm against the @base64@ package version in use.
b64Encode :: String -> String
b64Encode input = T.unpack (B64.encodeBase64 (T.pack input))

-- | Decode Base64 text back into a 'String' via 'B64.decodeBase64Lenient'.
--
-- The decoder's @Text -> Text@ type has no error channel, so malformed input
-- is not reported to the caller.  NOTE(review): how invalid characters,
-- padding and non-UTF-8 payloads are treated is defined by the @base64@
-- package — confirm there before relying on it.
b64Decode :: String -> String
b64Decode encoded = T.unpack (B64.decodeBase64Lenient (T.pack encoded))

-- | Render bytes as lowercase hexadecimal: two digits per byte, high nibble
-- first, no separators.
showOctet :: [Word8] -> String
showOctet bytes =
    concat [ [digit hi, digit lo] | b <- bytes, let (hi, lo) = b `divMod` 16 ]
  where
    digit :: Word8 -> Char
    digit = (hexDigits !)

    hexDigits :: Array Word8 Char
    hexDigits = listArray (0, 15) "0123456789abcdef"

-- | MD5 digest of a byte list, returned as a byte list.
--
-- A list-based wrapper around 'hash' from "Crypto.Hash.MD5", which operates
-- on strict 'B.ByteString's.
hashMD5 :: [Word8] -> [Word8]
hashMD5 bytes = B.unpack digest
  where
    digest = hash (B.pack bytes)

-- | HMAC-MD5 (RFC 2104) of a message under a secret key, as raw digest bytes.
--
-- The arguments are in /message, key/ order, which reverses the usual
-- @HMAC(K, text)@ notation.  Each 'Char' of both strings becomes a single
-- byte via @toEnum . fromEnum@.  Only code points 0..255 are representable;
-- a larger code point makes the 'Word8' conversion fail at runtime.
-- NOTE(review): callers with non-Latin-1 input may want UTF-8 encoding
-- instead — confirm before changing, since it alters bytes 128..255.
hmacMD5 :: String   -- ^ message (for CRAM-MD5, the server challenge)
        -> String   -- ^ key (for CRAM-MD5, the password)
        -> [Word8]
hmacMD5 text key = hashMD5 (okey ++ hashMD5 (ikey ++ toBytes text))
    where
      -- MD5's input block size (B in RFC 2104).
      blockSize :: Int
      blockSize = 64

      toBytes :: String -> [Word8]
      toBytes = map (toEnum . fromEnum)

      koc :: [Word8]
      koc = toBytes key

      -- RFC 2104: a key longer than the block size is replaced by its hash
      -- *before* zero-padding.  Hashing after padding would leave key' only
      -- 16 bytes long, and the zipWith below would then silently truncate
      -- ipad/opad to 16 bytes, producing a wrong MAC.
      shortKey :: [Word8]
      shortKey
        | length koc > blockSize = hashMD5 koc
        | otherwise              = koc

      -- Key zero-padded to exactly one block.
      key' :: [Word8]
      key' = shortKey ++ replicate (blockSize - length shortKey) 0

      ipad, opad :: [Word8]
      ipad = replicate blockSize 0x36
      opad = replicate blockSize 0x5c

      ikey, okey :: [Word8]
      ikey = zipWith xor key' ipad
      okey = zipWith xor key' opad

-- | Build the client response for the @PLAIN@ mechanism.
--
-- The three fields are an empty authorization identity, the user name and
-- the password.  They are joined with NUL characters and the result is
-- Base64-encoded with 'b64Encode'.
plain :: UserName -> Password -> String
plain user pass = b64Encode ('\0' : user ++ '\0' : pass)

-- | Build the two client responses for the @LOGIN@ mechanism: the
-- Base64-encoded user name and the Base64-encoded password, in that order.
login :: UserName -> Password -> (String, String)
login user pass = (encodedUser, encodedPass)
  where
    encodedUser = b64Encode user
    encodedPass = b64Encode pass

-- | Build the client response for the @CRAM-MD5@ mechanism.
--
-- The response is the user name, a space, and the lowercase-hex HMAC-MD5 of
-- the challenge keyed by the password, all Base64-encoded.
--
-- NOTE(review): the challenge is used exactly as given.  Presumably the
-- caller has already Base64-decoded the server's challenge (e.g. with
-- 'b64Decode') — confirm at call sites.
cramMD5 :: String -> UserName -> Password -> String
cramMD5 challenge user pass = b64Encode (unwords [user, digestHex])
  where
    digestHex = showOctet (hmacMD5 challenge pass)

-- | Produce the client authentication response for the chosen mechanism.
--
-- The second argument is the server challenge, and only 'CRAM_MD5' uses it.
-- For 'LOGIN', the two Base64 fields produced by 'login' are joined with a
-- single space.
auth :: AuthType -> String -> UserName -> Password -> String
auth mech challenge user pass =
    case mech of
      PLAIN    -> plain user pass
      LOGIN    -> encUser ++ " " ++ encPass
      CRAM_MD5 -> cramMD5 challenge user pass
  where
    -- Lazy pattern binding: only forced in the LOGIN branch.
    (encUser, encPass) = login user pass