-- |
-- Module      : Crypto.Store.CMS.Util
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- CMS and ASN.1 utilities
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Crypto.Store.CMS.Util
    (
    -- * Testing ASN.1 types
      nullOrNothing
    , intOrNothing
    , dateTimeOrNothing
    -- * Object Identifiers
    , OIDTable
    , lookupOID
    , Enumerable(..)
    , OIDNameableWrapper(..)
    , withObjectID
    -- * Parsing and encoding ASN.1 objects
    , ASN1Event
    , ASN1ObjectExact(..)
    , ProduceASN1Object(..)
    , encodeASN1Object
    , ParseASN1Object(..)
    , decodeASN1Object
    -- * Algorithm Identifiers
    , AlgorithmId(..)
    , algorithmASN1S
    , algorithmMaybeASN1S
    , parseAlgorithm
    , parseAlgorithmMaybe
    -- * Miscellaneous functions
    , orElse
    ) where

import           Data.ASN1.BinaryEncoding
import           Data.ASN1.BinaryEncoding.Raw
import           Data.ASN1.Encoding
import           Data.ASN1.OID
import           Data.ASN1.Types
import           Data.ByteString (ByteString)
import           Data.List (find)
import           Data.X509

import Time.Types (DateTime)

import Crypto.Store.ASN1.Generate
import Crypto.Store.ASN1.Parse
import Crypto.Store.Error

-- | Try to parse a 'Null' ASN.1 value.
--
-- Returns @'Just' ()@ when the element is 'Null', and 'Nothing' for any
-- other ASN.1 element.
nullOrNothing :: ASN1 -> Maybe ()
nullOrNothing Null = Just ()
nullOrNothing _    = Nothing

-- | Try to parse an 'IntVal' ASN.1 value.
--
-- Returns the wrapped integer, or 'Nothing' for any other ASN.1 element.
intOrNothing :: ASN1 -> Maybe Integer
intOrNothing (IntVal i) = Just i
intOrNothing _          = Nothing

-- | Try to parse a 'DateTime' ASN.1 value.
--
-- Accepts an 'ASN1Time' element and returns its 'DateTime' component.  The
-- time type and the optional timezone offset carried by the element are
-- discarded.  Any other element gives 'Nothing'.
dateTimeOrNothing :: ASN1 -> Maybe DateTime
dateTimeOrNothing (ASN1Time _ t _) = Just t
dateTimeOrNothing _                = Nothing

-- | Mapping between values and OIDs.
--
-- Each entry pairs a value with its object identifier.  The lookup helpers
-- in this module scan the list from the front and return the first matching
-- entry.
type OIDTable a = [(a, OID)]

-- | Find the value associated with an OID.
--
-- Reverse lookup in an 'OIDTable': returns the value of the first entry
-- whose OID equals the argument, or 'Nothing' when no entry matches.
lookupByOID :: OIDTable a -> OID -> Maybe a
lookupByOID table oid = fst <$> find ((== oid) . snd) table

-- | Find the OID associated with a value.
--
-- Returns the OID of the first entry whose value equals the argument, or
-- 'Nothing' when no entry matches.
lookupOID :: Eq a => OIDTable a -> a -> Maybe OID
lookupOID table a = lookup a table

-- | Types with a finite set of values.
--
-- The list returned by 'values' is used by the 'OIDNameable' instance of
-- 'OIDNameableWrapper' to build its OID lookup table.
class Enumerable a where
    -- | Return all possible values for the given type.
    values :: [a]

-- | Type used to turn an 'Enumerable' instance into an 'OIDNameable'
-- instance.
--
-- Wrap a type that is both 'Enumerable' and 'OIDable' to obtain OID-to-value
-- resolution.  The resolution works by scanning all enumerated values.
newtype OIDNameableWrapper a = OIDNW { unOIDNW :: a }
    deriving (Show,Eq)

-- Resolve an OID by building a table from every enumerated value and its
-- 'getObjectID', then looking the OID up in that table.  When several values
-- share an OID, the first one listed by 'values' wins.
instance (Enumerable a, OIDable a) => OIDNameable (OIDNameableWrapper a) where
    fromObjectID = lookupByOID table
      where
        -- The @a@ here is the one bound by the instance head
        -- (ScopedTypeVariables).
        table :: [(OIDNameableWrapper a, OID)]
        table = [ (OIDNW val, getObjectID val) | val <- values ]

-- | Convert the specified OID and apply a parser to the result.
--
-- The 'String' argument names the kind of object being decoded.  It is used
-- only to build the parse error raised when 'fromObjectID' does not
-- recognise the OID.
withObjectID :: OIDNameable a
             => String -> OID -> (a -> ParseASN1 e b) -> ParseASN1 e b
withObjectID name oid fn =
    case fromObjectID oid of
        Just val -> fn val
        Nothing  ->
            throwParseError ("Unsupported " ++ name ++ ": OID " ++ show oid)

