module Data.ASN1.Raw
	( GetErr
	
	, runGetErr
	, runGetErrInGet
	
	, ASN1Err(..)
	, CheckFn
	, TagClass(..)
	, TagNumber
	, ValLength(..)
	, ValStruct(..)
	, Value(..)
	
	, getValueCheck
	, getValue
	
	, putValuePolicy
	, putValue
	) where
import Data.Bits
import Data.ASN1.Internal
import Data.Binary.Get
import Data.Binary.Put
import Data.ByteString (ByteString)
import Data.Word
import Control.Monad.Error
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
-- | Class of an ASN.1 tag, carried in the two high bits of the first
-- identifier octet (see 'getIdentifier' and 'putIdentifier').
data TagClass
	= Universal   -- ^ class bits @00@
	| Application -- ^ class bits @01@
	| Context     -- ^ class bits @10@ (context-specific)
	| Private     -- ^ class bits @11@
	deriving (Eq, Show)
-- | Form of the length octets of a value.
data ValLength
	= LenShort Int     -- ^ short form: the length (0 to 0x7f) in a single octet
	| LenLong Int Int  -- ^ long form: count of length octets (set by 'getLength', ignored by 'putLength'), then the length
	| LenIndefinite    -- ^ indefinite form (octet 0x80): contents end with an end-of-contents value
	deriving (Eq, Show)
-- | Tag number of an identifier.
type TagNumber = Int
-- | Whether the constructed bit (0x20) of the first identifier octet is set.
type TagConstructed = Bool
-- | A decoded identifier: tag class, constructed flag and tag number.
type Identifier = (TagClass, TagConstructed, TagNumber)
-- | Contents of a value.
data ValStruct
	= Primitive ByteString -- ^ raw contents octets of a primitive value
	| Constructed [Value]  -- ^ nested values of a constructed value
	deriving (Eq, Show)
-- | A single ASN.1 value: tag class, tag number and contents. The constructed
-- flag is not stored; when encoding it is derived from the 'ValStruct'.
data Value
	= Value TagClass TagNumber ValStruct
	deriving (Eq, Show)
-- | Errors reported while decoding or checking ASN.1 values.
data ASN1Err
	= ASN1LengthDecodingLongContainsZero
	-- ^ raised by 'getTagNumberLong' when the first high-tag-number octet is
	-- 0x80 (a leading zero group); despite the name it concerns tag numbers
	| ASN1PolicyFailed String String
	-- ^ not raised by any function shown here; NOTE(review): presumably meant
	-- for 'CheckFn' implementations — confirm what the two fields hold
	| ASN1NotImplemented String
	-- ^ not raised by any function shown here
	| ASN1Multiple [ASN1Err]
	-- ^ a collection of errors; not raised by any function shown here
	| ASN1Misc String
	-- ^ any other error with a message; also what 'strMsg' produces
	deriving (Eq, Show)
-- | Validation hook consulted by 'getValueCheck' for every value. It runs
-- after the identifier (class, constructed flag, tag number) and the length
-- have been read, and before any contents are consumed. Returning @Just err@
-- aborts decoding with @err@; @Nothing@ lets decoding continue.
type CheckFn
	= (TagClass, Bool, TagNumber)
	-> ValLength
	-> Maybe ASN1Err
-- | Required by 'ErrorT'. Both the message-less error and string messages
-- become 'ASN1Misc'.
instance Error ASN1Err where
	strMsg msg = ASN1Misc msg
	noMsg      = strMsg ""
-- | A binary 'Get' parser that can also fail with an 'ASN1Err'
-- ('ErrorT' layered over 'Get').
newtype GetErr a = GE
	{ runGE :: ErrorT ASN1Err Get a
	} deriving (Monad, MonadError ASN1Err)
-- | Maps over the result of the wrapped 'ErrorT' computation.
instance Functor GetErr where
	fmap f (GE m) = GE (fmap f m)
