{- |
Loading MIDI Files

This module loads and parses a MIDI File.
It can convert it into a 'MIDIFile.T' data type object or
simply print out the contents of the file.
-}

{-
The MIDI file format is quite similar to the Interchange File Format (IFF)
of Electronic Arts.
However, re-using functionality from the @iff@ package
does not seem sensible.
-}
module Sound.MIDI.File.Load (fromFile, fromStream, maybeFromStream, showFile)
 where

import           Sound.MIDI.File
import qualified Sound.MIDI.File as MIDIFile
import qualified Sound.MIDI.Event as MIDIEvent
import qualified Data.EventList.Relative.TimeBody as EventList
import qualified Numeric.NonNegative.Wrapper as NonNeg

import Sound.MIDI.IO (ByteString, readBinaryFile, stringCharFromByte)
import qualified Sound.MIDI.Bit as Bit
import Data.Bits (testBit, clearBit)
import Data.Word (Word8)
import Data.Maybe (mapMaybe, fromMaybe)
import           Sound.MIDI.String (unlinesS)
import qualified Sound.MIDI.Parser as Parser
import qualified Sound.MIDI.ParserState as ParserState
import qualified Control.Monad.State as State
import Control.Monad (replicateM, liftM, liftM2)
import Control.Monad.State (StateT(StateT, runStateT), evalStateT, lift)

{- |
The main load function.
Reads the file at the given path as binary data
and decodes it with 'fromStream'.
-}
fromFile :: FilePath -> IO MIDIFile.T
fromFile path =
   do
      contents <- readBinaryFile path
      return (fromStream contents)

{- |
Decode a complete MIDI file from a byte stream.
Calls 'error' if the stream cannot be parsed
or if bytes remain after the file has been parsed.
Use 'maybeFromStream' to get hold of the parser message
and the remaining bytes instead.
-}
fromStream :: ByteString -> MIDIFile.T
fromStream contents =
   either
      (\msg -> error ("MIDI.Load.fromStream: " ++ msg))
      (\(mf, rest) ->
          if null rest
            then mf
            else error "Garbage left over.")
      (maybeFromStream contents)

{- |
Run the file parser 'parse' on a byte stream.
Yields either a parser message
or the decoded file together with the bytes that were not consumed.
-}
maybeFromStream :: ByteString -> Either String (MIDIFile.T, ByteString)
maybeFromStream contents = evalParser parse contents

{- |
Run a 'ByteParser' on the given input.
The result is the parser message on failure,
or the parsed value paired with the unconsumed input on success.
-}
evalParser :: ByteParser a -> ByteString -> Either String (a, ByteString)
evalParser parser input = runStateT parser input


{- |
A MIDI file is made of /chunks/, each of which is either a /header chunk/
or a /track chunk/.  To be correct, it must consist of one header chunk
followed by any number of track chunks, but for robustness's sake we skip
any non-header chunks that come before a header chunk.

The header tells us how many chunks to read afterwards.
Of these chunks only the track chunks are kept (see 'trackFromChunk'),
and each track is passed through 'removeEndOfTrack'.
-}
parse :: ByteParser MIDIFile.T
parse =
   do
      chunk <- getChunk
      case chunk of
         Header (format, nTracks, division) ->
            do
               chunks <- replicateM (NonNeg.toNumber nTracks) getChunk
               let tracks = map removeEndOfTrack (mapMaybe trackFromChunk chunks)
               return (MIDIFile.Cons format division tracks)
         -- not a header: skip this chunk and try the next one
         _ -> parse

{- |
Extract the track from a track chunk.
Every other kind of chunk yields 'Nothing',
which lets 'parse' drop it silently.
-}
trackFromChunk :: Chunk -> Maybe Track
trackFromChunk chunk =
   case chunk of
      Track t -> Just t
      _       -> Nothing