-- | Objects that can produce an ASN.1 stream.
--
-- The @e@ parameter is the stream element type.  Instances in this module
-- are either polymorphic in @e@ (with an 'ASN1Elem' constraint) or fixed to
-- 'ASN1P' when they emit already-encoded bytes.
class ProduceASN1Object e obj where
    -- | Stream transformer that emits the object's ASN.1 elements in front of
    -- the continuation it is applied to.
    asn1s :: obj -> ASN1Stream e

-- A list is encoded as the streams of its elements, one after another in
-- list order, with no enclosing container.
instance ProduceASN1Object e obj => ProduceASN1Object e [obj] where
    asn1s l r = foldr asn1s r l

-- A distinguished name is emitted as a SEQUENCE.  Each (OID, string)
-- attribute goes in its own SET containing a single SEQUENCE { OID, string }.
instance ASN1Elem e => ProduceASN1Object e DistinguishedName where
    asn1s = asn1Container Sequence . inner
      where
        -- Chain the encodings of all attributes in front of the continuation.
        inner (DistinguishedName dn) cont = foldr dnSet cont dn
        dnSet (oid, cs) =
            asn1Container Set $
                asn1Container Sequence (gOID oid . gASN1String cs)

-- Signed objects are not re-encoded element by element.  The bytes returned
-- by 'encodeSignedObject' are emitted as-is, which is why this instance
-- exists only for the 'ASN1P' element type.
instance (Show a, Eq a, ASN1Object a) => ProduceASN1Object ASN1P (SignedExact a) where
    asn1s = gEncoded . encodeSignedObject

-- | Encode the ASN.1 object to DER format.
--
-- The object's 'asn1s' stream is rendered to bytes with 'encodeASN1S'.
encodeASN1Object :: ProduceASN1Object ASN1P obj => obj -> ByteString
encodeASN1Object = encodeASN1S . asn1s

-- | Objects that can be parsed from an ASN.1 stream.
--
-- The @e@ parameter is the annotation type that 'ParseASN1' carries next to
-- each ASN.1 element.  Instances in this module that need the original
-- encoding use @[ASN1Event]@ annotations together with 'withAnnotations'.
class Monoid e => ParseASN1Object e obj where
    -- | Parser for one value of type @obj@.
    parse :: ParseASN1 e obj

-- A list is parsed by applying the element parser repeatedly with 'getMany'.
instance ParseASN1Object e obj => ParseASN1Object e [obj] where
    parse = getMany parse

-- Parses a SEQUENCE of SETs, where each SET holds any number of
-- SEQUENCE { OID, string } attributes.  The attributes of all SETs are
-- concatenated, in order, into a single list.
instance Monoid e => ParseASN1Object e DistinguishedName where
    parse = DistinguishedName <$> onNextContainer Sequence inner
      where
        -- One attribute list per SET, flattened.
        inner = concat <$> getMany parseOne
        parseOne =
            onNextContainer Set $ getMany $
                onNextContainer Sequence $ do
                    -- A non-matching element makes the pattern bind fail.
                    OID oid <- getNext
                    ASN1String cs <- getNext
                    return (oid, cs)

-- Consumes one SEQUENCE and records the raw events that make it up.  Those
-- events are re-serialised with 'toByteString' and the resulting bytes are
-- decoded with 'decodeSignedObject'.  Decoder errors become parse errors
-- prefixed with "SignedExact: ".
instance (Show a, Eq a, ASN1Object a) => ParseASN1Object [ASN1Event] (SignedExact a) where
    parse = withAnnotations parseSequence >>= finish
      where
        parseSequence = onNextContainer Sequence (getMany getNext)
        -- Only the captured events matter; the parsed elements are dropped.
        finish (_, events) =
            case decodeSignedObject (toByteString events) of
                Right se -> return se
                Left err -> throwParseError ("SignedExact: " ++ err)

