{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE NoImplicitPrelude #-} -- | -- Module : Authenticator.Vault -- Description : Secrets and storage for OTP keys. -- Copyright : (c) Justin Le 2017 -- License : MIT -- Maintainer : justin@jle.im -- Stability : unstable -- Portability : portable -- -- Types for storing, serializing, accessing OTP keys. Gratuitous -- type-level programming here for no reason because I have issues. -- -- Based off of . module Authenticator.Vault ( Mode (..), SMode (..), withSMode, fromSMode, HashAlgo (..), parseAlgo, Secret (..), OTPDigits (..), pattern OTPDigitsInt, ModeState (..), SomeSecretState, Vault (..), _Vault, hotp, totp, totp_, otp, someSecret, vaultSecrets, describeSecret, secretURI, parseSecretURI, ) where import Authenticator.Common import qualified Codec.Binary.Base32 as B32 import Control.Applicative import Control.Monad hiding (fail) import Crypto.Hash.Algorithms import qualified Crypto.OTP as OTP import qualified Data.Aeson as J import Data.Bifunctor import qualified Data.Binary as B import Data.Bitraversable import qualified Data.ByteString as BS import Data.Char import Data.Dependent.Sum import Data.Function import Data.GADT.Show import Data.Kind import qualified Data.Map as M import Data.Maybe import Data.Ord import qualified Data.Set as S import Data.Some import qualified Data.Text as T import qualified Data.Text.Encoding as T import Data.Time.Clock.POSIX import Data.Vinyl import Data.Void import Data.Word import GHC.Generics import qualified Network.URI.Encode as U import Prelude.Compat import qualified Text.Megaparsec as P import qualified Text.Megaparsec.Char as P import Text.Printf import Text.Read (readMaybe) -- | OTP generation mode data Mode = -- | Counter-based HOTP | -- | Time-based TOTP deriving (Generic, Show) -- | Singleton for 'Mode' data SMode :: Mode -> Type where SHOTP :: SMode 'HOTP STOTP :: SMode 'TOTP deriving instance Show (SMode m) instance GShow SMode where gshowsPrec = showsPrec instance B.Binary Mode instance J.ToJSON Mode where toJSON HOTP = J.toJSON @T.Text "hotp" toJSON TOTP = J.toJSON @T.Text "totp" -- | Reify a 'Mode' to its singleton withSMode :: Mode -> (forall m. SMode m -> r) -> r withSMode = \case HOTP -> ($ SHOTP) TOTP -> ($ STOTP) -- | Reflect a 'SMode' to its value. fromSMode :: SMode m -> Mode fromSMode = \case SHOTP -> HOTP STOTP -> TOTP -- | A data family consisting of the state required by each mode. data family ModeState :: Mode -> Type -- | For 'HOTP' (counter-based) mode, the state is the current counter. data instance ModeState 'HOTP = HOTPState {hotpCounter :: Word64} deriving (Generic, Show) -- | For 'TOTP' (time-based) mode, there is no state. data instance ModeState 'TOTP = TOTPState deriving (Generic, Show) instance B.Binary (ModeState 'HOTP) instance B.Binary (ModeState 'TOTP) instance J.ToJSON (ModeState 'HOTP) where toEncoding HOTPState {..} = J.pairs $ "counter" J..= hotpCounter toJSON HOTPState {..} = J.object ["counter" J..= hotpCounter] instance J.ToJSON (ModeState 'TOTP) modeStateBinary :: SMode m -> DictOnly B.Binary (ModeState m) modeStateBinary = \case SHOTP -> DictOnly STOTP -> DictOnly -- | Which OTP-approved hash algorithm to use? data HashAlgo = HASHA1 | HASHA256 | HASHA512 deriving (Generic, Show) instance B.Binary HashAlgo instance J.ToJSON HashAlgo where toJSON HASHA1 = J.toJSON @T.Text "sha1" toJSON HASHA256 = J.toJSON @T.Text "sha256" toJSON HASHA512 = J.toJSON @T.Text "sha512" -- | Generate the /cryptonite/ 'HashAlgorithm' instance. hashAlgo :: HashAlgo -> Some (Dict HashAlgorithm) hashAlgo HASHA1 = Some $ Dict SHA1 hashAlgo HASHA256 = Some $ Dict SHA256 hashAlgo HASHA512 = Some $ Dict SHA512 -- | Parse a hash algorithm string into the appropriate 'HashAlgo'. parseAlgo :: String -> Maybe HashAlgo parseAlgo = (`lookup` algos) . map toLower . unwords . words where algos = [ ("sha1", HASHA1), ("sha256", HASHA256), ("sha512", HASHA512) ] -- | Newtype wrapper to provide 'Eq', 'Ord', 'B.Binary', and 'J.ToJSON' -- instances. You can convert to and from this and the 'Int' -- representation using 'OTPDigitsInt' newtype OTPDigits = OTPDigits {otpDigits :: OTP.OTPDigits} deriving (Show) instance Eq OTPDigits where (==) = (==) `on` show instance Ord OTPDigits where compare = comparing show otpDigitsSet :: S.Set OTPDigits otpDigitsSet = S.fromList $ OTPDigits <$> [OTP.OTP4, OTP.OTP5, OTP.OTP6, OTP.OTP7, OTP.OTP8, OTP.OTP9] pattern OTPDigitsInt :: OTPDigits -> Int pattern OTPDigitsInt o <- ((`safeElemAt` otpDigitsSet) . subtract 4 -> Just o) where OTPDigitsInt o = S.findIndex o otpDigitsSet + 4 instance B.Binary OTPDigits where get = do OTPDigitsInt o <- B.get pure o put = B.put . OTPDigitsInt instance J.ToJSON OTPDigits where toEncoding = J.toEncoding . OTPDigitsInt toJSON = J.toJSON . OTPDigitsInt -- | A standards-compliant secret key type. Well, almost. It doesn't -- include configuration for the time period if it's time-based. data Secret :: Mode -> Type where Sec :: { secAccount :: T.Text, secIssuer :: Maybe T.Text, secAlgo :: HashAlgo, secDigits :: OTPDigits, secKey :: BS.ByteString } -> Secret m deriving (Generic, Show) instance B.Binary (Secret m) instance J.ToJSON (Secret m) where toEncoding Sec {..} = J.pairs ( "account" J..= secAccount <> maybe mempty ("issuer" J..=) secIssuer <> "algorithm" J..= secAlgo <> "digits" J..= secDigits <> "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey)) ) toJSON Sec {..} = J.object $ [ "account" J..= secAccount, "algorithm" J..= secAlgo, "digits" J..= secDigits, "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey)) ] ++ maybe [] ((: []) . ("issuer" J..=)) secIssuer formatKey :: -- | chunk size Int -> T.Text -> T.Text formatKey c = T.unwords . T.chunksOf c . T.map toLower . T.filter isAlphaNum -- | Print out the metadata (account name and issuer) of a 'Secret'. describeSecret :: Secret m -> T.Text describeSecret s = secAccount s <> case secIssuer s of Nothing -> "" Just i -> " / " <> i instance B.Binary SomeSecretState where get = do m <- B.get withSMode m $ \s -> case modeStateBinary s of DictOnly -> do sc <- B.get ms <- B.get return $ s :=> sc :*: ms put = \case s :=> sc :*: ms -> case modeStateBinary s of DictOnly -> do B.put $ fromSMode s B.put sc B.put ms instance J.ToJSON SomeSecretState where toEncoding (s :=> sc :*: ms) = J.pairs ( "type" J..= fromSMode s <> "secret" J..= sc <> ( case s of SHOTP -> "state" J..= ms STOTP -> mempty ) ) toJSON (s :=> sc :*: ms) = J.object $ [ "type" J..= fromSMode s, "secret" J..= sc ] ++ case s of SHOTP -> ["state" J..= ms] STOTP -> [] -- | A 'Secret' coupled with its 'ModeState', existentially quantified over -- its 'Mode'. type SomeSecretState = DSum SMode (Secret :*: ModeState) -- | A list of secrets and their states, of various modes. newtype Vault = Vault {vaultList :: [SomeSecretState]} deriving (Generic) instance B.Binary Vault instance J.ToJSON Vault where toEncoding l = J.pairs $ "vault" J..= vaultList l toJSON l = J.object ["vault" J..= vaultList l] -- | Generate an HTOP (counter-based) code, returning a modified state. hotp :: Secret 'HOTP -> ModeState 'HOTP -> (T.Text, ModeState 'HOTP) hotp Sec {..} (HOTPState i) = (formatKey 3 . T.pack $ printf fmt p, HOTPState (i + 1)) where fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d" p = withSome (hashAlgo secAlgo) $ \case Dict a -> OTP.hotp a (otpDigits secDigits) secKey i -- | (Purely) generate a TOTP (time-based) code, for a given time. totp_ :: Secret 'TOTP -> POSIXTime -> T.Text totp_ Sec {..} t = withSome (hashAlgo secAlgo) $ \case Dict a -> let tparam = case OTP.mkTOTPParams a 0 30 (otpDigits secDigits) OTP.TwoSteps of Left e -> error $ "totp_: " <> e Right x -> x in formatKey 3 . T.pack $ printf fmt $ OTP.totp tparam secKey (round t) where fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d" -- | Generate a TOTP (time-based) code in IO for the current time. totp :: Secret 'TOTP -> IO T.Text totp s = totp_ s <$> getPOSIXTime -- | Abstract over both 'hotp' and 'totp'. otp :: SMode m -> Secret m -> ModeState m -> IO (T.Text, ModeState m) otp = \case SHOTP -> curry $ return . uncurry hotp STOTP -> curry $ bitraverse totp return -- | Some sort of RankN lens and traversal over a 'SomeSecret'. Allows you -- to traverse (effectfully map) over the 'ModeState' in -- a 'SomeSecretState', with access to the 'Secret' as well. -- -- With this you can implement getters and setters. It's also used by the -- library to update the 'ModeState' in IO. someSecret :: (Functor f) => (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m)) -> SomeSecretState -> f SomeSecretState someSecret f = \case s :=> (sc :*: ms) -> (s :=>) . (sc :*:) <$> f s sc ms -- | A RankN traversal over all of the 'Secret's and 'ModeState's in -- a 'Vault'. vaultSecrets :: (Applicative f) => (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m)) -> Vault -> f Vault vaultSecrets f = (_Vault . traverse) (someSecret f) -- | A lens into the list of 'SomeSecretState's in a 'Vault'. Should be an -- Iso but we don't want a lens dependency now, do we. _Vault :: (Functor f) => ([SomeSecretState] -> f [SomeSecretState]) -> Vault -> f Vault _Vault f s = Vault <$> f (vaultList s) type Parser = P.Parsec Void String -- | A parser for a otpauth URI. secretURI :: Parser SomeSecretState secretURI = do _ <- P.string "otpauth://" m <- otpMode _ <- P.char '/' (a, i) <- otpLabel ps <- M.fromList <$> P.try param `P.sepBy` P.char '&' sec <- case M.lookup "secret" ps of Nothing -> fail "Required parameter 'secret' not present" Just s -> case decodePad s of Just s' -> return s' Nothing -> fail $ "Not a valid base-32 string: " ++ T.unpack s let dig = fromMaybe (OTPDigits OTP.OTP6) $ do d <- M.lookup "digits" ps OTPDigitsInt o <- readMaybe $ T.unpack d pure o i' = i <|> M.lookup "issuer" ps alg = fromMaybe HASHA1 $ do al <- M.lookup "algorithm" ps parseAlgo . T.unpack . T.map toLower $ al secr :: forall m. Secret m secr = Sec a i' alg dig sec withSMode m $ \case SHOTP -> case M.lookup "counter" ps of Nothing -> fail "Paramater 'counter' required for hotp mode" Just (T.unpack -> c) -> case readMaybe c of Nothing -> fail $ "Could not parse 'counter' parameter: " ++ c Just c' -> return $ SHOTP :=> secr :*: HOTPState c' STOTP -> return $ STOTP :=> secr :*: TOTPState where otpMode :: Parser Mode otpMode = HOTP <$ P.string "hotp" <|> HOTP <$ P.string "HOTP" <|> TOTP <$ P.string "totp" <|> TOTP <$ P.string "TOTP" otpLabel :: Parser (T.Text, Maybe T.Text) otpLabel = do x <- P.some (P.try (mfilter (/= ':') uriChar)) rest <- Just <$> ( colon *> P.manyTill (P.try uriChar <|> uriSpace) (P.char '?') ) <|> Nothing <$ P.char '?' return $ case rest of Nothing -> (T.pack . U.decode $ x, Nothing) Just y -> (T.pack . U.decode $ y, Just . T.pack . U.decode $ x) param :: Parser (T.Text, T.Text) param = do k <- T.map toLower . T.pack <$> P.some (P.try uriChar) _ <- P.char '=' v <- T.pack <$> P.some (P.try uriChar) return (k, v) uriChar = P.try (P.satisfy U.isAllowed) <|> P.char '@' <|> ( do x <- U.decode <$> sequence [P.char '%', P.hexDigitChar, P.hexDigitChar] case x of [y] -> return y _ -> fail "Invalid URI escape code" ) colon = void (P.char ':') <|> void (P.string "%3A") uriSpace = ' ' <$ (void P.space <|> void (P.string "%20")) -- | Parse a valid otpauth URI and initialize its state. -- -- See parseSecretURI :: String -> Either String SomeSecretState parseSecretURI s = first P.errorBundlePretty $ P.parse secretURI "secret URI" s safeElemAt :: Int -> S.Set a -> Maybe a safeElemAt i s | i < S.size s = Just (S.elemAt i s) | otherwise = Nothing