{- |
There are two ways to mark the end of a track:
The end of the event list and the meta event 'EndOfTrack'.
Thus the end marker is redundant and we strip the 'EndOfTrack'
that terminates the track.

The function calls 'error' if

* the track contains no event at all,

* an 'EndOfTrack' occurs before the last event, or

* the last event is not an 'EndOfTrack'.
-}
removeEndOfTrack :: Track -> Track
removeEndOfTrack xs =
   case EventList.viewR xs of
      Nothing -> error "Empty track, missing EndOfTrack"
      Just (initEvents, lastEvent)
         | not (EventList.null eots) ->
              error "EndOfTrack inside a track"
         | not (isEndOfTrack (snd lastEvent)) ->
              error "Track does not end with EndOfTrack"
         | otherwise -> track
         where
            -- split the events before the last one
            -- into end markers and regular events
            (eots, track) =
               EventList.partition isEndOfTrack initEvents

{- |
Tell whether an event is the meta event 'EndOfTrack'.
-}
isEndOfTrack :: Event -> Bool
isEndOfTrack (MetaEvent EndOfTrack) = True
isEndOfTrack _                      = False

{-
removeEndOfTrack :: Track -> Track
removeEndOfTrack =
   maybe
      (error "Track does not end with EndOfTrack")
      (\(ev,evs) ->
          case snd ev of
             MetaEvent EndOfTrack ->
                if EventList.null evs
                  then evs
                  else error "EndOfTrack inside a track"
             _ -> uncurry EventList.cons ev (removeEndOfTrack evs)) .
      EventList.viewL
-}

{- |
Parse a chunk, whether a header chunk, a track chunk, or otherwise.
A chunk consists of a four-byte type code
(a header is @MThd@; a track is @MTrk@),
four bytes for the size of the coming data,
and the data itself.

The chunk body is first cut out by 'getPlainChunk'
and then parsed on its own, depending on the type code:

* @MThd@: parsed with 'getHeader'.
  Bytes that follow the header fields read by 'getHeader' are skipped.
  The chunk length is authoritative
  and a header chunk may be longer than the fields known here.

* @MTrk@: parsed with 'getTrack', starting from 'initReadEvent'.

* anything else: returned unparsed as 'AlienChunk'.

If parsing a header or track body fails,
this parser fails with the message of the inner parser,
prefixed by @getChunk header: @ or @getChunk track: @, respectively.
That way 'maybeFromStream' answers with 'Left'
rather than with a value that throws an exception when it is inspected.
-}
getChunk :: ByteParser Chunk
getChunk =
   do
      (ty, body) <- getPlainChunk
      case ty of
        "MThd" ->
           -- 'fst' drops the unparsed rest of the header body
           either (failWith "getChunk header: ") (return . Header . fst) $
           evalParser getHeader body
        "MTrk" ->
           either (failWith "getChunk track: ") (return . Track) $
           evalStateT (evalStateT getTrack initReadEvent) body
        _ -> return (AlienChunk ty body)
   where
      -- Fail in the underlying 'Either' monad with a prefixed message.
      failWith :: String -> String -> ByteParser a
      failWith prefix msg = lift (Left (prefix ++ msg))

{- |
The result of parsing one chunk with 'getChunk':

* 'Header': file type, number of chunks announced by the header
  and time division, as read by 'getHeader'.

* 'Track': the events of a track chunk.

* 'AlienChunk': type code and raw body
  of a chunk that is neither a header nor a track.
-}
data Chunk =
     Header (MIDIFile.Type, NonNeg.Int, Division)
   | Track Track
   | AlienChunk String ByteString
  deriving Eq

{- |
Parse the body of a header chunk.  It consists of three fields:
the file format as a two-byte number (converted with 'toEnum'),
the number of track chunks to come as a two-byte number,
and the time division (see 'getDivision').
-}
getHeader :: ByteParser (MIDIFile.Type, NonNeg.Int, Division)
getHeader =
   get2 >>= \formatCode ->
   get2 >>= \trackCount ->
   getDivision >>= \division ->
   return
      (toEnum formatCode,
       NonNeg.fromNumberMsg "MIDI.Load.getHeader" trackCount,
       division)

