-- |
-- Module      : Data.ASN1.Raw
--
-- Low-level reading and writing of ASN.1 values as identifier octets,
-- length octets and content.  Decoding runs in 'GetErr' (a 'Get' parser
-- that can fail with an 'ASN1Err'); encoding produces a 'Put' action.
module Data.ASN1.Raw (
GetErr,
runGetErr,
runGetErrInGet,
ASN1Err(..),
CheckFn,
TagClass(..),
TagNumber,
ValLength(..),
ValStruct(..),
Value(..),
getValueCheck,
getValue,
putValuePolicy,
putValue
) where
import Data.Bits
import Data.ASN1.Internal
import Data.Binary.Get
import Data.Binary.Put
import Data.ByteString (ByteString)
import Data.Word
import Control.Monad.Error
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
-- | Class of a tag.  'getIdentifier' and 'putIdentifier' map the
-- constructors, in declaration order, to the numeric codes 0 through 3.
data TagClass
    = Universal
    | Application
    | Context
    | Private
    deriving (Show, Eq)
-- | Length header of a value.
data ValLength
    = LenShort Int
      -- ^ Length held in a single octet; 'putLength' only accepts 0..0x7f.
    | LenLong Int Int
      -- ^ Long form.  The first field is the number of length octets as
      -- read by 'getLength' (it is ignored by 'putLength'); the second is
      -- the length itself.
    | LenIndefinite
      -- ^ Indefinite length, written as the single octet 0x80.
    deriving (Show, Eq)
-- | Tag number of a value, as carried by the identifier octets.
type TagNumber = Int
-- | 'True' when the constructed bit (0x20) of the identifier octet is set.
type TagConstructed = Bool
-- | Decoded identifier: tag class, constructed flag and tag number.
type Identifier = (TagClass, TagConstructed, TagNumber)
-- | Content of a value: either the raw content octets, or a list of
-- nested values.
data ValStruct
    = Primitive ByteString
    | Constructed [Value]
    deriving (Show, Eq)

-- | A value: its tag class, its tag number and its content.  Whether the
-- constructed bit is written is derived from the 'ValStruct' constructor
-- (see 'putValuePolicy').
data Value = Value TagClass TagNumber ValStruct
    deriving (Show, Eq)
-- | Errors reported while decoding.
--
-- Within this module only 'ASN1Misc' is constructed directly; the other
-- constructors are exported for use by callers (for example as the result
-- of a 'CheckFn').
data ASN1Err
    = ASN1LengthDecodingLongContainsZero
    | ASN1PolicyFailed String String
      -- ^ NOTE(review): the meaning of the two strings is not established
      -- by this module — confirm against callers.
    | ASN1NotImplemented String
    | ASN1Multiple [ASN1Err]
      -- ^ Several errors bundled together.
    | ASN1Misc String
      -- ^ Free-form error message; also used by the 'Error' instance.
    deriving (Show, Eq)
-- | Policy hook consulted by 'getValueCheck' after the identifier and the
-- length of a value have been read, and before its content is read.
-- Returning @Just err@ aborts decoding with @err@; 'Nothing' accepts the
-- value.
type CheckFn = Identifier -> ValLength -> Maybe ASN1Err
-- | Lets 'ASN1Err' be the error type of 'ErrorT': string failures become
-- 'ASN1Misc', and the message-less failure is 'ASN1Misc' with an empty
-- string.
instance Error ASN1Err where
    noMsg      = strMsg ""
    strMsg msg = ASN1Misc msg
-- | A 'Get' parser that can fail with an 'ASN1Err'.  It is 'ErrorT'
-- layered over 'Get'; use 'runGetErr' or 'runGetErrInGet' to run it.
newtype GetErr a = GE
    { runGE :: ErrorT ASN1Err Get a
      -- ^ Unwrap to the underlying transformer stack.
    }
    deriving (Monad, MonadError ASN1Err)
-- | Maps over the result by delegating to the wrapped 'ErrorT' action.
instance Functor GetErr where
    fmap f (GE inner) = GE (fmap f inner)
