module Crypto.MAC.TOTP.Verifier
( Verifier (..)
, initializeIO
, initialize
, tryRefreshEvery
, startRefreshThread
, refresh
, getNext
, getMessages
, getNextIO
, getMessagesIO
, isAuthentic
, isAuthenticIO
) where

import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
import qualified Data.List as List
import qualified Data.Set as Set
import System.Posix.Time
import System.Posix.Types (EpochTime)
import Foreign.C.Types (CTime (..))
import Control.Concurrent
import Data.IORef
import Control.Exception (evaluate)

import qualified Crypto.MAC.TOTP.Factory as Factory

-- | State needed to hand out tokens and to check incoming ones.
data Verifier = Verifier
  { factory :: Factory.Factory
    -- ^ Factory used for the current period.
  , usedTokens :: Set.Set ByteString
    -- ^ Tokens already issued by 'getNext' or accepted by 'isAuthentic'
    -- since the factory was last refreshed.
  , grace :: [GraceVerifier]
    -- ^ Verifiers consulted when a token is not accepted by 'factory'.
  , graceSeconds :: CTime
    -- ^ Seconds added to a grace factory's @validUntil@ when deciding
    -- whether it may still accept tokens.
  }

-- | A factory kept around for grace-period checks, together with the
-- tokens that have already been accepted or issued for it.
data GraceVerifier = GraceVerifier
  { graceFactory :: Factory.Factory
    -- ^ Factory used to recompute the expected token.
  , graceUsedTokens :: Set.Set ByteString
    -- ^ Tokens that must not be accepted again for this factory.
  }

-- | Compare a grace verifier against slot @n@ of the verifier's factory.
--
-- Delegates to 'Factory.epochEq'; presumably this tests whether the grace
-- factory belongs to the epoch @n@ periods away from the verifier's factory
-- (TODO confirm against "Crypto.MAC.TOTP.Factory").
graceEq :: Verifier -> CTime -> GraceVerifier -> Bool
graceEq v n = Factory.epochEq (factory v) n . graceFactory

-- | Demote a verifier to a grace verifier, keeping its factory and its set
-- of used tokens and dropping the grace-related fields.
toGrace :: Verifier -> GraceVerifier
toGrace Verifier { factory = current, usedTokens = spent } =
    GraceVerifier current spent

