{-# LANGUAGE DeriveFunctor        #-}
{-# LANGUAGE DeriveGeneric        #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE RankNTypes           #-}
{-# LANGUAGE RecordWildCards      #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE StandaloneDeriving   #-}
{-# LANGUAGE TemplateHaskell      #-}
{-# LANGUAGE TupleSections        #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeInType           #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE ViewPatterns         #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Authenticator.Vault (
    Mode(..)
  , Sing(SHOTP, STOTP)
  , SMode, HOTPSym0, TOTPSym0
  , HashAlgo(..)
  , parseAlgo
  , Secret(..)
  , ModeState(..)
  , Vault(..)
  , _Vault
  , hotp
  , totp
  , totp_
  , otp
  , someotp
  , someSecret
  , vaultSecrets
  , describeSecret
  , secretURI
  , parseSecretURI
  ) where

import           Authenticator.Common
import           Control.Applicative
import           Control.Monad
import           Crypto.Hash.Algorithms
import           Data.Bitraversable
import           Data.Char
import           Data.Dependent.Sum
import           Data.Kind
import           Data.Maybe
import           Data.Semigroup
import           Data.Singletons
import           Data.Singletons.TH
import           Data.Time.Clock
import           Data.Type.Combinator
import           Data.Type.Conjunction
import           Data.Word
import           GHC.Generics           (Generic)
import           Text.Printf
import           Text.Read              (readMaybe)
import           Type.Class.Higher
import           Type.Class.Witness
import qualified Codec.Binary.Base32    as B32
import qualified Data.Aeson             as J
import qualified Data.Binary            as B
import qualified Data.ByteString        as BS
import qualified Data.Map               as M
import qualified Data.OTP               as OTP
import qualified Data.Text              as T
import qualified Data.Text.Encoding     as T
import qualified Network.URI.Encode     as U
import qualified Text.Trifecta          as P

-- The flavour of one-time password a secret produces: counter-based (HOTP)
-- or time-based (TOTP).
--
-- The declaration lives inside a 'singletons' Template Haskell splice so
-- that, besides the plain 'Mode' type, singleton and promoted versions of
-- its constructors are generated.  The names SHOTP, STOTP, SMode, HOTPSym0
-- and TOTPSym0 in this module's export list are not defined anywhere else
-- in the file, so they presumably originate from this splice.
$(singletons [d|
  data Mode = HOTP | TOTP
    deriving (Generic, Show)
  |])

-- | Binary serialisation through the 'Generic' default methods.
instance B.Binary Mode

-- | A mode is rendered as the lower-case JSON string naming it.
instance J.ToJSON Mode where
    toJSON = J.toJSON @T.Text . \case
        HOTP -> "hotp"
        TOTP -> "totp"

-- | State that accompanies a secret, indexed by the secret's 'Mode'.
-- Each mode gets its own representation through a data instance.
data family ModeState :: Mode -> Type
-- | State of a counter-based secret.
data instance ModeState 'HOTP =
    -- 'hotpCounter' is the counter value used for the next code; 'hotp'
    -- returns a new state with it incremented by one.
    HOTPState { hotpCounter :: Word64
              }
  deriving (Generic, Show)
-- | Time-based secrets carry no stored state: a single nullary constructor.
data instance ModeState 'TOTP = TOTPState
  deriving (Generic, Show)

-- Binary serialisation for both state types uses the 'Generic' defaults.
instance B.Binary (ModeState 'HOTP)
instance B.Binary (ModeState 'TOTP)

-- | HOTP state is rendered as an object with a single @counter@ field.
instance J.ToJSON (ModeState 'HOTP) where
    toJSON (HOTPState n) = J.object ["counter" J..= n]
    toEncoding (HOTPState n) = J.pairs ("counter" J..= n)

-- | TOTP state relies on the class's default methods.
instance J.ToJSON (ModeState 'TOTP)

-- | Given the singleton for a mode, produce a witness that the matching
-- 'ModeState' has a 'B.Binary' instance.  Matching on the singleton is what
-- lets each equation pick the concrete instance.
modeStateBinary :: Sing m -> Wit1 B.Binary (ModeState m)
modeStateBinary SHOTP = Wit1
modeStateBinary STOTP = Wit1

-- | The hash algorithms a secret may be configured with.  'hashAlgo' maps
-- each constructor to the corresponding SHA1 / SHA256 / SHA512 value.
data HashAlgo = HASHA1 | HASHA256 | HASHA512
  deriving (Generic, Show)

-- | Binary serialisation through the 'Generic' default methods.
instance B.Binary HashAlgo

-- | An algorithm is rendered as its lower-case name; these are the same
-- names 'parseAlgo' recognises.
instance J.ToJSON HashAlgo where
    toJSON = J.toJSON @T.Text . \case
        HASHA1   -> "sha1"
        HASHA256 -> "sha256"
        HASHA512 -> "sha512"

-- | Turn the value-level algorithm tag into the corresponding hash
-- algorithm value, wrapped existentially together with its
-- 'HashAlgorithm' instance so callers can use it without knowing the
-- concrete type.
hashAlgo :: HashAlgo -> SomeC HashAlgorithm I
hashAlgo = \case
    HASHA1   -> SomeC (I SHA1)
    HASHA256 -> SomeC (I SHA256)
    HASHA512 -> SomeC (I SHA512)

-- | Recognise a hash algorithm by name.
--
-- The input is normalised first: surrounding whitespace is dropped, runs
-- of inner whitespace collapse to a single space, and letters are
-- lower-cased.  The result must then be exactly @sha1@, @sha256@ or
-- @sha512@; anything else (including names with inner whitespace) gives
-- 'Nothing'.
parseAlgo :: String -> Maybe HashAlgo
parseAlgo raw = case canonical of
    "sha1"   -> Just HASHA1
    "sha256" -> Just HASHA256
    "sha512" -> Just HASHA512
    _        -> Nothing
  where
    canonical = map toLower (unwords (words raw))

-- | The static description of one OTP entry, indexed by its 'Mode'.
--
-- The mode index is phantom: no field mentions it, so the same record
-- shape serves both HOTP and TOTP secrets.
--
-- TODO: add period?
data Secret :: Mode -> Type where
           -- secAccount: account name; always shown by 'describeSecret'.
    Sec :: { secAccount :: T.Text
           -- secIssuer: optional issuer; appended by 'describeSecret' when present.
           , secIssuer  :: Maybe T.Text
           -- secAlgo: hash algorithm used when generating codes.
           , secAlgo    :: HashAlgo
           -- secDigits: number of digits; also the zero-padded width of the printed code.
           , secDigits  :: Word
           -- secKey: raw key bytes (the JSON instance re-encodes them as base-32).
           , secKey     :: BS.ByteString
           }
        -> Secret m
  deriving (Generic, Show)

-- | Binary serialisation through the 'Generic' default methods.
instance B.Binary (Secret m)

-- | JSON rendering of a secret.
--
-- The @issuer@ field is emitted only when an issuer is present.  The key
-- is base-32 encoded and then passed through 'formatKey' with groups of
-- four characters.
instance J.ToJSON (Secret m) where
    toJSON Sec{..} = J.object (mandatory ++ issuerField)
      where
        mandatory =
          [ "account"   J..= secAccount
          , "algorithm" J..= secAlgo
          , "digits"    J..= secDigits
          , "key"       J..= (formatKey 4 . T.decodeUtf8 . B32.encode) secKey
          ]
        issuerField = maybe [] (\issuer -> ["issuer" J..= issuer]) secIssuer
    -- Field order for the streaming encoder: account, issuer (if any),
    -- algorithm, digits, key.
    toEncoding Sec{..} = J.pairs $ mconcat
        [ "account"   J..= secAccount
        , maybe mempty ("issuer" J..=) secIssuer
        , "algorithm" J..= secAlgo
        , "digits"    J..= secDigits
        , "key"       J..= (formatKey 4 . T.decodeUtf8 . B32.encode) secKey
        ]

-- | Normalise a key or code for display: discard every non-alphanumeric
-- character, lower-case what remains, and split it into space-separated
-- groups of @c@ characters (the final group may be shorter).
formatKey :: Int -> T.Text -> T.Text
formatKey c txt = T.unwords (T.chunksOf c cleaned)
  where
    cleaned = T.map toLower (T.filter isAlphaNum txt)

-- | One-line human-readable label for a secret: the account name,
-- followed by the issuer after a slash separator when one is set.
describeSecret :: Secret m -> T.Text
describeSecret s = secAccount s <> maybe "" (" / " <>) (secIssuer s)

-- | Binary serialisation of a secret together with its mode tag and state.
--
-- Wire order, for both directions: the 'Mode' tag, then the 'Secret', then
-- the 'ModeState'.  Changing that order changes the stored format.
instance B.Binary (DSum Sing (Secret :&: ModeState)) where
    get = do
      -- The tag is read as a plain 'Mode' value ...
      m <- B.get
      -- ... then promoted to a singleton.  'modeStateBinary' supplies the
      -- 'B.Binary' witness for the matching 'ModeState', which '//'
      -- presumably makes available to the reads below (Type.Class.Witness).
      withSomeSing m $ \s -> modeStateBinary s // do
        sc <- B.get
        ms <- B.get
        return $ s :=> sc :&: ms
    put = \case
      s :=> sc :&: ms -> modeStateBinary s // do
        -- The singleton is demoted back to a plain 'Mode' for writing.
        B.put $ fromSing s
        B.put sc
        B.put ms

-- | JSON rendering of a tagged secret: a @type@ field holding the mode, a
-- @secret@ field, and - for HOTP entries only - a @state@ field.
-- TOTP entries omit @state@.
instance J.ToJSON (DSum Sing (Secret :&: ModeState)) where
    toJSON (SHOTP :=> sc :&: ms) = J.object
        [ "type"   J..= HOTP
        , "secret" J..= sc
        , "state"  J..= ms
        ]
    toJSON (STOTP :=> sc :&: _) = J.object
        [ "type"   J..= TOTP
        , "secret" J..= sc
        ]
    toEncoding (SHOTP :=> sc :&: ms) = J.pairs
        ( "type"   J..= HOTP
       <> "secret" J..= sc
       <> "state"  J..= ms
        )
    toEncoding (STOTP :=> sc :&: _) = J.pairs
        ( "type"   J..= TOTP
       <> "secret" J..= sc
        )

-- | A collection of secrets.  Each entry is a dependent sum: a mode
-- singleton paired with the 'Secret' and 'ModeState' for that mode, so
-- HOTP and TOTP entries can share one list.
data Vault = Vault { vaultList :: [DSum Sing (Secret :&: ModeState)] }
  deriving Generic

-- | Binary serialisation through the 'Generic' default methods.
instance B.Binary Vault

-- | A vault is rendered as an object whose single @vault@ field holds the
-- list of entries.
instance J.ToJSON Vault where
    toJSON     = J.object . (: []) . ("vault" J..=) . vaultList
    toEncoding = J.pairs . ("vault" J..=) . vaultList

-- | Generate the code for a counter-based secret at its current counter.
--
-- Returns the code, zero-padded to the secret's digit count and grouped
-- in threes by 'formatKey', together with the state for the next call
-- (counter incremented by one).
hotp :: Secret 'HOTP -> ModeState 'HOTP -> (T.Text, ModeState 'HOTP)
hotp Sec{ secAlgo = algo, secDigits = digits, secKey = key } (HOTPState counter) =
    (code, HOTPState (counter + 1))
  where
    code = formatKey 3 (T.pack (printf ("%0" ++ show digits ++ "d") raw))
    -- Unpack the existentially wrapped hash algorithm to run the library
    -- HOTP computation.
    raw  = hashAlgo algo >>~ \(I a) -> OTP.hotp a key counter digits

-- | Generate the code for a time-based secret at the given time.
--
-- The time step is fixed at 30.  The code is zero-padded to the secret's
-- digit count and grouped in threes by 'formatKey'.
--
-- NOTE(review): the timestamp handed to the library is the supplied time
-- plus 30 seconds.  Presumably this compensates for how the OTP library
-- derives its counter - confirm against the library version in use.
totp_ :: Secret 'TOTP -> UTCTime -> T.Text
totp_ Sec{ secAlgo = algo, secDigits = digits, secKey = key } t =
    hashAlgo algo >>~ \(I a) ->
      formatKey 3 (T.pack (printf fmt (OTP.totp a key shifted 30 digits)))
  where
    fmt     = "%0" ++ show digits ++ "d"
    shifted = addUTCTime 30 t

-- | Generate the code for a time-based secret using the current system
-- time; see 'totp_' for the pure version.
totp :: Secret 'TOTP -> IO T.Text
totp s = fmap (totp_ s) getCurrentTime

-- | Mode-generic code generation: dispatches on the mode singleton to
-- 'hotp' (pure, returns the advanced state) or 'totp' (reads the clock,
-- returns the state unchanged).
otp :: forall m. SingI m => Secret m -> ModeState m -> IO (T.Text, ModeState m)
otp sec st = case sing @_ @m of
    SHOTP -> return (hotp sec st)
    STOTP -> bitraverse totp return (sec, st)

-- | Generate a code for a vault entry of either mode, returning the code
-- together with the entry carrying its updated state.
--
-- The IO-of-pair result of 'otp' is wrapped in 'Comp' so it can serve as
-- the functor that 'someSecret' maps over, and unwrapped again afterwards.
someotp :: DSum Sing (Secret :&: ModeState) -> IO (T.Text, DSum Sing (Secret :&: ModeState))
someotp entry = getComp (someSecret (\sec st -> Comp (otp sec st)) entry)

-- | Run a mode-polymorphic, state-updating action on one vault entry.
--
-- The entry is unpacked, the action is given the secret and its current
-- state (with the 'SingI' instance for the entry's mode in scope), and the
-- resulting state is re-packed with the same tag and secret.
someSecret
    :: Functor f
    => (forall m. SingI m => Secret m -> ModeState m -> f (ModeState m))
    -> DSum Sing (Secret :&: ModeState)
    -> f (DSum Sing (Secret :&: ModeState))
someSecret f (s :=> (sc :&: ms)) =
    withSingI s (fmap (\ms' -> s :=> (sc :&: ms')) (f sc ms))

-- Orphan instance (the module disables the orphan warning for it): a
-- composition of two functors is itself a functor.  'someotp' needs this
-- in order to use a 'Comp'-wrapped value as the functor for 'someSecret'.
deriving instance (Functor f, Functor g) => Functor (f :.: g)

-- | Traverse every entry of a vault with a mode-polymorphic,
-- state-updating action, rebuilding the vault from the updated entries.
vaultSecrets
    :: Applicative f
    => (forall m. SingI m => Secret m -> ModeState m -> f (ModeState m))
    -> Vault
    -> f Vault
vaultSecrets f = _Vault (traverse (someSecret f))

-- | Van Laarhoven-style accessor for the entry list inside a 'Vault':
-- applies the given function to the list and re-wraps the result.
_Vault
    :: Functor f
    => ([DSum Sing (Secret :&: ModeState)] -> f [DSum Sing (Secret :&: ModeState)])
    -> Vault
    -> f Vault
_Vault f = fmap Vault . f . vaultList

-- | Parser for an @otpauth:@ key URI, yielding a 'Secret' tagged with its
-- mode singleton and paired with that mode's initial 'ModeState'.
--
-- Accepted layout: the literal scheme prefix, the mode (@hotp@ or @totp@,
-- all lower-case or all upper-case), a slash, a label, a question mark,
-- and then @key=value@ parameters separated by @&@.
--
-- The label is either @account@ or @issuer:account@ (the colon may be
-- percent-escaped, and spaces after it are skipped).
--
-- Parameters (keys are lower-cased before lookup):
--
-- * @secret@    - required; must be accepted by 'decodePad'.
-- * @counter@   - required in HOTP mode; must parse as a number.
-- * @digits@    - optional; defaults to 6.
-- * @algorithm@ - optional, read with 'parseAlgo'; defaults to 'HASHA1'.
-- * @issuer@    - optional; only consulted when the label has no issuer.
--
-- NOTE(review): a present-but-unparseable @digits@ or @algorithm@ value
-- silently falls back to its default instead of failing the parse.
secretURI :: P.Parser (DSum Sing (Secret :&: ModeState))
secretURI = do
    _ <- P.string "otpauth://"
    m <- otpMode
    _ <- P.char '/'
    (a,i) <- otpLabel
    ps <- M.fromList <$> param `P.sepBy` P.char '&'
    sec <- case M.lookup "secret" ps of
      Nothing -> fail "Required parameter 'secret' not present"
      Just s ->
        case decodePad s of
          Just s' -> return s'
          Nothing -> fail $ "Not a valid base-32 string: " ++ T.unpack s
    let dig = fromMaybe 6 $ do
          d <- M.lookup "digits" ps
          readMaybe @Word $ T.unpack d
        -- An issuer given in the label takes precedence over the parameter.
        i' = i <|> M.lookup "issuer" ps
        alg = fromMaybe HASHA1 $ do
          al <- M.lookup "algorithm" ps
          parseAlgo . T.unpack . T.map toLower $ al
        -- Mode-polymorphic so the same record can be used in either branch
        -- of the mode dispatch below.
        secr :: forall m. Secret m
        secr = Sec a i' alg dig sec

    withSomeSing m $ \case
      SHOTP -> case M.lookup "counter" ps of
          Nothing -> fail "Parameter 'counter' required for hotp mode"
          Just (T.unpack->c) -> case readMaybe c of
            Nothing -> fail $ "Could not parse 'counter' parameter: " ++ c
            Just c' -> return $ SHOTP :=> secr :&: HOTPState c'
      STOTP -> return $ STOTP :=> secr :&: TOTPState
  where
    otpMode :: P.Parser Mode
    otpMode = HOTP <$ P.string "hotp"
          <|> HOTP <$ P.string "HOTP"
          <|> TOTP <$ P.string "totp"
          <|> TOTP <$ P.string "TOTP"
    -- Parses the label up to and including the '?' that ends it, returning
    -- (account, optional issuer).
    otpLabel :: P.Parser (T.Text, Maybe T.Text)
    otpLabel = do
      -- Everything up to a colon (literal or escaped) or the '?'.
      x <- P.some (P.try (mfilter (/= ':') uriChar))
      rest <- Just <$> (colon
                     *> P.many (P.try uriSpace)
                     *> P.some (P.try uriChar)
                     <* P.char '?'
                       )
          <|> Nothing <$ P.char '?'
      -- 'uriChar' has already resolved percent-escapes, so the characters
      -- are used as they are.  Decoding them a second time would corrupt
      -- labels whose decoded text itself contains a percent sequence.
      return $ case rest of
        Nothing -> (T.pack x, Nothing)
        Just y  -> (T.pack y, Just (T.pack x))
    -- One @key=value@ pair; the key is lower-cased, the value kept as is.
    param :: P.Parser (T.Text, T.Text)
    param = do
      k <- T.map toLower . T.pack <$> P.some (P.try uriChar)
      _ <- P.char '='
      v <- T.pack <$> P.some (P.try uriChar)
      return (k, v)
    -- A single logical character: one accepted by 'U.isAllowed', a literal
    -- '@', or a percent sign with two hex digits, decoded to the character
    -- it stands for.
    uriChar = P.satisfy U.isAllowed
          <|> P.char '@'
          <|> (do x <- U.decode <$> sequence [P.char '%', P.hexDigit, P.hexDigit]
                  case x of
                    [y] -> return y
                    _   -> fail "Invalid URI escape code"
              )
    colon    = void (P.char ':') <|> void (P.string "%3A")
    uriSpace = void P.space      <|> void (P.string "%20")

-- | Run 'secretURI' over a string.  On success the parsed entry is
-- returned in 'Right'; on failure 'Left' carries the 'show'n form of the
-- parser's failure value.
parseSecretURI
    :: String
    -> Either String (DSum Sing (Secret :&: ModeState))
parseSecretURI = toEither . P.parseString secretURI mempty
  where
    toEither = \case
      P.Success parsed -> Right parsed
      P.Failure err    -> Left (show err)