module Music.LilyPond.Light.Analysis where

import Control.Arrow {- base -}
import Data.Function {- base -}
import Data.List {- base -}
import Data.Maybe {- base -}
import Data.Ratio {- base -}

import qualified Music.Theory.Duration as T {- hmt -}
import qualified Music.Theory.Pitch as T {- hmt -}
import qualified Music.Theory.Pitch.Spelling.Table as T {- hmt -}
import qualified Music.Theory.Tempo_Marking as T {- hmt -}
import qualified Music.Theory.Time_Signature as T {- hmt -}

import qualified Music.LilyPond.Light.Model as L
import qualified Music.LilyPond.Light.Notation as L
import qualified Music.LilyPond.Light.Output.LilyPond as L (ly_music_elem)

-- * Basic traversal

-- | Apply a function to every element of a music tree and return the
--   results in traversal order (children before their parent, the order
--   in which 'transform_st' visits them).
m_traverse :: (L.Music -> a) -> L.Music -> [a]
m_traverse fn m =
    let step acc x = (fn x : acc, x)
        (collected, _) = transform_st step [] m
    in reverse collected

-- | Collect every element, at any depth, that satisfies the predicate,
--   in traversal order.
collect_entries :: (L.Music -> Bool) -> L.Music -> [L.Music]
collect_entries p m = [x | x <- m_traverse id m, p x]

-- * Basic statistical analysis

-- | Count the elements, at any depth, that satisfy the predicate.
count_entries :: (L.Music -> Bool) -> L.Music -> Integer
count_entries fn m = genericLength [() | True <- m_traverse fn m]

-- | Number of note elements.
count_notes :: L.Music -> Integer
count_notes m = count_entries L.is_note m

-- | Number of chord elements.
count_chords :: L.Music -> Integer
count_chords m = count_entries L.is_chord m

-- | Number of time signature elements.
count_ts :: L.Music -> Integer
count_ts m = count_entries L.is_time m

-- * Basic pitch analysis

-- | Does a music element carry one or more pitches?  True for notes,
--   chords, grace notes and after-grace constructs.
has_pitch :: L.Music -> Bool
has_pitch x = any ($ x) [L.is_note, L.is_chord, L.is_grace, L.is_after_grace]

-- | Collect pitches from a note, a chord, a skip (which contributes
--   none) or a join of such.  Chord members and join elements are
--   handled by 'collect_pitches'.  Any other element is an error.
collect_pitches_no_grace :: L.Music -> [T.Pitch]
collect_pitches_no_grace (L.Note x _ _) = [x]
collect_pitches_no_grace (L.Chord xs _ _) = concatMap collect_pitches xs
collect_pitches_no_grace (L.Skip _ _) = []
collect_pitches_no_grace (L.Join xs) = concatMap collect_pitches xs
collect_pitches_no_grace m = error ("collect_pitches_no_grace: " ++ L.ly_music_elem m)

-- | Collect pitches from a note, chord, or grace construct.  Grace
--   wrappers are unwrapped (both parts of an after-grace contribute).
--   Everything else is delegated to 'collect_pitches_no_grace'.
collect_pitches :: L.Music -> [T.Pitch]
collect_pitches (L.Grace x) = collect_pitches x
collect_pitches (L.AfterGrace x0 x1) = concatMap collect_pitches [x0, x1]
collect_pitches m = collect_pitches_no_grace m

-- | Sequence of note elements in traversal order, with tied notes
--   removed.
note_seq :: L.Music -> [L.Music]
note_seq m = [n | n <- collect_entries L.is_note m, not (L.is_tied n)]

-- * Frequency analysis utilities

-- | Frequency analysis under a custom ordering.
--
--   Elements that compare 'EQ' under @c@ are counted together.  The
--   representative of each class is the first element of its group after
--   sorting by @c@.  The result is a list of @(count, representative)@
--   pairs in descending order (most frequent first).
--
--   Grouping uses @c@ rather than '=='.  A comparator coarser than
--   'compare' therefore merges its equivalence classes, which is
--   consistent with the sort it follows.
freq_anal_by :: (Ord a) => (a -> a -> Ordering) -> [a] -> [(Int,a)]
freq_anal_by c =
    let eq_c p q = c p q == EQ
        fn xs = (length xs, head xs) -- groupBy never yields an empty group
    in reverse . sort . map fn . groupBy eq_c . sortBy c

