-- |Code that is common to Web Authentication
-- ("Network.WindowsLive.Login") and Delegated Authentication
-- ("Network.WindowsLive.ConsentToken")
module Network.WindowsLive.Token
    ( -- * Application State
      App(..)
    , AppID
    , Secret
    , newApp

    -- * Decryption and validation (internal)
    , decodeToken
    , validateToken

    -- * Generate a signed application verifier
    , appVerifier
    )
where
import qualified Codec.Binary.Base64 as Base64
import qualified Codec.Encryption.AES as AES
import Codec.Encryption.Modes ( unCbc )
import Codec.Text.Raw ( hexdump )
import Codec.Utils ( Octet, fromOctets, toOctets, listFromOctets )
import Control.Monad ( when, replicateM )
import Control.Monad.Error ( MonadError )
import qualified Data.Digest.SHA256 as SHA256
import Data.HMAC ( hmac, HashMethod(..) )
import Data.LargeWord ( Word128 )
import Data.List.Split ( splitOn )
import Data.Monoid ( mconcat, mappend )
import Data.Time.Clock.POSIX ( POSIXTime )
import Network.URI ( unEscapeString )
import Network.WindowsLive.Query ( (%=) )
import qualified Network.WindowsLive.Query as Query
import qualified Text.Parsec as Parsec
import Text.PrettyPrint.HughesPJ ( text, (<+>), char )

-- |The credentials of a registered application: its Application ID
-- together with its 'Secret' key.  Use 'newApp' to build one from
-- validated strings.
--
-- Visit
-- <https://lx.azure.microsoft.com/Cloud/Provisioning/Default.aspx> to
-- get your application's Application ID and Secret key.
data App = App { appId :: AppID -- ^The Application ID
               , secret :: Secret -- ^The Secret key that signing and encryption keys are derived from
               }

-- |An Application ID, kept as a plain string.  'newApp' only accepts
-- strings that pass 'validateAppId'.
type AppID = String

-- |An application's Secret key, held as a list of octets.  'newApp'
-- builds it from the characters of the secret string, and 'derive'
-- turns it into the signing and encryption keys.
newtype Secret = Secret [Octet]

-- |Shows a 'Secret' as a hex dump of its octets placed between
-- @Secret<@ and @>@.
instance Show Secret where
    showsPrec _ (Secret octets) = showsPrec 0 rendered
        where
          -- The pretty-printer document; its own Show instance does
          -- the rendering.
          rendered = text "Secret<" <+> hexdump 24 octets <+> char '>'

-- |Build an 'App' from an Application ID and a Secret key, both given
-- as strings.  The two inputs are validated first (ID, then secret);
-- a validation problem is reported through the monad's 'fail'.  The
-- secret is stored as the octets of its characters.
newApp :: MonadError e m => String -> String -> m App
newApp appIdStr secretStr =
    validateAppId appIdStr >>
    validateSecret secretStr >>
    return (App appIdStr (Secret (toBytes secretStr)))

-- |Check that an Application ID consists of exactly 16 hexadecimal
-- digits.  Succeeds with @()@ when it does; otherwise the parse error
-- is shown and reported through the monad's 'fail'.
validateAppId :: MonadError e m => String -> m ()
validateAppId = either (fail . show) (const $ return ()) .
                -- The trailing 'Parsec.eof' matters: without it any
                -- string that merely *starts* with 16 hex digits
                -- (a longer ID, trailing junk) would be accepted.
                Parsec.parse (replicateM 16 Parsec.hexDigit >> Parsec.eof)
                             "appid"