{- |
Parse the two-byte time division.
If the most significant bit of the first byte is 0,
both bytes together give the number of ticks per quarter note.
If it is 1, the division is an SMPTE value:
the first byte is stored as @256 - byte@, the second byte as it is.
-}
getDivision :: ByteParser Division
getDivision = liftM2 decode get1 get1
   where
      decode :: Int -> Int -> Division
      decode hi lo
         | hi < 128 =
              Ticks (NonNeg.fromNumberMsg "MIDI.Load.getDivision" (hi*256+lo))
         | otherwise =
              SMPTE (256-hi) lo

{- |
A track is a series of events.
Parse as many scheduled events as 'ParserState.zeroOrMore' accepts
and assemble them into an event list.
-}
getTrack :: TrackParser MIDIFile.Track
getTrack =
   do
      events <- ParserState.zeroOrMore getSchedEvent
      return (EventList.fromPairList events)

{- |
Each event is preceded by the delta time: the time in ticks between the
last event and the current event.
Parse the delta time as a variable-length quantity (see 'getVar'),
then the event itself (see 'getEvent'), and return both as a pair.
-}
getSchedEvent :: TrackParser MIDIFile.SchedEvent
getSchedEvent =
   do
      time  <- lift getVar
      event <- getEvent
      return (time, event)

{- |
Parse an event.  The first byte (the tag) decides how to go on:

* 240 and 247 start a System Exclusive message:
  a variable-length size followed by that many bytes.

* 255 starts a meta event:
  a type code, a variable-length size and the body (see 'getMetaEvent').

* Any other tag above 127 is the status byte of a regular MIDI event.
  The parser for that status is remembered for running status,
  then the first data byte is read and passed to it.

* A tag of at most 127 means running status:
  the tag already is the first data byte
  and is handed to the parser remembered from the previous status byte.
-}
getEvent :: TrackParser MIDIFile.Event
getEvent = lift get1 >>= dispatch
   where
      -- size-prefixed payload of a System Exclusive message
      sysex wrap = liftM wrap $ lift (getVar >>= getBigN)

      dispatch tag
         | tag == 240 = sysex SysExStart
         | tag == 247 = sysex SysExCont
         | tag == 255 =
              lift $
                 get1 >>= \code ->
                 getVar >>= \size ->
                 liftM MetaEvent (getMetaEvent code size)
         | tag > 127 =
              let parser = decodeStatus tag
              in  putEventParser parser >>
                  lift (get1 >>= parseEvent parser)
         | otherwise =
              -- running status
              getEventParser >>= \parser ->
              lift (parseEvent parser tag)

{- |
Variant of 'getTrack' that is used by the Show functions.
It collects the events with 'ParserState.oneOrMore'
rather than 'ParserState.zeroOrMore'.
-}
getPlainTrack :: TrackParser MIDIFile.Track
getPlainTrack =
   do
      events <- ParserState.oneOrMore getSchedEvent
      return (EventList.fromPairList events)


{- |
Parser for the data of a regular MIDI event whose status is already known.
'parseEvent' receives the first data byte, which 'getEvent' has already read,
and parses the remaining data bytes of the event.
-}
newtype EventParser =
   EventParser {parseEvent :: Int -> ByteParser MIDIFile.Event}

{- |
Build the event parser that belongs to a status byte.
The byte is split by @Bit.splitAt 4@ into the event code and the channel,
and both are handed to 'getMIDIEvent'.
-}
decodeStatus :: Int -> EventParser
decodeStatus tag =
   EventParser (getMIDIEvent code (MIDIEvent.toChannel channel))
   where
      (code, channel) = Bit.splitAt 4 tag