-- | Frequency analysis under the natural ordering: @(count, value)@
--   pairs, most frequent first.
freq_anal :: (Ord a) => [a] -> [(Int,a)]
freq_anal xs = freq_anal_by compare xs

-- * Temporal map

-- | Measure index ('locate' starts counting at 0).
type Measure = Integer

-- | Time signatures keyed by the measure at which they occur.
type Time_Signature_Map = [(Measure,T.Time_Signature)]

-- | Tempo markings keyed by the measure at which they occur.
type Tempo_Marking_Map = [(Measure,T.Tempo_Marking)]

-- | Time-signature map paired with tempo map.
type Temporal_Map = (Time_Signature_Map,Tempo_Marking_Map)

-- | Build the time-signature and tempo maps of a sequence of music.
temporal_map :: [L.Music] -> Temporal_Map
temporal_map xs = (ts_map xs, tempo_map xs)

-- | Return duration (in seconds) and pulse count for each of the first
--   @n@ measures (measures @0 .. n-1@).
--
--   The time signature and tempo in force at each measure are found with
--   'map_lookup'.  A non-positive @n@ yields the empty list.  The
--   previous accumulator loop only stopped when its counter equalled
--   @n@, so it never terminated for negative @n@.
mm_durations :: Temporal_Map -> Integer -> [(Rational,Integer)]
mm_durations (ts_m,tm_m) n =
    [ (T.measure_duration t (map_lookup tm_m i), fst t)
    | i <- [0 .. n - 1]
    , let t = map_lookup ts_m i ]

-- | Running sum (prefix sums) of a list, e.g. @[1,2,3]@ to @[1,3,6]@.
--   The empty list maps to itself.
integrate :: (Num a) => [a] -> [a]
integrate = scanl1 (+)

-- | Return start time, duration (both in seconds) and pulse count for
--   each of the first @i@ measures.  Start times are the running sum of
--   the preceding durations, beginning at 0.
mm_start_times :: Temporal_Map -> Integer -> [(Rational,Rational,Integer)]
mm_start_times tm i =
    let table = mm_durations tm i
        starts = 0 : integrate (map fst table)
    in zipWith (\s (d, p) -> (s, d, p)) starts table

-- | Convert a location to real time (seconds), given the per-measure
--   @(start, duration, pulses)@ table produced by 'mm_start_times'.
--   The pulse offset is scaled by the measure duration divided by its
--   pulse count.  The measure index must lie within the table.
location_to_rt :: [(Rational,Rational,Integer)] -> Location -> Rational
location_to_rt mm l =
    let (start, dur, pulses) = genericIndex mm (measure l)
        frac = fromRational (pulse l) / fromIntegral pulses
    in start + frac * dur

-- | Locate every element of a music sequence in real time (seconds).
locate_rt :: [L.Music] -> [(Rational,L.Music)]
locate_rt xs =
    let located = locate' xs
        n_measures = lv_last_measure located + 1
        table = mm_start_times (temporal_map xs) n_measures
    in [(location_to_rt table loc, x) | (loc, x) <- located]

-- * Temporal annotations

-- | Whether an element was located inside a tuplet.  Constructor order
--   is significant for the derived 'Ord' instance.
data Locate_Mode
    = LM_Normal
    | LM_In_Tuplet
    deriving (Show, Eq, Ord)

-- | Offset within a measure, counted in pulses of the prevailing time
--   signature.
type Pulse = Rational

-- | Identifier of a polyphonic part (voice).
type Part_ID = Integer

-- | Data type representing the location of a musical element.  Field
--   order is significant for the derived 'Ord' instance.
data Location = Location
    { measure :: Measure   -- ^ measure index
    , pulse :: Pulse       -- ^ pulse offset within the measure
    , part :: Part_ID      -- ^ part (voice) identifier
    , mode :: Locate_Mode  -- ^ whether located inside a tuplet
    } deriving (Show, Eq, Ord)