-- |Check that a Secret key string is usable.  Fails (through the
-- monad's 'fail') when the string is empty or when it contains a
-- character that does not fit in a single octet.
validateSecret :: MonadError e m => String -> m ()
validateSecret s = do
  when (null s) $ fail "Empty secret"
  -- 'newApp' converts every character with @toEnum . fromEnum@ into
  -- an 'Octet'.  NOTE(review): assuming 'Octet' is an 8-bit word, a
  -- character above '\255' would make that conversion raise a pure
  -- exception later, when the secret is first used; reject it here
  -- instead so the problem is reported as a validation failure.
  when (any (> '\255') s) $
       fail "Secret contains characters outside the 8-bit range"

-- |Which key to 'derive' from the secret: 'Signature' is used by
-- 'signToken', 'Encryption' by 'decodeToken'.
data KeyType = Signature | Encryption deriving Show

-- |The octets of the label that is prepended to the secret before
-- hashing, one distinct label per 'KeyType'.
keyPrefix :: KeyType -> [Octet]
keyPrefix = map (toEnum . fromEnum) . label
    where
      label :: KeyType -> String
      label Signature = "SIGNATURE"
      label Encryption = "ENCRYPTION"

-- |Derive a key from the secret for the given purpose: the first 16
-- octets of the SHA-256 hash of the key-type prefix followed by the
-- secret's octets.
derive :: Secret -> KeyType -> [Octet]
derive (Secret secretOctets) keyType = take 16 digest
    where
      digest = SHA256.hash (keyPrefix keyType ++ secretOctets)

-- |Decrypt a token, returning its plaintext with trailing EOT
-- characters removed.
--
-- Fails (through the monad's 'fail') when the input is not valid
-- base64, decodes to nothing at all, or decodes to a length that is
-- not a multiple of 16 octets.
decodeToken :: MonadError e m => App -> String -> m String
decodeToken app tokStr = do
  -- Undo the URL escaping and the base64 encoding.
  cipherOctets <- u64 tokStr
  when (null cipherOctets) $ fail "Missing initialization vector"
  when (length cipherOctets `rem` 16 /= 0) $
       fail "Attempted to decode invalid token"

  -- The first 16-octet block is the IV; the remainder is ciphertext.
  -- The two checks above guarantee there is at least one block, so
  -- this pattern cannot fail.
  let iv:cipherBlocks = toBlocks cipherOctets

      -- Decrypt in CBC mode with the key derived for encryption.
      aesKey = fromOctets (256::Integer) (derive (secret app) Encryption) :: Word128
      plainBlocks = unCbc AES.decrypt iv aesKey cipherBlocks

  return . stripEOT $ toString plainBlocks

-- |URL-unescape a string and then base64-decode it into octets,
-- failing (through the monad's 'fail') if the unescaped text is not
-- valid base64.
u64 :: MonadError e m => String -> m [Octet]
u64 = maybe (fail "Data was not valid base64") return
      . Base64.decode
      . unEscapeString

-- |Check the signature of this token (failing if it is not valid).
--
-- The token must contain exactly one @&sig=@ separator: everything
-- before it is the signed body, everything after it is the URL-escaped,
-- base64-encoded signature.  The signature is recomputed from the body
-- with 'signToken' and compared with the one supplied.
--
-- Fails (through the monad's 'fail') when there is no signature, more
-- than one signature, the signature is not valid base64, or the
-- signatures differ.
validateToken :: MonadError e m => App -> String -> m ()
validateToken app tok = do
  (body, sig) <- case splitOn "&sig=" tok of
                   [b, s] -> return (b, s)
                   [_] -> fail $ "No sig found: " ++ show tok
                   unexpected ->
                       fail $ "More than one sig found: " ++ show unexpected

  extractedSig <- u64 sig
  let calculatedSig = signToken (secret app) body
  -- The failure message deliberately reports only the signature that
  -- was supplied.  Including the calculated one would hand a valid
  -- signature for an arbitrary body to anyone able to see the error.
  when (not $ sameSignature extractedSig calculatedSig) $
       fail $ "Signature did not match: extracted=" ++ show extractedSig
  where
    -- Compare without stopping at the first differing octet, so the
    -- work done does not depend on how long a matching prefix is.
    -- (Best effort only: lazy evaluation gives no hard timing
    -- guarantee.)
    sameSignature xs ys =
        length xs == length ys
        && length (filter id (zipWith (/=) xs ys)) == 0

-- |Compute the signature of a token body: an HMAC built on SHA-256,
-- keyed with the key derived from the secret for 'Signature', over
-- the octets of the body.
signToken :: Secret -> String -> [Octet]
signToken sec body = hmac sha256Method signingKey (toBytes body)
    where
      sha256Method = HashMethod SHA256.hash 512
      signingKey = derive sec Signature

-- |Remove every EOT character from the end of a string.  EOT
-- characters that are followed by some other character are kept.
stripEOT :: String -> String
stripEOT = foldr keep []
    where
      -- Working from the right: an EOT is dropped only while nothing
      -- has been kept after it yet.
      keep c rest
          | c == '\EOT' && null rest = []
          | otherwise = c : rest

-- |Convert a string to octets, one octet per character, using each
-- character's code point as the octet value.
toBytes :: String -> [Octet]
toBytes str = [ toEnum (fromEnum c) | c <- str ]

-- |Convert a list of 128-bit blocks to a string, one character per
-- octet, 16 octets per block.
toString :: [Word128] -> String
toString = map (toEnum . fromEnum) . concatMap blockOctets
    where
      -- NOTE(review): 'toOctets' appears to emit no leading zero
      -- octets (confirm against Codec.Utils), so a block whose first
      -- octet is 0 would come out shorter than 16 and shift everything
      -- after it.  Left-pad to the full block width; this is a no-op
      -- whenever 'toOctets' already returns 16 octets.
      blockOctets block =
          let os = toOctets (256::Integer) block
          in replicate (16 - length os) 0 ++ os

-- |Group octets into 128-bit blocks.  The octets are reversed before
-- being handed to 'listFromOctets' and the resulting blocks are
-- reversed again.
toBlocks :: [Octet] -> [Word128]
toBlocks octets = reverse (listFromOctets (reverse octets))

-- |Build the application verifier that proves to the server that we
-- know the application ID and its secret.
--
-- The result is a query holding @appid@, @ts@ (the timestamp rounded
-- to whole seconds) and @sig@, where @sig@ is the base64-encoded
-- signature of the query string formed from the first two fields.
appVerifier :: App -> POSIXTime -> Query.Query
appVerifier app ts = unsigned `mappend` ("sig" %= signature)
    where
      seconds = round ts :: Integer
      unsigned = mconcat [ "appid" %= appId app
                         , "ts" %= show seconds
                         ]
      signature =
          Base64.encode (signToken (secret app) (Query.toQueryString unsigned))