-- | Time-windowed token factory built on HMAC.
--
-- A 'Factory' derives a per-window secret from an initial secret and the
-- start of the current validity window, and produces tokens by MAC-ing a
-- running counter (or an arbitrary serialisable message) with that secret.
module Crypto.MAC.TOTP.Factory
  ( Factory (..)
  , initialize
  , initializeIO
  , initGrace
  , epochEq
  , authenticate
  , authenticateBS
  , roundTime
  , setTime
  , validUntil
  , shouldRefresh
  , refresh
  , refreshIO
  , tryRefreshEvery
  , startRefreshThread
  , getNext
  , getMessages
  , getNextIO
  , getMessagesIO
  ) where

import Crypto.Hash hiding (hmac)
import Crypto.MAC.HMAC
import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
import Data.Int
import Data.Word
import Data.Bits
import System.Posix.Time
import System.Posix.Types (EpochTime)
import Foreign.C.Types (CTime (..))
import Control.Concurrent
import Data.IORef
import Data.Serialize

-- | Integral arithmetic on 'CTime', delegating to the wrapped value.
--
-- NOTE(review): this is an orphan instance; it is kept because 'roundTime'
-- (and possibly other modules importing this one) rely on it.
instance Integral CTime where
  quot (CTime a) (CTime b) = CTime (a `quot` b)
  rem (CTime a) (CTime b) = CTime (a `rem` b)
  div (CTime a) (CTime b) = CTime (a `div` b)
  mod (CTime a) (CTime b) = CTime (a `mod` b)
  quotRem (CTime n) (CTime d) =
    case quotRem n d of
      (q, r) -> (CTime q, CTime r)
  divMod (CTime n) (CTime d) =
    case divMod n d of
      (q, r) -> (CTime q, CTime r)
  toInteger (CTime t) = toInteger t

-- | State needed to produce and verify tokens.
data Factory = Factory
  { secret :: ByteString
    -- ^ Key for the current window; derived by 'setTime'.
  , secretInit :: ByteString
    -- ^ Initial secret from which every window key is derived.
  , count :: Int64
    -- ^ Number of tokens handed out by 'getNext' since the last refresh.
  , validSeconds :: CTime
    -- ^ Length of one validity window, in seconds.
  , refreshEpoch :: EpochTime
    -- ^ Start of the current window (a multiple of 'validSeconds').
  , hashMethod :: ByteString -> ByteString
    -- ^ Hash function used both for key derivation and inside the HMAC.
  , blockSize :: Int
    -- ^ Block size passed to 'hmac' together with 'hashMethod'.
  , prefix :: ByteString -> ByteString
    -- ^ Truncation applied to the raw MAC to obtain the token.
  }

-- | Build a factory that has not been keyed yet ('secret' is empty and
-- 'refreshEpoch' is 0); call 'refresh' or 'setTime' before issuing tokens.
--
-- Calls 'error' when the window length is smaller than one second.
initialize
  :: (ByteString -> ByteString) -- ^ Hash function.
  -> Int                        -- ^ Block size handed to 'hmac'.
  -> Int                        -- ^ Number of leading MAC bytes kept as the token.
  -> ByteString                 -- ^ Initial secret.
  -> CTime                      -- ^ Validity window in seconds (must be >= 1).
  -> Factory
initialize hashFn blockLen tokenBytes seed lifetime
  | lifetime < 1 = error "validSeconds must be >= 1"
  | otherwise =
      -- Explicit field assignments (no record puns) so that the module
      -- does not depend on the NamedFieldPuns extension.
      Factory
        { secret = BS.empty
        , secretInit = seed
        , count = 0
        , validSeconds = lifetime
        , refreshEpoch = 0
        , hashMethod = hashFn
        , blockSize = blockLen
        , prefix = BS.take tokenBytes
        }

-- | Like 'initialize', then immediately 'refresh' against the current
-- system time.
initializeIO
  :: (ByteString -> ByteString)
  -> Int
  -> Int
  -> ByteString
  -> CTime
  -> IO (Factory)
initializeIO hashFn blockLen tokenBytes seed lifetime =
  fmap (\now -> refresh now unkeyed) epochTime
  where
    unkeyed = initialize hashFn blockLen tokenBytes seed lifetime

-- | Derive the factory of an earlier window: the one that started
-- @graceSeconds * validSeconds@ before the given factory's window.
-- (Despite its name, the second argument is multiplied by 'validSeconds',
-- i.e. it counts windows.)
--
-- The result is obtained by resetting the key state and calling 'refresh'
-- with that earlier time, so it is only keyed when 'refresh' decides a
-- refresh is due for that time.
initGrace :: Factory -> CTime -> Factory
initGrace f graceSeconds =
  refresh earlier f { secret = BS.empty, count = 0, refreshEpoch = 0 }
  where
    earlier = refreshEpoch f - graceSeconds * validSeconds f