-- | Run a 'GetErr' parser over a lazy ByteString. ASN.1 errors come back as
-- 'Left'. NOTE(review): failures of the underlying 'Get' itself (e.g. input
-- ending early) presumably surface from 'runGet' as exceptions rather than as
-- 'Left' — confirm for the binary version in use.
runGetErr :: GetErr a -> L.ByteString -> Either ASN1Err a
runGetErr parser input = runGet (runErrorT (runGE parser)) input
-- | Unwrap a 'GetErr' parser into a plain 'Get' whose result reports the
-- ASN.1 error (if any) as an 'Either'.
runGetErrInGet :: GetErr a -> Get (Either ASN1Err a)
runGetErrInGet (GE inner) = runErrorT inner
-- | Run a plain 'Get' action inside 'GetErr'.
liftGet :: Get a -> GetErr a
liftGet getter = GE (lift getter)
-- | Read one octet.
geteWord8 :: GetErr Word8
geteWord8 = GE (lift getWord8)
-- | Read the given number of octets as a strict ByteString.
geteBytes :: Int -> GetErr ByteString
geteBytes n = liftGet (getBytes n)
-- | Read the octets that follow the first identifier octet when the tag
-- number uses the high-tag-number form.
--
-- Each octet contributes its low seven bits, most significant group first.
-- Bit 8 set means another octet follows. A first octet of exactly 0x80 (a
-- leading all-zero group) is rejected with
-- 'ASN1LengthDecodingLongContainsZero'.
--
-- NOTE(review): there is no limit on the number of octets, so an overly long
-- encoding silently overflows 'TagNumber'.
getTagNumberLong :: GetErr TagNumber
getTagNumberLong = loop True 0
	where
		loop isFirst acc = do
			octet <- liftM fromIntegral geteWord8
			when (isFirst && octet == 0x80) $
				throwError ASN1LengthDecodingLongContainsZero
			let acc' = (acc `shiftL` 7) + (octet .&. 0x7f)
			if octet .&. 0x80 /= 0
				then loop False acc'
				else return acc'
-- | Write a tag number in the high-tag-number form: base-128 groups, most
-- significant first, with bit 8 set on every octet except the last.
--
-- A tag number of 0 produces no octets at all. The only caller shown,
-- 'putIdentifier', uses this for numbers of at least 0x1f.
putTagNumberLong :: TagNumber -> Put
putTagNumberLong n = mapM_ putWord8 (encodeGroups (sevenBitGroups n))
	where
		-- 7-bit groups of the number, least significant group first.
		sevenBitGroups :: TagNumber -> [Word8]
		sevenBitGroups 0 = []
		sevenBitGroups i
			| i <= 0x7f = [fromIntegral i]
			| otherwise = fromIntegral (i .&. 0x7f) : sevenBitGroups (i `shiftR` 7)
		-- Put the groups in big-endian order, marking every group except the
		-- least significant one (which ends up last) as "more follows".
		encodeGroups :: [Word8] -> [Word8]
		encodeGroups []              = []
		encodeGroups (lowest : rest) = reverse (map (`setBit` 7) rest) ++ [lowest]
-- | Read an identifier:
--
-- * the tag class comes from the top two bits of the first octet;
-- * the constructed flag comes from bit 0x20;
-- * the tag number comes from the low five bits. When those five bits are
--   all ones (0x1f), it is read instead from the following octets via
--   'getTagNumberLong'.
getIdentifier :: GetErr Identifier
getIdentifier = do
	octet <- geteWord8
	let tagClass = case octet `shiftR` 6 of
		1 -> Application
		2 -> Context
		3 -> Private
		_ -> Universal
	let constructed = testBit octet 5
	let lowBits = fromIntegral (octet .&. 0x1f)
	number <- if lowBits == 0x1f then getTagNumberLong else return lowBits
	return (tagClass, constructed, number)
