module Web.Csrf (
  Csrf (..)
, CsrfCheckResult (..)
, mkCsrf
, getCsrf
, runCheck
, unMkToken
) where

import           Codec.Crypto.SimpleAES
import           Data.ByteString
import           Data.ByteString.Base64.URL
import           Data.String.Conversions

-- | Opaque wrapper around the raw bytes of a CSRF token.
--
-- Within this module the wrapped value is the URL-safe Base64 text of an
-- encrypted payload (see 'getCsrf'); use 'unMkToken' to get the bytes back.
newtype Token
  = MkToken ByteString
  deriving (Eq, Show)

-- | Outcome of comparing the two tokens held in a 'Csrf' value.
data CsrfCheckResult
  = Invalid -- ^ The tokens did not decode, or their decrypted payloads differ.
  | Valid   -- ^ Both tokens decoded and decrypted to the same payload.
  deriving (Eq, Show)

-- | A pair of CSRF tokens together with the (optional) outcome of checking
-- them against each other.
data Csrf = Csrf
  { cookie           :: Token
    -- ^ First token of the pair (presumably the one delivered as a cookie).
  , formToken        :: Token
    -- ^ Second token of the pair (presumably the one embedded in the form).
  , validationResult :: Maybe CsrfCheckResult
    -- ^ 'Nothing' until a check has been run; 'runCheck' fills in 'Just'.
  } deriving (Eq, Show)

-- | Compare the two tokens of a 'Csrf' value and record the verdict.
--
-- The first argument is the Base64 (URL alphabet) text of the AES key; it is
-- decoded leniently here, so callers pass the encoded key as-is.  Each token
-- is strictly Base64-decoded and decrypted in ECB mode, and the check is
-- 'Valid' exactly when both decrypted payloads are equal.  A token that is
-- not valid Base64 yields 'Invalid'.
--
-- The returned value is the input with 'validationResult' replaced; any
-- previous result is ignored.
--
-- NOTE(review): decryption happens lazily when the result is inspected;
-- whether 'decryptMsg' can fail on malformed ciphertext is not visible
-- here -- confirm against the SimpleAES documentation.
runCheck :: ByteString -> Csrf -> Csrf
runCheck keyBS csrf@(Csrf (MkToken cookieB64) (MkToken formB64) _) =
  csrf { validationResult = Just verdict }
  where
    -- Raw AES key bytes recovered from the encoded key.
    aesKey :: ByteString
    aesKey = decodeLenient keyBS

    -- Decrypt one decoded token with the shared key.
    reveal sealed = decryptMsg ECB aesKey (cs sealed)

    verdict :: CsrfCheckResult
    verdict
      | Right cookieRaw <- decode cookieB64
      , Right formRaw <- decode formB64
      , reveal cookieRaw == reveal formRaw = Valid
      | otherwise = Invalid

-- | Create a fresh, not-yet-validated pair of CSRF tokens.
--
-- The argument is the Base64 (URL alphabet) text of the AES key; it is
-- decoded leniently before use.  A new random secret is generated and
-- encrypted twice in ECB mode, once per token; each ciphertext is then
-- URL-safe Base64 encoded.  'validationResult' of the result is 'Nothing'.
getCsrf :: ByteString -> IO Csrf
getCsrf key = do
  secret <- randomKey
  let aesKey = decodeLenient key
      -- One encryption of the shared secret, wrapped up as a token.  This is
      -- an IO action and is deliberately executed once per token below.
      freshToken =
        fmap (MkToken . encode . cs) (encryptMsg ECB aesKey (cs secret))
  cookieToken <- freshToken
  pageToken <- freshToken
  return (Csrf cookieToken pageToken Nothing)

-- | Build a 'Csrf' value from two received token strings and validate it
-- immediately.
--
-- Arguments, in order: the Base64 (URL alphabet) text of the AES key, the
-- first token (stored in 'cookie') and the second token (stored in
-- 'formToken').  The result always carries @Just@ a 'CsrfCheckResult' in
-- 'validationResult'.
--
-- The key is handed to 'runCheck' still encoded: 'runCheck' performs the
-- Base64 decoding itself, so decoding it here as well would decode the key
-- twice and produce a different key from the one 'getCsrf' encrypts with.
mkCsrf :: ByteString -> ByteString -> ByteString -> Csrf
mkCsrf keyBS c t =
  runCheck keyBS (Csrf (MkToken c) (MkToken t) Nothing)

-- | Extract the raw bytes wrapped by a 'Token'.
unMkToken :: Token -> ByteString
unMkToken token = case token of
  MkToken bytes -> bytes