-- | Run a 'GetErr' parser over a lazy 'L.ByteString', returning either
-- the 'ASN1Err' it threw or its result.
runGetErr :: GetErr a -> L.ByteString -> Either ASN1Err a
runGetErr = runGet . runErrorT . runGE
-- | Embed a 'GetErr' parser in a plain 'Get' parser; the failure, if any,
-- is returned as a 'Left' instead of being thrown.
runGetErrInGet :: GetErr a -> Get (Either ASN1Err a)
runGetErrInGet = runErrorT . runGE
-- | Promote a plain 'Get' action to 'GetErr'.
liftGet :: Get a -> GetErr a
liftGet g = GE (lift g)
-- | Read one octet.
geteWord8 :: GetErr Word8
geteWord8 = GE (lift getWord8)
-- | Read the given number of octets as a strict 'ByteString'.
geteBytes :: Int -> GetErr ByteString
geteBytes n = liftGet (getBytes n)
-- | Read a tag number in long form: a run of octets each carrying seven
-- payload bits, most significant group first, where bit 8 (0x80) is set
-- on every octet except the last.
--
-- The argument says whether the first octet's payload must be non-zero
-- ('getIdentifier' passes 'True').  A violation is reported through
-- 'GetErr' as an 'ASN1Misc' error, so malformed input yields a 'Left'
-- from 'runGetErr' rather than a pure exception.
--
-- The groups are folded into an accumulator that is shifted left by seven
-- bits for every further octet, so tag numbers spanning three or more
-- octets are decoded correctly (shifting only the first group once gives
-- the right answer for at most two octets).
getTagNumberLong :: Bool -> GetErr TagNumber
getTagNumberLong nz = do
    t <- geteWord8
    let tval = payload t
    when (nz && tval == 0) $
        throwError (ASN1Misc "long tag encoding failure: first value is 0")
    if testBit t 7 then go tval else return tval
  where
    -- The seven payload bits of one octet.
    payload :: Word8 -> TagNumber
    payload o = fromIntegral (o .&. 0x7f)

    -- Consume continuation octets, appending each group below the bits
    -- gathered so far.
    go :: TagNumber -> GetErr TagNumber
    go acc = do
        t <- geteWord8
        let acc' = (acc `shiftL` 7) .|. payload t
        if testBit t 7 then go acc' else return acc'
-- | Write a tag number in long form: seven payload bits per octet, most
-- significant group first, with bit 8 (0x80) set on every octet except
-- the last.  This is the layout 'getTagNumberLong' reads back; for
-- example 128 is written as @0x81 0x00@.
--
-- Values that fit in seven bits (and, as before, negative values, which
-- are masked to their low seven bits) produce a single octet.
putTagNumberLong :: TagNumber -> Put
putTagNumberLong i = do
    -- Every group above the lowest one carries the continuation bit.
    mapM_ (putWord8 . (.|. 0x80)) (leading (i `shiftR` 7) [])
    putWord8 (low7 i)
  where
    -- The low seven bits of a number as an octet.
    low7 :: TagNumber -> Word8
    low7 x = fromIntegral (x .&. 0x7f)

    -- Collect the remaining 7-bit groups; consing while shifting right
    -- leaves the most significant group at the head of the list.
    leading :: TagNumber -> [Word8] -> [Word8]
    leading x acc
        | x <= 0    = acc
        | otherwise = leading (x `shiftR` 7) (low7 x : acc)
-- | Read the identifier octets of a value and return its tag class, its
-- constructed flag and its tag number.
--
-- Layout of the first octet: the tag class is in the two most significant
-- bits (hence the shift by 6), bit 0x20 is the constructed flag, and the
-- low five bits hold the tag number.  When those five bits are all ones
-- (0x1f) the number follows in long form and is read with
-- 'getTagNumberLong'.
getIdentifier :: GetErr Identifier
getIdentifier = do
    w <- geteWord8
    let cl = case (w `shiftR` 6) .&. 3 of
                 0 -> Universal
                 1 -> Application
                 2 -> Context
                 _ -> Private  -- only 3 remains after masking with 3
        pc  = testBit w 5
        val = fromIntegral (w .&. 0x1f)
    vencoded <- if val < 0x1f then return val else getTagNumberLong True
    return (cl, pc, vencoded)
-- | Write the identifier octets for a tag class, constructed flag and tag
-- number.
--
-- The tag class goes into the two most significant bits of the first
-- octet (shift by 6, so that 'Context' and 'Private' survive the
-- narrowing to an octet), the constructed flag is bit 0x20, and a tag
-- number below 0x1f is stored in the low five bits.  Larger numbers set
-- all five bits and are followed by 'putTagNumberLong'.
putIdentifier :: Identifier -> Put
putIdentifier (cl, pc, val)
    | val < 0x1f = putWord8 (leading .|. fromIntegral val)
    | otherwise  = do
        putWord8 (leading .|. 0x1f)
        putTagNumberLong val
  where
    -- Numeric code of the tag class.
    classBits :: Word8
    classBits = case cl of
        Universal   -> 0
        Application -> 1
        Context     -> 2
        Private     -> 3

    -- Class and constructed flag, i.e. everything but the tag number.
    leading :: Word8
    leading = (classBits `shiftL` 6) .|. (if pc then 0x20 else 0x00)
-- | Read the length octets of a value.
--
-- If bit 8 of the first octet is clear, the octet itself is the length
-- ('LenShort').  Otherwise its low seven bits give the number of length
-- octets that follow: zero means 'LenIndefinite', anything else is read
-- and converted with 'uintOfBytes' into a 'LenLong'.
getLength :: GetErr ValLength
getLength = geteWord8 >>= decode
  where
    decode l1
        | not (testBit l1 7) = return (LenShort (fromIntegral l1))
        | nbytes == 0        = return LenIndefinite
        | otherwise          = do
            raw <- geteBytes nbytes
            return (LenLong nbytes (fromIntegral (snd (uintOfBytes raw))))
      where
        -- Count of following length octets (first octet without bit 8).
        nbytes = fromIntegral (l1 .&. 0x7f)
