-- |
-- Module      : Data.ASN1.Prim
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez
-- Stability   : experimental
-- Portability : unknown
--
-- Tools to read and write ASN1 primitives (e.g. boolean, integer, strings,
-- object identifiers, times), and to turn high-level 'ASN1' values into
-- header/content events.
--
module Data.ASN1.Prim
    (
    -- * ASN1 high level algebraic type
      ASN1(..)
    , ASN1ConstructionType(..)

    , encodeHeader
    , encodePrimitiveHeader
    , encodePrimitive
    , decodePrimitive
    , encodeConstructed
    , encodeList
    , encodeOne
    , mkSmallestLength

    -- * marshall an ASN1 type from a val struct or a bytestring
    , getBoolean
    , getInteger
    , getBitString
    , getOctetString
    , getUTF8String
    , getNumericString
    , getPrintableString
    , getT61String
    , getVideoTexString
    , getIA5String
    , getNull
    , getOID
    , getUTCTime
    , getGeneralizedTime
    , getGraphicString
    , getVisibleString
    , getGeneralString
    , getUniversalString
    , getCharacterString
    , getBMPString

    -- * marshall an ASN1 type to a bytestring
    , putUTCTime
    , putGeneralizedTime
    , putInteger
    , putBitString
    , putString
    , putOID
    ) where