{- |
Parse a MIDI Event from its event code, its channel and its first data byte.
Getting the first byte is a little complex
(there are issues with running status),
so that has already been done by 'getEvent'.
Depending on the code, a second data byte is read here.
An event code outside the range 8 to 14 makes the parser fail.
-}
getMIDIEvent :: Int -> MIDIEvent.Channel -> Int -> ByteParser MIDIFile.Event
getMIDIEvent code channel firstData =
   liftM (MIDIEvent channel) body
   where
      pitch    = MIDIEvent.toPitch firstData
      velocity = liftM MIDIEvent.toVelocity get1
      body =
         case code of
            8  -> liftM (MIDIEvent.NoteOff   pitch) velocity
            9  -> liftM (MIDIEvent.NoteOn    pitch) velocity
            10 -> liftM (MIDIEvent.PolyAfter pitch) get1
            11 -> liftM (MIDIEvent.Control (toEnum firstData)) get1
            12 -> return (MIDIEvent.ProgramChange (MIDIEvent.toProgram firstData))
            13 -> return (MIDIEvent.MonoAfter  firstData)
            -- the first data byte holds the low 7 bits, the second the high ones
            14 -> liftM (\msb -> MIDIEvent.PitchBend (firstData+128*msb)) get1
            _  -> fail ("invalid MIDIEvent code:" ++ show code)

{- |
Parse the body of a meta event, given its type code and its declared size.

The size is only consulted for events of variable length:
the text events (codes 1 to 7), sequencer specific data (127)
and unknown codes.
All other events read the number of bytes that their type requires
and do not look at the declared size.
-}
getMetaEvent :: Int -> NonNeg.Integer -> ByteParser MetaEvent
getMetaEvent code size =
   case code of
      0 -> liftM SequenceNum get2
      1 -> text TextEvent
      2 -> text Copyright
      3 -> text TrackName
      4 -> text InstrName
      5 -> text Lyric
      6 -> text Marker
      7 -> text CuePoint

      32 -> liftM (MIDIPrefix . MIDIEvent.toChannel) get1
      47 -> return EndOfTrack
      81 -> liftM (SetTempo . NonNeg.fromNumberMsg "MIDI.Load.getMetaEvent") get3

      84 ->
         get1 >>= \hrs ->
         get1 >>= \mins ->
         get1 >>= \secs ->
         get1 >>= \frames ->
         get1 >>= \bits ->
         return (SMPTEOffset hrs mins secs frames bits)

      88 ->
         get1 >>= \n ->
         get1 >>= \d ->
         get1 >>= \c ->
         get1 >>= \b ->
         return (TimeSig n d c b)

      89 -> liftM2 (\sf mi -> KeySig (toKeyName sf) (toEnum mi)) get1 get1

      127 -> liftM SequencerSpecific payload

      _ -> liftM (Unknown code) payload
   where
      -- text of the declared size, wrapped by the given constructor
      text = getText size
      -- raw bytes of the declared size
      payload = getBigN size

{- |
Read a string of the given size
and wrap it with the given meta event constructor.
-}
getText :: NonNeg.Integer -> (String -> MetaEvent) -> ByteParser MetaEvent
getText size c =
   do
      str <- getString size
      return (c str)

{- |
Convert the accidentals byte of a key signature meta event into a 'Key'.

In the file the number of accidentals is a two's complement byte:
negative for flats, positive for sharps.
'get1' delivers that byte as an unsigned number from 0 to 255,
so values above 127 are first mapped back to their negative counterpart.
Without this step all flat keys come out one position too high,
because @256 `mod` 15 == 1@.

Arguments that are already signed (@-7@ to @7@) are treated as before.

NOTE(review): this relies on 'Key' enumerating fifteen keys
in the order from seven flats to seven sharps -
confirm against the definition of 'Key'.
-}
toKeyName :: Int -> Key
toKeyName sf =
   let accidentals = if sf > 127 then sf - 256 else sf
   in  toEnum (mod (accidentals+7) 15)