-- | @epochEq baseF n f@ holds when @f@'s window starts exactly @n@ windows
-- (of @baseF@'s length) before @baseF@'s window.
epochEq :: Factory -> CTime -> Factory -> Bool
epochEq baseF n f = refreshEpoch f == expected
  where
    expected = refreshEpoch baseF - n * validSeconds baseF

-- | Advance the token counter by one.
incr :: Factory -> Factory
incr f = f { count = count f + 1 }

-- | MAC a serialisable message with the factory's current key, using
-- 'encode' to obtain the bytes.
authenticate :: Serialize b => Factory -> b -> ByteString
authenticate factory = authenticateBS factory encode

-- | MAC a message with the factory's current key, using the supplied
-- encoder, and truncate the result with the factory's 'prefix'.
authenticateBS :: Factory -> (b -> ByteString) -> b -> ByteString
authenticateBS factory encodeFun message = prefix factory mac
  where
    mac =
      hmac
        (hashMethod factory)
        (blockSize factory)
        (secret factory)
        (encodeFun message)

-- | Token for the factory's current counter value.
hashCount :: Factory -> ByteString
hashCount f = authenticate f (count f)

-- | @roundTime t r@ rounds @t@ down to a multiple of @r@ (using 'div').
-- @r@ must be non-zero.
roundTime :: CTime -> CTime -> CTime
roundTime t r = (t `div` r) * r

-- | Move the factory to the window containing @t@: sets 'refreshEpoch' to
-- the window start and re-derives 'secret' by hashing 'secretInit'
-- followed by the encoded window start. Does not touch 'count'.
setTime :: CTime -> Factory -> Factory
setTime t f =
  f { refreshEpoch = windowStart
    , secret = hashMethod f (BS.concat [secretInit f, timeBytes])
    }
  where
    windowStart@(CTime rawStart) = roundTime t (validSeconds f)
    timeBytes = encode rawStart

-- | First instant at which the current window is no longer valid.
validUntil :: Factory -> EpochTime
validUntil f = refreshEpoch f + validSeconds f

-- | Whether the given time lies at or past the end of the current window.
shouldRefresh :: Factory -> EpochTime -> Bool
shouldRefresh f t = t >= validUntil f

-- | Re-key the factory for the window containing @time@ and reset the
-- counter, but only if the current window has expired; otherwise the
-- factory is returned unchanged.
refresh :: EpochTime -> Factory -> Factory
refresh time factory
  | shouldRefresh factory time = (setTime time factory) { count = 0 }
  | otherwise = factory

-- | 'refresh' against the current system time.
refreshIO :: Factory -> IO (Factory)
refreshIO factory = fmap (\now -> refresh now factory) epochTime

-- | Loop forever: sleep, then refresh the shared factory if its window has
-- expired. Intended to be run in its own thread (see 'startRefreshThread').
tryRefreshEvery
  :: Int
     -- ^ The delay according to Control.Concurrent.threadDelay before refresh attempts.
  -> IORef (Factory)
     -- ^ The current factory.
  -> IO ()
tryRefreshEvery delay factoryRef = do
  threadDelay delay
  now <- epochTime
  -- 'refresh' already checks for expiry, so no extra guard is needed here.
  -- The strict variant keeps the IORef from accumulating one unevaluated
  -- thunk per tick while nobody reads the factory.
  atomicModifyIORef' factoryRef (\f -> (refresh now f, ()))
  tryRefreshEvery delay factoryRef

-- | Refresh the factory against the current time, store it in a fresh
-- 'IORef' and fork a thread running 'tryRefreshEvery' on it.
--
-- Returns the refresher's 'ThreadId' and the shared reference.
startRefreshThread :: Int -> Factory -> IO (ThreadId, IORef (Factory))
startRefreshThread delay factory = do
  now <- epochTime
  factoryRef <- newIORef (refresh now factory)
  tid <- forkIO (tryRefreshEvery delay factoryRef)
  return (tid, factoryRef)

-- | Produce the token for the current counter and advance the counter.
getNext :: Factory -> (Factory, ByteString)
getNext f = (incr f, hashCount f)

-- | Produce @n@ consecutive tokens and the advanced factory.
--
-- The list is ordered newest first (the token for the highest counter
-- value is at the head). A non-positive @n@ yields no tokens and leaves
-- the factory untouched.
getMessages :: Int -> Factory -> (Factory, [ByteString])
getMessages n f0 = go n f0 []
  where
    go remaining f acc
      | remaining <= 0 = (f, acc)
      | otherwise =
          let (f', key) = getNext f
          in go (remaining - 1) f' (key : acc)

-- | Atomically take the next token from a shared factory.
getNextIO :: IORef (Factory) -> IO (ByteString)
getNextIO factoryRef = atomicModifyIORef' factoryRef getNext

-- | Atomically take @n@ tokens from a shared factory (see 'getMessages'
-- for ordering). For @n < 1@ the reference is not touched.
getMessagesIO :: IORef (Factory) -> Int -> IO [ByteString]
getMessagesIO factoryRef n
  | n < 1 = return []
  | otherwise = atomicModifyIORef' factoryRef (getMessages n)