{-# LANGUAGE ExistentialQuantification #-}
module Sound.MIDI.ALSA.Causal where

import Sound.MIDI.ALSA.Common (Time, TimeAbs, )
import qualified Sound.MIDI.ALSA.Common as Common

import qualified Sound.ALSA.Sequencer.Address as Addr
import qualified Sound.ALSA.Sequencer.Event as Event

import qualified Sound.MIDI.ALSA as MALSA
import qualified Sound.MIDI.Message.Channel as ChannelMsg
import qualified Sound.MIDI.Message.Channel.Voice as VoiceMsg

import Sound.MIDI.ALSA (normalNoteFromEvent, )
import Sound.MIDI.Message.Channel (Channel, )
import Sound.MIDI.Message.Channel.Voice (Controller, Program, )

import qualified Data.EventList.Relative.TimeBody as EventList
import qualified Data.EventList.Relative.MixedBody as EventListMB
import qualified Data.EventList.Absolute.TimeBody as EventListAbs

import qualified Data.Accessor.Monad.Trans.State as AccM
import qualified Data.Accessor.Tuple as AccTuple
import qualified Data.Accessor.Basic as Acc
import Data.Accessor.Basic ((^.), )

import Data.Tuple.HT (mapFst, mapSnd, mapPair, )
import Data.Ord.HT (limit, )
import qualified Data.List as List

import qualified Data.Map as Map

import qualified Control.Category as Cat
import qualified Control.Monad.Trans.State as State
import qualified Control.Monad.Trans.Reader as Reader
import Control.Monad.Trans.Reader (ReaderT, )
import Control.Monad (guard, )

import qualified Data.Monoid as Mn
import qualified Data.List as List
import Data.Word (Word8, )

import Prelude hiding (init, map, filter, )


{- |
A causal stream processor that turns input events of type @a@
into output events of type @b@.

Besides its hidden state @s@, a processor can request trigger events
of a hidden type @c@.
A due trigger is handed back to the step function as 'Left',
whereas regular input events arrive as 'Right'.

The fields of 'Cons' are:

* the step function.
  It receives the current absolute time and either a due trigger
  or an input event.
  It updates the state, optionally emits an output,
  and returns newly requested triggers.
  The trigger times are relative to the current time
  ('process' delays them by the current time before scheduling them).

* the initial state.

* the initial triggers, with times relative to the start.

The returned event list must be finite,
because 'process' traverses all of it when it schedules the triggers.
-}
data T a b =
   forall s c.
   Cons
      (TimeAbs -> Either c a ->
       State.State s (Maybe b, EventList.T Time c))
      s (EventList.T Time c)

{-
data T a b =
   forall s c.
   Cons (Time -> Either c a ->
         State.State s (Maybe b, Maybe (Time,c)))
-}
{-
This design allows modifying a trigger event until it fires.
However, when can we ship it?
Only when a later event comes in do we learn
that the trigger would already have been shipped.

data T a b =
   forall s c.
   Cons (Time -> Either c a ->
         State.State (s, Maybe (Time,c)) (Maybe b))
-}
{-
data T a b =
   forall s c.
   Cons (Time -> Maybe a ->
         State.State (s, EventList.T Time c) (Maybe b))
-}


{- |
Stateless processor that applies @f@ to every input event.
-}
map :: (a -> b) -> T a b
map f = mapMaybe (\x -> Just (f x))

{- |
Stateless processor that transforms every input event with @f@
and drops the events for which @f@ returns 'Nothing'.
It never requests triggers.
-}
mapMaybe :: (a -> Maybe b) -> T a b
mapMaybe f =
   Cons step () EventList.empty
   where
      -- No triggers are ever requested,
      -- so a 'Left' input is not expected; it produces no output.
      step _time input =
         return (either (\_ -> Nothing) f input, EventList.empty)

{- |
Serial composition: @compose g f@ feeds the outputs of @f@ into @g@.

The state of the composite processor is the pair of the states of @f@ and @g@.
Triggers requested by @f@ are tagged with 'Left'
and triggers requested by @g@ are tagged with 'Right'.
This lets a due trigger be routed back to the processor that requested it.
-}
compose :: T b c -> T a b -> T a c
compose (Cons stepG initG trigG) (Cons stepF initF trigF) =
   Cons step (initF, initG)
      (mergeTriggers (fmap Left trigF) (fmap Right trigG))
   where
      mergeTriggers xs ys = EventList.mergeBy (\_ _ -> False) xs ys

      -- Run the inner processor.
      -- Its output is wrapped in 'Right',
      -- because that is how the outer processor receives regular input.
      runInner t input stF =
         let ((out, trigs), stF1) = State.runState (stepF t input) stF
         in  ((fmap Right out, trigs), stF1)

      step t input = State.state $ \(stF0, stG0) ->
         let ((outerInput, innerTrigs), stF1) =
                case input of
                   Right a -> runInner t (Right a) stF0
                   Left (Left trig) -> runInner t (Left trig) stF0
                   -- a trigger of the outer processor bypasses the inner one
                   Left (Right trig) ->
                      ((Just (Left trig), EventList.empty), stF0)
             taggedInner = fmap Left innerTrigs
         in  case outerInput of
                Nothing -> ((Nothing, taggedInner), (stF1, stG0))
                Just x ->
                   let ((out, outerTrigs), stG1) =
                          State.runState (stepG t x) stG0
                   in  ((out, mergeTriggers taggedInner (fmap Right outerTrigs)),
                        (stF1, stG1))

{- |
Run two stream processors in parallel.
Every input event is passed to both processors,
and their outputs are combined with 'Mn.mappend'.
A due trigger is routed back only to the processor that requested it.

We cannot use the @Arrow@ method @&&&@,
because we cannot implement the @first@ method of the @Arrow@ class.
Consider @first :: arrow a b -> arrow (c,a) (c,b)@
and a trigger where @arrow a b@ generates an event of type @b@.
How could we additionally generate an event of type @c@
without having an input event?
-}
parallel ::
   (Mn.Monoid b) =>
   T a b -> T a b -> T a b
parallel (Cons stepF initF trigF) (Cons stepG initG trigG) =
   Cons step (initF, initG)
      (mergeTriggers (fmap Left trigF) (fmap Right trigG))
   where
      mergeTriggers xs ys = EventList.mergeBy (\_ _ -> False) xs ys

      step t input = State.state $ \(stF0, stG0) ->
         case input of
            Right a ->
               let ((outF, trigsF), stF1) =
                      State.runState (stepF t (Right a)) stF0
                   ((outG, trigsG), stG1) =
                      State.runState (stepG t (Right a)) stG0
               in  ((Mn.mappend outF outG,
                     mergeTriggers (fmap Left trigsF) (fmap Right trigsG)),
                    (stF1, stG1))
            Left (Left trig) ->
               let ((outF, trigsF), stF1) =
                      State.runState (stepF t (Left trig)) stF0
               in  ((outF, fmap Left trigsF), (stF1, stG0))
            Left (Right trig) ->
               let ((outG, trigsG), stG1) =
                      State.runState (stepG t (Left trig)) stG0
               in  ((outG, fmap Right trigsG), (stF0, stG1))


-- | Stream processors form a category:
-- the identity passes every event through unchanged,
-- and composition is 'compose'.
instance Cat.Category T where
   id = map (\x -> x)
   g . f = compose g f


{- |
Stateful processor without triggers.
Every input event is processed by the state action @f@,
starting from the initial state @s@.
-}
traverse :: s -> (a -> State.State s b) -> T a b
traverse s f =
   Cons step s EventList.empty
   where
      step _time (Left _) = return (Nothing, EventList.empty)
      step _time (Right a) = do
         b <- f a
         return (Just b, EventList.empty)


{- |
Start the queue, then run a stream processor on the ALSA sequencer
connection forever.

Every incoming event is passed to the step function of the processor.
Each event of the resulting bundle is sent out at the reception time
advanced by its own offset.
Triggers requested by the processor are implemented as echo events
sent through our output port.
An echo event from that port is matched with the first pending trigger.
-}
process ::
   T Event.Data Common.EventDataBundle ->
   ReaderT Common.Handle IO ()
process (Cons f s initTriggers) = do
   Common.startQueue
   Reader.ReaderT $ \h ->
      {-
      We maintain the pending triggers as a priority queue
      in parallel to the queue of ALSA.
      We need this in order to associate Haskell values
      with the incoming trigger (echo) events.
      -}
      let outputTriggers triggers =
             EventListAbs.mapM_
                (\t ->
                   Event.output (Common.sequ h)
                      (Common.makeEcho h (Common.deconsTime t) (Event.Custom 0 0 0))
                    >> return ())
                (const $ return ())
                (EventList.toAbsoluteEventList 0 triggers)
          -- s0: current processor state,
          -- lastTime: time of the previously handled event,
          -- triggers0: pending triggers, relative to lastTime
          go s0 (lastTime,triggers0) = do
{-
             print (realToFrac lastTime :: Double,
                    List.map
                       ((realToFrac :: TimeAbs -> Double) . Common.deconsTime) $
                    EventList.getTimes triggers0)
-}
             ev <- Event.input (Common.sequ h)
             let time =
                    Common.deconsTime $
                    Common.timeFromStamp (Event.timestamp ev)
                 -- make the pending triggers relative to the current time
                 triggers1 =
                    EventList.decreaseStart
                       (Common.consTime "Causal.process.decreaseStart" (time-lastTime))
                       triggers0
                 (restTriggers1, ((mb,newTriggers), s1)) =
                    case Event.body ev of
                       Event.CustomEv Event.Echo _ ->
                          case (Event.source ev ==
                                   Addr.Cons (Common.client h) (Common.portOut h),
                                EventList.viewL triggers1) of
                             (True, Just ((_,c),restTriggers0)) ->
                                (restTriggers0,
                                 State.runState (f time (Left c)) s0)
                             _ ->
                                -- Either an echo that does not come from our port,
                                -- or our echo while no trigger is pending.
                                -- Keep the pending triggers:
                                -- dropping them would desynchronize our queue
                                -- from the echoes still queued in ALSA.
                                (triggers1,
                                 ((Nothing, EventList.empty), s0))
                       dat ->
                          (triggers1,
                           State.runState (f time (Right dat)) s0)

             case mb of
                Nothing -> return ()
                Just dats ->
                   flip mapM_ dats $ \(dt,dat) ->
                      Event.output (Common.sequ h)
                         (Common.makeEvent h (Common.incTime dt time) dat)
             -- new triggers are relative to now, ALSA needs absolute times
             outputTriggers
                (EventList.delay (Common.consTime "Causal.process.delay" time) $
                 newTriggers)
             Event.drainOutput (Common.sequ h)
             go s1 (time,
                    EventList.mergeBy (\_ _ -> False)
                       restTriggers1 newTriggers)
      in  outputTriggers initTriggers >>
          Event.drainOutput (Common.sequ h) >>
          go s (0,initTriggers)


{- |
Apply 'Common.transpose' with offset @d@ to every event
and wrap the result in a bundle.
An event for which 'Common.transpose' returns 'Nothing'
yields an empty bundle.
-}
transposeBundle :: Int -> T Event.Data Common.EventDataBundle
transposeBundle d =
   map $ \ev ->
      case Common.transpose d ev of
         Nothing -> []
         Just moved -> Common.singletonBundle moved

{- |
Apply 'Common.transpose' with offset @d@ to every input event,
dropping the events for which it returns 'Nothing'.
-}
transpose :: Int -> T Event.Data Event.Data
transpose d =
   mapMaybe $ \ev -> Common.transpose d ev

{- |
Apply 'Common.delayAdd' with @decay@ and delay @d@ to every input event.
-}
delayAdd ::
   Word8 -> Time -> T Event.Data Common.EventDataBundle
delayAdd decay d =
   map $ \ev -> Common.delayAdd decay d ev


{- |
Repeatedly select notes from the currently held chord.

Every @dur@, the next index from @ixs@ is passed to @select@,
together with @dur@ and the held keys in ascending order.
When the index list is exhausted, no further trigger is scheduled.
Note events only update the held chord (via 'Common.updateChord')
and are not forwarded. All other events pass through unchanged.
-}
pattern ::
   (Common.Selector i, [i]) ->
   Time ->
   T Event.Data Common.EventDataBundle
pattern (select, ixs) dur =
   Cons step Map.empty (EventList.singleton 0 ixs)
   where
      step _ (Left []) =
         return (Nothing, EventList.empty)
      step _ (Left (n:ns)) =
         State.gets $ \keys ->
            (Just (select n dur (Map.toAscList keys)),
             EventList.singleton dur ns)
      step _ (Right (Event.NoteEv notePart note)) = do
         State.modify (Common.updateChord notePart note)
         return (Nothing, EventList.empty)
      step _ (Right e) =
         return (Just (Common.singletonBundle e), EventList.empty)

{- |
Shared input handler of the tempo controlled patterns.

* Note events update the held chord (second state component)
  via 'Common.updateChord'.
* Controller events matching @chanCtrl@ set the step duration
  (first state component) via 'Common.updateDur' with @minMaxDur@.

Both kinds are swallowed. Every other event is forwarded unchanged.
No triggers are requested.
-}
updateChordDur ::
   (Channel, Controller) ->
   (Time, Time) ->
   Event.Data ->
   State.State
      (Time, Common.KeySet)
      (Maybe Common.EventDataBundle, EventList.T time body)
updateChordDur chanCtrl minMaxDur e =
   case e of
      Event.NoteEv notePart note ->
         consume $
            AccM.modify AccTuple.second (Common.updateChord notePart note)
      Event.CtrlEv Event.Controller param
         | uncurry Common.controllerMatch chanCtrl param ->
            consume $
               AccM.set AccTuple.first (Common.updateDur param minMaxDur)
      _ ->
         return (Just (Common.singletonBundle e), EventList.empty)
   where
      -- run a state update that absorbs the event: no output, no triggers
      consume update = update >> return (Nothing, EventList.empty)

{- |
Like 'pattern', but the step duration starts at @defltDur@.
It can then be changed at runtime by the controller @ctrl@ on channel @chan@,
using 'Common.updateDur' with the bounds @(minDur,maxDur)@.
-}
patternTempo ::
   (Common.Selector i, [i]) ->
   ((Channel,Controller), (Time,Time,Time)) ->
   T Event.Data Common.EventDataBundle
patternTempo (select, ixs) ((chan,ctrl), (minDur, defltDur, maxDur)) =
   Cons step (defltDur, Map.empty) (EventList.singleton 0 ixs)
   where
      step _ (Left []) =
         return (Nothing, EventList.empty)
      step _ (Left (n:ns)) =
         State.gets $ \(dur, keys) ->
            (Just (select n dur (Map.toAscList keys)),
             EventList.singleton dur ns)
      step _ (Right e) =
         updateChordDur (chan,ctrl) (minDur,maxDur) e

{- |
Like 'patternTempo', but driven by an event list of chords of index notes.

The time stamps of @ixs@ count steps of the current duration,
each relative to the previous entry.
Each 'Common.IndexNote' carries a multiplier of the current duration
and the index that is passed to @select@.
-}
patternMultiTempo ::
   (Common.Selector i, EventList.T Int [Common.IndexNote i]) ->
   ((Channel,Controller), (Time,Time,Time)) ->
   T Event.Data Common.EventDataBundle
patternMultiTempo (select, ixs) ((chan,ctrl), (minDur, defltDur, maxDur)) =
   Cons step (defltDur, Map.empty) (scheduleNext defltDur ixs)
   where
      -- Schedule the remaining pattern as a single trigger,
      -- delayed by the step count of its first element.
      scheduleNext dur rest =
         EventList.switchL
            EventList.empty
            (\(steps, _) _ ->
               EventList.singleton (fromIntegral steps * dur) rest)
            rest

      step _ (Right e) =
         updateChordDur (chan,ctrl) (minDur,maxDur) e
      step _ (Left pending) =
         EventList.switchL
            (return (Nothing, EventList.empty))
            (\(_, notes) rest ->
               State.gets $ \(dur, keys) ->
                  (Just
                     [ev |
                        Common.IndexNote d i <- notes,
                        ev <- select i (fromIntegral d * dur) (Map.toAscList keys)],
                   scheduleNext dur rest))
            pending


{- |
On a note-on, prepend the (pitch, channel) key with its velocity
to the serial chord and keep at most the @maxNum@ newest entries.
Any other note event leaves the chord unchanged.
-}
updateSerialChord ::
   Int ->
   Event.NoteEv -> Event.Note ->
   Common.KeyQueue -> Common.KeyQueue
updateSerialChord maxNum notePart note chord =
   case normalNoteFromEvent notePart note of
      (Event.NoteOn, velocity) -> take maxNum (entry velocity : chord)
      _ -> chord
   where
      entry velocity =
         ((note ^. MALSA.notePitch, note ^. MALSA.noteChannel), velocity)

{- |
Counterpart of 'updateChordDur' for serial chords.

* Note events update the key queue (second state component)
  via 'updateSerialChord' with the limit @maxNum@.
* Controller events matching @chanCtrl@ set the step duration
  (first state component) via 'Common.updateDur' with @minMaxDur@.

Both kinds are swallowed. Every other event is forwarded unchanged.
No triggers are requested.
-}
updateSerialChordDur ::
   Int ->
   (Channel, Controller) ->
   (Time, Time) ->
   Event.Data ->
   State.State
      (Time, Common.KeyQueue)
      (Maybe Common.EventDataBundle, EventList.T time body)
updateSerialChordDur maxNum chanCtrl minMaxDur e =
   case e of
      Event.NoteEv notePart note ->
         consume $
            AccM.modify AccTuple.second (updateSerialChord maxNum notePart note)
      Event.CtrlEv Event.Controller param
         | uncurry Common.controllerMatch chanCtrl param ->
            consume $
               AccM.set AccTuple.first (Common.updateDur param minMaxDur)
      _ ->
         return (Just (Common.singletonBundle e), EventList.empty)
   where
      -- run a state update that absorbs the event: no output, no triggers
      consume update = update >> return (Nothing, EventList.empty)

{- |
Like 'patternTempo', but @select@ works on the serial chord,
i.e. the at most @maxNum@ most recently started keys, newest first,
instead of the set of held keys.
-}
patternSerialTempo ::
   Int ->
   (Common.Selector i, [i]) ->
   ((Channel,Controller), (Time,Time,Time)) ->
   T Event.Data Common.EventDataBundle
patternSerialTempo
      maxNum (select, ixs) ((chan,ctrl), (minDur, defltDur, maxDur)) =
   Cons step (defltDur, []) (EventList.singleton 0 ixs)
   where
      step _ (Left []) =
         return (Nothing, EventList.empty)
      step _ (Left (n:ns)) =
         State.gets $ \(dur, keys) ->
            (Just (select n dur keys), EventList.singleton dur ns)
      step _ (Right e) =
         updateSerialChordDur maxNum (chan,ctrl) (minDur,maxDur) e


{- |
Oscillator on a MIDI controller.

Every @dur@, a trigger fires and a value for the controller @centerCtrl@
on channel @chan@ is emitted.
The value is @center + depth * wave phase@, clipped to the range 0 to 127.
Afterwards, the phase advances by the current speed
and is reduced with 'Common.fraction'.

Incoming controller events on @chan@ adjust the oscillator and are swallowed:

* @speedCtrl@ maps the controller value linearly to a speed
  between @minSpeed@ and @maxSpeed@, scaled by @dur@;
* @depthCtrl@ sets the depth;
* @centerCtrl@ sets the center.

All other events pass through unchanged.
-}
sweep ::
   Channel ->
   Time ->
   (Controller, (Time,Time)) ->
   Controller ->
   Controller ->
   (Double -> Double) ->
   T Event.Data Common.EventDataBundle
sweep chan dur (speedCtrl, (minSpeed, maxSpeed)) depthCtrl centerCtrl
      wave =
   Cons step initial (EventList.singleton 0 ())
   where
      initial =
         Common.SweepState {
            Common.sweepSpeed =
               realToFrac $ Common.deconsTime $
               dur*(minSpeed+maxSpeed)/2,
            Common.sweepDepth = 64,
            Common.sweepCenter = 64,
            Common.sweepPhase = 0
         }

      -- controller event for the given oscillator state
      output st =
         Event.CtrlEv Event.Controller $
         Event.Ctrl {
            Event.ctrlChannel = MALSA.fromChannel chan,
            Event.ctrlParam = MALSA.fromController centerCtrl,
            Event.ctrlValue =
               round $ limit (0,127) $
               Common.sweepCenter st + Common.sweepDepth st * wave (Common.sweepPhase st)
         }

      advance st =
         st{Common.sweepPhase =
               Common.fraction (Common.sweepPhase st + Common.sweepSpeed st)}

      step _ (Left ()) = do
         st <- State.get
         State.put (advance st)
         return (Just (Common.singletonBundle (output st)),
                 EventList.singleton dur ())
      step _ (Right e) =
         case adjust e of
            Nothing ->
               return (Just (Common.singletonBundle e), EventList.empty)
            Just update -> do
               State.modify update
               return (Nothing, EventList.empty)

      -- state modification for a controller event that steers the oscillator
      adjust (Event.CtrlEv Event.Controller param)
         | param ^. MALSA.ctrlChannel == chan =
              lookup (param ^. MALSA.ctrlController) (handlers param)
      adjust _ = Nothing

      handlers param =
         let x :: Num a => a
             x = fromIntegral (Event.ctrlValue param)
         in  (speedCtrl,
              \st -> st{Common.sweepSpeed =
                 realToFrac $ Common.deconsTime $ (dur *) $
                 minSpeed + (maxSpeed-minSpeed) * x/127}) :
             (depthCtrl,  \st -> st{Common.sweepDepth = x}) :
             (centerCtrl, \st -> st{Common.sweepCenter = x}) :
             []

{- |
Split every input event into a pair.
Events satisfying @p@ go to the first component,
all other events go to the second one.
-}
partition :: (a -> Bool) -> T a (Maybe a, Maybe a)
partition p =
   map classify
   where
      classify a
         | p a = (Just a, Nothing)
         | otherwise = (Nothing, Just a)

{- |
Lift a processor to optional inputs.
'Nothing' inputs are dropped without touching the state.
Triggers and 'Just' inputs are handed to the wrapped processor.
-}
maybeIn :: T a b -> T (Maybe a) b
maybeIn (Cons inner s0 trig) =
   Cons step s0 trig
   where
      step t (Left c) = inner t (Left c)
      step t (Right (Just a)) = inner t (Right a)
      step _ (Right Nothing) = return (Nothing, EventList.empty)

{- |
Route every input event to @f@ if it satisfies @p@, and otherwise to @g@.
The outputs of both processors are combined with 'parallel'.
-}
guide ::
   (Mn.Monoid b) =>
   (a -> Bool) -> T a b -> T a b -> T a b
guide p f g =
   let matching = compose (maybeIn f) (map fst)
       others   = compose (maybeIn g) (map snd)
   in  compose (parallel matching others) (partition p)

{- |
Step through the programs in @pgms@, repeated endlessly,
as driven by 'Common.traverseProgramsSeek'.
'cycle' fails on an empty list, so @pgms@ must not be empty.
-}
cyclePrograms :: [Program] -> T Event.Data Common.EventDataBundle
cyclePrograms pgms =
   -- Qualified with this module's name: since base-4.8 the Prelude
   -- exports its own 'traverse', which makes an unqualified use ambiguous.
   Sound.MIDI.ALSA.Causal.traverse (cycle pgms)
      (Common.traverseProgramsSeek (length pgms))

{- |
> cycleProgramsDefer defer pgms

Like 'cyclePrograms', but after a note that triggers a program change,
the program is not changed again during the next @defer@ seconds.
This allows playing chords and skipping accidentally played notes.
While blocked, and for all non-note-on events, input passes through unchanged.
-}
cycleProgramsDefer :: Time -> [Program] -> T Event.Data Common.EventDataBundle
cycleProgramsDefer defer pgms =
   Cons step (cycle pgms, False) EventList.empty
   where
      passThrough e =
         return (Just $ Common.singletonBundle e, EventList.empty)

      isNoteOn (Event.NoteEv notePart note) =
         case fst $ normalNoteFromEvent notePart note of
            Event.NoteOn -> True
            _ -> False
      isNoteOn _ = False

      -- the deferral period is over: accept program changes again
      step _ (Left ()) = do
         AccM.set AccTuple.second False
         return (Nothing, EventList.empty)
      step _ (Right e) = do
         blocked <- State.gets snd
         if blocked || not (isNoteOn e)
           then passThrough e
           else do
              AccM.set AccTuple.second True
              bundle <-
                 AccM.lift AccTuple.first $
                 Common.traverseProgramsSeek (length pgms) e
              return (Just bundle, EventList.singleton defer ())



{- |
Connect to the sequencer and run one of the example processors.
Change the argument of @example@ to select a different one.
-}
main :: IO ()
main =
   Common.with $ do
      Common.connectLLVM
      process (example 10)
   where
      -- tempo controller and (minimal, default, maximal) step duration
      tempo = (Common.defaultTempoCtrl, (0.05, 0.12, 0.25))

      example :: Int -> T Event.Data Common.EventDataBundle
      example 0 = transposeBundle 12
      example 1 = delayAdd 50 1
      example 2 = pattern (Common.cycleUp 4) 0.12
      example 3 = patternTempo (Common.cycleUp 4) tempo
      example 4 =
         patternMultiTempo
            (Common.selectFromLimittedChord, Common.examplePatternMultiTempo1)
            tempo
      example 5 =
         sweep (ChannelMsg.toChannel 1)
            0.01 (VoiceMsg.toController 72, (0.1, 1))
            (VoiceMsg.toController 73) (VoiceMsg.toController 91)
            (sin . (2*pi*))
      example 6 =
         guide
            (\e ->
               Common.checkPitch (VoiceMsg.toPitch 60 >) e ||
               Common.checkController (snd Common.defaultTempoCtrl ==) e)
            (patternTempo (Common.cycleUp 4) tempo `compose` transpose 12)
            (map Common.singletonBundle)
      example 7 = patternSerialTempo 4 (Common.cycleUp 4) tempo
      example 8 =
         cyclePrograms $ List.map VoiceMsg.toProgram [16..20]
      example 9 =
         cycleProgramsDefer 0.1 $ List.map VoiceMsg.toProgram [16..20]
      example _ = patternMultiTempo Common.binaryLegato tempo