{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-|
Module      : Data.Password.Scrypt
Copyright   : (c) Dennis Gosnell, 2019; Felix Paulusma, 2020
License     : BSD-style (see LICENSE file)
Maintainer  : cdep.illabout@gmail.com
Stability   : experimental
Portability : POSIX

= scrypt

The @scrypt@ algorithm is a fairly new one. It was first published
in 2009 and later, in 2016, published as <https://tools.ietf.org/html/rfc7914 RFC 7914>.
Originally used for the Tarsnap backup service, it is
designed to be costly by requiring large amounts of memory.

== Other algorithms

Compared to @'Data.Password.Bcrypt.Bcrypt'@ and @'Data.Password.PBKDF2.PBKDF2'@,
@scrypt@ does increase the memory required to compute a hash. However, it turned
out not to be as optimal as it could be, so others set out to find algorithms
that better fulfil that promise.
@'Data.Password.Argon2.Argon2'@ seems to be the winner in that search.

That is not to say using @scrypt@ somehow means your passwords
won't be properly protected. The cryptography is sound and
thus is fine for protection against brute-force attacks.
Because of the memory cost, it is generally advised to use
@'Data.Password.Bcrypt.Bcrypt'@ if you're not sure whether this might be a
problem on your system.
-}

module Data.Password.Scrypt (
  -- * Algorithm
  Scrypt
  -- * Plain-text Password
  , Password
  , mkPassword
  -- * Hash Passwords (scrypt)
  , hashPassword
  , PasswordHash(..)
  -- * Verify Passwords (scrypt)
  , checkPassword
  , PasswordCheck(..)
  -- * Hashing Manually (scrypt)
  , hashPasswordWithParams
  , defaultParams
  , ScryptParams(..)
  -- ** Hashing with salt (DISADVISED)
  --
  -- | Hashing with a set 'Salt' is almost never what you want
  -- to do. Use 'hashPassword' or 'hashPasswordWithParams' to have
  -- automatic generation of randomized salts.
  , hashPasswordWithSalt
  , newSalt
  , Salt(..)
  -- * Unsafe debugging function to show a Password
  , unsafeShowPassword
  , -- * Setup for doctests.
    -- $setup
  ) where

import Control.Monad (guard)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Crypto.KDF.Scrypt as Scrypt
import Data.ByteArray (Bytes, constEq, convert)
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (encodeBase64)
import qualified Data.ByteString.Char8 as C8 (length)
import Data.Maybe (fromMaybe)
import qualified Data.Text as T (intercalate, split)
import Data.Word (Word32)

import Data.Password (
         PasswordCheck(..)
       , PasswordHash(..)
       , Salt(..)
       , mkPassword
       , unsafeShowPassword
       )
import Data.Password.Internal (Password(..), from64, readT, showT, toBytes)
import qualified Data.Password.Internal (newSalt)

-- | Phantom type for __scrypt__
--
-- Tags the 'PasswordHash' and 'Salt' values produced and consumed by this
-- module, so they cannot be mixed up with those of other algorithms.
--
-- @since 2.0.0.0
data Scrypt

-- $setup
-- >>> :set -XFlexibleInstances
-- >>> :set -XOverloadedStrings
--
-- Import needed libraries.
--
-- >>> import Data.Password
-- >>> import Data.ByteString (pack)
-- >>> import Test.QuickCheck (Arbitrary(arbitrary), Blind(Blind), vector)
-- >>> import Test.QuickCheck.Instances.Text ()
--
-- >>> instance Arbitrary (Salt a) where arbitrary = Salt . pack <$> vector 32
-- >>> instance Arbitrary Password where arbitrary = fmap Password arbitrary
-- >>> let salt = Salt "abcdefghijklmnopqrstuvwxyz012345"
-- >>> let testParams = defaultParams {scryptRounds = 10}

-- -- >>> instance Arbitrary (PasswordHash Scrypt) where arbitrary = hashPasswordWithSalt testParams <$> arbitrary <*> arbitrary

-- | Hash the 'Password' using the 'Scrypt' hash algorithm, with
-- 'defaultParams' and a freshly generated random salt.
--
-- >>> hashPassword $ mkPassword "foobar"
-- PasswordHash {unPasswordHash = "14|8|1|...|..."}
hashPassword :: MonadIO m => Password -> m (PasswordHash Scrypt)
hashPassword pass = hashPasswordWithParams defaultParams pass

