{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE OverloadedStrings #-}

module OTP.Commons
  ( -- * Auxiliary
    OTP (..)
  , Digits
  , Algorithm (..)
  , mkDigits
  , digitsToWord32
  , totpCounter
  , counterRange
  , totpCounterRange
  ) where

import Chronos (Time (..), Timespan (..), asSeconds, sinceEpoch)
import Data.Int (Int64)
import Data.Text.Display
import Data.Text.Lazy.Builder (Builder)
import Data.Text.Lazy.Builder qualified as Text
import Data.Word
import Text.Printf (printf)

-- $setup
-- >>> import Chronos qualified
-- >>> import Chronos (DatetimeFormat(..))
-- >>> import Torsor qualified
-- >>> import Data.Maybe (fromJust)
-- >>> :set -XOverloadedStrings
-- >>> let format = DatetimeFormat (Just '-') (Just ' ') (Just ':')
-- >>> let decode txt = Chronos.datetimeToTime $ fromJust $ Chronos.decode_YmdHMS format txt

-- | HMAC variant, identified by its underlying hash function, that can be
-- selected for OTP generation.
--
-- @since 3.0.0.0
data Algorithm
  = -- | HMAC over SHA-1.
    HMAC_SHA1
  | -- | HMAC over SHA-256.
    HMAC_SHA256
  | -- | HMAC over SHA-512.
    HMAC_SHA512
  deriving stock
    ( Eq
      -- ^ @since 3.0.0.0
    , Ord
      -- ^ @since 3.0.0.0
    , Show
      -- ^ @since 3.0.0.0
    )

-- | Renders the bare hash name without the @HMAC_@ prefix; for example
-- 'HMAC_SHA256' is displayed as @SHA256@.
--
-- @since 3.0.0.0
instance Display Algorithm where
  displayBuilder :: Algorithm -> Builder
  displayBuilder algorithm =
    case algorithm of
      HMAC_SHA1 -> "SHA1"
      HMAC_SHA256 -> "SHA256"
      HMAC_SHA512 -> "SHA512"

-- | A one-time password: the width it is rendered with together with its
-- numeric value.
--
-- @since 3.0.0.0
data OTP = OTP
  { digits :: Word32
  -- ^ Number of digits; used as the zero-padding width when displayed
  , code :: Word32
  -- ^ Numeric value of the password
  }
  deriving stock
    ( Eq
      -- ^ @since 3.0.0.0
    , Ord
      -- ^ @since 3.0.0.0
    , Show
      -- ^ @since 3.0.0.0
    )

-- | Renders the code left-padded with zeros to 'digits' characters; for
-- example @OTP 6 42@ is displayed as @000042@.
--
-- @since 3.0.0.0
instance Display OTP where
  displayBuilder :: OTP -> Builder
  displayBuilder otp = displayWord32AsOTP (digits otp) (code otp)

-- | Formats a code as an unsigned decimal number left-padded with zeros to
-- the given width, via a @printf@ format of the form @%0\<width\>u@.
-- The width is a minimum, so a code with more digits is printed in full.
displayWord32AsOTP :: Word32 -> Word32 -> Builder
displayWord32AsOTP width value =
  Text.fromString rendered
  where
    rendered = printf formatSpec value
    formatSpec = concat ["%0", show width, "u"]

-- | Number of digits in an OTP code. Values are built with 'mkDigits', which
-- rejects counts below 6, and unwrapped with 'digitsToWord32'.
--
-- @since 3.0.0.0
newtype Digits = Digits Word32
  deriving newtype (Eq, Ord, Show)
  deriving (Display) via (ShowInstance Digits)

-- | Unwraps a 'Digits' value to the underlying digit count.
digitsToWord32 :: Digits -> Word32
digitsToWord32 wrapped =
  case wrapped of
    Digits count -> count

-- | Smart constructor for 'Digits'.
--
-- RFC 4226 §5.3 says "Implementations MUST extract a 6-digit code at a minimum and possibly 7 and 8-digit code".
--
-- Returns 'Just' when the requested number of digits is 6 or more, and
-- 'Nothing' otherwise.
--
-- NOTE(review): no upper bound is checked here. Presumably counts of 10 or
-- more cannot be represented by a 'Word32' code; confirm whether callers
-- rely on a cap elsewhere.
mkDigits
  :: Word32
  -- ^ Requested number of digits
  -> Maybe Digits
mkDigits userDigits =
  if userDigits < 6
    then Nothing
    else Just (Digits userDigits)

-- | Calculate the HOTP counter for a point in time. The starting time
-- (T0 in RFC 6238) is 0, i.e. the beginning of the UNIX epoch.
--
-- Both the time elapsed since the epoch and the period are converted to
-- whole seconds with 'asSeconds' before the integer division.
--
-- >>> let timestamp = decode "2010-10-10 00:00:30"
-- >>> let timespan = Torsor.scale 30 Chronos.second
-- >>> totpCounter timestamp timespan
-- 42888961
--
-- >>> let timestamp2 = decode "2010-10-10 00:00:45"
-- >>> totpCounter timestamp2 timespan
-- 42888961
--
-- >>> let timestamp3 = decode "2010-10-10 00:01:00"
-- >>> totpCounter timestamp3 timespan
-- 42888962
--
-- NOTE(review): a period that converts to 0 seconds (presumably anything
-- shorter than one second) makes the 'quot' throw a divide-by-zero error.
-- A 'Time' before the epoch gives a negative second count, which wraps
-- around when converted to 'Word64'.
--
-- @since 3.0.0.0
totpCounter
  :: Time
  -- ^ Time of totp
  -> Timespan
  -- ^ Length of one time step; only whole seconds are used
  -> Word64
  -- ^ Resulting counter
totpCounter time period = elapsedSeconds `quot` stepSeconds
  where
    elapsedSeconds = toWord64 (asSeconds (sinceEpoch time))
    stepSeconds = toWord64 (asSeconds period)

    toWord64 :: Int64 -> Word64
    toWord64 = fromIntegral

-- | Make a sequence of acceptable counters around an ideal counter.
--
-- The window is computed with saturating arithmetic, so it is clipped at 0
-- and at 'maxBound' instead of wrapping around. The window is capped at
-- 500 counters before and 499 after the ideal one (1000 in total), because
-- huge counter ranges are insecure.
--
-- >>> counterRange (0, 0) 9000
-- [9000]
--
-- >>> counterRange (1, 0) 9000
-- [8999,9000]
--
-- >>> length $ counterRange (5000, 0) 9000
-- 501
--
-- >>> length $ counterRange (5000, 5000) 9000
-- 1000
--
-- >>> counterRange (2, 2) maxBound
-- [18446744073709551613,18446744073709551614,18446744073709551615]
--
-- >>> counterRange (2, 2) minBound
-- [0,1,2]
--
-- >>> counterRange (2, 2) (maxBound `div` 2)
-- [9223372036854775805,9223372036854775806,9223372036854775807,9223372036854775808,9223372036854775809]
--
-- >>> counterRange (5, 5) 9000
-- [8995,8996,8997,8998,8999,9000,9001,9002,9003,9004,9005]
--
-- >>> counterRange (5, 0) 1
-- [0,1]
--
-- >>> counterRange (0, 5) (maxBound - 1)
-- [18446744073709551614,18446744073709551615]
--
-- RFC recommends avoiding excessively large values for counter ranges.
--
-- @since 3.0.0.0
counterRange
  :: (Word64, Word64)
  -- ^ Number of counters before and after ideal
  -> Word64
  -- ^ Ideal counter value
  -> [Word64]
counterRange (tolow', tohigh') ideal = [lo .. hi]
  where
    -- Cap the window so at most 500 + 1 + 499 = 1000 counters are produced.
    tolow :: Word64
    tolow = min 500 tolow'

    tohigh :: Word64
    tohigh = min 499 tohigh'

    -- Saturating subtraction. Word64 subtraction wraps on underflow, so
    -- clamping the result afterwards cannot detect it; check first instead.
    lo :: Word64
    lo
      | tolow >= ideal = 0
      | otherwise = ideal - tolow

    -- Saturating addition. ideal + tohigh wraps past maxBound, so compare
    -- against the remaining headroom before adding.
    hi :: Word64
    hi
      | tohigh >= maxBound - ideal = maxBound
      | otherwise = ideal + tohigh

-- | Make a sequence of acceptable periods. The 'totpCounter' for the given
-- time and period is widened by 'counterRange' using the given
-- (before, after) tolerance.
--
-- >>> let time = decode "2010-10-10 00:00:30"
-- >>> let timespan = Torsor.scale 30 Chronos.second
-- >>> totpCounterRange (1, 1) time timespan
-- [42888960,42888961,42888962]
--
-- @since 3.0.0.0
totpCounterRange
  :: (Word64, Word64)
  -- ^ Number of periods accepted before and after the current one
  -> Time
  -- ^ Time of totp
  -> Timespan
  -- ^ Length of one time step
  -> [Word64]
totpCounterRange range time period =
  let current = totpCounter time period
   in counterRange range current