-- | Convert a location to normal form under the given time signature.
--   While the pulse offset is at least the numerator of the time
--   signature (pulses per measure), carry one measure.
--   Comparing @p@ directly with the numerator is equivalent to comparing
--   @floor p@, since the numerator is an integer.
location_nf :: T.Time_Signature -> Location -> Location
location_nf ts@(n, _) l@(Location b p v m)
    | p >= fromIntegral n = location_nf ts (Location (b + 1) (p - fromIntegral n) v m)
    | otherwise = l

-- | State threaded through location calculations: the prevailing time
--   signature paired with the current location.
type Locate_ST = (T.Time_Signature, Location)

-- | Set the part number of the state's location, keeping everything else.
st_set_part :: Locate_ST -> Part_ID -> Locate_ST
st_set_part ~(ts, l) v = (ts, l { part = v })

-- | Update the locate mode ('LM_Normal' / 'LM_In_Tuplet') of the state's
--   location.  Time signature, measure, pulse and part are unchanged.
st_set_mode :: Locate_ST -> Locate_Mode -> Locate_ST
st_set_mode st m =
    let (ts, Location b p v _) = st
    in (ts, Location b p v m)

-- | Advance the state's location by a duration, converted to pulses of
--   the current time signature.  The result is then normalised with
--   'location_nf', so the pulse offset stays within the measure.
location_step :: Locate_ST -> T.Duration -> Locate_ST
location_step ~(ts, loc) d =
    let advanced = loc { pulse = pulse loc + T.ts_duration_pulses ts d }
    in (ts, location_nf ts advanced)

-- | A value paired with its location.
type LV a = (Location, a)

-- | A music element paired with its location.
type LM = LV L.Music