-- | Write an identifier.
--
-- Tag numbers below 0x1f fit in the low five bits of the first octet. For
-- larger ones, those bits are set to 0x1f and the number follows in the
-- high-tag-number form written by 'putTagNumberLong'.
putIdentifier :: Identifier -> Put
putIdentifier (cl, pc, val)
	| val < 0x1f = putWord8 (firstOctet val)
	| otherwise  = putWord8 (firstOctet 0x1f) >> putTagNumberLong val
	where
		classBits :: Int
		classBits = case cl of
			Universal   -> 0
			Application -> 1
			Context     -> 2
			Private     -> 3
		constructedBit :: Int
		constructedBit = if pc then 0x20 else 0x00
		firstOctet :: Int -> Word8
		firstOctet low = fromIntegral ((classBits `shiftL` 6) .|. constructedBit .|. low)
-- | Decode the length octets that follow an identifier.
--
-- * Short form (bit 8 of the first octet clear): the low seven bits are the
--   length ('LenShort').
-- * Indefinite form (first octet 0x80): 'LenIndefinite'.
-- * Long form (first octet 0x81..0xfe): the low seven bits give the number
--   of following length octets. Those octets are decoded with 'uintOfBytes'.
--   The result is 'LenLong' holding the octet count and the decoded length.
--
-- Throws 'ASN1Misc' in two cases:
--
-- * the initial octet is 0xff, which X.690 reserves;
-- * a long-form length does not fit in an 'Int'. Without this check it would
--   wrap to a bogus (possibly negative) length on untrusted input.
getLength :: GetErr ValLength
getLength = do
	l1 <- geteWord8
	if testBit l1 7
		then case fromIntegral (clearBit l1 7) of
			0    -> return LenIndefinite
			-- X.690 8.1.3.5(c): the value 0xff shall not be used.
			0x7f -> throwError $ ASN1Misc "length: reserved initial octet 0xff"
			len  -> do
				lw <- geteBytes len
				let v = toInteger $ snd $ uintOfBytes lw
				when (v > toInteger (maxBound :: Int)) $
					throwError $ ASN1Misc "length: long form value does not fit in an Int"
				return $ LenLong len (fromIntegral v)
		else
			return $ LenShort $ fromIntegral l1
-- | Encode a 'ValLength'.
--
-- 'LenIndefinite' becomes the single octet 0x80. 'LenShort' must lie in
-- 0..0x7f. For 'LenLong', the stored octet count is ignored and recomputed
-- from 'bytesOfUInt'. Invalid lengths (negative, or a short length above
-- 0x7f) call 'error' rather than 'throwError', because 'Put' has no error
-- channel.
putLength :: ValLength -> Put
putLength LenIndefinite = putWord8 0x80
putLength (LenShort i)
	| i >= 0 && i <= 0x7f = putWord8 (fromIntegral i)
	| otherwise           = error "putLength: short length is not between 0x0 and 0x80"
putLength (LenLong _ i)
	| i >= 0    = do
		let octets = bytesOfUInt (fromIntegral i)
		-- First octet: bit 8 set, low bits = number of length octets.
		putWord8 (fromIntegral (length octets .|. 0x80))
		mapM_ putWord8 octets
	| otherwise = error "putLength: long length is negative"
-- | Decode values one after another until the current input is exhausted.
-- 'getValueOfLength' runs this over the isolated contents of a
-- definite-length constructed value.
getValueConstructed :: CheckFn -> GetErr [Value]
getValueConstructed check = do
	left <- liftGet remaining
	if left <= 0
		then return []
		else do
			v  <- getValueCheck check
			vs <- getValueConstructed check
			return (v : vs)
-- | Decode values until an end-of-contents marker (a 'Universal' value with
-- tag number 0) is read. The marker itself is dropped from the result. Used
-- for indefinite-length constructed values.
--
-- NOTE(review): any 'Universal' value with tag 0 ends the list, whatever its
-- contents.
getValueConstructedUntilEOC :: CheckFn -> GetErr [Value]
getValueConstructedUntilEOC check = getValueCheck check >>= continue
	where
		continue (Value Universal 0 _) = return []
		continue v                     = liftM (v :) (getValueConstructedUntilEOC check)