{- |
'getByte' takes a single byte off the front of the input.
It fails with a message if the input is exhausted.
-}
getByte :: ByteParser Word8
getByte =
   StateT $ \input ->
      case input of
         []     -> Left "reached end of file"
         (b:bs) -> Right (b, bs)

{- |
Split a list into its first element and the rest.
The empty list yields 'Nothing'.
-}
viewL :: [a] -> Maybe (a, [a])
viewL []     = Nothing
viewL (c:cs) = Just (c,cs)

{- |
@getN n@ reads @n@ bytes from the input, one 'getByte' at a time.
-}
getN :: NonNeg.Int -> ByteParser ByteString
getN n =
   let count = NonNeg.toNumber n
   in  replicateM count getByte

{- |
@getString n@ reads @n@ bytes from the input
and converts them to a 'String' with 'stringCharFromByte'.
-}
getString :: NonNeg.Integer -> ByteParser String
getString n =
   do
      bytes <- getBigN n
      return (stringCharFromByte bytes)

{- |
@getBigN n@ reads @n@ bytes from the input like 'getN',
but accepts a count of type 'NonNeg.Integer'.
The repetition of 'getByte' is delegated to @Bit.replicateBig@,
which is given one more than the largest 'NonNeg.Int' as first argument.

NOTE(review): presumably @Bit.replicateBig@ uses that argument
to split the count into portions that fit into an 'Int' -
confirm against "Sound.MIDI.Bit".
-}
getBigN :: NonNeg.Integer -> ByteParser ByteString
getBigN n =
   sequence $
   Bit.replicateBig
      (succ (fromIntegral (maxBound :: NonNeg.Int)))
      (NonNeg.toNumber n)
      getByte


{- |
'get1', 'get2', 'get3', and 'get4' take 1-, 2-, 3-, or
4-byte numbers from the input (respectively), convert the base-256 data
into a single number, and return it.

'get1' reads one byte and converts it to an 'Int'.
-}
get1 :: ByteParser Int
get1 =
   do
      byte <- getByte
      return (fromIntegral byte)

{- |
@getNByteInt n@ reads @n@ bytes with 'get1'
and combines them into one number with @Bit.fromBytes@.
-}
getNByteInt :: NonNeg.Int -> ByteParser Int
getNByteInt n =
   do
      digits <- replicateM (NonNeg.toNumber n) get1
      return (Bit.fromBytes digits)

{- |
Read a number that is stored in 2, 3 or 4 bytes, respectively.
All three are instances of 'getNByteInt'.
-}
get2, get3, get4 :: ByteParser Int
get2 = getNByteInt 2
get3 = getNByteInt 3
get4 = getNByteInt 4

{- |
/Variable-length quantities/ are used often in MIDI notation.
They are represented in the following way:
Each byte (containing 8 bits) uses the 7 least significant bits to store information.
The most significant bit signals whether or not more information is coming.
If it's @1@, another byte is coming.
If it's @0@, that byte is the last one.

'getVar' reads a variable-length quantity from the input:
it fetches the 7-bit digits with 'getVarBytes'
and interprets them as a number to the base 128.
-}
getVar :: ByteParser NonNeg.Integer
getVar =
   do
      digits <- getVarBytes
      return (Bit.fromBase (2^(7::Int)) (map fromIntegral digits))

{- |
Read the bytes of a variable-length quantity.
The returned list contains only bytes with the most significant bit cleared.
These are digits of a 128-ary number.
-}
getVarBytes :: ByteParser [Word8]
getVarBytes =
   getByte >>= \digit ->
      if testBit digit 7
        -- bit 7 is set: more bytes follow
        then liftM (clearBit digit 7 :) getVarBytes
        -- bit 7 is clear: this is the last byte
        else return [digit]

{- |
Print the decoded contents of a MIDI file in an easy-to-read format.
The file is read as binary data, formatted by 'showChunks'
and written to standard output.
-}
showFile :: FilePath -> IO ()
showFile fileName =
   do
      contents <- readBinaryFile fileName
      putStr (showChunks contents)

