module Network.Wai.SAML2.Request (
AuthnRequest(..),
issueAuthnRequest,
renderBase64,
renderUrlEncodingDeflate,
renderXML,
) where
import Crypto.Random
import Data.Time.Clock
import Network.Wai.SAML2.NameIDFormat
import Network.Wai.SAML2.XML
import Text.XML
import qualified Codec.Compression.Zlib.Raw as Deflate
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16 as Base16
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Network.HTTP.Types (urlEncode)
-- | The parameters of a SAML2 @\<samlp:AuthnRequest\>@ message. See
-- 'renderXML' for how each field is turned into XML.
data AuthnRequest = AuthnRequest
    { -- | Written to the @IssueInstant@ attribute of the request.
      authnRequestTimestamp :: !UTCTime
      -- | Written to the @ID@ attribute of the request.
    , authnRequestID :: !T.Text
      -- | Text content of the @\<saml:Issuer\>@ child element.
    , authnRequestIssuer :: !T.Text
      -- | Optional value for the @Destination@ attribute of the request.
    , authnRequestDestination :: !(Maybe T.Text)
      -- | Rendered as @true@ / @false@ on the @NameIDPolicy@ element.
    , authnRequestAllowCreate :: !Bool
      -- | Rendered as the @Format@ attribute of the @NameIDPolicy@ element.
    , authnRequestNameIDFormat :: !NameIDFormat
    }
    deriving (Eq, Show)
-- | Create a fresh 'AuthnRequest' for the given issuer.
--
-- The request is stamped with the current time. Its identifier is the
-- literal prefix @id@ followed by the hexadecimal encoding of 16 random
-- bytes (32 hex characters). All other fields take fixed defaults:
-- allow-create enabled, 'Transient' NameID format and no destination.
issueAuthnRequest
    :: T.Text -- ^ Text to place in the request's @\<saml:Issuer\>@ element.
    -> IO AuthnRequest
issueAuthnRequest issuer = do
    now <- getCurrentTime
    nonce <- getRandomBytes 16
    -- NOTE(review): the "id" prefix presumably keeps the identifier from
    -- starting with a digit, which an XML ID value may not do.
    let requestId = "id" <> T.decodeUtf8 (Base16.encode nonce)
    pure AuthnRequest
        { authnRequestTimestamp = now
        , authnRequestID = requestId
        , authnRequestIssuer = issuer
        , authnRequestDestination = Nothing
        , authnRequestAllowCreate = True
        , authnRequestNameIDFormat = Transient
        }
-- | Serialise an 'AuthnRequest' so it can be placed in a URL.
--
-- The XML from 'renderXML' is compressed with raw DEFLATE, Base64-encoded,
-- and then percent-encoded using query-string rules (@urlEncode True@).
-- NOTE(review): presumably intended for the SAML HTTP-Redirect binding.
renderUrlEncodingDeflate :: AuthnRequest -> B.ByteString
renderUrlEncodingDeflate =
    urlEncode True
        . Base64.encode
        . BL.toStrict
        . Deflate.compress
        . renderXML
-- | Serialise an 'AuthnRequest' as Base64-encoded XML (no compression and
-- no URL encoding).
renderBase64 :: AuthnRequest -> B.ByteString
renderBase64 request =
    let xml = BL.toStrict (renderXML request)
    in Base64.encode xml
-- | Render an 'AuthnRequest' as an XML document (empty prologue and
-- epilogue) using the default render settings.
--
-- The root @AuthnRequest@ element carries the @ID@, @Version@ (@2.0@),
-- @IssueInstant@ and @AssertionConsumerServiceIndex@ (@1@) attributes. It
-- also carries @Destination@, but only when 'authnRequestDestination' is
-- 'Just'. Its children are the @Issuer@ and @NameIDPolicy@ elements.
renderXML :: AuthnRequest -> BL.ByteString
renderXML AuthnRequest{..} =
    renderLBS def $
        Document
            { documentPrologue = Prologue [] Nothing []
            , documentRoot = root
            , documentEpilogue = []
            }
    where
        timestamp :: T.Text
        timestamp = showUTCTime authnRequestTimestamp

        -- Only emit a Destination attribute when one was supplied. A pattern
        -- generator (rather than a lazy `let Just uri = ...`) yields an empty
        -- list for Nothing. The `let` form instead produced an entry whose
        -- value is bottom, which Data.Map.Strict forced, crashing the render.
        destination :: [(Name, T.Text)]
        destination =
            [ ("Destination", uri) | Just uri <- [authnRequestDestination] ]

        root :: Element
        root = Element
            (saml2pName "AuthnRequest")
            (Map.fromList $
                [ ("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
                , ("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")
                , ("ID", authnRequestID)
                , ("Version", "2.0")
                , ("IssueInstant", timestamp)
                , ("AssertionConsumerServiceIndex", "1")
                ]
                ++ destination)
            [NodeElement issuer, NodeElement nameIdPolicy]

        issuer :: Element
        issuer = Element
            (saml2Name "Issuer")
            mempty
            [NodeContent authnRequestIssuer]

        -- The SAML 2.0 protocol schema (NameIDPolicyType) spells this
        -- attribute "AllowCreate"; XML attribute names are case-sensitive.
        nameIdPolicy :: Element
        nameIdPolicy = Element
            (saml2pName "NameIDPolicy")
            (Map.fromList
                [ ("AllowCreate"
                  , if authnRequestAllowCreate then "true" else "false")
                , ("Format", showNameIDFormat authnRequestNameIDFormat)
                ])
            []