module Music.LilyPond.Light.Analysis where

import Data.Function
import Data.List
import Data.Maybe
import Data.Monoid
import Data.Ratio
import qualified Music.LilyPond.Light as L
import Music.LilyPond.Light.Model
import Music.Theory.Duration
import Music.Theory.Pitch
import qualified Music.Theory.Spelling as T

-- | Real-valued quantity; used below for times and durations in seconds.
type R = Double

-- * Basic traversal

-- | Apply a function to every element of a music tree and collect the
--   results in a list, in the order 'transform_st' visits elements
--   (children before their parent).
traverse :: (Music -> a) -> Music -> [a]
traverse fn music =
    let step acc elt = (fn elt : acc, elt)
        (collected, _) = transform_st step [] music
    in reverse collected

-- | Collect every element of a music tree that satisfies a predicate.
collect_entries :: (Music -> Bool) -> Music -> [Music]
collect_entries fn m = [x | x <- traverse id m, fn x]

-- * Basic statistical analysis

-- | Count the elements of a music tree that satisfy a predicate.
count_entries :: (Music -> Bool) -> Music -> Integer
count_entries fn m = genericLength [() | True <- traverse fn m]

-- | Number of note elements in a music tree.
count_notes :: Music -> Integer
count_notes m = count_entries L.is_note m

-- | Number of chord elements in a music tree.
count_chords :: Music -> Integer
count_chords m = count_entries L.is_chord m

-- | Number of time signature elements in a music tree.
count_ts :: Music -> Integer
count_ts m = count_entries L.is_time m

-- * Basic pitch analysis

-- | Does a music element contain one or more pitches?  True for notes,
--   chords, grace notes and after-grace constructs.
has_pitch :: Music -> Bool
has_pitch x = any ($ x) [L.is_note, L.is_chord, L.is_grace, L.is_after_grace]

-- | Collect pitches from a note, a chord, a skip (which has none) or a
--   join of such; any other element is an error.  Members of chords and
--   joins are handled by 'collect_pitches', so grace notes inside a join
--   are still visited.
collect_pitches_no_grace :: Music -> [Pitch]
collect_pitches_no_grace m =
    case m of
      Note p _ _ -> [p]
      Skip _ -> []
      Chord ns _ _ -> nested ns
      Join ms -> nested ms
      _ -> error ("collect_pitches_no_grace: " ++ L.ly_music_elem m)
    where nested = concatMap collect_pitches

-- | Collect pitches from a note, chord, grace note, or after-grace
--   construct (main element first, then the grace part).
collect_pitches :: Music -> [Pitch]
collect_pitches (Grace x) = collect_pitches x
collect_pitches (AfterGrace x0 x1) = concatMap collect_pitches [x0, x1]
collect_pitches m = collect_pitches_no_grace m

-- | Collect the sequence of note elements, omitting notes marked as tied.
note_seq :: Music -> [Music]
note_seq m = [n | n <- collect_entries L.is_note m, not (L.is_tied n)]

-- * Frequency analysis utilities

-- | Frequency analysis under a custom ordering.  The input is sorted with
--   the comparison function @c@, and runs of elements for which @c@
--   answers 'EQ' are counted as one group.  Each group is reported as
--   @(count, representative)@, the representative being the first element
--   of the group after sorting.  Results are in descending order: most
--   frequent first, ties broken by descending representative.
freq_anal_by :: (Ord a) => (a -> a -> Ordering) -> [a] -> [(Int,a)]
freq_anal_by c =
    let -- Group with the same relation used for sorting; plain '(==)'
        -- would split classes that 'c' considers equal.
        same x y = c x y == EQ
        -- groupBy never yields an empty group, so 'head' is safe here.
        fn xs = (length xs, head xs)
    in reverse . sort . map fn . groupBy same . sortBy c

-- | Frequency analysis under the natural ordering: @(count, element)@
--   pairs, most frequent first.
freq_anal :: (Ord a) => [a] -> [(Int,a)]
freq_anal xs = freq_anal_by compare xs

-- * Durations

-- | A tempo marking: a duration and the number of those durations per
--   minute.
type TempoMarking = (Duration,Integer)