import Data.ASN1.Internal
import Data.ASN1.Raw
import Data.ASN1.Stream
import Data.Bits
import Data.Word
import Data.List (foldl', unfoldr)
import Data.ByteString (ByteString)
import Data.Char (ord)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as T
import Data.Text.Lazy.Encoding (decodeASCII, decodeUtf8, decodeUtf32BE, encodeUtf8, encodeUtf32BE)

-- | Encode text as UCS-2 big-endian: two bytes per character, most
-- significant byte first (the content encoding used for BMPString).
--
-- Characters above U+FFFF cannot be represented in UCS-2; their high byte
-- is truncated to 8 bits by 'fromIntegral'.
encodeUCS2BE :: Text -> L.ByteString
encodeUCS2BE t =
    L.pack $ concatMap (\c -> let (d, m) = fromEnum c `divMod` 256
                               in [fromIntegral d, fromIntegral m])
           $ T.unpack t

-- | Decode UCS-2 big-endian bytes (most significant byte first) into text.
-- A trailing odd byte is ignored.
decodeUCS2BE :: L.ByteString -> Text
decodeUCS2BE l = T.pack $ loop l
  where
    loop x
        | L.null x  = []
        | otherwise =
            -- split the *remaining* input, not the original bytestring
            let (h, r) = L.splitAt 2 x
             in case L.unpack h of
                    [a, b] -> toEnum (fromIntegral a * 256 + fromIntegral b) : loop r
                    _      -> loop r

-- | Build the header for an 'ASN1' value.
--
-- The 'Bool' is the primitive/constructed flag stored in the header and the
-- 'ASN1Length' is the content length. Universal values get their universal
-- tag; 'Container' and 'Other' keep their own class and tag. Calling this on
-- an 'End' marker is a programming error.
encodeHeader :: Bool -> ASN1Length -> ASN1 -> ASN1Header
encodeHeader pc len (Boolean _)                = ASN1Header Universal 0x1 pc len
encodeHeader pc len (IntVal _)                 = ASN1Header Universal 0x2 pc len
encodeHeader pc len (BitString _ _)            = ASN1Header Universal 0x3 pc len
encodeHeader pc len (OctetString _)            = ASN1Header Universal 0x4 pc len
encodeHeader pc len Null                       = ASN1Header Universal 0x5 pc len
encodeHeader pc len (OID _)                    = ASN1Header Universal 0x6 pc len
encodeHeader pc len (Real _)                   = ASN1Header Universal 0x9 pc len
encodeHeader pc len Enumerated                 = ASN1Header Universal 0xa pc len
encodeHeader pc len (UTF8String _)             = ASN1Header Universal 0xc pc len
encodeHeader pc len (NumericString _)          = ASN1Header Universal 0x12 pc len
encodeHeader pc len (PrintableString _)        = ASN1Header Universal 0x13 pc len
encodeHeader pc len (T61String _)              = ASN1Header Universal 0x14 pc len
encodeHeader pc len (VideoTexString _)         = ASN1Header Universal 0x15 pc len
encodeHeader pc len (IA5String _)              = ASN1Header Universal 0x16 pc len
encodeHeader pc len (UTCTime _)                = ASN1Header Universal 0x17 pc len
encodeHeader pc len (GeneralizedTime _)        = ASN1Header Universal 0x18 pc len
encodeHeader pc len (GraphicString _)          = ASN1Header Universal 0x19 pc len
encodeHeader pc len (VisibleString _)          = ASN1Header Universal 0x1a pc len
encodeHeader pc len (GeneralString _)          = ASN1Header Universal 0x1b pc len
encodeHeader pc len (UniversalString _)        = ASN1Header Universal 0x1c pc len
encodeHeader pc len (CharacterString _)        = ASN1Header Universal 0x1d pc len
encodeHeader pc len (BMPString _)              = ASN1Header Universal 0x1e pc len
encodeHeader pc len (Start Sequence)           = ASN1Header Universal 0x10 pc len
encodeHeader pc len (Start Set)                = ASN1Header Universal 0x11 pc len
encodeHeader pc len (Start (Container tc tag)) = ASN1Header tc tag pc len
encodeHeader pc len (Other tc tag _)           = ASN1Header tc tag pc len
encodeHeader _  _   (End _)                    = error "this should not happen"

-- | Header for a primitive (non-constructed) value.
encodePrimitiveHeader :: ASN1Length -> ASN1 -> ASN1Header
encodePrimitiveHeader = encodeHeader False

-- | Content bytes of a primitive value. 'Real' and 'Enumerated' are not
-- implemented and encode to empty content; non-primitive values are an error.
encodePrimitiveData :: ASN1 -> ByteString
encodePrimitiveData (Boolean b)         = B.singleton (if b then 0xff else 0)
encodePrimitiveData (IntVal i)          = putInteger i
encodePrimitiveData (BitString i bits)  = putBitString i bits
encodePrimitiveData (OctetString b)     = putString b
encodePrimitiveData Null                = B.empty
encodePrimitiveData (OID oid)           = putOID oid
encodePrimitiveData (Real _)            = B.empty -- not implemented
encodePrimitiveData Enumerated          = B.empty -- not implemented
encodePrimitiveData (UTF8String b)      = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (NumericString b)   = putString b
encodePrimitiveData (PrintableString b) = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (T61String b)       = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (VideoTexString b)  = putString b
encodePrimitiveData (IA5String b)       = putString $ encodeUtf8 $ T.pack b
encodePrimitiveData (UTCTime t)         = putUTCTime t
encodePrimitiveData (GeneralizedTime t) = putGeneralizedTime t
encodePrimitiveData (GraphicString b)   = putString b
encodePrimitiveData (VisibleString b)   = putString b
encodePrimitiveData (GeneralString b)   = putString b
encodePrimitiveData (UniversalString b) = putString $ encodeUtf32BE $ T.pack b
encodePrimitiveData (CharacterString b) = putString b
encodePrimitiveData (BMPString b)       = putString $ encodeUCS2BE $ T.pack b
encodePrimitiveData (Other _ _ b)       = b
encodePrimitiveData o                   = error ("not a primitive " ++ show o)

-- | Encode a primitive value into its header and content events.
--
-- Returns the total encoded size (header bytes + content bytes) together
-- with the events @[Header hdr, Primitive content]@.
encodePrimitive :: ASN1 -> (Int, [ASN1Event])
encodePrimitive a =
    let content = encodePrimitiveData a
        clen    = B.length content
        hdr     = encodePrimitiveHeader (mkSmallestLength clen) a
     in (B.length (putHeader hdr) + clen, [Header hdr, Primitive content])

-- | Encode a single non-'Start' value (see 'encodePrimitive').
encodeOne :: ASN1 -> (Int, [ASN1Event])
encodeOne (Start _) = error "encode one cannot do start"
encodeOne t         = encodePrimitive t

-- | Encode a flat stream of 'ASN1' values into events.
--
-- A 'Start' node is encoded together with the children up to its matching
-- end (found by 'getConstructedEnd'); stray 'End' markers are skipped.
-- Returns the total encoded size and the event list.
encodeList :: [ASN1] -> (Int, [ASN1Event])
encodeList [] = (0, [])
encodeList (End _ : xs) = encodeList xs
encodeList (t@(Start _) : xs) =
    let (ys, zs)    = getConstructedEnd 0 xs
        (llen, lev) = encodeList zs
        (len, ev)   = encodeConstructed t ys
     in (llen + len, ev ++ lev)
encodeList (x : xs) =
    let (llen, lev) = encodeList xs
        (len, ev)   = encodeOne x
     in (llen + len, ev ++ lev)

-- | Encode a constructed value: its 'Start' node plus the list of children.
--
-- Produces @Header h : ConstructionBegin : children ++ [ConstructionEnd]@
-- and the total size including the header. Anything other than a 'Start'
-- node is an error.
encodeConstructed :: ASN1 -> [ASN1] -> (Int, [ASN1Event])
encodeConstructed c@(Start _) children =
    let (clen, events) = encodeList children
        h              = encodeHeader True (mkSmallestLength clen) c
        tlen           = B.length (putHeader h) + clen
     in (tlen, Header h : ConstructionBegin : events ++ [ConstructionEnd])
encodeConstructed _ _ = error "not a start node"

-- | Smallest definite length form for a content length: short form below
-- 0x80, otherwise long form with the minimal number of length bytes.
mkSmallestLength :: Int -> ASN1Length
mkSmallestLength i
    | i < 0x80  = LenShort i
    | otherwise = LenLong (nbBytes i) i
  where
    nbBytes nb = if nb > 255 then 1 + nbBytes (nb `div` 256) else 1

type ASN1Ret = Either ASN1Err ASN1

-- | Decode the content bytes of a primitive value according to its header.
--
-- Unknown universal tags and all non-universal classes are returned as
-- 'Other'. Unsupported universal types, and sequence/set tags (which cannot
-- be primitive), yield 'Left' rather than crashing on malformed input.
decodePrimitive :: ASN1Header -> B.ByteString -> ASN1Ret
decodePrimitive (ASN1Header Universal 0x1 _ _) p  = getBoolean False p
decodePrimitive (ASN1Header Universal 0x2 _ _) p  = getInteger p
decodePrimitive (ASN1Header Universal 0x3 _ _) p  = getBitString p
decodePrimitive (ASN1Header Universal 0x4 _ _) p  = getOctetString p
decodePrimitive (ASN1Header Universal 0x5 _ _) p  = getNull p
decodePrimitive (ASN1Header Universal 0x6 _ _) p  = getOID p
decodePrimitive (ASN1Header Universal 0x7 _ _) _  = Left $ ASN1NotImplemented "Object Descriptor"
decodePrimitive (ASN1Header Universal 0x8 _ _) _  = Left $ ASN1NotImplemented "External"
decodePrimitive (ASN1Header Universal 0x9 _ _) _  = Left $ ASN1NotImplemented "real"
decodePrimitive (ASN1Header Universal 0xa _ _) _  = Left $ ASN1NotImplemented "enumerated"
decodePrimitive (ASN1Header Universal 0xb _ _) _  = Left $ ASN1NotImplemented "EMBEDDED PDV"
decodePrimitive (ASN1Header Universal 0xc _ _) p  = getUTF8String p
decodePrimitive (ASN1Header Universal 0xd _ _) _  = Left $ ASN1NotImplemented "RELATIVE-OID"
decodePrimitive (ASN1Header Universal 0x10 _ _) _ = Left $ ASN1Misc "sequence not a primitive"
decodePrimitive (ASN1Header Universal 0x11 _ _) _ = Left $ ASN1Misc "set not a primitive"
decodePrimitive (ASN1Header Universal 0x12 _ _) p = getNumericString p
decodePrimitive (ASN1Header Universal 0x13 _ _) p = getPrintableString p
decodePrimitive (ASN1Header Universal 0x14 _ _) p = getT61String p
decodePrimitive (ASN1Header Universal 0x15 _ _) p = getVideoTexString p
decodePrimitive (ASN1Header Universal 0x16 _ _) p = getIA5String p
decodePrimitive (ASN1Header Universal 0x17 _ _) p = getUTCTime p
decodePrimitive (ASN1Header Universal 0x18 _ _) p = getGeneralizedTime p
decodePrimitive (ASN1Header Universal 0x19 _ _) p = getGraphicString p
decodePrimitive (ASN1Header Universal 0x1a _ _) p = getVisibleString p
decodePrimitive (ASN1Header Universal 0x1b _ _) p = getGeneralString p
decodePrimitive (ASN1Header Universal 0x1c _ _) p = getUniversalString p
decodePrimitive (ASN1Header Universal 0x1d _ _) p = getCharacterString p
decodePrimitive (ASN1Header Universal 0x1e _ _) p = getBMPString p
decodePrimitive (ASN1Header tc tag _ _) p         = Right $ Other tc tag p

-- | Decode a boolean from exactly one content byte.
--
-- 0 is False and 0xff is True. Any other byte is True, unless @isDer@ is set,
-- in which case it is rejected as non-canonical.
getBoolean :: Bool -> ByteString -> Either ASN1Err ASN1
getBoolean isDer s =
    if B.length s == 1
        then case B.head s of
            0    -> Right (Boolean False)
            0xff -> Right (Boolean True)
            _    -> if isDer then Left $ ASN1PolicyFailed "DER" "boolean value not canonical" else Right (Boolean True)
        else Left $ ASN1Misc "boolean: length not within bound"

-- | Parse a two's-complement encoded integer.
--
-- Rejects empty content and encodings whose first 9 bits are all equal,
-- i.e. encodings that are not the shortest possible.
getInteger :: ByteString -> Either ASN1Err ASN1
getInteger s
    | B.null s        = Left $ ASN1Misc "integer: null encoding"
    | B.length s == 1 = Right $ IntVal $ snd $ intOfBytes s
    | (v1 == 0xff && testBit v2 7) || (v1 == 0x0 && not (testBit v2 7))
                      = Left $ ASN1Misc "integer: not shortest encoding"
    | otherwise       = Right $ IntVal $ snd $ intOfBytes s
  where
    v1 = s `B.index` 0
    v2 = s `B.index` 1

-- | Parse a bit string: the first byte is the number of unused bits (0..7)
-- in the last byte, the rest is the data.
--
-- An ASCII digit '0'..'7' in the first byte is also accepted as the
-- unused-bit count. Empty content is rejected.
getBitString :: ByteString -> Either ASN1Err ASN1
getBitString s = case B.uncons s of
    Nothing -> Left $ ASN1Misc "bitstring: empty encoding"
    Just (toSkip, xs) ->
        let toSkip' = if toSkip >= 48 && toSkip <= 48 + 7
                          then toSkip - fromIntegral (ord '0')
                          else toSkip
         in if toSkip' <= 7
                then Right $ BitString (fromIntegral toSkip') (L.fromChunks [xs])
                else Left $ ASN1Misc ("bitstring: skip number not within bound " ++ show toSkip' ++ " " ++ show s)

-- | Run a validation over raw string content and, if it passes, return the
-- content as a lazy bytestring.
getString :: (ByteString -> Maybe ASN1Err) -> ByteString -> Either ASN1Err L.ByteString
getString check s = maybe (Right $ L.fromChunks [s]) Left (check s)

-- | Validation that accepts any content.
acceptAnyString :: ByteString -> Maybe ASN1Err
acceptAnyString _ = Nothing

getOctetString :: ByteString -> Either ASN1Err ASN1
getOctetString = fmap OctetString . getString acceptAnyString

getNumericString :: ByteString -> Either ASN1Err ASN1
getNumericString = fmap NumericString . getString acceptAnyString

getPrintableString :: ByteString -> Either ASN1Err ASN1
getPrintableString = fmap (PrintableString . T.unpack . decodeASCII) . getString acceptAnyString

getUTF8String :: ByteString -> Either ASN1Err ASN1
getUTF8String = fmap (UTF8String . T.unpack . decodeUtf8) . getString acceptAnyString

getT61String :: ByteString -> Either ASN1Err ASN1
getT61String = fmap (T61String . T.unpack . decodeASCII) . getString acceptAnyString

getVideoTexString :: ByteString -> Either ASN1Err ASN1
getVideoTexString = fmap VideoTexString . getString acceptAnyString

getIA5String :: ByteString -> Either ASN1Err ASN1
getIA5String = fmap (IA5String . T.unpack . decodeASCII) . getString acceptAnyString

getGraphicString :: ByteString -> Either ASN1Err ASN1
getGraphicString = fmap GraphicString . getString acceptAnyString

getVisibleString :: ByteString -> Either ASN1Err ASN1
getVisibleString = fmap VisibleString . getString acceptAnyString

getGeneralString :: ByteString -> Either ASN1Err ASN1
getGeneralString = fmap GeneralString . getString acceptAnyString

getUniversalString :: ByteString -> Either ASN1Err ASN1
getUniversalString = fmap (UniversalString . T.unpack . decodeUtf32BE) . getString acceptAnyString

getCharacterString :: ByteString -> Either ASN1Err ASN1
getCharacterString = fmap CharacterString . getString acceptAnyString

getBMPString :: ByteString -> Either ASN1Err ASN1
getBMPString = fmap (BMPString . T.unpack . decodeUCS2BE) . getString acceptAnyString

-- | Null must have empty content.
getNull :: ByteString -> Either ASN1Err ASN1
getNull s
    | B.null s  = Right Null
    | otherwise = Left $ ASN1Misc "Null: data length not within bound"

-- | Decode an object identifier.
--
-- Content is a sequence of base-128 subidentifiers (high bit set on every
-- byte but the last of each). The first subidentifier packs the first two
-- arcs as @X*40 + Y@; X is at most 2, and only X = 2 allows Y >= 40.
-- Empty content is rejected.
getOID :: ByteString -> Either ASN1Err ASN1
getOID s
    | B.null s  = Left $ ASN1Misc "oid: empty encoding"
    | otherwise = Right $ OID arcs
  where
    arcs = case groupOID (B.unpack s) of
        []                -> []
        (firstSub : rest) -> splitFirstArc firstSub ++ rest

    splitFirstArc :: Integer -> [Integer]
    splitFirstArc v
        | v < 80    = [v `div` 40, v `mod` 40]
        | otherwise = [2, v - 80]

    -- combine each group of 7-bit chunks into one subidentifier value
    groupOID :: [Word8] -> [Integer]
    groupOID = map (foldl' (\acc n -> (acc `shiftL` 7) + fromIntegral n) 0) . groupSubOID

    groupSubOID :: [Word8] -> [[Word8]]
    groupSubOID = unfoldr groupSubOIDHelper

    groupSubOIDHelper :: [Word8] -> Maybe ([Word8], [Word8])
    groupSubOIDHelper [] = Nothing
    groupSubOIDHelper l  = Just $ spanSubOIDbound l

    -- take bytes up to and including the first one with bit 7 clear,
    -- clearing the continuation bit on the bytes before it
    spanSubOIDbound :: [Word8] -> ([Word8], [Word8])
    spanSubOIDbound []       = ([], [])
    spanSubOIDbound (a : as) =
        if testBit a 7
            then (clearBit a 7 : ys, zs)
            else ([a], as)
      where
        (ys, zs) = spanSubOIDbound as

-- | Decode pairs of ASCII decimal digits into their two-digit values.
-- Returns 'Nothing' if any byte is not an ASCII digit ('0'..'9').
decodeTimeDigits :: [(Word8, Word8)] -> Maybe [Int]
decodeTimeDigits = mapM pair
  where
    pair (a, b)
        | isDigitByte a && isDigitByte b = Just (digit a * 10 + digit b)
        | otherwise                      = Nothing
    isDigitByte w = w >= 0x30 && w <= 0x39
    digit w = fromIntegral w - ord '0'

-- | Parse a 13-byte UTCTime @YYMMDDhhmmssZ@.
--
-- Two-digit years follow RFC 5280: YY >= 50 is 19YY, YY < 50 is 20YY. The
-- final flag is True when the last byte is 90 ('Z'). Non-digit bytes in the
-- date/time fields are rejected.
getUTCTime :: ByteString -> Either ASN1Err ASN1
getUTCTime s = case B.unpack s of
    [y1, y2, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z]
        | Just [y, month, day, hour, minute, second] <-
            decodeTimeDigits [(y1, y2), (m1, m2), (d1, d2), (h1, h2), (mi1, mi2), (s1, s2)] ->
            let year = if y < 50 then 2000 + y else 1900 + y
             in Right $ UTCTime (year, month, day, hour, minute, second, z == 90)
    _ -> Left $ ASN1Misc "utctime unexpected format"

-- | Parse a 15-byte GeneralizedTime @YYYYMMDDhhmmssZ@.
--
-- The final flag is True when the last byte is 90 ('Z'). Non-digit bytes in
-- the date/time fields are rejected.
getGeneralizedTime :: ByteString -> Either ASN1Err ASN1
getGeneralizedTime s = case B.unpack s of
    [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, z]
        | Just [yhi, ylo, month, day, hour, minute, second] <-
            decodeTimeDigits [(y1, y2), (y3, y4), (m1, m2), (d1, d2), (h1, h2), (mi1, mi2), (s1, s2)] ->
            Right $ GeneralizedTime (yhi * 100 + ylo, month, day, hour, minute, second, z == 90)
    _ -> Left $ ASN1Misc "generalizedtime unexpected format"

-- | Serialize a time tuple as ASCII digits followed by a timezone byte.
--
-- The generalized form writes a 4-digit year, the UTC form only the last two
-- digits. The last byte is 90 ('Z') when the flag is set, 0 otherwise.
putTime :: Bool -> (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putTime generalized (y, m, d, h, mi, s, z) = B.pack etime
  where
    etime
        | generalized = [y1, y2, y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, tz]
        | otherwise   = [y3, y4, m1, m2, d1, d2, h1, h2, mi1, mi2, s1, s2, tz]
    tz = if z then 90 else 0

    split2 :: Int -> (Word8, Word8)
    split2 n = (fromIntegral $ n `div` 10 + ord '0', fromIntegral $ n `mod` 10 + ord '0')

    ((y1, y2), (y3, y4)) = (split2 (y `div` 100), split2 (y `mod` 100))
    (m1, m2)             = split2 m
    (d1, d2)             = split2 d
    (h1, h2)             = split2 h
    (mi1, mi2)           = split2 mi
    (s1, s2)             = split2 s

putUTCTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putUTCTime = putTime False

putGeneralizedTime :: (Int, Int, Int, Int, Int, Int, Bool) -> ByteString
putGeneralizedTime = putTime True

putInteger :: Integer -> ByteString
putInteger i = B.pack $ bytesOfInt i

-- | Bit string content: the unused-bit count byte followed by the data.
putBitString :: Int -> L.ByteString -> ByteString
putBitString i bits = B.concat $ B.singleton (fromIntegral i) : L.toChunks bits

putString :: L.ByteString -> ByteString
putString l = B.concat $ L.toChunks l

-- | Encode an object identifier.
--
-- The first two arcs are combined as @oid1*40 + oid2@ and, like every other
-- subidentifier, written in base-128 (so values >= 128, e.g. 2.100, encode
-- correctly). No check is enforced that oid1 is in [0..2] and oid2 in
-- [0..39]. Fewer than two arcs is an error.
putOID :: [Integer] -> ByteString
putOID (oid1 : oid2 : suboids) = B.concat $ map encode (oid1 * 40 + oid2 : suboids)
  where
    encode x
        | x == 0    = B.singleton 0
        | otherwise = putVarEncodingIntegral x
putOID _ = error "putOID: an OID needs at least two arcs"