-- | Stateful verifier built on "Crypto.MAC.TOTP.Factory".
--
-- A 'Verifier' holds the factory for the current validity window and the set
-- of tokens already accepted in that window (so a token cannot be replayed).
-- It also keeps grace verifiers for earlier windows. A token from an earlier
-- window is still accepted while that window's @validUntil + graceSeconds@
-- lies in the future.
module Crypto.MAC.TOTP.Verifier
  ( Verifier (..)
  , initializeIO
  , initialize
  , tryRefreshEvery
  , startRefreshThread
  , refresh
  , getNext
  , getMessages
  , getNextIO
  , getMessagesIO
  , isAuthentic
  , isAuthenticIO
  ) where

import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
import Data.Bits (xor, (.|.))
import qualified Data.List as List
import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import System.Posix.Time
import System.Posix.Types (EpochTime)
import Foreign.C.Types (CTime (..))
import Control.Concurrent
import Data.IORef
import Control.Exception (evaluate)

import qualified Crypto.MAC.TOTP.Factory as Factory

-- | Verifier state.
data Verifier = Verifier
  { factory      :: Factory.Factory
    -- ^ Factory for the current validity window.
  , usedTokens   :: Set.Set ByteString
    -- ^ Tokens accepted in the current window, plus messages handed out by
    -- 'getNext'. Reset to empty whenever 'refresh' moves to a new window.
  , grace        :: [GraceVerifier]
    -- ^ One entry per earlier window that may still overlap the grace period.
  , graceSeconds :: CTime
    -- ^ How long past an earlier window's @validUntil@ its tokens are still
    -- accepted.
  }

-- | Snapshot of an earlier window: its factory and the tokens already
-- accepted for it.
data GraceVerifier = GraceVerifier
  { graceFactory    :: Factory.Factory
  , graceUsedTokens :: Set.Set ByteString
  }

-- | Compare two byte strings without stopping at the first differing byte.
-- This makes a best effort to avoid leaking, through timing, how much of an
-- attacker-supplied token matches. Only the lengths are compared eagerly.
constantTimeEq :: ByteString -> ByteString -> Bool
constantTimeEq a b =
  BS.length a == BS.length b
    && List.foldl' (.|.) 0 (BS.zipWith xor a b) == 0

-- | Delegates to 'Factory.epochEq'. Presumably this decides whether @gv@
-- belongs to the window @n@ steps before @v@'s current one (it is paired with
-- @Factory.initGrace factory n@ in 'refreshGrace'). TODO confirm in Factory.
graceEq :: Verifier -> CTime -> GraceVerifier -> Bool
graceEq v n gv = Factory.epochEq (factory v) n (graceFactory gv)

-- | Demote a verifier's current window to a grace snapshot, keeping the tokens
-- already used in it.
toGrace :: Verifier -> GraceVerifier
toGrace v =
  GraceVerifier { graceFactory = factory v, graceUsedTokens = usedTokens v }

