module Data.ASN1.Serialize (getHeader, putHeader) where
import qualified Data.ByteString as B
import Data.ASN1.Get
import Data.ASN1.Internal
import Data.ASN1.Types
import Data.ASN1.Types.Lowlevel
import Data.Bits
import Data.Word
import Control.Applicative ((<$>))
import Control.Monad
-- | Decode a complete ASN.1 header: the identifier octet(s) followed by the
-- length octet(s).
--
-- When the low five bits of the first octet are all set (@0x1f@), the tag
-- number does not fit in the first octet. It is then read from the following
-- octets by 'getTagLong'.
getHeader :: Get ASN1Header
getHeader = do
    firstOctet <- getWord8
    let (cls, constructed, shortTag) = parseFirstWord firstOctet
    tagNum <- if shortTag == 0x1f then getTagLong else return shortTag
    ASN1Header cls tagNum constructed <$> getLength
-- | Split the first identifier octet into its three fields:
--
-- * bits 7-6: the class, converted with 'toEnum' (values 0..3);
-- * bit 5: the primitive/constructed flag (@True@ when set);
-- * bits 4-0: the short tag number (@0x1f@ signals a long-form tag).
--
-- NOTE(review): assumes 'ASN1Class' has at least four constructors so
-- 'toEnum' is total on 0..3 — confirm against its definition.
parseFirstWord :: Word8 -> (ASN1Class, Bool, ASN1Tag)
parseFirstWord octet =
    ( toEnum (fromIntegral (octet `shiftR` 6))
    , octet `testBit` 5
    , fromIntegral (octet .&. 0x1f)
    )
-- | Decode a long-form tag number.
--
-- The tag is a big-endian sequence of 7-bit groups. Bit 7 of each octet is
-- set on every octet except the last. Fails if the first group octet is
-- @0x80@, because a leading zero group is a non-canonical encoding.
--
-- NOTE(review): the number of continuation octets is unbounded. If
-- 'ASN1Tag' is a fixed-width integer, a very long tag presumably wraps
-- silently — confirm whether a bound is wanted.
getTagLong :: Get ASN1Tag
getTagLong = do
    first <- fromIntegral <$> getWord8
    when (first == 0x80) $ fail "non canonical encoding of long tag"
    continue 0 first
  where
    -- Fold one octet's 7-bit group into the accumulator. Read another octet
    -- while the continuation bit (bit 7) is set.
    continue acc octet
        | testBit octet 7 = do
            next <- fromIntegral <$> getWord8
            continue acc' next
        | otherwise = return acc'
      where
        acc' = acc `shiftL` 7 + clearBit octet 7
-- | Decode the length octet(s) of an ASN.1 header.
--
-- * A first octet with bit 7 clear is the short form: the octet itself is
--   the length ('LenShort').
-- * @0x80@ is the indefinite form ('LenIndefinite').
-- * Otherwise, bits 6-0 give the number of following octets. Those octets
--   are read as a big-endian unsigned integer
--   (@'LenLong' octetCount value@).
--
-- Fails on the first octet @0xff@, which X.690 §8.1.3.5 (c) reserves and
-- forbids.
getLength :: Get ASN1Length
getLength = do
    l1 <- fromIntegral <$> getWord8
    if testBit l1 7
        then case clearBit l1 7 of
            0    -> return LenIndefinite
            -- 0xff is reserved for future extension. Without this check the
            -- decoder would silently consume 127 bytes as a length.
            0x7f -> fail "reserved length octet 0xff"
            len  -> do
                lw <- getBytes len
                return (LenLong len $ uintbs lw)
        else
            return (LenShort l1)
  where
    -- Unsigned big-endian integer represented by the bytes. The strict fold
    -- avoids building a chain of thunks.
    -- NOTE(review): nothing bounds the octet count, so a value wider than
    -- the target integer type presumably wraps silently — confirm whether
    -- such lengths should be rejected.
    uintbs = B.foldl' (\acc n -> (acc `shiftL` 8) + fromIntegral n) 0
-- | Encode an ASN.1 header: the identifier octet(s) followed by the length
-- octet(s).
--
-- Tag numbers below @0x1f@ are packed into the first octet. Larger tags set
-- the low five bits to @0x1f@, and the tag follows as a variable-length
-- encoding produced by 'putVarEncodingIntegral'.
--
-- Calls 'error' on a negative tag. 'putLength' likewise rejects negative
-- lengths. Without the guard, @fromIntegral@ would wrap the tag into the
-- class and P/C bits and produce a malformed identifier octet.
putHeader :: ASN1Header -> B.ByteString
putHeader (ASN1Header cl tag pc len)
    | tag < 0   = error "putHeader: tag is negative"
    | otherwise = B.concat [B.singleton word1, tagBS, lenBS]
  where
    -- Class in bits 7-6, constructed flag in bit 5.
    clBits = shiftL (fromIntegral $ fromEnum cl) 6
    pcBits = if pc then 0x20 else 0x00
    -- Short tags live in the first octet. Long tags use the 0x1f escape
    -- followed by their own octets.
    (tag0, tagBS)
        | tag < 0x1f = (fromIntegral tag, B.empty)
        | otherwise  = (0x1f, putVarEncodingIntegral tag)
    word1 = clBits .|. pcBits .|. tag0
    lenBS = B.pack $ putLength len
-- | Encode an 'ASN1Length' as its length octets.
--
-- * 'LenIndefinite' becomes the single octet @0x80@.
-- * @'LenShort' i@ is one octet and requires @0 <= i <= 0x7f@.
-- * @'LenLong' _ i@ is an octet holding @0x80 .|. count@, followed by the
--   @count@ bytes produced by 'bytesOfUInt'. The stored octet count is
--   ignored and recomputed from the value.
--
-- Calls 'error' on out-of-range short lengths and on negative long lengths.
putLength :: ASN1Length -> [Word8]
putLength LenIndefinite = [0x80]
putLength (LenShort i)
    | i >= 0 && i <= 0x7f = [fromIntegral i]
    | otherwise           = error "putLength: short length is not between 0x0 and 0x80"
putLength (LenLong _ i)
    | i >= 0    = fromIntegral (0x80 .|. length octets) : octets
    | otherwise = error "putLength: long length is negative"
  where
    octets = bytesOfUInt $ fromIntegral i