-- | State threading form of location calculations.
--
--   Walks a music tree and returns the final state together with every
--   located element, in order.  Notes, chords, rests, multi-measure rests
--   and skips are reported at the current location and then advance it
--   by their duration.  Grace notes, clefs, keys, tempo marks, commands
--   and 'L.Empty' are reported without advancing it.  A 'L.Time' element
--   replaces the prevailing time signature.
--
--   Elements inside a tuplet are located with mode 'LM_In_Tuplet'.  The
--   tuplet element itself is reported with the enclosing mode, and that
--   mode is restored afterwards.  This keeps the remainder of an outer
--   tuplet marked 'LM_In_Tuplet' after a nested tuplet.
--
--   Currently, nested polyphonic parts generate duplicate IDs (?)
locate_st :: Locate_ST -> L.Music -> (Locate_ST, [LM])
locate_st st m =
    let (ts, l) = st
        this = (l,m)
    in case m of
         L.Note _ (Just d) _ -> (location_step st d, [this])
         L.Note {} -> error "locate_st: note without duration"
         L.Chord _ d _ -> (location_step st d, [this])
         L.Tremolo _ _ -> error "locate_st: Tremolo"
         L.Rest _ d _ -> (location_step st d, [this])
         L.MMRest i (j,k) _ ->
             -- presumably a duration of division k scaled by i*j, i.e. the
             -- whole multi-measure span -- TODO confirm T.Duration fields
             let d = T.Duration k 0 ((i * j) % 1)
             in (location_step st d, [this])
         L.Skip d _ -> (location_step st d, [this])
         -- the repeat element itself is not reported; its body is located once
         L.Repeat _ x -> locate_st st x
         L.Tuplet _ _ x ->
             let st' = (ts, l { mode = LM_In_Tuplet })
                 ((ts', l''), r) = locate_st st' x
             -- restore the enclosing mode rather than forcing LM_Normal
             in ((ts', l'' { mode = mode l } ), this:r)
         L.Grace _ -> (st, [this])
         -- only the after-grace element is reported; the main element x
         -- supplies the time advance
         L.AfterGrace x _ ->
             let (st', _) = locate_st st x
             in (st', [this])
         L.Join xs ->
             let (st', r) = mapAccumL locate_st st xs
             in (st', concat r)
         L.Clef _ -> (st, [this])
         L.Time ts' -> ((ts', l), [this])
         L.Key {} -> (st, [this])
         L.Tempo _ _ _ -> (st, [this])
         L.Command _ _ -> (st, [this])
         -- voices are numbered upwards from the current part and all start
         -- at the current location; the state after the first voice is kept
         L.Polyphony x0 x1 ->
              let (Location _ _ v _) = l
                  st' = map (st_set_part st) [v ..]
                  r = zipWith locate_st st' [x0,x1]
                  ((st'',_):_) = r
              in (st'', this : concatMap snd r)
         L.Empty -> (st, [this]) -- error "locate_st: empty"

-- | Run location calculations.  Locating starts at measure 0, pulse 0,
--   part 0, in normal mode, under 4/4 until a time signature element
--   changes it.
locate :: L.Music -> [LM]
locate m =
    let initial = ((4,4), Location 0 0 0 LM_Normal)
        (_, located) = locate_st initial m
    in located

-- | Locate a sequence of music elements joined with 'mconcat'.
locate' :: [L.Music] -> [LM]
locate' xs = locate (mconcat xs)

-- | Extract the distinct part identifiers, in ascending order.
--
--   Duplicates are adjacent after sorting, so collapsing runs with
--   'group' suffices.  This avoids the quadratic 'nub'.  It is safe
--   because 'Part_ID' is 'Integer', whose 'Eq' and 'Ord' agree.
lv_located_parts :: [LV a] -> [Part_ID]
lv_located_parts = map head . group . sort . map (part . fst) -- groups are never empty

-- | Group located values by part, parts in ascending order.
lv_group_parts :: [LV a] -> [[LV a]]
lv_group_parts xs = kv_group_by part xs

-- | Drop the leading values located before measure @n@.  Only the prefix
--   is dropped: once a value at or beyond measure @n@ is reached,
--   everything after it is kept (presumably the input is in location
--   order -- confirm at call sites).
lv_from_measure :: Integer -> [LV a] -> [LV a]
lv_from_measure n = dropWhile (\(l, _) -> measure l < n)

-- | Group located values by measure, measures in ascending order.
lv_group_measures :: [LV a] -> [[LV a]]
lv_group_measures xs = kv_group_by measure xs

-- | Keep only the values located in part @n@.
lv_extract_part :: Part_ID -> [LV a] -> [LV a]
lv_extract_part n xs = [x | x@(l, _) <- xs, part l == n]

-- | Keep only the values located in measure @n@.
lv_extract_measure :: Measure -> [LV a] -> [LV a]
lv_extract_measure n xs = [x | x@(l, _) <- xs, measure l == n]

-- | Pitches of all pitched located elements, sorted, with duplicates
--   (under '==') removed.
--   'nub' is kept deliberately.  The 'Ord' instance of 'T.Pitch' may
--   equate pitches that '==' distinguishes, so adjacent dedup could
--   behave differently.
lm_pitches :: [LM] -> [T.Pitch]
lm_pitches xs = nub (sort [p | (_, e) <- xs, has_pitch e, p <- collect_pitches e])

-- | Distinct pitch classes of all pitched located elements, sorted.
lm_pcset :: [LM] -> [T.PitchClass]
lm_pcset xs = nub (sort (map T.pitch_to_pc (lm_pitches xs)))

-- | 'lm_pitches' of each measure that has located elements, measures in
--   ascending order.
lm_pitches_per_measure :: [LM] -> [[T.Pitch]]
lm_pitches_per_measure xs = [lm_pitches grp | grp <- lv_group_measures xs]

-- | 'lm_pcset' of each measure that has located elements, measures in
--   ascending order.
lm_pcset_per_measure :: [LM] -> [[T.PitchClass]]
lm_pcset_per_measure xs = [lm_pcset grp | grp <- lv_group_measures xs]

-- | Should a located element be kept when flattening located music back
--   to an element list?  Durationless notes, repeats and joins are
--   dropped.  Polyphony is not supported and is an error.
unlocate_p :: L.Music -> Bool
unlocate_p (L.Note _ Nothing _) = False
unlocate_p (L.Repeat _ _) = False
unlocate_p (L.Join _) = False
unlocate_p (L.Polyphony _ _) = error "unlocate_p: Polyphony"
unlocate_p _ = True

-- | Was the location computed outside any tuplet?
normal_mode_p :: Location -> Bool
normal_mode_p = (== LM_Normal) . mode

-- | Flatten located music back to an element list.  Elements located
--   inside tuplets are dropped (the tuplet element itself is kept), as
--   are elements rejected by 'unlocate_p'.
lm_unlocate :: [LM] -> [L.Music]
lm_unlocate xs = [e | (l, e) <- xs, normal_mode_p l, unlocate_p e]

-- | Temporal part of a location: measure and pulse.
location_time :: Location -> (Measure,Pulse)
location_time l = (measure l, pulse l)

-- | Stable sort of located values by measure, then pulse; part and mode
--   are ignored.
lv_sort :: [LV a] -> [LV a]
lv_sort = sortBy (\(a, _) (b, _) -> compare (location_time a) (location_time b))

-- | Located pitches of several voices.
--
--   Each input sequence is located separately, with the continuations of
--   tied notes removed.  Its part number is set to its 1-based position
--   in the input.  The pitched elements are then sorted by measure and
--   pulse and paired with their pitches.
located_pitches :: [[L.Music]] -> [(Location, [T.Pitch])]
located_pitches xs =
    let per_voice = map (lm_discard_tied_notes . locate') xs
        tagged = concat [ [(l { part = i }, e) | (l, e) <- v]
                        | (i, v) <- zip [1..] per_voice ]
        pitched = lv_sort [le | le@(_, e) <- tagged, has_pitch e]
    in [(l, collect_pitches e) | (l, e) <- pitched]

-- * Time-signature structure analysis.

-- | Number of measures from the first location to the second.
measure_diff :: Location -> Location -> Integer
measure_diff l1 l2 = subtract (measure l1) (measure l2)

-- | Greatest measure index among the located values (errors on an empty
--   list).
lv_last_measure :: [LV a] -> Measure
lv_last_measure xs = measure (maximumBy (compare `on` measure) (map fst xs))

-- | Extract the time signature from a 'L.Time' element; any other
--   element is an error.
time_unpack :: L.Music -> T.Time_Signature
time_unpack (L.Time t) = t
time_unpack _ = error "time_unpack"

-- | Time signature structure of music.  For each part, this is the
--   sequence of time signatures paired with the number of measures each
--   one spans.
--
--   The span of a part's final time signature is the last located
--   measure minus the measure where that signature starts.
--   NOTE(review): this omits the final measure itself.  Compare
--   'locate_rt', which counts @lv_last_measure + 1@ measures.  The
--   exception is music that ends with an element located at the start of
--   the following measure.  Confirm which is intended.
ts_structure :: L.Music -> [[(T.Time_Signature, Integer)]]
ts_structure m =
    let located = locate m
        end_measure = lv_last_measure located
        ts_changes = [(l, time_unpack e) | (l, e) <- located, L.is_time e]
        spans part_ts =
            let inner = zipWith (\(l1, t1) (l2, _) -> (t1, measure_diff l1 l2))
                                part_ts (drop 1 part_ts)
                (l_end, t_end) = last part_ts
            in inner ++ [(t_end, end_measure - measure l_end)]
    in map spans (lv_group_parts ts_changes)

-- | 'ts_structure' of a sequence of music elements joined with 'mconcat'.
ts_structure' :: [L.Music] -> [[(T.Time_Signature, Integer)]]
ts_structure' xs = ts_structure (mconcat xs)

-- | Expand @(value, count)@ pairs into a flat list, each value repeated
--   @count@ times.  Non-positive counts contribute nothing.
structure_unfold' :: (Integral i) => [(a,i)] -> [a]
structure_unfold' xs = concat [genericReplicate n x | (x, n) <- xs]

-- | Expand @(value, count)@ pairs into 'Just' the value followed by
--   @count - 1@ 'Nothing's.  A count of zero or less still yields the
--   leading 'Just'.
structure_unfold :: (Integral i) => [(a,i)] -> [Maybe a]
structure_unfold xs = concat [Just x : genericReplicate (n - 1) Nothing | (x, n) <- xs]

-- | Time-signature map of located music: each 'L.Time' element's
--   signature keyed by the measure it is located in.
lm_ts_map :: [LM] -> Time_Signature_Map
lm_ts_map xs = [(measure l, time_unpack e) | (l, e) <- xs, L.is_time e]

-- | Time-signature map of a sequence of music elements.
ts_map :: [L.Music] -> Time_Signature_Map
ts_map xs = lm_ts_map (locate' xs)

-- | Step-function lookup.  Keys are expected in ascending order.  The
--   value returned belongs to the last entry whose key is not greater
--   than the requested key.  The scan stops at the first greater key.
--   It is an error if the first key is already greater, or the map is
--   empty.
map_lookup :: Ord i => [(i,a)] -> i -> a
map_lookup mp i =
    case [x | (_, x) <- takeWhile (\(j, _) -> not (j > i)) mp] of
      [] -> error "map_lookup"
      found -> last found

-- | Time signature in force at the given measure (see 'map_lookup').
ts_lookup :: [(Measure,T.Time_Signature)] -> Measure -> T.Time_Signature
ts_lookup mp i = map_lookup mp i

-- * Tempo

-- | Tempo map of located music: each 'L.Tempo' element's marking keyed
--   by the measure it is located in.
lm_tempo_map :: [LM] -> [(Measure,T.Tempo_Marking)]
lm_tempo_map xs = [(measure l, (d, x)) | (l, L.Tempo _ d x) <- xs]

-- | Tempo markings keyed by the measure at which they occur.
type Tempo_Map = [(Measure,T.Tempo_Marking)]

-- | Tempo map of a sequence of music elements.
tempo_map :: [L.Music] -> Tempo_Map
tempo_map xs = lm_tempo_map (locate' xs)

-- | Tempo marking in force at the given measure (see 'map_lookup').
tempo_lookup :: Tempo_Map -> Measure -> T.Tempo_Marking
tempo_lookup tm i = map_lookup tm i

-- * Key/value utilities

-- | Group key/value pairs whose keys have equal images under @fn@.  The
--   pairs are sorted (stably) by that image first, so groups come out in
--   ascending order.
kv_group_by :: (Ord c) => (a -> c) -> [(a, b)] -> [[(a, b)]]
kv_group_by fn xs =
    let key = fn . fst
        same p q = key p == key q
        order p q = compare (key p) (key q)
    in groupBy same (sortBy order xs)

-- | Collate by key.  The input is sorted (stably) by @k@ and runs of
--   equal keys are grouped.  Each key is paired with the @v@-projections
--   of its group, keeping their input order.
kv_collate :: (Ord k) => (a -> k) -> (a -> v) -> [a] -> [(k,[v])]
kv_collate k v xs =
    [ (k y, map v grp)
    | grp@(y:_) <- groupBy ((==) `on` k) (sortBy (compare `on` k) xs) ]

-- | 'kv_collate' of a list that is already key/value pairs.
kv_collate' :: (Ord k) => [(k,v)] -> [(k,[v])]
kv_collate' xs = kv_collate fst snd xs

-- | Keep the pairs whose key satisfies @k@ and whose value satisfies @v@.
--   The value test is skipped when the key test fails.
kv_filter :: (k -> Bool) -> (v -> Bool) -> [(k,v)] -> [(k,v)]
kv_filter k v xs = [kv | kv@(a, b) <- xs, k a, v b]

-- | Apply @f@ to every key and @g@ to every value.
kv_map :: (k -> k') -> (v -> v') -> [(k,v)] -> [(k',v')]
kv_map f g = map (\ ~(a, b) -> (f a, g b))

-- * Measure collation

-- | For each part, collate the elements satisfying @p@ by the measure in
--   which they are located.
measure_collate :: (L.Music -> Bool) -> L.Music -> [[(Integer, [L.Music])]]
measure_collate p m =
    let selected = [le | le@(_, e) <- locate m, p e]
        by_measure grp = kv_collate' [(measure l, e) | (l, e) <- grp]
    in map by_measure (lv_group_parts selected)

-- | Expand a measure collation into a dense list indexed from 0:
--   'Nothing' for each missing index and 'Just' the value at each key.
--   Keys are presumably strictly ascending.  A key not greater than its
--   predecessor simply emits its 'Just' without padding.
collation_unfold :: [(Integer, a)] -> [Maybe a]
collation_unfold = go 0
  where
    go _ [] = []
    go next ((k, x) : rest) = genericReplicate (k - next) Nothing ++ Just x : go (k + 1) rest

-- * Transformation

-- | Result of a stateful transformation step: new state and new music.
type ST_r st = (st, L.Music)

-- | A stateful transformation step over a single music element.
type ST_f st = st -> L.Music -> ST_r st

-- | Stateful bottom-up transformation of a music tree.
--
--   Container elements are chords, tremolos, repeats, tuplets, grace
--   constructs, joins and polyphony.  Their children are transformed
--   first, left to right, threading the state.  Then @fn@ is applied to
--   the rebuilt container with the resulting state.  Any other element
--   is passed to @fn@ directly.
transform_st :: ST_f st -> st -> L.Music -> ST_r st
transform_st fn st m =
    let rc = transform_st fn
    in case m of
         L.Chord xs d a ->
             let (st',xs') = mapAccumL rc st xs
             in fn st' (L.Chord xs' d a)
         L.Tremolo (Left x) n ->
             let (st',x') = rc st x
             in fn st' (L.Tremolo (Left x') n)
         L.Tremolo (Right (x0,x1)) n ->
             let (st',x0') = rc st x0
                 (st'',x1') = rc st' x1
             in fn st'' (L.Tremolo (Right (x0', x1')) n)
         L.Repeat n x ->
             let (st',x') = rc st x
             in fn st' (L.Repeat n x')
         L.Tuplet o t x ->
             let (st',x') = rc st x
             in fn st' (L.Tuplet o t x')
         L.Grace x ->
             let (st',x') = rc st x
             in fn st' (L.Grace x')
         L.AfterGrace x0 x1 ->
             let (st',x0') = rc st x0
                 (st'',x1') = rc st' x1
             in fn st'' (L.AfterGrace x0' x1')
         L.Join xs ->
             let (st',xs') = mapAccumL rc st xs
             in fn st' (L.Join xs')
         -- the second voice receives the state left by the first voice
         L.Polyphony x0 x1 ->
             let (st',x0') = rc st x0
                 (st'',x1') = rc st' x1
             in fn st'' (L.Polyphony x0' x1')
         -- leaf element: no children to visit
         _ -> fn st m

-- | Apply a pure rewrite to every element of a music tree, children
--   before their parents (see 'transform_st').
transform :: (L.Music -> L.Music) -> L.Music -> L.Music
transform fn m = snd (transform_st (\st x -> (st, fn x)) () m)

-- * Repeats

-- | Replace every repeat with that many copies of its body, joined in
--   sequence.  Inner repeats are expanded before outer ones.
write_out_repeats :: L.Music -> L.Music
write_out_repeats = transform expand
  where
    expand (L.Repeat n x) = L.Join (genericReplicate n x)
    expand m = m

-- * Replace

-- | Replace the pitch of a note element with the given pitch.  The other
--   fields of the element are kept.
--   NOTE(review): this is a record update of 'L.note_pitch', so it
--   presumably fails at runtime for an element whose constructor lacks
--   that field.
note_replace_pitch :: T.Pitch -> L.Music -> L.Music
note_replace_pitch x n = n { L.note_pitch = x }

-- | Replace the pitch of note element @n1@ with that of note element @n0@.
note_replace_pitch_m :: L.Music -> L.Music -> L.Music
note_replace_pitch_m n0 n1 = note_replace_pitch (L.note_pitch n0) n1

-- | Step for the note-replacement traversals.
--
--   At each note, the pitch of the next list entry (via @fn@) is set on
--   the note.  A tied note does not consume its entry, so the following
--   note receives the same pitch.  Non-notes pass through unchanged, and
--   so does everything once the list is exhausted.
replace_notes_fn :: (a -> T.Pitch) -> [a] -> L.Music -> ([a], L.Music)
replace_notes_fn _ [] m = ([], m)
replace_notes_fn fn xs@(n:ns) m
    | not (L.is_note m) = (xs, m)
    | otherwise =
        let remaining = if L.is_tied m then xs else ns
        in (remaining, note_replace_pitch (fn n) m)

-- | Replace the pitches of notes, in traversal order, with the given
--   pitches.  Rhythms and annotations are kept.  Tied notes do not use
--   multiple pitches from the input sequence.
replace_notes_p :: [T.Pitch] -> L.Music -> L.Music
replace_notes_p ns m = snd (transform_st (replace_notes_fn id) ns m)

-- | As 'replace_notes_p', taking the new pitches from note elements.
replace_notes :: [L.Music] -> L.Music -> L.Music
replace_notes ns m = snd (transform_st (replace_notes_fn L.note_pitch) ns m)

-- * Insert

-- | Step for 'insert_after_notes'.  Each untied note consumes one list
--   entry; if the entry is 'Just', its music is joined after the note.
--   Other elements, and everything once the list is exhausted, pass
--   through unchanged.
insert_after_notes_fn :: [Maybe L.Music] -> L.Music -> ([Maybe L.Music], L.Music)
insert_after_notes_fn [] m = ([], m)
insert_after_notes_fn xs@(y:ys) m
    | L.is_note m && not (L.is_tied m) = (ys, maybe m (\y' -> L.Join [m, y']) y)
    | otherwise = (xs, m)

-- | Insert music after untied notes, in traversal order, as directed by
--   the list.  'Nothing' means nothing is inserted after that note.
insert_after_notes :: [Maybe L.Music] -> L.Music -> L.Music
insert_after_notes xs m = snd (transform_st insert_after_notes_fn xs m)

-- * Tied notes

-- | Drop the continuations of tied elements.
--
--   @i_pr@ marks an element tied to what follows.  After such an element,
--   elements satisfying @f_pr@ (or tied themselves) are dropped as well.
--   The first element satisfying neither predicate is also dropped, as
--   the continuation.  Keeping resumes after it.
discard_tied_notes_pr :: (a -> Bool) -> (a -> Bool) -> [a] -> [a]
discard_tied_notes_pr i_pr f_pr = keeping
  where
    -- emit the element; a tie switches to dropping
    keeping [] = []
    keeping (x:xs) = x : (if i_pr x then dropping xs else keeping xs)
    -- skip the element; keep dropping across ties and f_pr elements
    dropping [] = []
    dropping (x:xs) = if i_pr x || f_pr x then dropping xs else keeping xs

-- | Drop tied-note continuations from a music element sequence.
--   Unpitched elements between a tie and its continuation are dropped too.
discard_tied_notes :: [L.Music] -> [L.Music]
discard_tied_notes = discard_tied_notes_pr L.is_tied (not . has_pitch)

-- | 'discard_tied_notes' for located music, judging by the elements.
lm_discard_tied_notes :: [LM] -> [LM]
lm_discard_tied_notes = discard_tied_notes_pr (L.is_tied . snd) (not . has_pitch . snd)

-- * Spelling

-- | Durationless note for a pitch class at an octave, spelled with
--   'T.pc_spell_ks'.
spell_ks :: (T.Octave, T.PitchClass) -> L.Music
spell_ks (o,pc) = L.Note (uncurry T.Pitch (T.pc_spell_ks pc) o) Nothing []

-- | Durationless note for a pitch class at an octave, spelled with
--   'T.pc_spell_sharp'.
spell_sharp :: (T.Octave, T.PitchClass) -> L.Music
spell_sharp (o,pc) = L.Note (uncurry T.Pitch (T.pc_spell_sharp pc) o) Nothing []

-- | Durationless note for a pitch class at an octave, spelled with
--   'T.pc_spell_flat'.
spell_flat :: (T.Octave, T.PitchClass) -> L.Music
spell_flat (o,pc) = L.Note (uncurry T.Pitch (T.pc_spell_flat pc) o) Nothing []

-- * Validation

-- | Run a check: 'Nothing' when the predicate holds, otherwise 'Just'
--   the given message.
v_assert :: String -> (L.Music -> Bool) -> L.Music -> Maybe String
v_assert str fn m
    | fn m = Nothing
    | otherwise = Just str

-- | Notes in chords must not have a duration.
v_chord_note_valid :: L.Music -> Maybe String
v_chord_note_valid = v_assert "v_chord_note_valid" durationless_note
  where
    durationless_note (L.Note _ Nothing _) = True
    durationless_note _ = False

-- | Validate a single music element, returning its problems (empty when
--   valid).  Only chords are checked: an empty chord is an error, and
--   every chord member must be a durationless note.  Nested elements are
--   not visited.
validate :: L.Music -> [String]
validate (L.Chord [] _ _) = ["empty chord"]
validate (L.Chord xs _ _) = mapMaybe v_chord_note_valid xs
validate _ = []