module Network.Wai.SAML2.Request (
AuthnRequest(..),
issueAuthnRequest,
renderBase64,
renderUrlEncodingDeflate,
renderXML,
) where
import Crypto.Random
import Data.Time.Clock
import Network.Wai.SAML2.NameIDFormat
import Network.Wai.SAML2.XML
import Text.XML
import qualified Codec.Compression.Zlib.Raw as Deflate
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16 as Base16
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Network.HTTP.Types (urlEncode)
-- | A SAML2 authentication request, as rendered by 'renderXML'.
--
-- All fields are strict. Use 'issueAuthnRequest' to obtain a value with a
-- fresh identifier and timestamp, then adjust fields as needed.
data AuthnRequest = AuthnRequest
    { authnRequestTimestamp :: !UTCTime
      -- ^ When the request was issued; rendered as the @IssueInstant@
      -- attribute.
    , authnRequestID :: !T.Text
      -- ^ Identifier of the request; rendered as the @ID@ attribute.
    , authnRequestIssuer :: !T.Text
      -- ^ Rendered as the text content of the @Issuer@ element.
    , authnRequestDestination :: !(Maybe T.Text)
      -- ^ Optional value for the @Destination@ attribute.
    , authnRequestAllowCreate :: !Bool
      -- ^ Rendered as the allow-create flag on the @NameIDPolicy@ element.
    , authnRequestNameIDFormat :: !NameIDFormat
      -- ^ Rendered as the @Format@ attribute of the @NameIDPolicy@ element.
    }
    deriving (Eq, Show)
-- | 'issueAuthnRequest' @issuer@ creates a new 'AuthnRequest' issued by
-- @issuer@.
--
-- The request is stamped with the current time. Its identifier is the
-- literal prefix @id@ followed by the hex encoding of 16 random bytes. The
-- prefix presumably keeps the ID from starting with a digit, as @xs:ID@
-- requires; confirm before relying on it.
--
-- The remaining fields get these defaults:
--
-- * 'authnRequestAllowCreate' is 'True'.
-- * 'authnRequestNameIDFormat' is 'Transient'.
-- * 'authnRequestDestination' is 'Nothing'.
issueAuthnRequest
    :: T.Text -- ^ The value to use for the @Issuer@ element.
    -> IO AuthnRequest
issueAuthnRequest authnRequestIssuer = do
    now <- getCurrentTime
    -- The annotation fixes the byte-array type that getRandomBytes
    -- produces.
    entropy <- getRandomBytes 16 :: IO B.ByteString
    let requestId = "id" <> T.decodeUtf8 (Base16.encode entropy)
    pure AuthnRequest
        { authnRequestTimestamp = now
        , authnRequestID = requestId
        , authnRequestIssuer = authnRequestIssuer
        , authnRequestDestination = Nothing
        , authnRequestAllowCreate = True
        , authnRequestNameIDFormat = Transient
        }
-- | 'renderUrlEncodingDeflate' @request@ encodes @request@ in three steps:
--
-- 1. Render the XML with 'renderXML' and compress it with raw DEFLATE
--    (no zlib header).
-- 2. Base64-encode the compressed bytes.
-- 3. Percent-encode the result with 'urlEncode' in query-string mode.
--
-- This matches the shape expected by SAML's @URL-Encoding:DEFLATE@
-- encoding (presumably the HTTP-Redirect binding; verify against the
-- caller).
renderUrlEncodingDeflate :: AuthnRequest -> B.ByteString
renderUrlEncodingDeflate request =
    let document = renderXML request
        deflated = BL.toStrict (Deflate.compress document)
        encoded = Base64.encode deflated
    in urlEncode True encoded
-- | 'renderBase64' @request@ renders @request@ with 'renderXML' and
-- Base64-encodes the resulting bytes as a strict 'B.ByteString'. No
-- compression or URL-encoding is applied.
renderBase64 :: AuthnRequest -> B.ByteString
renderBase64 = Base64.encode . BL.toStrict . renderXML
-- | 'renderXML' @request@ renders @request@ as a complete XML document,
-- using the default render settings. The document has an empty prologue
-- and epilogue.
--
-- The root is a @samlp:AuthnRequest@ element with these attributes:
--
-- * the @samlp@ and @saml@ namespace declarations;
-- * @ID@ and @IssueInstant@;
-- * @Version=\"2.0\"@ and @AssertionConsumerServiceIndex=\"1\"@;
-- * @Destination@, only when 'authnRequestDestination' is set.
--
-- It contains an @Issuer@ element followed by a @NameIDPolicy@ element.
renderXML :: AuthnRequest -> BL.ByteString
renderXML AuthnRequest{..} =
    renderLBS def $ Document
        { documentPrologue = Prologue [] Nothing []
        , documentRoot = root
        , documentEpilogue = []
        }
    where
        timestamp :: T.Text
        timestamp = showUTCTime authnRequestTimestamp

        -- Emit a Destination attribute only when one was supplied.
        -- The previous `let Just uri = ...` inside a list comprehension is
        -- a lazy pattern binding, not a filter. It always produced one
        -- element, whose value failed when forced by the strict Map for a
        -- 'Nothing' destination (the default from 'issueAuthnRequest').
        destinationAttr :: [(Name, T.Text)]
        destinationAttr =
            maybe [] (\uri -> [("Destination", uri)]) authnRequestDestination

        root :: Element
        root = Element
            (saml2pName "AuthnRequest")
            (Map.fromList $
                [ ("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
                , ("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")
                , ("ID", authnRequestID)
                , ("Version", "2.0")
                , ("IssueInstant", timestamp)
                , ("AssertionConsumerServiceIndex", "1")
                ] ++ destinationAttr)
            [NodeElement issuer, NodeElement nameIdPolicy]

        issuer :: Element
        issuer = Element
            (saml2Name "Issuer")
            mempty
            [NodeContent authnRequestIssuer]

        nameIdPolicy :: Element
        nameIdPolicy = Element
            (saml2pName "NameIDPolicy")
            (Map.fromList
                -- The SAML 2.0 protocol schema names this attribute
                -- "AllowCreate"; the lowercase "allowCreate" is not
                -- recognised by schema-conforming IdPs.
                [ ("AllowCreate"
                  , if authnRequestAllowCreate then "true" else "false")
                , ("Format", showNameIDFormat authnRequestNameIDFormat)
                ])
            []