{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Authenticator.Vault (
Mode(..), SMode(..), withSMode, fromSMode
, HashAlgo(..)
, parseAlgo
, Secret(..), OTPDigits(..), pattern OTPDigitsInt
, ModeState(..)
, SomeSecretState
, Vault(..)
, _Vault
, hotp
, totp
, totp_
, otp
, someSecret
, vaultSecrets
, describeSecret
, secretURI
, parseSecretURI
) where
import Authenticator.Common
import Control.Applicative
import Control.Monad hiding (fail)
import Crypto.Hash.Algorithms
import Data.Bifunctor
import Data.Bitraversable
import Data.Char
import Data.Dependent.Sum
import Data.Function
import Data.GADT.Show
import Data.Kind
import Data.Maybe
import Data.Ord
import Data.Some
import Data.Time.Clock.POSIX
import Data.Vinyl
import Data.Void
import Data.Word
import GHC.Generics
import Prelude.Compat
import Text.Printf
import Text.Read (readMaybe)
import qualified Codec.Binary.Base32 as B32
import qualified Crypto.OTP as OTP
import qualified Data.Aeson as J
import qualified Data.Binary as B
import qualified Data.ByteString as BS
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Network.URI.Encode as U
import qualified Text.Megaparsec as P
import qualified Text.Megaparsec.Char as P
data Mode
= HOTP
| TOTP
deriving (Generic, Show)
data SMode :: Mode -> Type where
SHOTP :: SMode 'HOTP
STOTP :: SMode 'TOTP
deriving instance Show (SMode m)
instance GShow SMode where
gshowsPrec = showsPrec
instance B.Binary Mode
instance J.ToJSON Mode where
toJSON HOTP = J.toJSON @T.Text "hotp"
toJSON TOTP = J.toJSON @T.Text "totp"
withSMode
:: Mode
-> (forall m. SMode m -> r)
-> r
withSMode = \case
HOTP -> ($ SHOTP)
TOTP -> ($ STOTP)
fromSMode :: SMode m -> Mode
fromSMode = \case
SHOTP -> HOTP
STOTP -> TOTP
data family ModeState :: Mode -> Type
data instance ModeState 'HOTP = HOTPState { hotpCounter :: Word64 }
deriving (Generic, Show)
data instance ModeState 'TOTP = TOTPState
deriving (Generic, Show)
instance B.Binary (ModeState 'HOTP)
instance B.Binary (ModeState 'TOTP)
instance J.ToJSON (ModeState 'HOTP) where
toEncoding HOTPState{..} = J.pairs $ "counter" J..= hotpCounter
toJSON HOTPState{..} = J.object
[ "counter" J..= hotpCounter ]
instance J.ToJSON (ModeState 'TOTP)
modeStateBinary :: SMode m -> DictOnly B.Binary (ModeState m)
modeStateBinary = \case
SHOTP -> DictOnly
STOTP -> DictOnly
data HashAlgo = HASHA1 | HASHA256 | HASHA512
deriving (Generic, Show)
instance B.Binary HashAlgo
instance J.ToJSON HashAlgo where
toJSON HASHA1 = J.toJSON @T.Text "sha1"
toJSON HASHA256 = J.toJSON @T.Text "sha256"
toJSON HASHA512 = J.toJSON @T.Text "sha512"
hashAlgo :: HashAlgo -> Some (Dict HashAlgorithm)
hashAlgo HASHA1 = Some $ Dict SHA1
hashAlgo HASHA256 = Some $ Dict SHA256
hashAlgo HASHA512 = Some $ Dict SHA512
parseAlgo :: String -> Maybe HashAlgo
parseAlgo = (`lookup` algos) . map toLower . unwords . words
where
algos = [("sha1" , HASHA1 )
,("sha256", HASHA256)
,("sha512", HASHA512)
]
newtype OTPDigits = OTPDigits { otpDigits :: OTP.OTPDigits }
deriving Show
instance Eq OTPDigits where
(==) = (==) `on` show
instance Ord OTPDigits where
compare = comparing show
otpDigitsSet :: S.Set OTPDigits
otpDigitsSet = S.fromList $
OTPDigits <$> [OTP.OTP4, OTP.OTP5, OTP.OTP6, OTP.OTP7, OTP.OTP8, OTP.OTP9]
pattern OTPDigitsInt :: OTPDigits -> Int
pattern OTPDigitsInt o <- ((`safeElemAt` otpDigitsSet) . subtract 4->Just o)
where
OTPDigitsInt o = S.findIndex o otpDigitsSet + 4
instance B.Binary OTPDigits where
get = do
OTPDigitsInt o <- B.get
pure o
put = B.put . OTPDigitsInt
instance J.ToJSON OTPDigits where
toEncoding = J.toEncoding . OTPDigitsInt
toJSON = J.toJSON . OTPDigitsInt
data Secret :: Mode -> Type where
Sec :: { secAccount :: T.Text
, secIssuer :: Maybe T.Text
, secAlgo :: HashAlgo
, secDigits :: OTPDigits
, secKey :: BS.ByteString
}
-> Secret m
deriving (Generic, Show)
instance B.Binary (Secret m)
instance J.ToJSON (Secret m) where
toEncoding Sec{..} = J.pairs
( "account" J..= secAccount
<> maybe mempty ("issuer" J..=) secIssuer
<> "algorithm" J..= secAlgo
<> "digits" J..= secDigits
<> "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey))
)
toJSON Sec{..} = J.object $
[ "account" J..= secAccount
, "algorithm" J..= secAlgo
, "digits" J..= secDigits
, "key" J..= formatKey 4 (T.decodeUtf8 (B32.encode secKey))
] ++ maybe [] ((:[]) . ("issuer" J..=)) secIssuer
formatKey
:: Int
-> T.Text
-> T.Text
formatKey c = T.unwords
. T.chunksOf c
. T.map toLower
. T.filter isAlphaNum
describeSecret
:: Secret m
-> T.Text
describeSecret s = secAccount s <> case secIssuer s of
Nothing -> ""
Just i -> " / " <> i
instance B.Binary SomeSecretState where
get = do
m <- B.get
withSMode m $ \s -> case modeStateBinary s of
DictOnly -> do
sc <- B.get
ms <- B.get
return $ s :=> sc :*: ms
put = \case
s :=> sc :*: ms -> case modeStateBinary s of
DictOnly -> do
B.put $ fromSMode s
B.put sc
B.put ms
instance J.ToJSON SomeSecretState where
toEncoding (s :=> sc :*: ms) = J.pairs
( "type" J..= fromSMode s
<> "secret" J..= sc
<> (case s of SHOTP -> "state" J..= ms
STOTP -> mempty
)
)
toJSON (s :=> sc :*: ms) = J.object $
[ "type" J..= fromSMode s
, "secret" J..= sc
] ++ case s of SHOTP -> ["state" J..= ms]
STOTP -> []
type SomeSecretState = DSum SMode (Secret :*: ModeState)
newtype Vault = Vault { vaultList :: [SomeSecretState] }
deriving Generic
instance B.Binary Vault
instance J.ToJSON Vault where
toEncoding l = J.pairs $ "vault" J..= vaultList l
toJSON l = J.object ["vault" J..= vaultList l]
hotp :: Secret 'HOTP -> ModeState 'HOTP -> (T.Text, ModeState 'HOTP)
hotp Sec{..} (HOTPState i) =
(formatKey 3 . T.pack $ printf fmt p, HOTPState (i + 1))
where
fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d"
p = withSome (hashAlgo secAlgo) $ \case
Dict a -> OTP.hotp a (otpDigits secDigits) secKey i
totp_ :: Secret 'TOTP -> POSIXTime -> T.Text
totp_ Sec{..} t = withSome (hashAlgo secAlgo) $ \case
Dict a ->
let Right tparam =
OTP.mkTOTPParams a 0 30 (otpDigits secDigits) OTP.TwoSteps
in formatKey 3 . T.pack $
printf fmt $
OTP.totp tparam secKey (round t)
where
fmt = "%0" ++ show (OTPDigitsInt secDigits) ++ "d"
totp :: Secret 'TOTP -> IO T.Text
totp s = totp_ s <$> getPOSIXTime
otp :: SMode m -> Secret m -> ModeState m -> IO (T.Text, ModeState m)
otp = \case
SHOTP -> curry $ return . uncurry hotp
STOTP -> curry $ bitraverse totp return
someSecret
:: Functor f
=> (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m))
-> SomeSecretState
-> f SomeSecretState
someSecret f = \case
s :=> (sc :*: ms) -> (s :=>) . (sc :*:) <$> f s sc ms
vaultSecrets
:: Applicative f
=> (forall m. SMode m -> Secret m -> ModeState m -> f (ModeState m))
-> Vault
-> f Vault
vaultSecrets f = (_Vault . traverse) (someSecret f)
_Vault
:: Functor f
=> ([SomeSecretState] -> f [SomeSecretState])
-> Vault
-> f Vault
_Vault f s = Vault <$> f (vaultList s)
type Parser = P.Parsec Void String
secretURI :: Parser SomeSecretState
secretURI = do
_ <- P.string "otpauth://"
m <- otpMode
_ <- P.char '/'
(a,i) <- otpLabel
ps <- M.fromList <$> P.try param `P.sepBy` P.char '&'
sec <- case M.lookup "secret" ps of
Nothing -> fail "Required parameter 'secret' not present"
Just s ->
case decodePad s of
Just s' -> return s'
Nothing -> fail $ "Not a valid base-32 string: " ++ T.unpack s
let dig = fromMaybe (OTPDigits OTP.OTP6) $ do
d <- M.lookup "digits" ps
OTPDigitsInt o <- readMaybe $ T.unpack d
pure o
i' = i <|> M.lookup "issuer" ps
alg = fromMaybe HASHA1 $ do
al <- M.lookup "algorithm" ps
parseAlgo . T.unpack . T.map toLower $ al
secr :: forall m. Secret m
secr = Sec a i' alg dig sec
withSMode m $ \case
SHOTP -> case M.lookup "counter" ps of
Nothing -> fail "Paramater 'counter' required for hotp mode"
Just (T.unpack->c) -> case readMaybe c of
Nothing -> fail $ "Could not parse 'counter' parameter: " ++ c
Just c' -> return $ SHOTP :=> secr :*: HOTPState c'
STOTP -> return $ STOTP :=> secr :*: TOTPState
where
otpMode :: Parser Mode
otpMode = HOTP <$ P.string "hotp"
<|> HOTP <$ P.string "HOTP"
<|> TOTP <$ P.string "totp"
<|> TOTP <$ P.string "TOTP"
otpLabel :: Parser (T.Text, Maybe T.Text)
otpLabel = do
x <- P.some (P.try (mfilter (/= ':') uriChar))
rest <- Just <$> ( colon
*> P.manyTill (P.try uriChar <|> uriSpace) (P.char '?')
)
<|> Nothing <$ P.char '?'
return $ case rest of
Nothing -> (T.pack . U.decode $ x, Nothing)
Just y -> (T.pack . U.decode $ y, Just . T.pack . U.decode $ x)
param :: Parser (T.Text, T.Text)
param = do
k <- T.map toLower . T.pack <$> P.some (P.try uriChar)
_ <- P.char '='
v <- T.pack <$> P.some (P.try uriChar)
return (k, v)
uriChar = P.try (P.satisfy U.isAllowed)
<|> P.char '@'
<|> (do x <- U.decode <$> sequence [P.char '%', P.hexDigitChar, P.hexDigitChar]
case x of
[y] -> return y
_ -> fail "Invalid URI escape code"
)
colon = void (P.char ':') <|> void (P.string "%3A")
uriSpace = ' ' <$ (void P.space <|> void (P.string "%20"))
parseSecretURI
:: String
-> Either String SomeSecretState
parseSecretURI s = first P.errorBundlePretty $
P.parse secretURI "secret URI" s
safeElemAt :: Int -> S.Set a -> Maybe a
safeElemAt i s
| i < S.size s = Just (S.elemAt i s)
| otherwise = Nothing