{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Authenticator.Vault -- Description : Secrets and storage for OTP keys. -- Copyright : (c) Justin Le 2017 -- License : MIT -- Maintainer : justin@jle.im -- Stability : unstable -- Portability : portable -- -- Types for storing, serializing, accessing OTP keys. Gratuitous -- type-level programming here for no reason because I have issues. -- -- Based off of . -- module Authenticator.Vault ( Mode(..), SMode(..), withSMode, fromSMode , HashAlgo(..) , parseAlgo , Secret(..), OTPDigits(..), pattern OTPDigitsInt , ModeState(..) , SomeSecretState , Vault(..) , _Vault , hotp , totp , totp_ , otp , someSecret , vaultSecrets , describeSecret , secretURI , parseSecretURI ) where import Authenticator.Common import Control.Applicative import Control.Monad hiding (fail) import Crypto.Hash.Algorithms import Data.Bifunctor import Data.Bitraversable import Data.Char import Data.Dependent.Sum import Data.Function import Data.GADT.Show import Data.Kind import Data.Maybe import Data.Ord import Data.Some import Data.Time.Clock.POSIX import Data.Vinyl import Data.Void import Data.Word import GHC.Generics import Prelude.Compat import Text.Printf import Text.Read (readMaybe) import qualified Codec.Binary.Base32 as B32 import qualified Crypto.OTP as OTP import qualified Data.Aeson as J import qualified Data.Binary as B import qualified Data.ByteString as BS import qualified Data.Map as M import qualified Data.Set as S import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Network.URI.Encode as U import qualified Text.Megaparsec as P import qualified Text.Megaparsec.Char as P -- | OTP generation mode data Mode -- | Counter-based = HOTP -- | Time-based | TOTP deriving (Generic, Show) -- | Singleton for 'Mode' data SMode :: Mode -> Type where SHOTP :: SMode 'HOTP STOTP :: SMode 'TOTP deriving instance Show (SMode m) instance GShow SMode where gshowsPrec = showsPrec instance B.Binary Mode instance J.ToJSON Mode where toJSON HOTP = J.toJSON @T.Text "hotp" toJSON TOTP = J.toJSON @T.Text "totp" -- | Reify a 'Mode' to its singleton withSMode :: Mode -> (forall m. SMode m -> r) -> r withSMode = \case HOTP -> ($ SHOTP) TOTP -> ($ STOTP) -- | Reflect a 'SMode' to its value. fromSMode :: SMode m -> Mode fromSMode = \case SHOTP -> HOTP STOTP -> TOTP -- | A data family consisting of the state required by each mode. data family ModeState :: Mode -> Type -- | For 'HOTP' (counter-based) mode, the state is the current counter. data instance ModeState 'HOTP = HOTPState { hotpCounter :: Word64 } deriving (Generic, Show) -- | For 'TOTP' (time-based) mode, there is no state. data instance ModeState 'TOTP = TOTPState deriving (Generic, Show) instance B.Binary (ModeState 'HOTP) instance B.Binary (ModeState 'TOTP) instance J.ToJSON (ModeState 'HOTP) where toEncoding HOTPState{..} = J.pairs $ "counter" J..= hotpCounter toJSON HOTPState{..} = J.object [ "counter" J..= hotpCounter ] instance J.ToJSON (ModeState 'TOTP) modeStateBinary :: SMode m -> DictOnly B.Binary (ModeState m) modeStateBinary = \case SHOTP -> DictOnly STOTP -> DictOnly -- | Which OTP-approved hash algorithm to use? data HashAlgo = HASHA1 | HASHA256 | HASHA512 deriving (Generic, Show) instance B.Binary HashAlgo instance J.ToJSON HashAlgo where toJSON HASHA1 = J.toJSON @T.Text "sha1" toJSON HASHA256 = J.toJSON @T.Text "sha256" toJSON HASHA512 = J.toJSON @T.Text "sha512" -- | Generate the /cryptonite/ 'HashAlgorithm' instance. hashAlgo :: HashAlgo -> Some (Dict HashAlgorithm) hashAlgo HASHA1 = Some $ Dict SHA1 hashAlgo HASHA256 = Some $ Dict SHA256 hashAlgo HASHA512 = Some $ Dict SHA512 -- | Parse a hash algorithm string into the appropriate 'HashAlgo'. parseAlgo :: String -> Maybe HashAlgo parseAlgo = (`lookup` algos) . map toLower . unwords . words where algos = [("sha1" , HASHA1 ) ,("sha256", HASHA256) ,("sha512", HASHA512) ] -- | Newtype wrapper to provide 'Eq', 'Ord', 'B.Binary', and 'J.ToJSON' -- instances. You can convert to and from this and the 'Int' -- representation using 'OTPDigitsInt' newtype OTPDigits = OTPDigits { otpDigits :: OTP.OTPDigits } deriving Show instance Eq OTPDigits where (==) = (==) `on` show instance Ord OTPDigits where compare = comparing show otpDigitsSet :: S.Set OTPDigits otpDigitsSet = S.fromList $ OTPDigits <$> [OTP.OTP4, OTP.OTP5, OTP.OTP6, OTP.OTP7, OTP.OTP8, OTP.OTP9] pattern OTPDigitsInt :: OTPDigits -> Int pattern OTPDigitsInt o <- ((`safeElemAt` otpDigitsSet) . subtract 4->Just o) where OTPDigitsInt o = S.findIndex o otpDigitsSet + 4 instance B.Binary OTPDigits where get = do OTPDigitsInt o <- B.get pure o put = B.put . OTPDigitsInt instance J.ToJSON OTPDigits where toEncoding = J.toEncoding . OTPDigitsInt toJSON = J.toJSON . OTPDigitsInt -- | A standards-compliant secret key type. Well, almost. It doesn't -- include configuration for the time period if it's time-based. data Secret :: Mode -> Type where Sec :: { secAccount :: T.Text , secIssuer :: Maybe T.Text , secAlgo :: HashAlgo , secDigits :: OTPDigits , secKey :: BS.ByteString } -> Secret m deriving (Generic, Show) instance B.Binary (Secret m) instance J.ToJSON (Secret m) where toEncoding Sec{..} = J.pairs ( "account" J..= secAccount <> maybe mempty ("issuer" J..=) secIssuer <> "algorithm" J..= secAlgo <> "digits" J..= secDigits <> "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey)) ) toJSON Sec{..} = J.object $ [ "account" J..= secAccount , "algorithm" J..= secAlgo , "digits" J..= secDigits , "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey)) ] ++ maybe [] ((:[]) . ("issuer" J..=)) secIssuer formatKey :: Int -- ^ chunk size -> T.Text -> T.Text formatKey c = T.unwords . T.chunksOf c . T.map toLower . T.filter isAlphaNum -- | Print out the metadata (account name and issuer) of a 'Secret'. describeSecret :: Secret m -> T.Text describeSecret s = secAccount s <> case secIssuer s of Nothing -> "" Just i -> " / " <> i instance B.Binary SomeSecretState where get = do m <- B.get withSMode m $ \s -> case modeStateBinary s of DictOnly -> do sc <- B.get ms <- B.get return $ s :=> sc :*: ms put = \case s :=> sc :*: ms -> case modeStateBinary s of DictOnly -> do B.put $ fromSMode s B.put sc B.put ms instance J.ToJSON SomeSecretState where toEncoding (s :=> sc :*: ms) = J.pairs ( "type" J..= fromSMode s <> "secret" J..= sc <> (case s of SHOTP -> "state" J..= ms STOTP -> mempty ) ) toJSON (s :=> sc :*: ms) = J.object $ [ "type" J..= fromSMode s , "secret" J..= sc ] ++ case s of SHOTP -> ["state" J..= ms] STOTP -> [] -- | A 'Secret' coupled with its 'ModeState', existentially quantified over -- its 'Mode'. type SomeSecretState = DSum SMode (Secret :*: ModeState) -- | A list of secrets and their states, of various modes. newtype Vault = Vault { vaultList :: [SomeSecretState] } deriving Generic instance B.Binary Vault instance J.ToJSON Vault where toEncoding l = J.pairs $ "vault" J..= vaultList l toJSON l = J.object ["vault" J..= vaultList l] -- | Generate an HTOP (counter-based) code, returning a modified state. hotp :: Secret 'HOTP -> ModeState 'HOTP -> (T.Text, ModeState 'HOTP) hotp Sec{..} (HOTPState i) = (formatKey 3 . T.pack $ printf fmt p, HOTPState (i + 1)) where fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d" p = withSome (hashAlgo secAlgo) $ \case Dict a -> OTP.hotp a (otpDigits secDigits) secKey i -- | (Purely) generate a TOTP (time-based) code, for a given time. totp_ :: Secret 'TOTP -> POSIXTime -> T.Text totp_ Sec{..} t = withSome (hashAlgo secAlgo) $ \case Dict a -> let Right tparam = OTP.mkTOTPParams a 0 30 (otpDigits secDigits) OTP.TwoSteps in formatKey 3 . T.pack $ printf fmt $ OTP.totp tparam secKey (round t) where fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d" -- | Generate a TOTP (time-based) code in IO for the current time. totp :: Secret 'TOTP -> IO T.Text totp s = totp_ s <$> getPOSIXTime -- | Abstract over both 'hotp' and 'totp'. otp :: SMode m -> Secret m -> ModeState m -> IO (T.Text, ModeState m) otp = \case SHOTP -> curry $ return . uncurry hotp STOTP -> curry $ bitraverse totp return -- | Some sort of RankN lens and traversal over a 'SomeSecret'. Allows you -- to traverse (effectfully map) over the 'ModeState' in -- a 'SomeSecretState', with access to the 'Secret' as well. -- -- With this you can implement getters and setters. It's also used by the -- library to update the 'ModeState' in IO. someSecret :: Functor f => (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m)) -> SomeSecretState -> f SomeSecretState someSecret f = \case s :=> (sc :*: ms) -> (s :=>) . (sc :*:) <$> f s sc ms -- | A RankN traversal over all of the 'Secret's and 'ModeState's in -- a 'Vault'. vaultSecrets :: Applicative f => (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m)) -> Vault -> f Vault vaultSecrets f = (_Vault . traverse) (someSecret f) -- | A lens into the list of 'SomeSecretState's in a 'Vault'. Should be an -- Iso but we don't want a lens dependency now, do we. _Vault :: Functor f => ([SomeSecretState] -> f [SomeSecretState]) -> Vault -> f Vault _Vault f s = Vault <$> f (vaultList s) type Parser = P.Parsec Void String -- | A parser for a otpauth URI. secretURI :: Parser SomeSecretState secretURI = do _ <- P.string "otpauth://" m <- otpMode _ <- P.char '/' (a,i) <- otpLabel ps <- M.fromList <$> P.try param `P.sepBy` P.char '&' sec <- case M.lookup "secret" ps of Nothing -> fail "Required parameter 'secret' not present" Just s -> case decodePad s of Just s' -> return s' Nothing -> fail $ "Not a valid base-32 string: " ++ T.unpack s let dig = fromMaybe (OTPDigits OTP.OTP6) $ do d <- M.lookup "digits" ps OTPDigitsInt o <- readMaybe $ T.unpack d pure o i' = i <|> M.lookup "issuer" ps alg = fromMaybe HASHA1 $ do al <- M.lookup "algorithm" ps parseAlgo . T.unpack . T.map toLower $ al secr :: forall m. Secret m secr = Sec a i' alg dig sec withSMode m $ \case SHOTP -> case M.lookup "counter" ps of Nothing -> fail "Paramater 'counter' required for hotp mode" Just (T.unpack->c) -> case readMaybe c of Nothing -> fail $ "Could not parse 'counter' parameter: " ++ c Just c' -> return $ SHOTP :=> secr :*: HOTPState c' STOTP -> return $ STOTP :=> secr :*: TOTPState where otpMode :: Parser Mode otpMode = HOTP <$ P.string "hotp" <|> HOTP <$ P.string "HOTP" <|> TOTP <$ P.string "totp" <|> TOTP <$ P.string "TOTP" otpLabel :: Parser (T.Text, Maybe T.Text) otpLabel = do x <- P.some (P.try (mfilter (/= ':') uriChar)) rest <- Just <$> ( colon *> P.manyTill (P.try uriChar <|> uriSpace) (P.char '?') ) <|> Nothing <$ P.char '?' return $ case rest of Nothing -> (T.pack . U.decode $ x, Nothing) Just y -> (T.pack . U.decode $ y, Just . T.pack . U.decode $ x) param :: Parser (T.Text, T.Text) param = do k <- T.map toLower . T.pack <$> P.some (P.try uriChar) _ <- P.char '=' v <- T.pack <$> P.some (P.try uriChar) return (k, v) uriChar = P.try (P.satisfy U.isAllowed) <|> P.char '@' <|> (do x <- U.decode <$> sequence [P.char '%', P.hexDigitChar, P.hexDigitChar] case x of [y] -> return y _ -> fail "Invalid URI escape code" ) colon = void (P.char ':') <|> void (P.string "%3A") uriSpace = ' ' <$ (void P.space <|> void (P.string "%20")) -- | Parse a valid otpauth URI and initialize its state. -- -- See parseSecretURI :: String -> Either String SomeSecretState parseSecretURI s = first P.errorBundlePretty $ P.parse secretURI "secret URI" s safeElemAt :: Int -> S.Set a -> Maybe a safeElemAt i s | i < S.size s = Just (S.elemAt i s) | otherwise = Nothing