{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-|
Module      : Data.Password.Internal
Copyright   : (c) Dennis Gosnell, 2019; Felix Paulusma, 2020
License     : BSD-style (see LICENSE file)
Maintainer  : cdep.illabout@gmail.com
Stability   : experimental
Portability : POSIX
-}

module Data.Password.Internal (
  -- * Global types
    PasswordCheck(..)
  , newSalt
  -- * Utility
  , toBytes
  , fromBytes
  , from64
  , unsafePad64
  , unsafeRemovePad64
  , readT
  , showT
  ) where

import Control.Monad.IO.Class (MonadIO(liftIO))
import Crypto.Random (getRandomBytes)
import Data.ByteArray (Bytes, convert)
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (decodeBase64)
#if !MIN_VERSION_base(4,13,0)
import Data.Semigroup ((<>))
#endif
import Data.Text as T (
    Text,
    dropEnd,
    length,
    pack,
    replicate,
    unpack,
 )
import Data.Password.Types (Salt(..))
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Text.Read (readMaybe)


-- | Generate a random salt of the requested length.
--
-- The argument is the number of random bytes to draw with
-- 'getRandomBytes'; the resulting 'ByteString' is wrapped in the 'Salt'
-- constructor and the 'IO' action is lifted into any 'MonadIO'.
--
-- @since 2.0.0.0
newSalt :: MonadIO m => Int -> m (Salt a)
newSalt i = liftIO $ Salt <$> getRandomBytes i
{-# INLINE newSalt #-}

-- | The result of checking a password against a hashed version. This is
-- returned by the @checkPassword@ functions.
data PasswordCheck
  = PasswordCheckSuccess
  -- ^ The password check was successful. The plain-text password matches the
  -- hashed password.
  | PasswordCheckFail
  -- ^ The password check failed. The plain-text password does not match the
  -- hashed password.
  deriving (Eq, Read, Show)

-- | Convert a 'Text' to 'Bytes'.
--
-- The text is first encoded as UTF-8 ('encodeUtf8') and the resulting
-- 'ByteString' is then turned into 'Bytes' with 'convert'.
toBytes :: Text -> Bytes
toBytes = convert . encodeUtf8
{-# INLINE toBytes #-}

-- | Convert 'Bytes' to a 'Text'.
--
-- The bytes are turned into a 'ByteString' with 'convert' and then decoded
-- as UTF-8.
--
-- Note: 'decodeUtf8' is partial. It throws an exception when the input is
-- not valid UTF-8, so only call this on bytes known to be valid UTF-8.
fromBytes :: Bytes -> Text
fromBytes = decodeUtf8 . convert
{-# INLINE fromBytes #-}

-- | Decode a base64 'Text' into a plain 'ByteString', if possible.
--
-- The text is UTF-8 encoded and handed to 'decodeBase64'. A decoding
-- failure is reported as 'Nothing'; the error message carried in the
-- 'Left' is discarded.
from64 :: Text -> Maybe ByteString
from64 = either (const Nothing) Just . decodeBase64 . encodeUtf8
{-# INLINE from64 #-}

-- | Like 'readMaybe', but takes its input as 'Text'.
--
-- Unlike 'read' this is total: input that does not parse yields 'Nothing'
-- instead of throwing.
readT :: Read a => Text -> Maybe a
readT = readMaybe . T.unpack
{-# INLINE readT #-}

-- | Like 'show', but produces a 'Text' instead of a 'String'.
showT :: Show a => a -> Text
showT = T.pack . show
{-# INLINE showT #-}

-- | (UNSAFE) Pad a base64 text with @=@ until "length `rem` 4 == 0".
--
-- Only the length of the input is inspected; the text is not checked to be
-- valid base64. A text whose length is already a multiple of 4 (including
-- the empty text) is returned unchanged.
unsafePad64 :: Text -> Text
unsafePad64 t
    | remains == 0 = t
    | otherwise = t <> pad
  where
    -- Number of characters past the last complete group of 4.
    remains = T.length t `rem` 4
    -- Enough "=" characters to complete the final group of 4.
    pad = T.replicate (4 - remains) (T.pack "=")

-- | (UNSAFE) Remove the "=" padding from a base64 text,
-- given the length of the original bytestring.
--
-- The number of characters to drop is derived solely from the given
-- length; the text itself is not inspected, so whatever characters are at
-- the end are dropped, padding or not.
unsafeRemovePad64 :: Int -> Text -> Text
unsafeRemovePad64 bsLen = T.dropEnd drops
  where
    drops = case bsLen `rem` 3 of
        -- 1 extra byte results in 2 characters (4 - 2 = 2)
        1 -> 2
        -- 2 extra bytes results in 3 characters (4 - 3 = 1)
        2 -> 1
        -- A whole number of 3-byte groups carries no padding. (A negative
        -- length also lands here; 'T.dropEnd' treats any non-positive
        -- count as "drop nothing", so 0 covers that case too.)
        _ -> 0