-- |
-- Module      : Data.ASN1.Prim
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez
-- Stability   : experimental
-- Portability : unknown
--
-- Tools to read ASN1 primitive (e.g. boolean, int)
--
module Data.ASN1.Prim
    (
    -- * ASN1 high level algebraic type
      ASN1(..)
    , ASN1ConstructionType(..)

    , encodeHeader
    , encodePrimitiveHeader
    , encodePrimitive
    , decodePrimitive
    , encodeConstructed
    , encodeList
    , encodeOne
    , mkSmallestLength

    -- * marshall an ASN1 type from a val struct or a bytestring
    , getBoolean
    , getInteger
    , getBitString
    , getOctetString
    , getUTF8String
    , getNumericString
    , getPrintableString
    , getT61String
    , getVideoTexString
    , getIA5String
    , getNull
    , getOID
    , getUTCTime
    , getGeneralizedTime
    , getGraphicString
    , getVisibleString
    , getGeneralString
    , getUniversalString
    , getCharacterString
    , getBMPString

    -- * marshall an ASN1 type to a bytestring
    , putUTCTime
    , putGeneralizedTime
    , putInteger
    , putBitString
    , putString
    , putOID
    ) where

import Data.ASN1.Internal
import Data.ASN1.Stream
import Data.ASN1.BitArray
import Data.ASN1.Types
import Data.ASN1.Serialize
import Data.Serialize.Put (runPut)
import Data.Bits
import Data.Word
import Data.List (foldl', unfoldr)
import Data.ByteString (ByteString)
import Data.Char (ord)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as T
import Data.Text.Lazy.Encoding (decodeASCII, decodeUtf8, decodeUtf32BE, encodeUtf8, encodeUtf32BE)
import Control.Applicative

-- | Encode text as big-endian UCS-2, two bytes per character.
--
-- Characters above U+FFFF cannot be represented in UCS-2: the high byte is
-- computed with 'divMod' and then narrowed by 'fromIntegral', so such
-- characters are silently truncated rather than rejected.
encodeUCS2BE :: Text -> L.ByteString
encodeUCS2BE t =
    L.pack $ concatMap (\c -> let (d, m) = fromEnum c `divMod` 256
                               in [fromIntegral d, fromIntegral m])
           $ T.unpack t

-- | Decode big-endian UCS-2 (two bytes per character) into text.
--
-- A trailing odd byte, if any, is dropped.
decodeUCS2BE :: L.ByteString -> Text
decodeUCS2BE lbs = T.pack $ loop lbs
  where
    loop bs
        | L.null bs = []
        | otherwise =
            let (h, r) = L.splitAt 2 bs
             in case L.length h of
                    2 -> (toEnum $ fromIntegral $ be16 h) : loop r
                    _ -> loop r
    be16 :: L.ByteString -> Word16
    be16 b = (fromIntegral (L.index b 0) `shiftL` 8) + fromIntegral (L.index b 1)

-- | Build the header (class, tag, primitive/constructed flag, length) for
-- an ASN1 value. The 'Bool' is the constructed flag; universal tag numbers
-- are chosen from the value's constructor.
--
-- Calling this on 'End' is an error.
encodeHeader :: Bool -> ASN1Length -> ASN1 -> ASN1Header
encodeHeader pc len (Boolean _)                = ASN1Header Universal 0x1 pc len
encodeHeader pc len (IntVal _)                 = ASN1Header Universal 0x2 pc len
encodeHeader pc len (BitString _)              = ASN1Header Universal 0x3 pc len
encodeHeader pc len (OctetString _)            = ASN1Header Universal 0x4 pc len
encodeHeader pc len Null                       = ASN1Header Universal 0x5 pc len
encodeHeader pc len (OID _)                    = ASN1Header Universal 0x6 pc len
encodeHeader pc len (Real _)                   = ASN1Header Universal 0x9 pc len
encodeHeader pc len Enumerated                 = ASN1Header Universal 0xa pc len
encodeHeader pc len (UTF8String _)             = ASN1Header Universal 0xc pc len
encodeHeader pc len (NumericString _)          = ASN1Header Universal 0x12 pc len
encodeHeader pc len (PrintableString _)        = ASN1Header Universal 0x13 pc len
encodeHeader pc len (T61String _)              = ASN1Header Universal 0x14 pc len
encodeHeader pc len (VideoTexString _)         = ASN1Header Universal 0x15 pc len
encodeHeader pc len (IA5String _)              = ASN1Header Universal 0x16 pc len
encodeHeader pc len (UTCTime _)                = ASN1Header Universal 0x17 pc len
encodeHeader pc len (GeneralizedTime _)        = ASN1Header Universal 0x18 pc len
encodeHeader pc len (GraphicString _)          = ASN1Header Universal 0x19 pc len
encodeHeader pc len (VisibleString _)          = ASN1Header Universal 0x1a pc len
encodeHeader pc len (GeneralString _)          = ASN1Header Universal 0x1b pc len
encodeHeader pc len (UniversalString _)        = ASN1Header Universal 0x1c pc len
encodeHeader pc len (CharacterString _)        = ASN1Header Universal 0x1d pc len
encodeHeader pc len (BMPString _)              = ASN1Header Universal 0x1e pc len
encodeHeader pc len (Start Sequence)           = ASN1Header Universal 0x10 pc len
encodeHeader pc len (Start Set)                = ASN1Header Universal 0x11 pc len
encodeHeader pc len (Start (Container tc tag)) = ASN1Header tc tag pc len
encodeHeader pc len (Other tc tag _)           = ASN1Header tc tag pc len
encodeHeader _ _ (End _)                       = error "this should not happen"

-- | Header for a primitive (non-constructed) value.
encodePrimitiveHeader :: ASN1Length -> ASN1 -> ASN1Header
encodePrimitiveHeader = encodeHeader False

-- | Serialize the content octets of a primitive value.
--
-- 'Real' and 'Enumerated' are not implemented and encode as empty content.
-- Text-typed strings are encoded as UTF-8 (UTF8String, PrintableString,
-- T61String, IA5String), UTF-32BE (UniversalString) or UCS-2BE (BMPString).
encodePrimitiveData :: ASN1 -> ByteString
encodePrimitiveData (Boolean b)         = B.singleton (if b then 0xff else 0)
encodePrimitiveData (IntVal i)          = putInteger i
encodePrimitiveData (BitString bits)    = putBitString bits
encodePrimitiveData (OctetString b)     = putString b
encodePrimitiveData Null                = B.empty
encodePrimitiveData (OID oid)           = putOID oid
encodePrimitiveData (Real _)            = B.empty -- not implemented
encodePrimitiveData Enumerated          = B.empty -- not implemented
encodePrimitiveData (UTF8String b)      = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (NumericString b)   = putString b
encodePrimitiveData (PrintableString b) = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (T61String b)       = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (VideoTexString b)  = putString b
encodePrimitiveData (IA5String b)       = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (UTCTime t)         = putUTCTime t
encodePrimitiveData (GeneralizedTime t) = putGeneralizedTime t
encodePrimitiveData (GraphicString b)   = putString b
encodePrimitiveData (VisibleString b)   = putString b
encodePrimitiveData (GeneralString b)   = putString b
encodePrimitiveData (UniversalString b) = putString $ encodeUtf32BE $ T.pack b
encodePrimitiveData (CharacterString b) = putString b
encodePrimitiveData (BMPString b)       = putString $ encodeUCS2BE $ T.pack b
encodePrimitiveData (Other _ _ b)       = b
encodePrimitiveData o                   = error ("not a primitive " ++ show o)

-- | Encode a primitive value into header + content events.
--
-- Returns the total encoded size (serialized header plus content) and the
-- event list @[Header hdr, Primitive content]@.
encodePrimitive :: ASN1 -> (Int, [ASN1Event])
encodePrimitive a =
    let b    = encodePrimitiveData a
        blen = B.length b
        -- same short/long length form selection used for constructed values
        len  = mkSmallestLength blen
        hdr  = encodePrimitiveHeader len a
     in (B.length (runPut $ putHeader hdr) + blen, [Header hdr, Primitive b])

-- | Encode a single non-'Start' value. 'Start' must go through
-- 'encodeConstructed' / 'encodeList' instead.
encodeOne :: ASN1 -> (Int, [ASN1Event])
encodeOne (Start _) = error "encode one cannot do start"
encodeOne t         = encodePrimitive t

-- | Encode a flat stream of ASN1 values, grouping each 'Start' with its
-- children (as split by 'getConstructedEnd') into a constructed encoding.
-- Stray 'End' markers are skipped. Returns total size and the event list.
encodeList :: [ASN1] -> (Int, [ASN1Event])
encodeList [] = (0, [])
encodeList (End _:xs) = encodeList xs
encodeList (t@(Start _):xs) =
    let (ys, zs)    = getConstructedEnd 0 xs
        (llen, lev) = encodeList zs
        (len, ev)   = encodeConstructed t ys
     in (llen + len, ev ++ lev)
encodeList (x:xs) =
    let (llen, lev) = encodeList xs
        (len, ev)   = encodeOne x
     in (llen + len, ev ++ lev)

-- | Encode a 'Start' node together with its children as a constructed
-- value. The events are wrapped in 'ConstructionBegin' / 'ConstructionEnd'
-- after the header. Any node other than 'Start' is an error.
encodeConstructed :: ASN1 -> [ASN1] -> (Int, [ASN1Event])
encodeConstructed c@(Start _) children =
    let (clen, events) = encodeList children
        len            = mkSmallestLength clen
        h              = encodeHeader True len c
        tlen           = B.length (runPut $ putHeader h) + clen
     in (tlen, Header h : ConstructionBegin : events ++ [ConstructionEnd])
encodeConstructed _ _ = error "not a start node"

-- | Pick the shortest length form: short form below 0x80, otherwise long
-- form with the minimal number of length bytes.
mkSmallestLength :: Int -> ASN1Length
mkSmallestLength i
    | i < 0x80  = LenShort i
    | otherwise = LenLong (nbBytes i) i
  where
    nbBytes nb = if nb > 255 then 1 + nbBytes (nb `div` 256) else 1

type ASN1Ret = Either ASN1Error ASN1

-- | Decode the content octets of a primitive according to its header.
--
-- Unknown universal tags and all non-universal classes decode to 'Other'.
-- Booleans are decoded leniently (non-DER).
decodePrimitive :: ASN1Header -> B.ByteString -> ASN1Ret
decodePrimitive (ASN1Header Universal 0x1 _ _) p  = getBoolean False p
decodePrimitive (ASN1Header Universal 0x2 _ _) p  = getInteger p
decodePrimitive (ASN1Header Universal 0x3 _ _) p  = getBitString p
decodePrimitive (ASN1Header Universal 0x4 _ _) p  = getOctetString p
decodePrimitive (ASN1Header Universal 0x5 _ _) p  = getNull p
decodePrimitive (ASN1Header Universal 0x6 _ _) p  = getOID p
decodePrimitive (ASN1Header Universal 0x7 _ _) _  = Left $ TypeNotImplemented "Object Descriptor"
decodePrimitive (ASN1Header Universal 0x8 _ _) _  = Left $ TypeNotImplemented "External"
decodePrimitive (ASN1Header Universal 0x9 _ _) _  = Left $ TypeNotImplemented "real"
decodePrimitive (ASN1Header Universal 0xa _ _) _  = Left $ TypeNotImplemented "enumerated"
decodePrimitive (ASN1Header Universal 0xb _ _) _  = Left $ TypeNotImplemented "EMBEDDED PDV"
decodePrimitive (ASN1Header Universal 0xc _ _) p  = getUTF8String p
decodePrimitive (ASN1Header Universal 0xd _ _) _  = Left $ TypeNotImplemented "RELATIVE-OID"
decodePrimitive (ASN1Header Universal 0x10 _ _) _ = error "sequence not a primitive"
decodePrimitive (ASN1Header Universal 0x11 _ _) _ = error "set not a primitive"
decodePrimitive (ASN1Header Universal 0x12 _ _) p = getNumericString p
decodePrimitive (ASN1Header Universal 0x13 _ _) p = getPrintableString p
decodePrimitive (ASN1Header Universal 0x14 _ _) p = getT61String p
decodePrimitive (ASN1Header Universal 0x15 _ _) p = getVideoTexString p
decodePrimitive (ASN1Header Universal 0x16 _ _) p = getIA5String p
decodePrimitive (ASN1Header Universal 0x17 _ _) p = getUTCTime p
decodePrimitive (ASN1Header Universal 0x18 _ _) p = getGeneralizedTime p
decodePrimitive (ASN1Header Universal 0x19 _ _) p = getGraphicString p
decodePrimitive (ASN1Header Universal 0x1a _ _) p = getVisibleString p
decodePrimitive (ASN1Header Universal 0x1b _ _) p = getGeneralString p
decodePrimitive (ASN1Header Universal 0x1c _ _) p = getUniversalString p
decodePrimitive (ASN1Header Universal 0x1d _ _) p = getCharacterString p
decodePrimitive (ASN1Header Universal 0x1e _ _) p = getBMPString p
decodePrimitive (ASN1Header tc tag _ _) p         = Right $ Other tc tag p

-- | Decode a one-byte boolean. 0 is False and 0xff is True. Any other byte
-- is True in lenient mode, and a DER policy failure when @isDer@ is set.
getBoolean :: Bool -> ByteString -> Either ASN1Error ASN1
getBoolean isDer s =
    if B.length s == 1
        then case B.head s of
            0    -> Right (Boolean False)
            0xff -> Right (Boolean True)
            _    -> if isDer
                        then Left $ PolicyFailed "DER" "boolean value not canonical"
                        else Right (Boolean True)
        else Left $ TypeDecodingFailed "boolean: length not within bound"

{- | getInteger, parse a value bytestring and get the integer out of the two
   complement encoded bytes. Empty content and non-minimal encodings (a
   redundant leading 0x00 or 0xff byte) are rejected. -}
getInteger :: ByteString -> Either ASN1Error ASN1
getInteger s
    | B.length s == 0 = Left $ TypeDecodingFailed "integer: null encoding"
    | B.length s == 1 = Right $ IntVal $ snd $ intOfBytes s
    | otherwise       =
        if (v1 == 0xff && testBit v2 7) || (v1 == 0x0 && not (testBit v2 7))
            then Left $ TypeDecodingFailed "integer: not shortest encoding"
            else Right $ IntVal $ snd $ intOfBytes s
  where
    v1 = s `B.index` 0
    v2 = s `B.index` 1

-- | Decode a bit string: the first octet is the number of unused bits
-- (0..7) and the rest holds the bits.
--
-- As a leniency, an ASCII digit '0'..'7' is also accepted as the unused-bit
-- count. Empty content is rejected instead of crashing on 'B.head'.
getBitString :: ByteString -> Either ASN1Error ASN1
getBitString s
    | B.null s  = Left $ TypeDecodingFailed "bitstring: empty encoding"
    | otherwise =
        let toSkip  = B.head s
            toSkip' = if toSkip >= 48 && toSkip <= 48 + 7
                          then toSkip - fromIntegral (ord '0')
                          else toSkip
            xs      = B.tail s
         in if toSkip' <= 7
                then Right $ BitString $ toBitArray (L.fromChunks [xs]) (fromIntegral toSkip')
                else Left $ TypeDecodingFailed ("bitstring: skip number not within bound " ++ show toSkip' ++ " " ++ show s)

-- | Run a validation on the raw content and, if it passes, return the
-- content as a lazy bytestring.
getString :: (ByteString -> Maybe ASN1Error) -> ByteString -> Either ASN1Error L.ByteString
getString check s = case check s of
    Nothing  -> Right $ L.fromChunks [s]
    Just err -> Left err

getOctetString :: ByteString -> Either ASN1Error ASN1
getOctetString = (OctetString <$>) . getString (\_ -> Nothing)

getNumericString :: ByteString -> Either ASN1Error ASN1
getNumericString = (NumericString <$>) . getString (\_ -> Nothing)

getPrintableString :: ByteString -> Either ASN1Error ASN1
getPrintableString = (PrintableString . T.unpack . decodeASCII <$>) . getString (\_ -> Nothing)

getUTF8String :: ByteString -> Either ASN1Error ASN1
getUTF8String = (UTF8String . T.unpack . decodeUtf8 <$>) . getString (\_ -> Nothing)

getT61String :: ByteString -> Either ASN1Error ASN1
getT61String = (T61String . T.unpack . decodeASCII <$>) . getString (\_ -> Nothing)

getVideoTexString :: ByteString -> Either ASN1Error ASN1
getVideoTexString = (VideoTexString <$>) . getString (\_ -> Nothing)

getIA5String :: ByteString -> Either ASN1Error ASN1
getIA5String = (IA5String . T.unpack . decodeASCII <$>) . getString (\_ -> Nothing)

getGraphicString :: ByteString -> Either ASN1Error ASN1
getGraphicString = (GraphicString <$>) . getString (\_ -> Nothing)

getVisibleString :: ByteString -> Either ASN1Error ASN1
getVisibleString = (VisibleString <$>) . getString (\_ -> Nothing)

getGeneralString :: ByteString -> Either ASN1Error ASN1
getGeneralString = (GeneralString <$>) . getString (\_ -> Nothing)

getUniversalString :: ByteString -> Either ASN1Error ASN1
getUniversalString = (UniversalString . T.unpack . decodeUtf32BE <$>) . getString (\_ -> Nothing)

getCharacterString :: ByteString -> Either ASN1Error ASN1
getCharacterString = (CharacterString <$>) . getString (\_ -> Nothing)

getBMPString :: ByteString -> Either ASN1Error ASN1
getBMPString = (BMPString . T.unpack . decodeUCS2BE <$>) . getString (\_ -> Nothing)

-- | Decode NULL; its content must be empty.
getNull :: ByteString -> Either ASN1Error ASN1
getNull s
    | B.length s == 0 = Right Null
    | otherwise       = Left $ TypeDecodingFailed "Null: data length not within bound"

{- | return an OID.

   The first octet packs the first two arcs as @40 * arc1 + arc2@. Each
   following arc is base-128, with the high bit set on every octet except
   the last. Empty content is rejected instead of crashing on a failed
   irrefutable pattern. -}
getOID :: ByteString -> Either ASN1Error ASN1
getOID s = case B.unpack s of
    []     -> Left $ TypeDecodingFailed "OID: empty encoding"
    (x:xs) -> Right $ OID (fromIntegral (x `div` 40) : fromIntegral (x `mod` 40) : groupOID xs)
  where
    groupOID :: [Word8] -> [Integer]
    -- strict fold so a long arc does not build a chain of thunks
    groupOID = map (foldl' (\acc n -> (acc `shiftL` 7) + fromIntegral n) 0) . groupSubOID

    groupSubOIDHelper [] = Nothing
    groupSubOIDHelper l  = Just $ spanSubOIDbound l

    groupSubOID :: [Word8] -> [[Word8]]
    groupSubOID = unfoldr groupSubOIDHelper

    -- split off one arc: continuation octets (high bit set) are masked and
    -- kept, and the first octet with a clear high bit ends the arc
    spanSubOIDbound [] = ([], [])
    spanSubOIDbound (a:as) =
        if testBit a 7
            then (clearBit a 7 : ys, zs)
            else ([a], as)
      where
        (ys, zs) = spanSubOIDbound as

-- | Decode a UTCTime of the exact form YYMMDDhhmmssZ (13 octets).
--
-- Two-digit years follow RFC 5280: YY >= 50 is 19YY and YY < 50 is 20YY.
-- The final flag is True when the zone octet is 'Z'.
getUTCTime :: ByteString -> Either ASN1Error ASN1
getUTCTime s =
    case B.unpack s of
        [y1, y2, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z] ->
            let y      = integerise y1 y2
                year   = 1900 + (if y < 50 then y + 100 else y)
                month  = integerise m1 m2
                day    = integerise d1 d2
                hour   = integerise h1 h2
                minute = integerise mi1 mi2
                second = integerise s1 s2
             in Right $ UTCTime (year, month, day, hour, minute, second, z == 90)
        _ -> Left $ TypeDecodingFailed "utctime unexpected format"
  where
    integerise a b = (fromIntegral a - ord '0') * 10 + (fromIntegral b - ord '0')

-- | Decode a GeneralizedTime of the exact form YYYYMMDDhhmmssZ (15 octets).
-- The final flag is True when the zone octet is 'Z'.
getGeneralizedTime :: ByteString -> Either ASN1Error ASN1
getGeneralizedTime s =
    case B.unpack s of
        [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z] ->
            let year   = integerise y1 y2 * 100 + integerise y3 y4
                month  = integerise m1 m2
                day    = integerise d1 d2
                hour   = integerise h1 h2
                minute = integerise mi1 mi2
                second = integerise s1 s2
             in Right $ GeneralizedTime (year, month, day, hour, minute, second, z == 90)
        _ -> Left $ TypeDecodingFailed "generalizedtime unexpected format"
  where
    integerise a b = (fromIntegral a - ord '0') * 10 + (fromIntegral b - ord '0')

-- | Serialize a time tuple as ASCII digits. The 'Bool' chooses a four-digit
-- (generalized) or two-digit (UTC) year.
--
-- The zone octet is 'Z' when the flag is set, otherwise a NUL byte. The
-- decoders above accept that (they only test for 'Z'), so round-trips work.
putTime :: Bool -> (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putTime generalized (y, m, d, h, mi, s, z) = B.pack etime
  where
    etime
        | generalized = [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, zone]
        | otherwise   = [y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, zone]
    zone = if z then 90 else 0
    split2 n = (fromIntegral $ n `div` 10 + ord '0', fromIntegral $ n `mod` 10 + ord '0')
    ((y1, y2), (y3, y4)) = (split2 (y `div` 100), split2 (y `mod` 100))
    (m1, m2)   = split2 m
    (d1, d2)   = split2 d
    (h1, h2)   = split2 h
    (mi1, mi2) = split2 mi
    (s1, s2)   = split2 s

putUTCTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putUTCTime time = putTime False time

putGeneralizedTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putGeneralizedTime time = putTime True time

putInteger :: Integer -> ByteString
putInteger i = B.pack $ bytesOfInt i

-- | Serialize a bit array: an unused-bit-count octet followed by the bits.
putBitString :: BitArray -> ByteString
putBitString (BitArray n bits) =
    B.concat $ B.singleton (fromIntegral i) : L.toChunks bits
  where
    i = (8 - (n `mod` 8)) .&. 0x7

putString :: L.ByteString -> ByteString
putString l = B.concat $ L.toChunks l

{- | Serialize an OID. The first two arcs are packed into one octet as
   @oid1 * 40 + oid2@ and the remaining arcs are variable-length encoded.

   No check is enforced that oid1 is in [0..2] and oid2 in [0..39]. An OID
   with fewer than two arcs is an error. -}
putOID :: [Integer] -> ByteString
putOID (oid1:oid2:suboids) = B.cons eoidclass subeoids
  where
    eoidclass = fromIntegral (oid1 * 40 + oid2)
    encode x
        | x == 0    = B.singleton 0
        | otherwise = putVarEncodingIntegral x
    subeoids = B.concat $ map encode suboids
putOID _ = error "putOID: an OID needs at least two arcs"