{-# LANGUAGE OverloadedStrings #-}
-- | This module contains support for the OpenID authentication standard. See
--   <http://www.openid.net> for details of the protocol. At the moment, only
--   basic OpenID 2.0 authentication is supported, and only HTML-based
--   discovery is implemented: neither Yadis nor XRI discovery is supported.
--
--   Only DH-SHA1 sessions (with HMAC-SHA1 associations) are used for the
--   cryptography. This used to be SHA256, but Yahoo doesn't support it (boo!)
--
--   References in this module are to the OpenID v2 spec
--      <http://openid.net/specs/openid-authentication-2_0.html>
module Network.MiniHTTP.OpenID
  ( -- * Types
    OpenIDDiscovery(..)
  , CheckIDType(..)
  , Handle
  , Key

    -- * Actions
  , findKey
  , discover
  , associate
  , checkID
  , processCheckIDReply
  ) where

import           Control.Monad (liftM)
import           Control.Concurrent.STM
import           Control.Exception (handle, throwIO)

import           Data.Bits (shiftR, xor, (.&.), (.|.))
import qualified Data.Binary.Strict.Class as C
import qualified Data.Binary.Strict.Get as G
import qualified Data.ByteString as B
import           Data.ByteString.Char8 ()
import           Data.ByteString.Internal (w2c)
import qualified Data.Map as Map
import           Data.Maybe (maybe, fromJust)
import           Data.String (IsString(..))
import           Data.Time.Clock (UTCTime, getCurrentTime, addUTCTime)
import           Data.Word (Word8, Word32)

import qualified Network.Connection as C
import           Network.Socket (sClose)
import           Network.MiniHTTP.Client
import           Network.MiniHTTP.HTTPConnection
import           Network.MiniHTTP.Marshal
import qualified Network.MiniHTTP.URL as URL

import           System.IO.Unsafe (unsafePerformIO)

import           Text.HTML.TagSoup

import qualified OpenSSL.EVP.Base64 as Base64
import qualified OpenSSL.EVP.Digest as Digest
import qualified OpenSSL.BN as BN

-- | The outcome of 'discover': the OP endpoint URL that requests are sent to
--   and, when the discovered page declared one, the OP-local identifier.
data OpenIDDiscovery = OpenIDDiscovery
  { discoveryProvider :: URL.URL       -- ^ OP endpoint URL
  , discoveryLocalID  :: Maybe String  -- ^ OP-local identifier, if declared
  } deriving (Show, Eq)

-- | The number of bytes of HTML (4 KiB) that 'discover' asks the server for
--   when looking for the OpenID @<link>@ elements.
discoveryHTMLLimit :: Integral i => i
discoveryHTMLLimit = 4 * 1024

-- | Use HTML-based discovery (section 7.3.3) to find the OpenID information
--   for a given URL.
--
--   A @Range@ request for the first 'discoveryHTMLLimit' bytes of the page is
--   made, and at most that many bytes are read. Every @<link>@ tag at or
--   after a @<head>@ tag is examined. The @rel@ attribute is split on
--   whitespace, so combined values such as
--   @rel="openid2.provider openid.server"@ are recognised.
--
--   * @openid2.provider@ supplies the provider endpoint. Links whose @href@
--     does not parse as a URL are ignored.
--   * @openid2.local_id@ supplies the OP-local identifier.
--
--   If either appears more than once, the last one wins.
--
--   Fails (in 'IO') when the server returns no content, when the body cannot
--   be read, or when no provider link is found.
discover :: URL.URL -> IO OpenIDDiscovery
discover uri = do
  r <- fetchBasic (emptyHeaders { httpRange = Just [RangeOf 0 discoveryHTMLLimit] }) uri
  case r of
       (conn, _, Nothing) -> do
         C.close conn
         fail "HTTP server returned no content"
       (conn, _, Just source) -> do
         mbody <- sourceToBS discoveryHTMLLimit source
         C.close conn
         case mbody of
              Nothing -> fail "Error reading HTTP reply"
              Just body ->
                case foldl addLink (Nothing, Nothing) (linkTags body) of
                     (Nothing, _) -> fail "No provider discovered"
                     (Just provider, mlocalid) ->
                       return $ OpenIDDiscovery provider mlocalid
  where
    -- The leading tag of every section starting at a <link>, inside every
    -- section starting at a <head>. Matching the non-empty pattern avoids
    -- the partial 'head'.
    linkTags body =
      [ tag
      | headSection <- sections (~== ("<head>" :: String)) $
                         parseTags $ map w2c $ B.unpack body
      , (tag : _) <- sections (~== ("<link>" :: String)) headSection
      ]

    -- Fold one tag into the (provider, local id) accumulator. Only link
    -- tags that carry both @rel@ and @href@ contribute.
    addLink acc (TagOpen "link" attrs) =
      case (lookup "rel" attrs, lookup "href" attrs) of
           (Just rel, Just href) -> foldl (addRel href) acc (words rel)
           _ -> acc
    addLink acc _ = acc

    -- Apply a single link type from a (possibly space-separated) @rel@.
    addRel href acc@(_, mlocalid) "openid2.provider" =
      maybe acc (\provider -> (Just provider, mlocalid)) $ URL.parse $ fromString href
    addRel href (mprovider, _) "openid2.local_id" = (mprovider, Just href)
    addRel _ acc _ = acc

-- | The default Diffie-Hellman generator @g@. 'associateRequest' uses it to
--   compute the consumer's public value g^x mod p.
dhG :: Integer
dhG = 2
-- | The default Diffie-Hellman modulus @p@ (see appendix B). It is used to
--   compute the consumer's public value in 'associateRequest', which also
--   draws the private value @x@ below it. 'processAssociateReply' uses it to
--   compute the shared secret (g^y)^x mod p.
dhP :: Integer
dhP = 0xDCF93A0B883972EC0E19989AC5A2CE310E1D37717E8D9571BB7623731866E61EF75A2E27898B057F9891C2E27A639C3F29B60814581CD3B2CA3986D2683705577D45C2E7E52DC81C7A171876E5CEA74B1448BFDFAF18828EFD2519F14E45E3826634AF1949E5B535CC829A483B8A76223E5D490A257F05BDFF16F2FB22C583AB

-- | Encode a map of query key-value pairs into a query string (not including
--   the leading "?"). Every key is prefixed with "openid." before encoding.
--   See 4.1.2.
postEncode :: Map.Map B.ByteString B.ByteString -> B.ByteString
postEncode args = URL.serialiseArguments $ Map.mapKeys ("openid." `B.append`) args

-- | Encode a list of key-value pairs, in order, into the OpenID Key-Value
--   format (see 4.1.1). Each pair is written as @key:value@, and every line,
--   including the last, is terminated by a newline. An empty list therefore
--   encodes to a single newline.
keyValueEncode :: [(B.ByteString, B.ByteString)] -> B.ByteString
keyValueEncode values =
  B.append (B.intercalate newline [B.concat [k, colon, v] | (k, v) <- values]) newline
  where
    newline = B.singleton 10
    colon   = B.singleton 0x3a

-- | The parameters sent with every OpenID request. Currently this is just the
--   OpenID 2.0 namespace (@ns@); callers add the "openid." prefix themselves.
defaultParams :: Map.Map B.ByteString B.ByteString
defaultParams = Map.singleton "ns" "http://specs.openid.net/auth/2.0"

-- | Convert an Integer to base64(btwoc) form. See 4.2.
--
--   The MPI produced by 'BN.integerToMPI' starts with a 4-byte length header
--   (the same layout 'btwocToInteger' builds). Dropping that header is taken
--   to leave the btwoc bytes, which are then base64-encoded.
integerToBase64btwoc :: Integer -> IO B.ByteString
integerToBase64btwoc i = do
  mpi <- BN.integerToMPI i
  return $ Base64.encodeBase64BS (B.drop 4 mpi)

-- | Convert a ByteString in btwoc form to an Integer. See 4.2.
--
--   The bytes are prefixed with their length, written as a 4-byte big-endian
--   number, to form an MPI that OpenSSL then decodes.
btwocToInteger :: B.ByteString -> IO Integer
btwocToInteger bs = BN.mpiToInteger (lengthPrefix `B.append` bs)
  where
    len :: Word32
    len = fromIntegral (B.length bs)
    lengthPrefix = B.pack [ fromIntegral ((len `shiftR` s) .&. 0xff) | s <- [24, 16, 8, 0] ]

-- | An OpenID association handle. Handles identify an association (session)
--   between the consumer and an OP.
newtype Handle = Handle B.ByteString
  deriving (Show, Eq)

-- | The MAC key shared with an OP for an association.
newtype Key = Key B.ByteString
  deriving (Show, Eq)

-- | The process-wide cache of associations. It maps a string identifying the
--   OP (as chosen by 'associate') to a (handle, key, expiry time) triple.
--
--   This is a global created with 'unsafePerformIO'. The NOINLINE pragma is
--   required: without it GHC may inline the definition at each use site, which
--   would create several independent TVars and silently split the cache.
associateCache :: TVar (Map.Map B.ByteString (Handle, Key, UTCTime))
associateCache = unsafePerformIO $ newTVarIO Map.empty
{-# NOINLINE associateCache #-}

-- | Apply a function to the contents of a 'TVar' within a transaction.
--
--   The new value is forced to weak head normal form before it is written.
--   This stops repeated updates from building a chain of unevaluated thunks
--   inside the variable (for a 'Map.Map' this evaluates the whole spine).
updateTVar :: TVar a -> (a -> a) -> STM ()
updateTVar var f = do
  old <- readTVar var
  writeTVar var $! f old

-- | Look up the association key cached under the given OP string. The key is
--   returned only when the cached handle equals the supplied one. Generally
--   used after an indirect request, to check a signature from an OP.
--
--   The cached expiry time is not consulted here.
findKey :: B.ByteString -> Handle -> STM (Maybe Key)
findKey host (Handle wanted) = do
  cache <- readTVar associateCache
  return $ case Map.lookup host cache of
    Just (Handle stored, key, _) | stored == wanted -> Just key
    _ -> Nothing

-- | Perform an association with a discovered OP and return the association
--   handle and the shared MAC key.
--
--   Internally this uses a cache, so 'associate' may not actually involve an
--   HTTP request to the OP. The cache is keyed on the serialised provider URL,
--   so that it can match the @openid.op_endpoint@ value that
--   'processCheckIDReply' passes to 'findKey'.
--
--   A cached association that has expired, or will expire within the next
--   five minutes, is discarded and renewed. The expiry of a new association
--   is measured from the time taken before the request. That is
--   conservative, because the OP starts its countdown no earlier.
--
--   Errors from 'associateHTTP' propagate as 'IO' exceptions.
associate :: OpenIDDiscovery -> IO (Handle, Key)
associate discovery@(OpenIDDiscovery provider _) = do
  let cacheKey = URL.serialise provider
  currentTime <- getCurrentTime
  cached <- atomically $ do
    cache <- readTVar associateCache
    case cacheKey `Map.lookup` cache of
         Nothing -> return Nothing
         Just entry@(_, _, expiry)
           -- if the key expires in the next five minutes, dump it now
           | expiry <= 300 `addUTCTime` currentTime -> do
               writeTVar associateCache $ cacheKey `Map.delete` cache
               return Nothing
           | otherwise -> return $ Just entry
  case cached of
       Just (h, k, _) -> return (h, k)
       Nothing -> do
         (h, k, secs) <- associateHTTP discovery
         let expiry = addUTCTime (fromIntegral secs) currentTime
         atomically $ updateTVar associateCache (Map.insert cacheKey (h, k, expiry))
         return (h, k)

-- | The value for a @Host@ header for a URL. This is the hostname when the
--   URL's host is a 'URL.Hostname', and 'Nothing' for any other host form.
urlToHost :: URL.URL -> Maybe B.ByteString
urlToHost url = case url of
  URL.URL { URL.urlHost = URL.Hostname name } -> Just name
  _ -> Nothing

-- | An implementation of 'associate' which makes an actual HTTP request
--   every time. This is wrapped by 'associate', which handles caching of
--   these values. You can call this directly to bypass the cache.
--
--   It POSTs the request built by 'associateRequest' to the provider. The
--   Key-Value form reply (at most 4096 bytes are read) is then decoded with
--   'processAssociateReply'. The result is (handle, key, seconds to expiry).
--
--   Any failure is raised as an 'IO' error: no reply, a non-200 status, a
--   missing or unreadable body, or an unparsable reply. If an exception
--   escapes, the socket is closed.
associateHTTP :: OpenIDDiscovery -> IO (Handle, Key, Int)
associateHTTP (OpenIDDiscovery provider _) = do
  sock <- connection provider
  handle (\e -> sClose sock >> throwIO e) $ do
    conn <- transport provider sock
    (postbody, x) <- associateRequest
    postsource <- bsSource postbody
    r <- request conn (Request POST (URL.toRelative provider) 1 1 $
            emptyHeaders { httpHost = urlToHost provider
                         , httpContentType = Just (("application", "x-www-form-urlencoded"), [])
                         , httpContentLength = Just $ fromIntegral $ B.length postbody
                         }) $ Just postsource
    case r of
         Just (Reply {replyStatus = 200}, Just source) -> do
           mreplyBytes <- sourceToBS 4096 source
           C.close conn
           case mreplyBytes of
                Nothing -> fail "Error reading reply"
                Just replyBytes ->
                  case G.runGet (parseKeyValue 10 0x3a) replyBytes of
                       (Right reply, _) -> processAssociateReply reply x
                       _ -> fail "Error parsing reply"
         -- a successful status without a body cannot carry an association
         Just (Reply {replyStatus = 200}, Nothing) -> fail "HTTP server returned no content"
         Just _ -> fail "Bad HTTP reply code"
         -- previously handled with a partial 'fromJust', which crashed here
         Nothing -> fail "No HTTP reply"

-- | Handle a reply from an associate request (DH-SHA1 session, section 8.2.3).
--
--   The steps are:
--
--   1. Compute the shared secret g^(xy) mod p from the server's public value
--      (@dh_server_public@) and the private @x@.
--   2. Hash the secret's btwoc form with SHA1.
--   3. XOR the hash with the decoded @enc_mac_key@ to recover the MAC key.
--
--   The decoded @enc_mac_key@ must be exactly as long as the SHA1 digest.
--   Otherwise the XOR would silently truncate the key.
--
--   Fails (in 'IO') in these cases:
--
--   * a required field is missing;
--   * the SHA1 digest is unavailable;
--   * the key length is wrong;
--   * @expires_in@ cannot be parsed.
processAssociateReply :: Map.Map B.ByteString B.ByteString  -- ^ reply from the server
                      -> Integer  -- ^ the x value from associateRequest
                      -> IO (Handle, Key, Int)
                         -- ^ (handle, key, seconds to expiry)
processAssociateReply reply x =
  mapWrapper reply ["assoc_handle", "dh_server_public", "enc_mac_key", "expires_in"] $
    \[assocHandle, serverPublic, encKey, expiresStr] -> do
       Just sha1 <- Digest.getDigestByName "SHA1"
       gy <- btwocToInteger $ Base64.decodeBase64BS serverPublic
       sharedmpi <- BN.integerToMPI $ BN.modexp gy x dhP
       let encKey' = Base64.decodeBase64BS encKey
           -- drop the 4-byte MPI length header to get btwoc(g^(xy) mod p)
           sharedkey = Digest.digestBS' sha1 $ B.drop 4 sharedmpi
           key = B.pack $ B.zipWith xor sharedkey encKey'
       if B.length encKey' /= B.length sharedkey
          then fail "Bad enc_mac_key length"
          else case maybeRead expiresStr of
                 Nothing -> fail "Failed to parse expiry"
                 Just expires -> return (Handle assocHandle, Key key, expires)


-- | Create an associate request for a DH-SHA1 session with an HMAC-SHA1
--   association.
--
--   Returns the URL-encoded POST body and the private Diffie-Hellman value
--   @x@. The caller needs @x@ later to decode the reply with
--   'processAssociateReply'. @x@ is drawn at random below 'dhP' and is never
--   zero.
associateRequest :: IO (B.ByteString, Integer)
associateRequest = do
  x <- BN.randIntegerUptoNMinusOneSuchThat (/= 0) dhP
  encoded <- integerToBase64btwoc $ BN.modexp dhG x dhP
  let extras = [ ("mode", "associate")
               , ("assoc_type", "HMAC-SHA1")
               , ("session_type", "DH-SHA1")
               , ("dh_consumer_public", encoded)
               ]
  -- no debug output here: the public value used to be printed to stdout
  return (postEncode $ Map.union defaultParams $ Map.fromList extras, x)

-- | Look up every required key in a map and pass the values, in the same
--   order as the keys, to a continuation.
--
--   If any key is missing, this fails in 'IO' with
--   "Map missing required value" and the continuation is not called.
mapWrapper :: Map.Map B.ByteString B.ByteString  -- ^ the map of values
           -> [B.ByteString]  -- ^ the required keys
           -> ([B.ByteString] -> IO a)  -- ^ the continuation
           -> IO a  -- ^ the result
mapWrapper m keys f =
  maybe (fail "Map missing required value") f $ mapM (`Map.lookup` m) keys

-- | Parse a Key-Value form message (section 4.1.1) into a map.
--
--   Each line is @key<pairBreak>value@, and lines are separated by
--   @lineBreak@. A final @lineBreak@ after the last line is tolerated. An
--   @openid.@ prefix on a key is stripped. Any input left after the last
--   line makes the parse fail with "Trailing garbage found". If a key occurs
--   more than once, 'Map.fromList' keeps the last value.
--
--   NOTE(review): this assumes two things about Data.Binary.Strict.Class —
--   confirm both:
--
--   * 'C.spanOf1' consumes a non-empty run of matching bytes, so a line
--     with an empty key or an empty value fails to parse.
--   * 'C.many' backtracks when its argument fails after consuming the
--     trailing line break.
parseKeyValue :: (C.BinaryParser m)
              => Word8  -- ^ byte which breaks 'lines' (e.g. \n)
              -> Word8  -- ^ byte which breaks pairs (e.g. ':')
              -> m (Map.Map B.ByteString B.ByteString)
parseKeyValue lineBreak pairBreak = do
  let parseLine = do
        key <- C.spanOf1 (/= pairBreak)
        C.word8 pairBreak
        -- the value runs to the next line break, so it may contain pairBreak
        value <- C.spanOf1 (/= lineBreak)
        -- keys may carry the "openid." namespace prefix; drop it
        if B.isPrefixOf "openid." key
           then return (B.drop 7 key, value)
           else return (key, value)
  line <- parseLine
  -- every further line is preceded by a line break
  rest <- C.many (C.word8 lineBreak >> parseLine)
  -- tolerate a final line break after the last pair
  C.optional $ C.word8 lineBreak
  donep <- C.isEmpty
  if not donep
     then fail "Trailing garbage found"
     else return $ Map.fromList $ line : rest

-- | The two kinds of checkid request. They correspond to the
--   @checkid_setup@ and @checkid_immediate@ modes respectively.
data CheckIDType
  = CheckIDSetup
  | CheckIDImmediate
  deriving (Show, Eq)

-- | The @openid.mode@ value sent for each kind of checkid request.
typeToString :: CheckIDType -> B.ByteString
typeToString ty = case ty of
  CheckIDSetup     -> "checkid_setup"
  CheckIDImmediate -> "checkid_immediate"

-- | Construct the URL for a checkid request.
--
--   The result is the provider URL with the @openid.*@ arguments merged into
--   its existing query arguments. On a key clash, the provider URL's own
--   argument wins ('Map.unions' is left-biased).
--
--   When no OP-local identifier was discovered, the claimed identifier is
--   also sent as @openid.identity@. @openid.realm@ is only sent when a realm
--   is given.
checkID :: CheckIDType
        -> URL.URL  -- ^ claimed id
        -> OpenIDDiscovery -- ^ discovery result (OP endpoint and OP-local id)
        -> Handle  -- ^ assoc handle
        -> B.ByteString  -- ^ return to URL
        -> Maybe B.ByteString  -- ^ trust realm
        -> URL.URL  -- ^ URL
checkID ty claimed (OpenIDDiscovery provider mlocalid) (Handle assocHandle) returnTo realm =
  provider { URL.urlArguments = Map.unions [ URL.urlArguments provider
                                           , withPrefix defaultParams
                                           , withPrefix (Map.fromList fields)
                                           ] }
  where
    withPrefix = Map.mapKeys (B.append "openid.")
    claimedStr = URL.serialise claimed
    fields = maybe id (\r -> (("realm", r) :)) realm
               [ ("mode", typeToString ty)
               , ("claimed_id", claimedStr)
               , ("identity", maybe claimedStr fromString mlocalid)
               , ("assoc_handle", assocHandle)
               , ("return_to", returnTo)
               ]

-- | Verify a positive assertion (an indirect @id_res@ response) from an OP.
--
--   An @openid.@ prefix on the keys of the arguments is stripped first.
--
--   Errors raised in 'IO' (via 'mapWrapper'):
--
--   * any of @assoc_handle@, @claimed_id@, @op_endpoint@, @signed@ or @sig@
--     is missing;
--   * a field named in @signed@ is missing.
--
--   The result is @Right claimedID@ when the HMAC-SHA1 signature over the
--   signed fields (in Key-Value form) matches @sig@. The key used is the one
--   'findKey' returns for @(op_endpoint, assoc_handle)@. A @Left@ reason is
--   returned when:
--
--   * no key is cached;
--   * a field that section 10.1 requires to be signed is absent from
--     @signed@;
--   * the signature does not match.
--
--   Without the signed-field check, an attacker could change an unsigned
--   @claimed_id@ in an otherwise validly signed response. Signatures are
--   compared in constant time.
--
--   NOTE(review): this does not check @return_to@, the response nonce, or
--   that @op_endpoint@ matches the discovered provider; callers must.
processCheckIDReply :: Map.Map B.ByteString B.ByteString  -- ^ the arguments
                    -> IO (Either String B.ByteString)
processCheckIDReply args' = do
  -- first strip openid. from the front of all the keys
  let args = Map.mapKeys stripPrefix args'
      stripPrefix k = if "openid." `B.isPrefixOf` k then B.drop 7 k else k
  Just sha1 <- Digest.getDigestByName "SHA1"
  mapWrapper args ["assoc_handle", "claimed_id", "op_endpoint", "signed", "sig"] $
    \[assocHandle, claimed, endpoint, signed, sig] -> do
      mkey <- atomically $ findKey endpoint $ Handle assocHandle
      let signedFields = B.split 0x2c signed
          -- section 10.1: these fields MUST be covered by the signature
          mustBeSigned = ["op_endpoint", "return_to", "response_nonce", "assoc_handle", "claimed_id"]
                         ++ ["identity" | "identity" `Map.member` args]
          unsigned = filter (`notElem` signedFields) mustBeSigned
      case mkey of
           Nothing -> return $ Left "Cannot find assoc key"
           Just (Key key)
             | not (null unsigned) ->
                 return $ Left $ "Required fields are not signed: " ++ show unsigned
             | otherwise ->
                 mapWrapper args signedFields $ \signedValues -> do
                   let kv = keyValueEncode $ zip signedFields signedValues
                       mySig = Base64.encodeBase64BS $ Digest.hmacBS sha1 key kv
                   return $ if constantTimeEq mySig sig
                               then Right claimed
                               else Left "OpenID signature verification failed"
  where
    -- Equality whose running time does not depend on where the first
    -- differing byte is, to avoid leaking signature prefixes via timing.
    constantTimeEq a b =
      B.length a == B.length b &&
      B.foldl' (.|.) 0 (B.pack (B.zipWith xor a b)) == 0