-- | 'initialize' using the current POSIX time.
initializeIO :: (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> CTime -> IO Verifier
initializeIO hashMethod blockSize tokenBytes secret validSecs graceSecs = do
  time <- epochTime
  return $ initialize time hashMethod blockSize tokenBytes secret validSecs graceSecs

-- | Build a verifier whose factory has been refreshed to @time@, with no used
-- tokens and fresh grace verifiers for the preceding windows.
--
-- Calls 'error' if @graceSeconds@ (the last argument) is negative.
--
-- NOTE(review): 'graceCount' divides by @validSeconds@, so a zero value fails
-- once the grace list is forced. Confirm that 'Factory.initialize' rejects it.
initialize :: EpochTime -> (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> CTime -> Verifier
initialize time hashMethod blockSize tokenBytes secret validSecs graceSecs
  | graceSecs < 0 = error "graceSeconds must be >= 0"
  | otherwise =
      Verifier
        { factory = current
        , usedTokens = Set.empty
        , grace = initGrace graceSecs validSecs current
        , graceSeconds = graceSecs
        }
  where
    current =
      Factory.refresh time (Factory.initialize hashMethod blockSize tokenBytes secret validSecs)

-- | Refresh loop that never returns. Every @delay@ microseconds (see
-- 'threadDelay') it applies 'refresh' to the verifier stored in
-- @verifierRef@. 'refresh' is a no-op unless the window has elapsed.
--
-- The first argument is ignored and kept only for API compatibility. The
-- 'IORef' is the single source of truth.
tryRefreshEvery :: Verifier -> Int -> IORef Verifier -> IO ()
tryRefreshEvery _ delay verifierRef = loop
  where
    loop = do
      threadDelay delay
      time <- epochTime
      -- Refresh whatever the ref holds *now*, atomically. Writing back a
      -- refreshed local copy would discard every token marked used by
      -- 'isAuthenticIO' / 'getNextIO' since the last refresh, which would
      -- make those tokens replayable.
      atomicModifyIORef' verifierRef (\current -> (refresh time current, ()))
      loop

-- | Refresh @verifier@ to the current time, store it in a new 'IORef', and
-- fork a thread running 'tryRefreshEvery' with @delay@ on that ref.
startRefreshThread :: Int -> Verifier -> IO (ThreadId, IORef Verifier)
startRefreshThread delay verifier = do
  time <- epochTime
  let verifier' = refresh time verifier
  verifierRef <- newIORef verifier'
  t <- forkIO (tryRefreshEvery verifier' delay verifierRef)
  return (t, verifierRef)

-- | Number of windows of length @validSecs@ needed to cover @graceSecs@:
-- @graceSecs / validSecs@ rounded up (for positive @validSecs@).
graceCount :: CTime -> CTime -> CTime
graceCount graceSecs validSecs =
  let (d, m) = graceSecs `divMod` validSecs
  in if m > 0 then d + 1 else d

-- | 'graceCount' for a verifier's own grace period and window length.
graceCountV :: Verifier -> CTime
graceCountV v = graceCount (graceSeconds v) (Factory.validSeconds (factory v))

-- | Fresh grace verifiers (no used tokens) for offsets @1 .. graceCount@
-- relative to @current@.
initGrace :: CTime -> CTime -> Factory.Factory -> [GraceVerifier]
initGrace graceSecs validSecs current =
  [ GraceVerifier (Factory.initGrace current n) Set.empty
  | n <- [1 .. graceCount graceSecs validSecs]
  ]

-- | Rebuild the grace list of @verifierNew@ after a window change. For each
-- offset, reuse the matching snapshot from @verifierOld@ (its current window
-- or its existing grace list), so that previously used tokens stay rejected.
-- A fresh grace verifier is created only when no snapshot matches.
refreshGrace :: Verifier -> Verifier -> Verifier
refreshGrace verifierOld verifierNew =
  verifierNew { grace = map pick [1 .. graceCountV verifierNew] }
  where
    candidates = toGrace verifierOld : grace verifierOld
    pick n =
      fromMaybe
        (GraceVerifier (Factory.initGrace (factory verifierNew) n) Set.empty)
        (List.find (graceEq verifierNew n) candidates)

-- | Whether the verifier's factory reports that @time@ requires a refresh.
shouldRefresh :: EpochTime -> Verifier -> Bool
shouldRefresh time v = Factory.shouldRefresh (factory v) time

-- | Move to the window containing @t@ if needed. The old current window and
-- its used tokens become grace state, and the new window starts with no used
-- tokens. Returns @v@ unchanged when no refresh is due.
refresh :: EpochTime -> Verifier -> Verifier
refresh t v
  | shouldRefresh t v =
      let fresh = v { factory = Factory.refresh t (factory v)
                    , usedTokens = Set.empty
                    , grace = []
                    }
      in refreshGrace v fresh
  | otherwise = v

-- | Draw the next message from the factory and record it in 'usedTokens',
-- without checking whether it was already there.
getNextUnsafe :: Verifier -> (Verifier, ByteString)
getNextUnsafe v =
  let (f', m) = Factory.getNext (factory v)
  in (v { factory = f', usedTokens = Set.insert m (usedTokens v) }, m)

-- | Next message not already present in 'usedTokens'. The returned verifier
-- has it recorded.
--
-- NOTE(review): this loops until the factory yields an unused message. If
-- 'Factory.getNext' can only produce already-used values, it never returns.
getNext :: Verifier -> (Verifier, ByteString)
getNext v =
  let (v', m) = getNextUnsafe v
  in if Set.member m (usedTokens v) then getNext v' else (v', m)

-- | @n@ successive 'getNext' messages, most recently generated first.
-- Returns no messages when @n <= 0@. Previously a negative @n@ recursed
-- forever.
getMessages :: Int -> Verifier -> (Verifier, [ByteString])
getMessages n v
  | n <= 0 = (v, [])
  | otherwise =
      let (v', rest) = getMessages (n - 1) v
          (v'', m) = getNext v'
      in (v'', m : rest)

-- | Atomic 'getNext' on a shared verifier.
getNextIO :: IORef Verifier -> IO ByteString
getNextIO vRef = atomicModifyIORef' vRef getNext

-- | Atomic 'getMessages' on a shared verifier. Returns @[]@ for @n < 1@.
getMessagesIO :: Int -> IORef Verifier -> IO [ByteString]
getMessagesIO n _ | n < 1 = return []
getMessagesIO n vRef = atomicModifyIORef' vRef (getMessages n)

-- | Try @token@ against the grace verifiers, in list order. The first one
-- that accepts it is one where the token is unused, equals that window's MAC
-- of @message@, and the window's grace period has not ended at
-- @currentTime@. That entry records the token, the list keeps its original
-- order, and the result is 'True'. Otherwise the list is returned unchanged
-- with 'False'.
isAuthenticGrace :: EpochTime -> ByteString -> ByteString -> Verifier -> ([GraceVerifier], Bool)
isAuthenticGrace currentTime message token v =
  case List.break accepts (grace v) of
    (before, g : after) ->
      let g' = g { graceUsedTokens = Set.insert token (graceUsedTokens g) }
      in (before ++ g' : after, True)
    (_, []) -> (grace v, False)
  where
    accepts g =
      not (Set.member token (graceUsedTokens g))
        && constantTimeEq token (Factory.authenticateBS (graceFactory g) id message)
        && Factory.validUntil (graceFactory g) + graceSeconds v > currentTime

-- | Verify that @token@ is the MAC of @message@, after refreshing to
-- @currentTime@. The current window is tried first, then the grace windows.
-- An accepted token is recorded so it cannot be replayed. The returned
-- verifier is always the refreshed one.
isAuthentic :: EpochTime -> ByteString -> ByteString -> Verifier -> (Verifier, Bool)
isAuthentic currentTime message token v
  | unused && matches = (v' { usedTokens = Set.insert token used }, True)
  | otherwise =
      case isAuthenticGrace currentTime message token v' of
        (g', True) -> (v' { grace = g' }, True)
        (_, False) -> (v', False)
  where
    v' = refresh currentTime v
    used = usedTokens v'
    unused = not (Set.member token used)
    matches = constantTimeEq token (Factory.authenticateBS (factory v') id message)

-- | Atomic 'isAuthentic' on a shared verifier, using the current POSIX time.
isAuthenticIO :: ByteString -> ByteString -> IORef Verifier -> IO Bool
isAuthenticIO message token vRef = do
  t <- epochTime
  atomicModifyIORef' vRef (isAuthentic t message token)