-- | Apply @d@ dots to the rational pulse count @n@.  Each dot adds half
--   of the value added by the previous one, giving
--   @n + n/2 + n/4 + ...@ with @d@ terms after @n@.  A non-positive @d@
--   leaves @n@ unchanged.
apply_dots :: Rational -> Integer -> Rational
apply_dots n d = n + sum [n / 2 ^ k | k <- [1 .. d]]

-- | Convert a duration to a pulse count relative to the time signature
--   @(_, b)@, where one pulse is a @1/b@ note.  Dots and the duration's
--   multiplier are applied.
dur_pulses :: TimeSignature -> Duration -> Rational
dur_pulses (_, b) (Duration dv dt ml) = ml * apply_dots (b % dv) dt

-- | The duration, in seconds, of a pulse at the indicated time
--   signature and tempo marking.  A marking @(x, i)@ gives @60 / i@
--   seconds per @x@, and @x@ spans @dur_pulses t x@ pulses.
pulse_duration :: TimeSignature -> TempoMarking -> R
pulse_duration t (x,i) =
    fromRational (recip (dur_pulses t x)) * (60 / fromIntegral i)

-- | The duration, in seconds, of a measure at the indicated time
--   signature and tempo marking: the pulse duration times the number of
--   pulses per measure (the time signature numerator).
measure_duration :: TimeSignature -> TempoMarking -> R
measure_duration ts@(n, _) tm = pulse_duration ts tm * fromIntegral n

-- * Temporal map

-- | Time signatures keyed by the measure at which they occur.
type TimeSignature_Map = [(Measure,TimeSignature)]
-- | Tempo markings keyed by the measure at which they occur.
type TempoMarking_Map = [(Measure,TempoMarking)]
-- | The time signature map and the tempo map together.
type Temporal_Map = (TimeSignature_Map,TempoMarking_Map)

-- | Build the time signature and tempo maps of a list of music elements,
--   treated as one sequence.
temporal_map :: [Music] -> Temporal_Map
temporal_map xs = (ts_map xs, tempo_map xs)

-- | Return duration (in seconds) and pulse count for each of the first
--   @n@ measures (measures @0 .. n-1@).  Each measure uses the time
--   signature and tempo in force at that measure (see 'map_lookup').
--   A non-positive @n@ gives the empty list; previously a negative @n@
--   never terminated.
mm_durations :: Temporal_Map -> Integer -> [(R,Integer)]
mm_durations (ts_m,tm_m) n =
    let measure_info i =
            let t = map_lookup ts_m i
            in (measure_duration t (map_lookup tm_m i), fst t)
    in map measure_info [0 .. n - 1]

-- | Running sum (prefix sums) of a list, e.g. @[1,2,3]@ gives @[1,3,6]@.
integrate :: (Num a) => [a] -> [a]
integrate = scanl1 (+)

-- | Return start time, duration (both in seconds) and pulse count for
--   each of the first @i@ measures.  The first measure starts at 0.
mm_start_times :: Temporal_Map -> Integer -> [(R,R,Integer)]
mm_start_times tm i =
    let durs = mm_durations tm i
        lens = map fst durs
        starts = 0 : integrate lens
    in zip3 starts lens (map snd durs)

-- | Convert a location to real time in seconds, given per-measure
--   @(start, duration, pulse count)@ triples as produced by
--   'mm_start_times'.  The location's measure must index into that list.
location_to_rt :: [(R,R,Integer)] -> Location -> R
location_to_rt mm l =
    case mm `genericIndex` measure l of
      (start, dur, pulses) ->
          start + ((fromRational (pulse l) / fromIntegral pulses) * dur)

-- | Locate every element of a list of music elements (treated as one
--   sequence) and pair it with its start time in seconds.
locate_rt :: [Music] -> [(R,Music)]
locate_rt xs = [(location_to_rt starts loc, x) | (loc, x) <- located]
    where
      located = locate' xs
      starts = mm_start_times (temporal_map xs) (lv_last_measure located + 1)

-- * Temporal annotations

-- | Whether a located element lies inside a tuplet; 'locate_st' marks
--   the contents of a 'Tuplet' as 'LM_In_Tuplet'.
data Locate_Mode = LM_Normal
                 | LM_In_Tuplet
                   deriving (Show, Eq, Ord)

-- | Measure index; 'locate' numbers measures from 0.
type Measure = Integer
-- | Offset within a measure, counted in pulses (see 'dur_pulses').
type Pulse = Rational
-- | Part (voice) identifier.
type Part_ID = Integer

