{- |
Parser which limits the input data to a given number of bytes.
We need this for parsing MIDI tracks and some MetaEvents,
where the length of a part is fixed by a length specification.
-}
module Sound.MIDI.Parser.Restricted
   (T(..), run, ) where

import qualified Sound.MIDI.Parser.Class as Parser

import Control.Monad.State
   (StateT(StateT, runStateT), mapStateT,
    get, put, liftM, lift, when, )

import qualified Numeric.NonNegative.Wrapper as NonNeg

import Prelude hiding (replicate, until, )


{- |
Run a restricted parser on a part of the input
that may span at most @maxLen@ bytes.

If the restricted parser finishes before the whole budget is used,
a warning reporting the number of unparsed bytes is emitted
(wrapped in 'Parser.force').
Note that those leftover bytes are not skipped here.
-}
run :: Parser.C parser =>
   NonNeg.Integer -> T parser a -> parser a
run maxLen p =
   runStateT (decons p) maxLen >>= \(result, leftover) ->
      Parser.force (reportLeftover leftover) >> return result
   where
      reportLeftover n =
         when (n>0) $
            Parser.warn $
               "unparsed bytes left in part (" ++ show n ++ " bytes)"



{- |
A parser wrapper that enforces a byte budget on the wrapped parser.

The state of the wrapped 'StateT' counts the bytes of the current part
that may still be consumed; 'run' initialises it with the part length.
-}
newtype T parser a =
   Cons {
      -- | Access the underlying stateful parser.
      decons :: StateT NonNeg.Integer parser a
   }

{-
Since GHC 7.10 (Applicative-Monad Proposal) 'Applicative' is a superclass
of 'Monad', and 'Functor' a superclass of 'Applicative'.
Without these two instances the 'Monad' instance below is rejected.
All three simply delegate to the wrapped 'StateT'.
-}
instance Functor parser => Functor (T parser) where
   fmap f = Cons . fmap f . decons

instance Monad parser => Applicative (T parser) where
   pure = Cons . pure
   Cons f <*> Cons x = Cons (f <*> x)

-- 'return' is left to its default, which is 'pure' (canonical definition).
instance Monad parser => Monad (T parser) where
   x >>= y = Cons $ decons . y =<< decons x

{-
Budget-checked parser primitives.

The state is the number of bytes of the current part that have not yet
been consumed from the underlying parser.
'getByte' and 'skip' check it before reading, so the part is never overrun.
'giveUp' consumes the rest of the part before failing,
and 'skip' does the same when it would exceed the budget.
This is what allows 'try' to reset the budget to zero after a failure.
-}
instance Parser.C parser => Parser.C (T parser) where
   -- The part ends when its budget is used up
   -- or when the underlying parser reports its own end.
   isEnd =
     Cons $ get >>= \remaining ->
       if remaining==0 then return True else lift Parser.isEnd
   getByte =
     Cons $ get >>= \remaining ->
       do when (remaining==0)
             (lift $ Parser.giveUp "unexpected end of part")
          -- No extra check of the underlying 'Parser.isEnd' here
          -- (part longer than container): in principle the underlying
          -- 'Parser.getByte' must check for remaining bytes itself.
          put (remaining-1)
          lift Parser.getByte
   skip n =
     Cons $ get >>= \remaining ->
       if n>remaining
         -- Consume the rest of the part before failing, exactly like 'giveUp',
         -- so that a surrounding 'try' (which resets the budget to 0)
         -- leaves the underlying parser at the end of the part
         -- instead of somewhere in its middle.
         then lift $
                 Parser.skip remaining >>
                 Parser.giveUp "skip beyond end of part"
         else put (remaining-n) >> lift (Parser.skip n)
   -- Warnings are forwarded to the underlying parser unchanged.
   warn   = Cons . lift . Parser.warn
   -- Simply lifting the underlying 'Parser.giveUp' would leave the unread
   -- rest of the part in the underlying input, so skip it first.
   giveUp errMsg =
      Cons $ StateT $ \remain ->
         Parser.skip remain >> Parser.giveUp errMsg
   -- On failure the budget is reset to 0. This assumes the failing branch
   -- consumed the rest of the part, as 'giveUp' and the budget checks in
   -- 'getByte' and 'skip' do.
   try (Cons st) =
      Cons $ StateT $ \remain0 ->
         liftM (either
                 (\errMsg -> (Left errMsg, 0))
                 (\(x,remain1) -> (Right x, remain1))) $
         Parser.try (runStateT st remain0)
   -- Apply the underlying 'Parser.force' to the wrapped state computation.
   force (Cons st) =
      Cons $ mapStateT Parser.force st