{- |
Format the contents of a MIDI file chunk by chunk.
The input is chopped into chunks by 'getChunks'.
Header and track chunks are parsed and shown in decoded form,
every other chunk is shown with its type code and its raw body.
Parse failures and leftover bytes are reported by 'showMR'.
-}
showChunks :: ByteString -> String
showChunks bytes = showMR getChunks (unlinesS . map showChunk) bytes ""
 where
  showChunk :: (String, ByteString) -> ShowS
  showChunk (ty, body) =
    case ty of
      "MThd" ->
        showString "Header: " .
        showMR getHeader shows body
      "MTrk" ->
        showString "Track:\n" .
        showMR (evalStateT getPlainTrack initReadEvent) showTrack body
      _ ->
        showString "Chunk: " .
        showString ty .
        showString " " .
        shows body .
        showString "\n"

  -- one line per event, each preceded by its time
  showTrack track rest =
    EventList.foldr
       MIDIFile.showTime
       (\e -> MIDIFile.showEvent e . showString "\n")
       rest track

{- |
Run a parser on the given input and format the outcome.
On success the parsed value is formatted by the given function;
unconsumed input is appended after the label @Junk: @.
On failure the parser message and the whole input are shown.
-}
showMR :: ByteParser a -> (a->ShowS) -> ByteString -> ShowS
showMR m pp contents =
  either
    (\msg ->
        showString "Parse failed: " . showString (msg++"\n") . shows contents)
    (\(value, rest) ->
        if null rest
          then pp value
          else pp value . showString "Junk: " . shows rest)
    (evalParser m contents)


{- |
The two parsers 'getPlainChunk' and 'getChunks'
do not combine directly into a single master parser.
Rather, they should be used to chop a MIDI file
into chunks of bytes which can be output separately.

'getChunks' applies 'getPlainChunk' repeatedly
and returns a list of /chunk-type/-contents pairs.
Bytes that are left over after running the parser
should not occur in a correctly formatted file.
-}
getChunks :: ByteParser [(String, ByteString)]
getChunks = Parser.zeroOrMore getPlainChunk

{- |
Read one chunk without interpreting its body:
a four-byte type code, a four-byte size and as many body bytes
as the size announces.
Returns the type code and the raw body.
-}
getPlainChunk :: ByteParser (String, ByteString)
getPlainChunk =
   do
      -- chunk type: header or track
      ty   <- getString 4
      size <- get4
      -- chunk body
      body <- getN (NonNeg.fromNumberMsg "getPlainChunk" size)
      return (ty, body)


{- |
A parser that consumes a 'ByteString'.
Running it with 'evalParser' yields either a message describing the failure
or the parsed value together with the unconsumed input.
-}
type ByteParser a = Parser.T ByteString a

{- |
The 'TrackParser' monad parses a track of a MIDI File.
In MIDI, a shortcut is used for long runs of similar MIDI events:
If consecutive events all have the same type and channel,
the type and channel can be omitted for all but the first event.
To implement this /feature/ (called running status),
the parser must keep track of the type and channel of the most recent MIDI Event.
This is done by storing an 'EventParser',
which parses the data bytes according to the currently active event type.
-}
type TrackParser a = ParserState.T EventParser ByteString a


{- |
Remember the given event parser as the currently active one.
'getEvent' calls this whenever it encounters a new status byte.
-}
putEventParser  :: EventParser -> TrackParser ()
putEventParser parser = State.put parser

{- |
Fetch the currently active event parser.
'getEvent' uses it for events with running status.
-}
getEventParser :: TrackParser EventParser
getEventParser = State.get

{- |
The event parser that is active at the beginning of a track,
before any status byte has been seen.
It consumes no input and succeeds,
but the event it returns raises an 'error' as soon as it is inspected.
-}
initReadEvent :: EventParser
initReadEvent =
   EventParser $ \_ ->
      return (error "At beginning, no event type set so far")