-- | Data type representing the location of a musical element: the
--   measure index, the pulse offset within that measure, the part
--   (voice) identifier, and whether the element lies inside a tuplet.
--   The derived 'Ord' compares these fields in that order.
data Location = Location { measure :: Measure
                         , pulse :: Pulse
                         , part :: Part_ID
                         , mode :: Locate_Mode }
                deriving (Show, Eq, Ord)

-- | Convert a location to normal form under a given time signature.
--   While the whole-pulse part of the offset reaches the numerator (one
--   full measure), a measure's worth of pulses is carried into the
--   measure index.
location_nf :: TimeSignature -> Location -> Location
location_nf ts@(n, _) l@(Location b p v m)
    | floor p >= n = location_nf ts (Location (b + 1) (p - fromIntegral n) v m)
    | otherwise = l

-- | Type to thread state through location calculations: the time
--   signature currently in force and the current location.
type Locate_ST = (TimeSignature, Location)

-- | Set the part number of the state's location; the time signature and
--   the other location fields are unchanged.
st_set_part :: Locate_ST -> Part_ID -> Locate_ST
st_set_part ~(ts, l) v = (ts, l { part = v })

-- | Set the locate mode (normal or in-tuplet) of the state's location;
--   the time signature and the other location fields are unchanged.
st_set_mode :: Locate_ST -> Locate_Mode -> Locate_ST
st_set_mode ~(ts, l) m = (ts, l { mode = m })

-- | Advance the location state by a duration.  The duration is converted
--   to pulses under the current time signature, and the result is
--   normalised with 'location_nf'.
location_step :: Locate_ST -> Duration -> Locate_ST
location_step st d = (ts, location_nf ts (l { pulse = pulse l + dur_pulses ts d }))
    where (ts, l) = st

-- | Located music: a music element paired with its location.
type LM = (Location, Music)

-- | Located value: any value paired with a location.
type LV a = (Location, a)