-- | Write the length octets for a 'ValLength'.
--
-- * 'LenShort' writes the length as one octet and calls 'error' unless it
--   lies in 0..0x7f.
-- * 'LenLong' ignores its first field: the length is split into octets
--   with 'bytesOfUInt' and preceded by their count with bit 8 set.  A
--   negative length calls 'error'.
-- * 'LenIndefinite' writes the single octet 0x80.
putLength :: ValLength -> Put
putLength vl = case vl of
    LenIndefinite -> putWord8 0x80
    LenShort n -> do
        when (n < 0 || n > 0x7f) $
            error "putLength: short length is not between 0x0 and 0x80"
        putWord8 (fromIntegral n)
    LenLong _ n -> do
        when (n < 0) $ error "putLength: long length is negative"
        let octets = bytesOfUInt (fromIntegral n)
            header = fromIntegral (length octets .|. 0x80)
        putWord8 header
        mapM_ putWord8 octets
-- | Read values one after another until the input is exhausted, applying
-- the check function to each of them.
getValueConstructed :: CheckFn -> GetErr [Value]
getValueConstructed check = liftGet remaining >>= step
  where
    step left
        | left > 0  = do
            v  <- getValueCheck check
            vs <- getValueConstructed check
            return (v : vs)
        | otherwise = return []
-- | Read values until one with class 'Universal' and tag number 0 is met;
-- that terminating value is consumed but not included in the result.
getValueConstructedUntilEOC :: CheckFn -> GetErr [Value]
getValueConstructedUntilEOC check = getValueCheck check >>= continue
  where
    continue (Value Universal 0 _) = return []
    continue v = liftM (v :) (getValueConstructedUntilEOC check)
-- | Read the content of a value whose length is known.
--
-- The content octets are read in full first.  For a constructed value
-- (third argument 'True') they are then parsed as a sequence of nested
-- values in a separate 'runGetErr' run, and any error from that run is
-- rethrown; otherwise they are returned untouched as 'Primitive'.
getValueOfLength :: CheckFn -> Int -> Bool -> GetErr ValStruct
getValueOfLength check len pc = do
    b <- geteBytes len
    if pc
        then either throwError (return . Constructed) (parseNested b)
        else return (Primitive b)
  where
    parseNested bytes =
        runGetErr (getValueConstructed check) (L.fromChunks [bytes])
-- | Read one complete value.
--
-- The identifier and the length are read first and handed to the check
-- function; a @Just err@ from it aborts with @err@.  The content is then
-- read according to the length: a definite length goes through
-- 'getValueOfLength', an indefinite one is only accepted for constructed
-- values and is read with 'getValueConstructedUntilEOC'.
getValueCheck :: CheckFn -> GetErr Value
getValueCheck check = do
    ident@(tc, pc, tn) <- getIdentifier
    vallen <- getLength
    maybe (return ()) throwError (check ident vallen)
    struct <- contents pc vallen
    return (Value tc tn struct)
  where
    contents pc LenIndefinite = do
        unless pc $
            throwError $ ASN1Misc "lenght indefinite not allowed with primitive"
        liftM Constructed (getValueConstructedUntilEOC check)
    contents pc (LenShort len)  = getValueOfLength check len pc
    contents pc (LenLong _ len) = getValueOfLength check len pc
-- | Read one complete value without any policy check.
getValue :: GetErr Value
getValue = getValueCheck acceptAll
  where
    acceptAll _ _ = Nothing
-- | Write the content of a value: the raw octets of a 'Primitive', or
-- each child of a 'Constructed' in order via 'putValue'.
putValStruct :: ValStruct -> Put
putValStruct vs = case vs of
    Primitive bytes      -> putByteString bytes
    Constructed children -> mapM_ putValue children
-- | Write a value, letting the caller choose how its length is encoded.
--
-- The policy receives the value and the size in octets of its serialised
-- content, and returns the 'ValLength' to write.  When it answers
-- 'LenIndefinite', the content is followed by an empty primitive value of
-- class 'Universal' and tag number 0.
--
-- The constructed flag of the identifier is taken from the 'ValStruct'
-- constructor.  Note that the policy applies to this value only: nested
-- values are serialised by 'putValStruct', which uses 'putValue'.
putValuePolicy :: (Value -> Int -> ValLength) -> Value -> Put
putValuePolicy policy v@(Value tc tn struct) = do
    putIdentifier (tc, isConstructed, tn)
    putLength lenEncoded
    putLazyByteString content
    when (lenEncoded == LenIndefinite) $
        putValue (Value Universal 0x0 (Primitive B.empty))
  where
    isConstructed = case struct of
        Primitive _   -> False
        Constructed _ -> True
    -- The content is rendered up front so that its size is known before
    -- the length octets are written.
    content    = runPut (putValStruct struct)
    lenEncoded = policy v (fromIntegral (L.length content))
-- | Write a value with the default length policy: 'LenShort' when the
-- content is shorter than 0x80 octets, 'LenLong' otherwise.
putValue :: Value -> Put
putValue = putValuePolicy defaultLength
  where
    defaultLength _ len
        | len < 0x80 = LenShort len
        | otherwise  = LenLong 0 len