-- | Like 'initialize', but reads the current time from the system clock
-- ('epochTime') instead of taking it as an argument.
initializeIO :: (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> CTime -> IO Verifier
initializeIO hashMethod blockSize tokenBytes secret validSeconds graceSeconds =
    fmap build epochTime
  where
    -- All arguments are forwarded unchanged; only the timestamp is supplied here.
    build now =
        initialize now hashMethod blockSize tokenBytes secret validSeconds graceSeconds

-- | Build a fresh 'Verifier' for the given point in time.
--
-- The factory is created from the hashing parameters and immediately
-- refreshed to @time@; the used-token set starts empty and the grace list
-- is populated by 'initGrace'.
--
-- Calls 'error' when @graceSeconds@ is negative.
initialize :: EpochTime -> (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> CTime -> Verifier
initialize time hashMethod blockSize tokenBytes secret validSeconds graceSeconds
  | graceSeconds < 0 = error "graceSeconds must be >= 0"
  | otherwise =
      Verifier
        { factory = current
        , usedTokens = Set.empty
        , grace = initGrace graceSeconds validSeconds current
        , graceSeconds = graceSeconds
        }
  where
    -- Factory brought up to date for @time@ before it is stored.
    current =
        Factory.refresh time
            (Factory.initialize hashMethod blockSize tokenBytes secret validSeconds)

-- | Loop forever: sleep for @delay@ (passed to 'threadDelay', i.e.
-- microseconds), then refresh the verifier stored in @verifierRef@ against
-- the current system time.
--
-- The refresh is applied atomically to whatever the reference holds at that
-- moment. Refreshing a privately held copy and writing it back would
-- discard updates made concurrently through the same reference (for
-- example tokens recorded by 'isAuthenticIO' or 'getNextIO'), which would
-- let already-used tokens be accepted again.
--
-- The first argument is no longer consulted; it is kept so that existing
-- callers keep compiling. 'refresh' returns its input unchanged when no
-- refresh is due, so applying it on every tick is safe.
tryRefreshEvery :: Verifier -> Int -> IORef Verifier -> IO ()
tryRefreshEvery _initial delay verifierRef = loop
  where
    loop = do
      threadDelay delay
      time <- epochTime
      -- Strict variant so no chain of pending refreshes builds up in the ref.
      atomicModifyIORef' verifierRef (\current -> (refresh time current, ()))
      loop

-- | Refresh the verifier to the current system time, store it in a new
-- 'IORef' and fork a thread running 'tryRefreshEvery' on that reference.
--
-- Returns the id of the forked thread together with the shared reference.
startRefreshThread :: Int -> Verifier -> IO (ThreadId, IORef Verifier)
startRefreshThread delay verifier =
    epochTime >>= \now ->
      let fresh = refresh now verifier
      in newIORef fresh >>= \sharedRef ->
           forkIO (tryRefreshEvery fresh delay sharedRef) >>= \tid ->
             return (tid, sharedRef)

-- | Number of validity periods needed to cover the grace window: the
-- quotient of @graceSeconds@ by @validSeconds@, plus one when the remainder
-- is positive.
graceCount :: CTime -> CTime -> CTime
graceCount graceSeconds validSeconds
    | leftover > 0 = whole + 1
    | otherwise = whole + 0
  where
    (whole, leftover) = graceSeconds `divMod` validSeconds

-- | 'graceCount' computed from a verifier's own grace window and the
-- validity period of its factory.
graceCountV :: Verifier -> CTime
graceCountV v = graceCount window period
  where
    window = graceSeconds v
    period = Factory.validSeconds (factory v)

-- | Create the initial grace list: one grace verifier with an empty
-- used-token set for every slot from 1 up to 'graceCount'.
initGrace :: CTime -> CTime -> Factory.Factory -> [GraceVerifier]
initGrace graceSeconds validSeconds factory' =
    map emptySlot [1 .. graceCount graceSeconds validSeconds]
  where
    -- Each slot's factory is derived from the base factory by 'Factory.initGrace'.
    emptySlot n = GraceVerifier (Factory.initGrace factory' n) Set.empty

-- | Rebuild the grace list of @verifierNew@ from the state of @verifierOld@.
--
-- For every slot from 1 up to 'graceCountV' of the new verifier, the old
-- verifier itself (via 'toGrace') and its grace list are searched with
-- 'graceEq'. A match is carried over together with its used tokens;
-- otherwise a fresh grace verifier with an empty used-token set is created.
refreshGrace :: Verifier -> Verifier -> Verifier
refreshGrace verifierOld verifierNew =
    verifierNew { grace = foldr placeSlot [] [1 .. graceCountV verifierNew] }
  where
    -- Everything that may be reused: the previous current period first,
    -- followed by its own grace entries.
    carried = toGrace verifierOld : grace verifierOld

    placeSlot n rest =
        maybe (blankSlot n : rest) (: rest)
              (List.find (graceEq verifierNew n) carried)

    blankSlot n =
        GraceVerifier (Factory.initGrace (factory verifierNew) n) Set.empty

-- | Whether the verifier's factory reports that it needs refreshing at the
-- given time (see 'Factory.shouldRefresh').
shouldRefresh :: EpochTime -> Verifier -> Bool
shouldRefresh time = (`Factory.shouldRefresh` time) . factory

-- | Bring a verifier up to date for time @t@.
--
-- When 'shouldRefresh' holds, the result carries a refreshed factory, an
-- empty used-token set, the same @graceSeconds@, and a grace list rebuilt
-- by 'refreshGrace' from the previous state. Otherwise the verifier is
-- returned unchanged.
refresh :: EpochTime -> Verifier -> Verifier
refresh t v
    | shouldRefresh t v = refreshGrace v rolled
    | otherwise = v
  where
    -- New-period verifier; its grace list is filled in by 'refreshGrace'.
    rolled =
        Verifier
          { factory = Factory.refresh t (factory v)
          , usedTokens = Set.empty
          , grace = []
          , graceSeconds = graceSeconds v
          }

-- | Draw the next value from the factory and record it as used.
--
-- No check is made as to whether the value was already in the used-token
-- set; 'getNext' adds that check.
getNextUnsafe :: Verifier -> (Verifier, ByteString)
getNextUnsafe v = (advanced, produced)
  where
    (nextFactory, produced) = Factory.getNext (factory v)
    advanced =
        v { factory = nextFactory
          , usedTokens = Set.insert produced (usedTokens v)
          }

-- | Draw the next value that was not already recorded as used.
--
-- Values that were present in the used-token set before the draw are
-- skipped by drawing again from the advanced verifier.
getNext :: Verifier -> (Verifier, ByteString)
getNext v
    | produced `Set.member` usedTokens v = getNext advanced
    | otherwise = (advanced, produced)
  where
    (advanced, produced) = getNextUnsafe v

-- | Draw @n@ unused values with 'getNext'.
--
-- Returns the advanced verifier and the drawn values, most recently drawn
-- first. A count of zero or less yields the verifier unchanged and an empty
-- list; previously a negative count never reached the base case and
-- recursed without end.
getMessages :: Int -> Verifier -> (Verifier, [ByteString])
getMessages n v
    | n <= 0 = (v, [])
    | otherwise =
        let (v', keys) = getMessages (n - 1) v
            (v'', key) = getNext v'
        in (v'', key : keys)

-- | Atomically draw the next unused value from the shared verifier and
-- store the advanced verifier back into the reference.
--
-- Uses the strict 'atomicModifyIORef'' so the stored verifier is evaluated
-- right away instead of accumulating as a chain of thunks in the reference.
getNextIO :: IORef Verifier -> IO ByteString
getNextIO vRef = atomicModifyIORef' vRef getNext

-- | Atomically draw @n@ unused values from the shared verifier (see
-- 'getMessages') and store the advanced verifier back into the reference.
--
-- For @n < 1@ the reference is not touched and an empty list is returned.
-- Uses the strict 'atomicModifyIORef'' so the stored verifier is evaluated
-- right away instead of accumulating as a chain of thunks in the reference.
getMessagesIO :: Int -> IORef Verifier -> IO [ByteString]
getMessagesIO n vRef
    | n < 1 = return []
    | otherwise = atomicModifyIORef' vRef (getMessages n)

-- | Try to authenticate a token against the grace verifiers.
--
-- Entries are visited in list order. An entry accepts the token when the
-- token has not been used for that entry, equals the value its factory
-- computes for @message@, and the factory's @validUntil@ plus the
-- verifier's @graceSeconds@ is still later than @currentTime@.
--
-- On success the result is @True@ together with an updated grace list in
-- which the token is recorded for the accepting entry. That entry is placed
-- at the head, followed by the entries not yet visited and then the visited
-- ones in reverse order. On failure the flag is @False@ and the list holds
-- the entries in reverse order.
isAuthenticGrace :: EpochTime -> ByteString -> ByteString -> Verifier -> ([GraceVerifier], Bool)
isAuthenticGrace currentTime message token v = scan [] (grace v)
  where
    scan visited [] = (visited, False)
    scan visited (candidate : pending)
        | accepts = (marked : (pending ++ visited), True)
        | otherwise = scan (candidate : visited) pending
      where
        spent = graceUsedTokens candidate
        expected = Factory.authenticateBS (graceFactory candidate) id message
        stillOpen =
            Factory.validUntil (graceFactory candidate) + graceSeconds v > currentTime
        accepts = not (Set.member token spent) && token == expected && stillOpen
        marked = candidate { graceUsedTokens = Set.insert token spent }

-- | Check a token for a message at the given time.
--
-- The verifier is first brought up to date with 'refresh'. The token is
-- accepted when it equals the value computed by the current factory and has
-- not been used before; it is then recorded as used. Otherwise the grace
-- verifiers are consulted through 'isAuthenticGrace'.
--
-- Returns the updated verifier (always the refreshed one, even on failure)
-- and whether the token was accepted.
--
-- NOTE(review): tokens are compared with '==' on 'ByteString', which is
-- presumably not constant-time; confirm whether timing side channels are a
-- concern for callers.
isAuthentic :: EpochTime -> ByteString -> ByteString -> Verifier -> (Verifier, Bool)
isAuthentic currentTime message token v
    -- Replayed or non-matching for the current period: fall back to grace.
    | Set.member token used || not areEq =
        case isAuthenticGrace currentTime message token v' of
          (g', True) -> (v' { grace = g' }, True)
          (_, False) -> (v', False)
    -- Here the token matches and is unused, so it is accepted and recorded.
    | otherwise = (v' { usedTokens = Set.insert token used }, True)
  where
    v' = refresh currentTime v
    used = usedTokens v'
    token' = Factory.authenticateBS (factory v') id message
    areEq = token == token'

-- | Check a token against the shared verifier using the current system
-- time, atomically storing the updated verifier back into the reference.
--
-- Uses the strict 'atomicModifyIORef'' so the stored verifier is evaluated
-- right away; with the lazy variant, callers that ignore the result would
-- leave a growing chain of thunks in the reference.
isAuthenticIO :: ByteString -> ByteString -> IORef Verifier -> IO Bool
isAuthenticIO message token vRef = do
  now <- epochTime
  atomicModifyIORef' vRef (isAuthentic now message token)