-- | State threading form of location calculations.  Given the current
--   state (time signature in force and location) and an element, returns
--   the state after the element together with the located elements it
--   yields.
--
--   * Notes, chords, rests, skips and multi-measure rests advance the
--     location by their duration; a note without duration and a tremolo
--     are errors.
--   * 'Join' locates its elements in sequence and 'Repeat' locates its
--     body once; neither appears in the output itself.
--   * The contents of a 'Tuplet' are marked 'LM_In_Tuplet'.  On leaving
--     the tuplet, the mode in force before it is restored, so a nested
--     tuplet does not mark the rest of its enclosing tuplet as normal.
--   * 'Grace' does not advance the location.  'AfterGrace' advances by
--     its main element only.  Neither yields its contents.
--   * 'Time' replaces the time signature used for subsequent steps.
--   * 'Polyphony' locates its two voices from the same point with part
--     numbers @v@ and @v+1@, and continues from the state reached by the
--     first voice.  Nested polyphonic parts can therefore generate
--     duplicate part IDs.
locate_st :: Locate_ST -> Music -> (Locate_ST, [LM])
locate_st st m =
    let (ts, l) = st
        this = (l,m)
    in case m of
         Note _ (Just d) _ -> (location_step st d, [this])
         Note _ _ _ -> error "locate_st: note without duration"
         Chord _ d _ -> (location_step st d, [this])
         Tremolo _ _ -> error "locate_st: Tremolo"
         Rest d _ -> (location_step st d, [this])
         -- (i*j) units of 1/k; presumably i measures of j/k time.
         MMRest i (j,k) _ -> let d = Duration k 0 ((i * j) % 1)
                             in (location_step st d, [this])
         Skip d -> (location_step st d, [this])
         Repeat _ x -> locate_st st x
         Tuplet _ _ x ->
             let st' = (ts, l { mode = LM_In_Tuplet })
                 ((ts', l''), r) = locate_st st' x
             -- Restore the enclosing mode (LM_In_Tuplet when nested)
             -- rather than unconditionally resetting to LM_Normal.
             in ((ts', l'' { mode = mode l }), this:r)
         Grace _ -> (st, [this])
         AfterGrace x _ ->
             let (st', _) = locate_st st x
             in (st', [this])
         Join xs ->
             let (st', r) = mapAccumL locate_st st xs
             in (st', concat r)
         Clef _ _ -> (st, [this])
         Time ts' -> ((ts', l), [this])
         Key _ _ _ -> (st, [this])
         Tempo _ _ -> (st, [this])
         Command _ -> (st, [this])
         Polyphony x0 x1 ->
              let (Location _ _ v _) = l
                  st' = map (st_set_part st) [v ..]
                  r = zipWith locate_st st' [x0,x1]
                  ((st'',_):_) = r
              in (st'', this : concatMap snd r)
         Empty -> (st, [this]) -- error "locate_st: empty"

-- | Run location calculations from measure 0, pulse 0, part 0 in normal
--   mode, assuming 4/4 until a time signature is encountered.
locate :: Music -> [LM]
locate m = snd (locate_st initial m)
    where initial = ((4,4), Location 0 0 0 LM_Normal)

-- | Locate a list of music elements, concatenated into one sequence.
locate' :: [Music] -> [LM]
locate' xs = locate (mconcat xs)

-- | The distinct part identifiers of a list of located values, in
--   ascending order.
lv_located_parts :: [LV a] -> [Part_ID]
lv_located_parts = map head . group . sort . map (part . fst)

-- | Group located values by part, in ascending part order; input order is
--   kept within each group.
lv_group_parts :: [LV a] -> [[LV a]]
lv_group_parts xs = kv_group_by part xs

-- | Drop every value located before measure @n@, keeping those at measure
--   @n@ or later in their original order.  'filter' is used rather than
--   'dropWhile': the output of 'locate' is not sorted by measure when
--   there is polyphony (each voice is listed in turn), and 'dropWhile'
--   would keep early-measure values of later voices.
lv_from_measure :: Integer -> [LV a] -> [LV a]
lv_from_measure n = filter ((>= n) . measure . fst)

-- | Group located values by measure, in ascending measure order; input
--   order is kept within each group.
lv_group_measures :: [LV a] -> [[LV a]]
lv_group_measures xs = kv_group_by measure xs

-- | Keep only the values located in part @n@.
lv_extract_part :: Part_ID -> [LV a] -> [LV a]
lv_extract_part n xs = [x | x@(l, _) <- xs, part l == n]

-- | Keep only the values located in measure @n@.
lv_extract_measure :: Measure -> [LV a] -> [LV a]
lv_extract_measure n xs = [x | x@(l, _) <- xs, measure l == n]

-- | The distinct pitches of a list of located music, sorted.
lm_pitches :: [LM] -> [Pitch]
lm_pitches l =
    nub (sort (concatMap collect_pitches [x | (_, x) <- l, has_pitch x]))

-- | The distinct pitch classes of a list of located music, sorted.
lm_pcset :: [LM] -> [PitchClass]
lm_pcset l = nub (sort [pitch_to_pc p | p <- lm_pitches l])

-- | Distinct pitches of each measure that has located elements, in
--   ascending measure order.
lm_pitches_per_measure :: [LM] -> [[Pitch]]
lm_pitches_per_measure l = [lm_pitches g | g <- lv_group_measures l]

-- | Distinct pitch classes of each measure that has located elements, in
--   ascending measure order.
lm_pcset_per_measure :: [LM] -> [[PitchClass]]
lm_pcset_per_measure l = [lm_pcset g | g <- lv_group_measures l]

-- | Should a located element be kept when converting back to a flat
--   element list?  Notes without duration, repeats and joins are dropped;
--   polyphony is an error.
unlocate_p :: Music -> Bool
unlocate_p (Note _ Nothing _) = False
unlocate_p (Repeat _ _) = False
unlocate_p (Join _) = False
unlocate_p (Polyphony _ _) = error "unlocate_p: Polyphony"
unlocate_p _ = True

-- | Is the location marked 'LM_Normal' (not inside a tuplet)?
normal_mode_p :: Location -> Bool
normal_mode_p (Location _ _ _ md) = md == LM_Normal

-- | Convert located music back to a list of elements.  Keeps only
--   elements located in normal mode that also satisfy 'unlocate_p'.
lm_unlocate :: [LM] -> [Music]
lm_unlocate xs = [x | (l, x) <- xs, normal_mode_p l, unlocate_p x]

-- | The temporal part of a location: its measure and pulse.
location_time :: Location -> (Measure,Pulse)
location_time Location { measure = i, pulse = j } = (i, j)

-- | Stable sort of located values by measure, then pulse.  Part and mode
--   are ignored.
lv_sort :: [LV a] -> [LV a]
lv_sort = sortBy (\x y -> compare (location_time (fst x)) (location_time (fst y)))

-- | Pitches of every pitched element of a set of parts, with locations,
--   sorted by time.  Each input list is one part.  Parts are renumbered
--   from 1, overriding any part number assigned by 'locate'.  Notes that
--   continue ties are discarded first (see 'lm_discard_tied_notes').
located_pitches :: [[Music]] -> [(Location, [Pitch])]
located_pitches xs =
    let renumber i lms = [(l { part = i }, x) | (l, x) <- lms]
        parts = zipWith renumber [1 ..] (map (lm_discard_tied_notes . locate') xs)
        pitched = filter (has_pitch . snd) (concat parts)
    in [(l, collect_pitches x) | (l, x) <- lv_sort pitched]

-- * Time-signature structure analysis.

-- | Rewrite a time signature to the indicated denominator, scaling the
--   numerator by the same power of two (e.g. 3/4 over 8 gives 6/8).
--   Raises @error "ts_rewrite"@ when the target cannot be reached by
--   exact doubling or halving: an odd value would have to be halved, or
--   the denominators are not related by a power of two.  Previously the
--   latter case (and a zero denominator) looped forever.
ts_rewrite :: Integer -> TimeSignature -> TimeSignature
ts_rewrite d' (n0, d0) = if d0 >= d' then down (n0, d0) else up (n0, d0)
    where
      dv i j = let (x,y) = i `divMod` j
               in if y == 0 then x else error "ts_rewrite"
      -- Halve towards d'; undershooting means d' is unreachable.
      down (n,d) = case compare d d' of
                     EQ -> (n,d)
                     GT -> down (n `dv` 2, d `dv` 2)
                     LT -> error "ts_rewrite"
      -- Double towards d'; overshooting means d' is unreachable, and a
      -- non-positive denominator would never grow.
      up (n,d) = case compare d d' of
                   EQ -> (n,d)
                   LT | d > 0 -> up (n * 2, d * 2)
                   _ -> error "ts_rewrite"

-- | Sum time signatures (ie. 3/16 and 1/2 sum to 11/16).  The result is
--   expressed over the largest denominator present.  The input must be
--   non-empty.
ts_sum :: [TimeSignature] -> TimeSignature
ts_sum xs = (sum [n | (n, _) <- map (ts_rewrite d) xs], d)
    where d = maximum (map snd xs)

-- | Number of measures from the first location to the second.
measure_diff :: Location -> Location -> Integer
measure_diff from to = measure to - measure from

-- | The greatest measure index among a list of located values.  Computed
--   in one pass instead of by sorting.  Raises a named error on empty
--   input.
lv_last_measure :: [LV a] -> Measure
lv_last_measure xs =
    case map (measure . fst) xs of
      [] -> error "lv_last_measure: empty list"
      ms -> maximum ms

-- | Extract the time signature from a 'Time' element; any other element is
--   an error.
time_unpack :: Music -> TimeSignature
time_unpack m =
    case m of
      Time t -> t
      _ -> error "time_unpack"

-- | Time signature structure of music.  For each part (ascending), gives
--   its sequence of time signatures, each with the number of measures it
--   spans.  A part's last time signature spans up to the last measure at
--   which any element of the music is located.
--   NOTE(review): that final span is @last - start@, not
--   @last - start + 1@.  It is a full count only if something is located
--   at the start of the measure after the music ends — confirm intended.
ts_structure :: Music -> [[(TimeSignature, Integer)]]
ts_structure m =
    let located = locate m
        final = lv_last_measure located
        sigs = [(l, time_unpack x) | (l, x) <- located, L.is_time x]
        spans ys = zipWith step ys (tail ys) ++ [closing (last ys)]
        step (l1, t1) (l2, _) = (t1, measure_diff l1 l2)
        closing (l, t) = (t, final - measure l)
    in map spans (lv_group_parts sigs)

-- | 'ts_structure' of a list of music elements treated as one sequence.
ts_structure' :: [Music] -> [[(TimeSignature, Integer)]]
ts_structure' xs = ts_structure (mconcat xs)

-- | Expand @(value, count)@ pairs, repeating each value @count@ times.
structure_unfold' :: (Integral i) => [(a,i)] -> [a]
structure_unfold' xs = [y | (x, n) <- xs, y <- genericReplicate n x]

-- | Expand @(value, count)@ pairs into @count@ slots each: the value in
--   the first slot, 'Nothing' in the rest.  A count below one still yields
--   the 'Just' slot.
structure_unfold :: (Integral i) => [(a,i)] -> [Maybe a]
structure_unfold xs = concat [Just x : genericReplicate (n - 1) Nothing | (x, n) <- xs]

-- | Time signature map of located music: each 'Time' element keyed by its
--   measure, in input order.
lm_ts_map :: [LM] -> TimeSignature_Map
lm_ts_map xs = [(measure l, time_unpack x) | (l, x) <- xs, L.is_time x]

-- | Time signature map of a list of music elements.
ts_map :: [Music] -> TimeSignature_Map
ts_map xs = lm_ts_map (locate' xs)

-- | Keys are in ascending order; the value retrieved is the one with the
--   greatest key less than or equal to the key requested.  It is an error
--   if no key is that small (including the empty map).
map_lookup :: Ord i => [(i,a)] -> i -> a
map_lookup mp i =
    case takeWhile ((<= i) . fst) mp of
      [] -> error "map_lookup"
      candidates -> snd (last candidates)

-- | Time signature in force at a measure (see 'map_lookup').
ts_lookup :: [(Measure,TimeSignature)] -> Measure -> TimeSignature
ts_lookup m i = map_lookup m i

-- * Tempo

-- | Tempo map of located music: each 'Tempo' element keyed by its
--   measure, in input order.
lm_tempo_map :: [LM] -> [(Measure,(Duration,Integer))]
lm_tempo_map xs = map entry (filter (L.is_tempo . snd) xs)
    where entry (t, Tempo d x) = (measure t, (d, x))

-- | Tempo map of a list of music elements.
tempo_map :: [Music] -> [(Measure,(Duration,Integer))]
tempo_map xs = lm_tempo_map (locate' xs)

-- | Tempo in force at a measure (see 'map_lookup').
tempo_lookup :: [(Measure,(Duration,Integer))] -> Measure -> (Duration,Integer)
tempo_lookup m i = map_lookup m i

-- * Key/value utilities

-- | Group key/value pairs whose keys are equal under @fn@.  Groups are in
--   ascending order of the @fn@ image; input order is kept within a
--   group.
kv_group_by :: (Ord c) => (a -> c) -> [(a, b)] -> [[(a, b)]]
kv_group_by fn xs =
    let key = fn . fst
    in groupBy (\p q -> key p == key q) (sortBy (\p q -> compare (key p) (key q)) xs)

-- | Collate by key: sort by @k@, group equal keys, and pair each key with
--   the @v@ images of its group (in input order).
kv_collate :: (Ord k) => (a -> k) -> (a -> v) -> [a] -> [(k,[v])]
kv_collate k v xs =
    [ (k (head g), map v g)
    | g <- groupBy (\p q -> k p == k q) (sortBy (\p q -> compare (k p) (k q)) xs) ]

-- | 'kv_collate' over pairs: collect the values of each key.
kv_collate' :: (Ord k) => [(k,v)] -> [(k,[v])]
kv_collate' xs = kv_collate fst snd xs

-- | Keep the pairs whose key satisfies @k@ and whose value satisfies @v@.
--   The value predicate is only consulted when the key predicate holds.
kv_filter :: (k -> Bool) -> (v -> Bool) -> [(k,v)] -> [(k,v)]
kv_filter k v xs = [kv | kv@(a, b) <- xs, k a, v b]

-- | Apply @f@ to every key and @g@ to every value.
kv_map :: (k -> k') -> (v -> v') -> [(k,v)] -> [(k',v')]
kv_map f g xs = [(f a, g b) | (a, b) <- xs]

-- * Measure collation

-- | For each part (ascending), the located elements satisfying @p@,
--   collated by measure as @(measure, elements)@ in ascending measure
--   order.
measure_collate :: (Music -> Bool) -> Music -> [[(Integer, [Music])]]
measure_collate p m =
    [ kv_collate' [(measure l, x) | (l, x) <- grp]
    | grp <- lv_group_parts [lm | lm@(_, x) <- locate m, p x] ]

-- | Expand an index-keyed collation into a dense list: index @m@ holds
--   @Just x@, and gaps are filled with 'Nothing'.  Keys are expected in
--   ascending order from 0.
collation_unfold :: [(Integer, a)] -> [Maybe a]
collation_unfold = go 0
    where
      go _ [] = []
      go next ((m, x) : rest) =
          genericReplicate (m - next) Nothing ++ Just x : go (m + 1) rest

-- * Transformation

-- | Result of a stateful transformation step: the new state and element.
type ST_r st = (st, Music)
-- | A stateful transformation step: given the state and an element,
--   return the new state and the (possibly replaced) element.
type ST_f st = (st -> Music -> ST_r st)

-- | Bottom-up stateful transformation of a music tree.  Children are
--   transformed first, left to right, with the state threaded through
--   them.  @fn@ is then applied to the rebuilt parent with the resulting
--   state.  Elements without children are passed straight to @fn@.
transform_st :: ST_f st -> st -> Music -> ST_r st
transform_st fn = go
    where
      go st m =
          case m of
            Chord xs d a -> many st xs (\ys -> Chord ys d a)
            Tremolo (x0,x1) n -> two st x0 x1 (\y0 y1 -> Tremolo (y0, y1) n)
            Repeat n x -> one st x (Repeat n)
            Tuplet o t x -> one st x (Tuplet o t)
            Grace x -> one st x Grace
            AfterGrace x0 x1 -> two st x0 x1 AfterGrace
            Join xs -> many st xs Join
            Polyphony x0 x1 -> two st x0 x1 Polyphony
            _ -> fn st m
      -- One child: transform it, rebuild the parent, apply fn.
      one st x rebuild =
          let (st', y) = go st x
          in fn st' (rebuild y)
      -- Two children, transformed left then right.
      two st x0 x1 rebuild =
          let (st', y0) = go st x0
              (st'', y1) = go st' x1
          in fn st'' (rebuild y0 y1)
      -- A list of children, transformed in order.
      many st xs rebuild =
          let (st', ys) = mapAccumL go st xs
          in fn st' (rebuild ys)

-- | Stateless bottom-up transformation: apply @fn@ to every element,
--   children before their parent (see 'transform_st').
transform :: (Music -> Music) -> Music -> Music
transform fn m = snd (transform_st (\_ x -> ((), fn x)) () m)

-- * Repeats

-- | Replace every 'Repeat' with a 'Join' of its body written out the given
--   number of times.  Since the transformation is bottom-up, nested
--   repeats are expanded too.
write_out_repeats :: Music -> Music
write_out_repeats = transform expand
    where
      expand (Repeat n x) = Join (genericReplicate n x)
      expand m = m

-- * Replace

-- | Replace the pitch of a note element with the given pitch; all other
--   fields are kept.  It is applied only to notes by 'replace_notes_fn'.
--   NOTE(review): presumably a record-update failure for elements
--   without a 'note_pitch' field — confirm against the Model.
note_replace_pitch :: Pitch -> Music -> Music
note_replace_pitch x n = n { note_pitch = x }

-- | Replace the pitch of note element @n1@ with that of note element @n0@.
note_replace_pitch_m :: Music -> Music -> Music
note_replace_pitch_m n0 n1 = note_replace_pitch (note_pitch n0) n1

-- | Step function for 'transform_st' used by the note-replacement
--   functions.  On a note, the pitch is taken from the head of the input
--   via @fn@.  The head is consumed only if the note is not tied, so a
--   tied note and the following note share one input pitch.  Other
--   elements pass through unchanged, as does everything once the input is
--   exhausted.
replace_notes_fn :: (a -> Pitch) -> [a] -> Music -> ([a], Music)
replace_notes_fn _ [] m = ([], m)
replace_notes_fn fn xs@(n:ns) m
    | not (L.is_note m) = (xs, m)
    | L.is_tied m = (xs, m')
    | otherwise = (ns, m')
    where m' = note_replace_pitch (fn n) m

-- | Replace the pitches of notes, in order, with the given pitches.
--   Rhythms and annotations are kept.  A tied note and the note after it
--   take a single pitch from the input.  Notes beyond the end of the
--   input are unchanged.
replace_notes_p :: [Pitch] -> Music -> Music
replace_notes_p ns m = snd (transform_st (replace_notes_fn id) ns m)

-- | As 'replace_notes_p', taking the pitches from a list of note elements.
replace_notes :: [Music] -> Music -> Music
replace_notes ns m = snd (transform_st (replace_notes_fn note_pitch) ns m)

-- * Insert

-- | Step function for 'insert_after_notes'.  An untied note consumes one
--   directive: 'Nothing' leaves it alone, @Just y@ turns it into
--   @Join [note, y]@.  Tied notes and other elements pass through without
--   consuming input.
insert_after_notes_fn :: [Maybe Music] -> Music -> ([Maybe Music], Music)
insert_after_notes_fn [] m = ([], m)
insert_after_notes_fn xs@(y:ys) m
    | L.is_note m && not (L.is_tied m) = (ys, maybe m (\y' -> Join [m, y']) y)
    | otherwise = (xs, m)

-- | Insert an element after each untied note, as directed by the list.
--   There is one directive per untied note, in order; 'Nothing' inserts
--   nothing.
insert_after_notes :: [Maybe Music] -> Music -> Music
insert_after_notes xs m = snd (transform_st insert_after_notes_fn xs m)

-- * Tied notes

-- | Remove the elements that continue ties.  @i_pr@ recognises an element
--   that starts a tie; @f_pr@ recognises an element that may sit between
--   the tied elements.  After a tie start, elements are dropped while
--   @f_pr@ holds, and the next element is dropped as well.  If that
--   element itself starts a tie, dropping continues.
discard_tied_notes_pr :: (a -> Bool) -> (a -> Bool) -> [a] -> [a]
discard_tied_notes_pr i_pr f_pr = go True
    where
      go _ [] = []
      go keep (x:xs) =
          let rest = go (not (i_pr x || (not keep && f_pr x))) xs
          in if keep then x : rest else rest

-- | Remove notes that continue ties from a sequence of music elements.
--   Non-pitched elements between a tied note and its continuation are
--   removed too.
discard_tied_notes :: [Music] -> [Music]
discard_tied_notes xs = discard_tied_notes_pr L.is_tied (not . has_pitch) xs

-- | 'discard_tied_notes' for located music: the tests look at the music
--   element and ignore the location.
lm_discard_tied_notes :: [LM] -> [LM]
lm_discard_tied_notes xs =
    discard_tied_notes_pr (L.is_tied . snd) (not . has_pitch . snd) xs

-- * Spelling

-- | Spell a pitch class at an octave as a note without duration, using
--   'T.pc_spell_ks' to choose note name and alteration.
spell_ks :: (Octave, PitchClass) -> Music
spell_ks (o,pc) = Note (Pitch n a o) Nothing []
    where (n,a) = T.pc_spell_ks pc

-- | Spell a pitch class at an octave as a note without duration, using
--   'T.pc_spell_sharp' to choose note name and alteration.
spell_sharp :: (Octave, PitchClass) -> Music
spell_sharp (o,pc) = Note (Pitch n a o) Nothing []
    where (n,a) = T.pc_spell_sharp pc

-- | Spell a pitch class at an octave as a note without duration, using
--   'T.pc_spell_flat' to choose note name and alteration.
spell_flat :: (Octave, PitchClass) -> Music
spell_flat (o,pc) = Note (Pitch n a o) Nothing []
    where (n,a) = T.pc_spell_flat pc

-- * Validation

-- | Turn a predicate into a validator: 'Nothing' when the predicate
--   holds, otherwise @Just str@.
v_assert :: String -> (Music -> Bool) -> (Music -> Maybe String)
v_assert str fn m
    | fn m = Nothing
    | otherwise = Just str

-- | Notes in chords must not have duration.  A chord member passes only if
--   it is a 'Note' whose duration is 'Nothing'.
v_chord_note_valid :: Music -> Maybe String
v_chord_note_valid = v_assert "v_chord_note_valid" ok
    where
      ok m = case m of
               Note _ Nothing _ -> True
               _ -> False

-- | Validate a single music element and return the problems found (empty
--   when valid).  Only a top-level chord is checked: it must be non-empty
--   and each member must pass 'v_chord_note_valid'.
validate :: Music -> [String]
validate (Chord xs _ _)
    | null xs = ["empty chord"]
    | otherwise = mapMaybe v_chord_note_valid xs
validate _ = []