{-# LANGUAGE ScopedTypeVariables #-}

-- | One-time password implementation as defined by the
-- <http://tools.ietf.org/html/rfc4226 HOTP> and <http://tools.ietf.org/html/rfc6238 TOTP>
-- specifications.
--
-- Both implementations use a shared key between the client and the server. HOTP passwords
-- are based on a synchronized counter. TOTP passwords use the same approach but calculate
-- the counter as a number of time steps from the Unix epoch to the current time, thus
-- requiring that both client and server have synchronized clocks.
--
-- Probably the best-known use of TOTP is in Google's 2-factor authentication.
--
-- The TOTP API doesn't depend on any particular time package, so the user needs to supply
-- the current @OTPTime@ value, based on the system time. For example, using the @hourglass@
-- package, you could create a @getOTPTime@ function:
--
-- >>> import Time.System
-- >>> import Time.Types
-- >>>
-- >>> let getOTPTime = timeCurrent >>= \(Elapsed t) -> return (fromIntegral t :: OTPTime)
--
-- Or if you prefer, the @time@ package could be used:
--
-- >>> import Data.Time.Clock.POSIX
-- >>>
-- >>> let getOTPTime = getPOSIXTime >>= \t -> return (floor t :: OTPTime)
--

module Crypto.OTP
    ( OTP
    , OTPDigits (..)
    , OTPTime
    , hotp
    , resynchronize
    , totp
    , totpVerify
    , TOTPParams
    , ClockSkew (..)
    , defaultTOTPParams
    , mkTOTPParams
    )
where

import           Data.Bits (shiftL, (.&.), (.|.))
import           Data.ByteArray.Mapping (fromW64BE)
import           Data.List (elemIndex)
import           Data.Word
import           Control.Monad (unless)
import           Crypto.Hash (HashAlgorithm, SHA1(..))
import           Crypto.MAC.HMAC
import           Crypto.Internal.ByteArray (ByteArrayAccess, Bytes)
import qualified Crypto.Internal.ByteArray as B


-- | A one-time password which is a sequence of 4 to 9 digits.
--
-- The value is stored numerically, so an OTP whose leading digits are zero
-- (e.g. @000123@ for a 6-digit code) is represented by a smaller number
-- (@123@); pad with zeros to the configured 'OTPDigits' when displaying it.
type OTP = Word32

-- | The strength of the calculated HOTP value, i.e. how many decimal digits
-- (from 4 up to 9) are extracted from the computed HMAC.
data OTPDigits
    = OTP4 -- ^ 4-digit codes
    | OTP5 -- ^ 5-digit codes
    | OTP6 -- ^ 6-digit codes
    | OTP7 -- ^ 7-digit codes
    | OTP8 -- ^ 8-digit codes
    | OTP9 -- ^ 9-digit codes
    deriving (Show)

-- | An integral time value in seconds.
--
-- For the TOTP functions this is normally the current POSIX time (seconds
-- since the Unix epoch); see the module header for ways to obtain it.
type OTPTime = Word64

-- | Calculate an HOTP value for the given counter.
--
-- An HMAC of the big-endian encoded counter is computed with the shared key.
-- It is then reduced to a 31-bit integer using the "dynamic truncation"
-- step: the low 4 bits of the last MAC byte select an offset, and the four
-- bytes at that offset are read with the top bit cleared. The result is
-- reduced modulo @10 ^ digits@.
--
-- The hash's digest must be at least 19 bytes long, because the offset can
-- be as large as 15 and four bytes are read from it. Shorter digests are
-- rejected with an error rather than reading past the end of the MAC.
hotp :: forall hash key. (HashAlgorithm hash, ByteArrayAccess key)
    => hash
    -- ^ The hash algorithm used for the HMAC (only its type is used)
    -> OTPDigits
    -- ^ Number of digits in the HOTP value extracted from the calculated HMAC
    -> key
    -- ^ Shared secret between the client and server
    -> Word64
    -- ^ Counter value synchronized between the client and server
    -> OTP
    -- ^ The HOTP value
hotp _ d k c
    | B.length mac < minMacLength =
        error "Crypto.OTP.hotp: hash digest too short for dynamic truncation (need at least 19 bytes)"
    | otherwise = dt `mod` digitsPower d
  where
    mac :: HMAC hash
    mac = hmac k (fromW64BE c :: Bytes)

    -- Largest possible offset (0xf) plus the four bytes read from it.
    minMacLength :: Int
    minMacLength = 0xf + 4

    -- Offset taken from the low nibble of the final MAC byte.
    offset :: Int
    offset = fromIntegral (B.index mac (B.length mac - 1) .&. 0xf)

    -- The i-th byte after the offset, widened to the OTP type.
    byteAt :: Int -> OTP
    byteAt i = fromIntegral (B.index mac (offset + i))

    -- Big-endian 32-bit value with the sign bit masked off.
    dt :: OTP
    dt =     ((byteAt 0 .&. 0x7f) `shiftL` 24)
         .|. (byteAt 1 `shiftL` 16)
         .|. (byteAt 2 `shiftL` 8)
         .|. byteAt 3

-- | Attempt to resynchronize the server's counter value with the client,
-- given a sequence of consecutive HOTP values submitted by the client.
--
-- The HOTP values for the counters @c@ to @c + s@ are computed, where @c@ is
-- the current server counter and @s@ the look-ahead window. For every
-- position in that window where the first submitted OTP matches, the
-- additional OTPs are checked against the counters immediately following
-- it. The first position at which the whole submitted sequence matches is
-- used. An accidental match of the first OTP earlier in the window
-- therefore does not hide the genuine position.
resynchronize :: (HashAlgorithm hash, ByteArrayAccess key)
    => hash
    -> OTPDigits
    -> Word16
    -- ^ The look-ahead window parameter. The current counter and up to this
    -- many following values are calculated and checked against the value(s)
    -- submitted by the client
    -> key
    -- ^ The shared secret
    -> Word64
    -- ^ The current server counter value
    -> (OTP, [OTP])
    -- ^ The first OTP submitted by the client and a list of additional
    -- sequential OTPs (which may be empty)
    -> Maybe Word64
    -- ^ The new counter value (the one following the last matched OTP),
    -- synchronized with the client's current counter, or Nothing if the
    -- submitted OTP values didn't match anywhere within the window
resynchronize h d s k c (p1, extras) = search c window
  where
    -- HOTP values for counters c .. c + s. Enumerating with take/[c ..]
    -- stops at maxBound instead of wrapping around, which would otherwise
    -- make the window empty when c + s overflows.
    window :: [OTP]
    window = map (hotp h d k) (take (fromIntegral s + 1) [c ..])

    -- Find the next position of p1 in the candidates (the first of which
    -- belongs to counter @base@) and verify the extra OTPs after it. On
    -- failure, keep searching in the remainder of the window.
    search :: Word64 -> [OTP] -> Maybe Word64
    search base candidates = do
        i <- elemIndex p1 candidates
        let next = base + fromIntegral i + 1
        case checkExtraOtps next extras of
            Nothing -> search next (drop (i + 1) candidates)
            found   -> found

    -- Check the remaining OTPs against consecutive counters starting at ctr,
    -- returning the counter after the last one on success.
    checkExtraOtps :: Word64 -> [OTP] -> Maybe Word64
    checkExtraOtps ctr [] = Just ctr
    checkExtraOtps ctr (p:ps)
        | hotp h d k ctr /= p = Nothing
        | otherwise           = checkExtraOtps (ctr + 1) ps

-- | The modulus that reduces a truncated HMAC value to the requested number
-- of digits: @10 ^ n@ for an @n@-digit OTP.
digitsPower :: OTPDigits -> Word32
digitsPower digits = 10 ^ numDigits
  where
    numDigits :: Int
    numDigits = case digits of
        OTP4 -> 4
        OTP5 -> 5
        OTP6 -> 6
        OTP7 -> 7
        OTP8 -> 8
        OTP9 -> 9


-- | A TOTP configuration, obtained from 'mkTOTPParams' or
-- 'defaultTOTPParams'.
--
-- The 'TP' constructor is not exported from this module, so values can only
-- be created through those two functions.
data TOTPParams h =
    TP !h          -- hash algorithm used for the HMAC
       !OTPTime    -- T0: time from which time steps are counted
       !Word16     -- X: length of a time step in seconds
       !OTPDigits  -- number of digits in the OTP
       !ClockSkew  -- time steps accepted either side of the current one
    deriving (Show)

-- | How many time steps either side of the current one 'totpVerify' also
-- accepts, to allow for clock drift between client and server or a delay in
-- submitting the value.
--
-- 'fromEnum' gives the number of steps: 0 for 'NoSkew' up to 4 for
-- 'FourSteps'.
data ClockSkew
    = NoSkew
    | OneStep
    | TwoSteps
    | ThreeSteps
    | FourSteps
    deriving (Enum, Show)

-- | The default TOTP configuration: HMAC with SHA1, T0 = 0, a 30 second
-- time step, 6-digit codes and a tolerance of two time steps either side of
-- the current one.
defaultTOTPParams :: TOTPParams SHA1
defaultTOTPParams = TP SHA1 startTime stepSeconds OTP6 TwoSteps
  where
    startTime :: OTPTime
    startTime = 0

    stepSeconds :: Word16
    stepSeconds = 30

-- | Create a TOTP configuration with customized parameters.
--
-- Returns 'Left' with a message when the time step is zero or exceeds 300
-- seconds; the rules are checked in that order and the first failure is
-- reported.
mkTOTPParams :: (HashAlgorithm hash)
    => hash
    -> OTPTime
    -- ^ The T0 parameter in seconds. This is the Unix time from which to start
    -- counting steps (default 0). Must be before the current time (this is
    -- not checked here).
    -> Word16
    -- ^ The time step parameter X in seconds (default 30, maximum allowed 300)
    -> OTPDigits
    -- ^ Number of required digits in the OTP (default 6)
    -> ClockSkew
    -- ^ The number of time steps to check either side of the current value
    -- to allow for clock skew between client and server and or delay in
    -- submitting the value. The default is two time steps.
    -> Either String (TOTPParams hash)
mkTOTPParams h t0 x d skew = do
    mapM_ (\(ok, msg) -> unless ok (Left msg)) rules
    Right (TP h t0 x d skew)
  where
    -- Validation rules paired with the error reported when they fail.
    rules :: [(Bool, String)]
    rules =
        [ (x > 0,    "Time step must be greater than zero")
        , (x <= 300, "Time step cannot be greater than 300 seconds")
        ]

-- | Calculate a TOTP value for the given time.
--
-- The time is turned into a counter (the number of whole time steps elapsed
-- since T0), and the HOTP value for that counter is returned. The clock skew
-- setting is not used here.
totp :: (HashAlgorithm hash, ByteArrayAccess key)
    => TOTPParams hash
    -> key
    -- ^ The shared secret
    -> OTPTime
    -- ^ The time for which the OTP should be calculated.
    -- This is usually the current time as returned by @Data.Time.Clock.POSIX.getPOSIXTime@
    -> OTP
totp (TP h t0 x d _) k now = hotp h d k counter
  where
    counter = timeToCounter now t0 x

-- | Check a supplied TOTP value is valid for the given time,
-- within the window defined by the skew parameter.
--
-- The time step counter @t@ for @now@ is computed, and the OTP is accepted
-- if it equals the HOTP value of any counter from @t - skew@ to @t + skew@.
-- The window is clamped to the range of 'Word64'. Just after T0 (or near
-- 'maxBound') the out-of-range counters are skipped instead of wrapping
-- around to unrelated counter values.
totpVerify :: (HashAlgorithm hash, ByteArrayAccess key)
    => TOTPParams hash
    -> key
    -- ^ The shared secret
    -> OTPTime
    -- ^ The time at which the OTP is checked (usually the current time)
    -> OTP
    -- ^ The OTP submitted by the client
    -> Bool
    -- ^ Whether the OTP matched a counter within the window
totpVerify (TP h t0 x d skew) k now otp =
    otp `elem` map (hotp h d k) [lo .. hi]
  where
    t :: Word64
    t = timeToCounter now t0 x

    window :: Word64
    window = fromIntegral (fromEnum skew)

    -- Clamp so neither bound wraps around in Word64 arithmetic.
    lo, hi :: Word64
    lo = t - min t window
    hi = t + min (maxBound - t) window

-- | Convert a time into a TOTP counter: the number of whole time steps of
-- @x@ seconds between @t0@ and @now@.
--
-- NOTE: the subtraction is done in 'Word64', so if @now@ is earlier than
-- @t0@ it wraps around and yields a very large counter.
timeToCounter :: Word64 -> Word64 -> Word16 -> Word64
timeToCounter now t0 x = elapsed `div` stepLength
  where
    elapsed    = now - t0
    stepLength = fromIntegral x