-- | Create an ASN.1 object from a byte string in BER format.
--
-- Errors from the BER decoder are reported as 'DecodingError'.  Parser
-- failures are reported as 'ParseFailure', as is input left over after the
-- object has been parsed (\"Incomplete parse\").
decodeASN1Object :: ParseASN1Object [ASN1Event] obj => ByteString -> Either StoreError obj
decodeASN1Object bs =
    case decodeASN1Repr' BER bs of
        Left e     -> Left (DecodingError e)
        Right asn1 ->
            case runParseASN1State_ parse asn1 of
                Right (obj, []) -> Right obj
                Right _         -> Left (ParseFailure "Incomplete parse")
                Left e          -> Left (ParseFailure e)

-- | An ASN.1 object associated with the raw data it was parsed from.
--
-- The raw bytes are kept so that the object can be emitted again exactly as
-- it was received.  Equality considers only 'exactObject'.
data ASN1ObjectExact a = ASN1ObjectExact
    { exactObject    :: a           -- ^ The wrapped ASN.1 object
    , exactObjectRaw :: ByteString  -- ^ The raw representation of this object
    } deriving Show

-- Equality compares only the wrapped objects and ignores the raw encodings.
-- Two different encodings of equal values are therefore considered equal.
instance Eq a => Eq (ASN1ObjectExact a) where
    a == b = exactObject a == exactObject b

-- Emits the stored raw bytes verbatim instead of re-encoding 'exactObject'.
instance ProduceASN1Object ASN1P a => ProduceASN1Object ASN1P (ASN1ObjectExact a) where
    asn1s = gEncoded . exactObjectRaw

-- Parses the wrapped object and captures the events it consumed.  The
-- events are re-serialised with 'toByteString' to obtain the raw
-- representation.
instance ParseASN1Object [ASN1Event] a => ParseASN1Object [ASN1Event] (ASN1ObjectExact a) where
    parse = do
        (obj, events) <- withAnnotations parse
        let objRaw = toByteString events
        return ASN1ObjectExact { exactObject = obj, exactObjectRaw = objRaw }

-- | Algorithm identifier with associated parameter.
--
-- A @param@ value describes an algorithm together with its parameters.
-- 'AlgorithmType' is the type converted to and from an OID: 'algorithmASN1S'
-- uses 'getObjectID' and 'parseAlgorithm' uses 'fromObjectID'.
class AlgorithmId param where
    -- | Type of the algorithm, mapped to an OID by the helpers in this module.
    type AlgorithmType param
    -- | Name of the algorithm, used in parse error messages.
    --
    -- 'parseAlgorithm' and 'parseAlgorithmMaybe' call this on a bottom value
    -- only to select the instance, so it must not inspect its argument.
    algorithmName  :: param -> String
    -- | Algorithm type of a parameter value.
    algorithmType  :: param -> AlgorithmType param
    -- | Encode the parameters; 'algorithmASN1S' emits them after the OID.
    parameterASN1S :: ASN1Elem e => param -> ASN1Stream e
    -- | Parse the parameters that follow the OID, given the algorithm type
    -- decoded from that OID.
    parseParameter :: Monoid e => AlgorithmType param -> ParseASN1 e param

-- | Transform the algorithm identifier to ASN.1 stream.
--
-- Emits a container of the given construction type.  It holds the OID of
-- the parameter's 'algorithmType', followed by the output of
-- 'parameterASN1S'.
algorithmASN1S :: (ASN1Elem e, AlgorithmId param, OIDable (AlgorithmType param))
               => ASN1ConstructionType -> param -> ASN1Stream e
algorithmASN1S ty p = asn1Container ty (oid . parameterASN1S p)
  where typ = algorithmType p
        oid = gOID (getObjectID typ)

-- | Transform the optional algorithm identifier to ASN.1 stream.
--
-- 'Nothing' emits nothing and leaves the stream unchanged.  @'Just' p@ is
-- encoded with 'algorithmASN1S'.
algorithmMaybeASN1S :: (ASN1Elem e, AlgorithmId param, OIDable (AlgorithmType param))
                    => ASN1ConstructionType -> Maybe param -> ASN1Stream e
algorithmMaybeASN1S _  Nothing  = id
algorithmMaybeASN1S ty (Just p) = algorithmASN1S ty p

-- | Parse an algorithm identifier from an ASN.1 stream.
--
-- Expects a container of the given construction type whose first element
-- is an OID.  The OID is resolved to an 'AlgorithmType' with 'fromObjectID',
-- and 'parseParameter' then parses the rest of the container.  If the OID is
-- unknown, the parse error names the algorithm via 'algorithmName'.
parseAlgorithm :: forall e param . (Monoid e, AlgorithmId param, OIDNameable (AlgorithmType param))
               => ASN1ConstructionType -> ParseASN1 e param
parseAlgorithm ty = onNextContainer ty $ do
    OID oid <- getNext
    withObjectID (getName placeholder) oid parseParameter
  where
    -- Pins 'algorithmName' to the @param@ bound by the signature's forall.
    getName :: param -> String
    getName = algorithmName

    -- Only used to select the instance; it must never be forced.  A named
    -- error is easier to diagnose than a bare 'undefined' if an instance
    -- does inspect it.
    placeholder :: param
    placeholder = error "parseAlgorithm: algorithmName must not inspect its argument"

-- | Parse an optional algorithm identifier from an ASN.1 stream.
--
-- Works like 'parseAlgorithm', but the container is optional.  The result is
-- 'Nothing' when 'onNextContainerMaybe' does not find a container of the
-- given type.
parseAlgorithmMaybe :: forall e param . (Monoid e, AlgorithmId param, OIDNameable (AlgorithmType param))
                    => ASN1ConstructionType -> ParseASN1 e (Maybe param)
parseAlgorithmMaybe ty = onNextContainerMaybe ty $ do
    OID oid <- getNext
    withObjectID (getName placeholder) oid parseParameter
  where
    -- Pins 'algorithmName' to the @param@ bound by the signature's forall.
    getName :: param -> String
    getName = algorithmName

    -- Only used to select the instance; it must never be forced.  A named
    -- error is easier to diagnose than a bare 'undefined' if an instance
    -- does inspect it.
    placeholder :: param
    placeholder = error "parseAlgorithmMaybe: algorithmName must not inspect its argument"

-- | Execute the second action only if the first action produced 'Nothing'.
--
-- When the first action yields @'Just' x@, that result is returned unchanged
-- and the second action is never run.
orElse :: Monad m => m (Maybe a) -> m (Maybe a) -> m (Maybe a)
orElse pa pb = pa >>= maybe pb (return . Just)