-- TODO: Add way to parse the following. From [https://hashcat.net/wiki/doku.php?id=example_hashes]
-- SCRYPT:1024:1:1:MDIwMzMwNTQwNDQyNQ==:5FW+zWivLxgCWj7qLiQbeC8zaNQ+qdO0NUinvqyFcfo=

-- | Parameters used in the 'Scrypt' hashing algorithm.
--
-- Only 'scryptRounds', 'scryptBlockSize' and 'scryptParallelism' are written
-- into a 'PasswordHash'. The generated salt itself is stored instead of
-- 'scryptSalt', and the output length is implied by the length of the
-- stored key.
--
-- @since 2.0.0.0
data ScryptParams = ScryptParams {
  scryptSalt :: Word32,
  -- ^ Bytes to randomly generate as a unique salt, default is __32__
  scryptRounds :: Word32,
  -- ^ log2(N) rounds to hash, default is __14__ (i.e. 2^14 rounds)
  --
  -- RFC 7914 requires N to be a power of 2 greater than 1, so this should be
  -- at least @1@. N is computed as a 64-bit word, so values of @64@ or more
  -- wrap N around to @0@ and must not be used.
  scryptBlockSize :: Word32,
  -- ^ Block size, default is __8__
  --
  -- Limits are min: @1@, and max: @scryptBlockSize * scryptParallelism < 2 ^ 30@
  scryptParallelism :: Word32,
  -- ^ Parallelism factor, default is __1__
  --
  -- Limits are min: @1@, and max: @scryptBlockSize * scryptParallelism < 2 ^ 30@
  scryptOutputLength :: Word32
  -- ^ Output key length in bytes, default is __64__
} deriving (Eq, Show)

-- | Default parameters for the 'Scrypt' algorithm: 2^14 rounds, a block size
-- of 8, no extra parallelism, a 32-byte salt and a 64-byte derived key.
--
-- >>> defaultParams
-- ScryptParams {scryptSalt = 32, scryptRounds = 14, scryptBlockSize = 8, scryptParallelism = 1, scryptOutputLength = 64}
--
-- @since 2.0.0.0
defaultParams :: ScryptParams
defaultParams =
  ScryptParams
    { scryptRounds = 14
    , scryptBlockSize = 8
    , scryptParallelism = 1
    , scryptSalt = 32
    , scryptOutputLength = 64
    }

-- | Hash a password with the given 'ScryptParams' and also with the given 'Salt'
-- instead of a randomly generated salt using 'scryptSalt' from 'ScryptParams'.
-- Using 'hashPasswordWithSalt' is strongly __disadvised__ and 'hashPasswordWithParams'
-- should be used instead. /Never use a static salt in production applications!/
--
-- The resulting 'PasswordHash' consists of five fields separated by @|@:
--
-- @
-- scryptRounds|scryptBlockSize|scryptParallelism|base64(salt)|base64(key)
-- @
--
-- The input 'Salt' is taken as raw bytes. Both the salt and the derived key
-- are base64 encoded in the resulting 'PasswordHash'. Neither 'scryptSalt' nor
-- 'scryptOutputLength' is stored: the salt itself is stored instead, and the
-- output length is implied by the length of the stored key.
--
-- >>> let salt = Salt "abcdefghijklmnopqrstuvwxyz012345"
-- >>> hashPasswordWithSalt defaultParams salt (mkPassword "foobar")
-- PasswordHash {unPasswordHash = "14|8|1|YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXowMTIzNDU=|nENDaqWBmPKapAqQ3//H0iBImweGjoTqn5SvBS8Mc9FPFbzq6w65maYPZaO+SPamVZRXQjARQ8Y+5rhuDhjIhw=="}
--
-- (Note that we use an explicit 'Salt' in the example above.  This is so that the
-- example is reproducible, but in general you should use 'hashPassword'. 'hashPassword'
-- generates a new 'Salt' every time it is called.)
hashPasswordWithSalt :: ScryptParams -> Salt Scrypt -> Password -> PasswordHash Scrypt
hashPasswordWithSalt params@ScryptParams{..} s@(Salt salt) pass =
  PasswordHash $ T.intercalate "|"
    [ showT scryptRounds
    , showT scryptBlockSize
    , showT scryptParallelism
    , encodeBase64 salt
    , encodeBase64 key
    ]
  where
    -- Raw derived key bytes; base64 encoding happens in the list above.
    key :: ByteString
    key = hashPasswordWithSalt' params s pass

-- | Derive the raw (not base64-encoded) @scrypt@ key for a 'Password' with
-- the given 'ScryptParams' and 'Salt'. The key is 'scryptOutputLength' bytes
-- long, and N is @2 ^ 'scryptRounds'@.
--
-- Only for internal use.
hashPasswordWithSalt' :: ScryptParams -> Salt Scrypt -> Password -> ByteString
hashPasswordWithSalt' params (Salt salt) (Password pass) =
    convert derived
  where
    derived :: Bytes
    derived = Scrypt.generate kdfParams (toBytes pass) saltBytes

    saltBytes :: Bytes
    saltBytes = convert salt

    -- 'scryptRounds' holds log2(N), so N is recovered by exponentiation.
    kdfParams :: Scrypt.Parameters
    kdfParams = Scrypt.Parameters
      { Scrypt.n = 2 ^ scryptRounds params
      , Scrypt.r = fromIntegral (scryptBlockSize params)
      , Scrypt.p = fromIntegral (scryptParallelism params)
      , Scrypt.outputLength = fromIntegral (scryptOutputLength params)
      }


-- | Hash a password using the 'Scrypt' algorithm with the given 'ScryptParams'.
-- A new random salt of 'scryptSalt' bytes is generated for every call.
--
-- __N.B.__: If you have any doubt in your knowledge of cryptography and/or the
-- 'Scrypt' algorithm, please just use 'hashPassword'.
--
-- Advice for setting the parameters:
--
-- * Memory used is about: @(2 ^ 'scryptRounds') * 'scryptBlockSize' * 128@
-- * Increasing 'scryptBlockSize' and 'scryptRounds' will increase CPU time
--   and memory used.
-- * Increasing 'scryptParallelism' will increase CPU time. (since this
--   implementation, like most, runs the 'scryptParallelism' parameter in
--   sequence, not in parallel)
--
-- @since 2.0.0.0
hashPasswordWithParams :: MonadIO m => ScryptParams -> Password -> m (PasswordHash Scrypt)
hashPasswordWithParams params pass =
  liftIO $ fmap hashWith freshSalt
  where
    -- Random salt whose length (in bytes) comes from the parameters.
    freshSalt :: IO (Salt Scrypt)
    freshSalt = Data.Password.Internal.newSalt (fromIntegral (scryptSalt params))

    hashWith :: Salt Scrypt -> PasswordHash Scrypt
    hashWith salt = hashPasswordWithSalt params salt pass

-- | Check a 'Password' against a 'PasswordHash' 'Scrypt'.
--
-- Returns 'PasswordCheckSuccess' on success.
--
-- >>> let pass = mkPassword "foobar"
-- >>> passHash <- hashPassword pass
-- >>> checkPassword pass passHash
-- PasswordCheckSuccess
--
-- Returns 'PasswordCheckFail' if an incorrect 'Password' or 'PasswordHash' 'Scrypt' is used.
--
-- >>> let badpass = mkPassword "incorrect-password"
-- >>> checkPassword badpass passHash
-- PasswordCheckFail
--
-- A 'PasswordHash' that cannot be parsed, or whose parameters violate the
-- constraints of the key derivation function (@scryptRounds >= 64@, or
-- @scryptBlockSize * scryptParallelism >= 2 ^ 30@), also yields
-- 'PasswordCheckFail' instead of throwing.
--
-- This should always fail if an incorrect password is given.
--
-- prop> \(Blind badpass) -> let correctPasswordHash = hashPasswordWithSalt testParams salt "foobar" in checkPassword badpass correctPasswordHash == PasswordCheckFail
checkPassword :: Password -> PasswordHash Scrypt -> PasswordCheck
checkPassword pass (PasswordHash passHash) =
  fromMaybe PasswordCheckFail $ do
    -- The stored format is @rounds|blockSize|parallelism|salt64|key64@ (as
    -- produced by 'hashPasswordWithSalt'). Any other number of fields fails
    -- this pattern, which in 'Maybe' short-circuits to 'Nothing'.
    [roundsT, blockSizeT, parallelismT, salt64, hashedKey64] <-
      Just $ T.split (== '|') passHash
    rounds <- readT roundsT
    blockSize <- readT blockSizeT
    parallelism <- readT parallelismT
    salt <- from64 salt64
    hashedKey <- from64 hashedKey64
    -- These parameters come from stored, possibly corrupted or tampered data.
    -- N is computed as @2 ^ rounds@ in a 64-bit word, so @rounds >= 64@ wraps
    -- it to 0. The KDF also requires @r * p < 2 ^ 30@. Reject both up front
    -- so verification fails cleanly instead of the KDF throwing. The product
    -- is computed in 'Integer' so it cannot overflow.
    guard $ rounds < (64 :: Word32)
    guard $ toInteger blockSize * toInteger parallelism < 2 ^ (30 :: Int)
    -- NOTE(review): a large but valid @rounds@ still makes verification
    -- allocate about @2 ^ rounds * blockSize * 128@ bytes; consider an upper
    -- bound if hashes may come from untrusted sources.
    let params = ScryptParams
          { scryptSalt = 32 -- unused when verifying: the salt comes from the hash
          , scryptRounds = rounds
          , scryptBlockSize = blockSize
          , scryptParallelism = parallelism
            -- The output length is not stored; it is implied by the stored key.
          , scryptOutputLength = fromIntegral $ C8.length hashedKey
          }
        producedKey = hashPasswordWithSalt' params (Salt salt) pass
    -- Constant-time comparison so timing does not reveal how many bytes match.
    guard $ hashedKey `constEq` producedKey
    return PasswordCheckSuccess

-- | Generate a random 32-byte @scrypt@ salt. The length is the same as the
-- 'scryptSalt' value in 'defaultParams'.
--
-- @since 2.0.0.0
newSalt :: MonadIO m => m (Salt Scrypt)
newSalt = Data.Password.Internal.newSalt saltLength
  where
    saltLength :: Int
    saltLength = 32