-- | Read exactly @len@ contents octets.
--
-- A primitive value keeps the octets as-is. For a constructed value (@pc@
-- true), the octets are parsed as a run of nested values by a separate
-- 'runGetErr', and any 'ASN1Err' from that parse is rethrown.
getValueOfLength :: CheckFn -> Int -> Bool -> GetErr ValStruct
getValueOfLength check len pc = do
	bytes <- geteBytes len
	if not pc
		then return (Primitive bytes)
		else either throwError (return . Constructed) $
			runGetErr (getValueConstructed check) (L.fromChunks [bytes])
-- | Decode one value, consulting @check@ after its identifier and length are
-- read.
--
-- If @check@ returns @Just err@, decoding stops with @err@ before any
-- contents are consumed. Definite lengths read that many contents octets.
-- An indefinite length is only accepted on a constructed value; the nested
-- values are then read up to the end-of-contents marker.
getValueCheck :: CheckFn -> GetErr Value
getValueCheck check = do
	ident@(tc, pc, tn) <- getIdentifier
	vallen <- getLength
	-- Let the caller veto this value before any contents are consumed.
	maybe (return ()) throwError (check ident vallen)
	contents <- case vallen of
		LenShort n  -> getValueOfLength check n pc
		LenLong _ n -> getValueOfLength check n pc
		LenIndefinite
			| pc        -> liftM Constructed (getValueConstructedUntilEOC check)
			| otherwise -> throwError $ ASN1Misc "lenght indefinite not allowed with primitive"
	return (Value tc tn contents)
-- | Decode one value with no 'CheckFn' restrictions.
getValue :: GetErr Value
getValue = getValueCheck acceptAll
	where acceptAll _ _ = Nothing
-- | Write the contents octets of a value. A primitive value writes its raw
-- bytes. A constructed value writes each nested value in turn, encoded with
-- 'putValue'.
putValStruct :: ValStruct -> Put
putValStruct struct = case struct of
	Primitive bytes  -> putByteString bytes
	Constructed vals -> mapM_ putValue vals
-- | Encode a value, letting @policy@ choose the length form.
--
-- @policy@ receives each value together with the byte length of its encoded
-- contents. It is consulted for every value in the tree, including the
-- children of constructed values. Previously, children were always written
-- with 'putValue', so a custom policy only ever applied to the root.
--
-- When the policy answers 'LenIndefinite', the contents are followed by an
-- end-of-contents value. That marker is always written with 'putValue', so it
-- is the two zero octets regardless of the policy.
--
-- The contents are rendered to a lazy ByteString first, so their length is
-- known before the length octets are written.
putValuePolicy :: (Value -> Int -> ValLength) -> Value -> Put
putValuePolicy policy v@(Value tc tn struct) = do
	-- The constructed flag follows from the contents. Nested values reuse the
	-- same policy as their parent.
	let (pc, body) = case struct of
		Primitive bytes  -> (False, putByteString bytes)
		Constructed vals -> (True, mapM_ (putValuePolicy policy) vals)
	putIdentifier (tc, pc, tn)
	let content = runPut body
	let lenEncoded = policy v (fromIntegral $ L.length content)
	putLength lenEncoded
	putLazyByteString content
	case lenEncoded of
		LenIndefinite -> putValue $ Value Universal 0x0 (Primitive B.empty)
		_             -> return ()
-- | Encode a value with definite lengths only. Contents shorter than 0x80
-- bytes use the short form; everything else uses the long form. The
-- indefinite form is never used.
putValue :: Value -> Put
putValue = putValuePolicy definiteLength
	where
		definiteLength _ len
			| len < 0x80 = LenShort len
			| otherwise  = LenLong 0 len