-- | /Spear/ @PTPF@ (par-text-partials-format) files.
module Sound.Analysis.Spear.PTPF where

import qualified Data.ByteString.Lazy.Char8 as C {- bytestring -}
import Data.ByteString.Lex.Lazy.Double {- bytestring-lexing -}
import Data.Function {- base -}
import Data.List {- base -}
import Data.List.Split {- split -}

-- * Tuple

-- | Pair up consecutive elements of a list.
--
-- The input is cut into chunks of two with 'chunksOf'; a chunk that
-- does not hold exactly two elements (the tail of an odd-length
-- list) evaluates to @error "duples"@.
--
-- > duples [1..6] == [(1,2),(3,4),(5,6)]
duples :: [t] -> [(t,t)]
duples = map pair . chunksOf 2
    where pair [i,j] = (i,j)
          pair _ = error "duples"

-- | Group consecutive elements of a list three at a time.
--
-- The input is cut into chunks of three with 'chunksOf'; a chunk that
-- does not hold exactly three elements (the tail of a list whose
-- length is not a multiple of three) evaluates to @error "triples"@.
--
-- > triples [1..6] == [(1,2,3),(4,5,6)]
triples :: [t] -> [(t,t,t)]
triples = map triple . chunksOf 3
    where triple [i,j,k] = (i,j,k)
          triple _ = error "triples"

-- * List

-- | Transform only the final element of a list with /f/; all earlier
-- elements are passed through unchanged and the empty list stays empty.
--
-- > at_last negate [1..3] == [1,2,-3]
at_last :: (a -> a) -> [a] -> [a]
at_last _ [] = []
at_last f [e] = [f e]
at_last f (e:x') = e : at_last f x'

-- | Numerically stable mean.
--
-- Computed as a running update: after seeing the @n@-th element @x@
-- the current mean @m@ becomes @m + (x - m) / n@, so no large
-- intermediate sum is formed.  The mean of the empty list is @0@.
--
-- Only division is required of the element type, so any 'Fractional'
-- type is accepted.
--
-- > map mean [[1..5],[3,5,7],[7,7],[3,9,10,11,12]] == [3,5,7,9]
mean :: Fractional a => [a] -> a
mean =
    let f (m,n) x =
            let n' = n + 1
                m' = m + (x - m) / n'
            -- foldl' only forces the pair constructor; force both
            -- components so that no chain of thunks accumulates.
            in m' `seq` n' `seq` (m',n')
    in fst . foldl' f (0,0)

-- | 'minimum' & 'maximum' in a single traversal.
--
-- The list must be non-empty; the empty list is an 'error'
-- (@"minmax: empty list"@).
--
-- > minmax [0..5] == (0,5)
minmax :: Ord b => [b] -> (b, b)
minmax l =
    let f (p,q) n =
            let p' = min p n
                q' = max q n
            -- Force both bounds at each step so the strict fold does
            -- not pile up unevaluated min/max applications.
            in p' `seq` q' `seq` (p',q')
    in case l of
         [] -> error "minmax: empty list"
         e:l' -> foldl' f (e,e) l'

-- * Node

-- | Time value, as stored at 'n_time' and at the temporal bounds of a 'Seq'.
type N_Time = Double

-- | Per-node measurement, as stored at 'n_frequency' and 'n_amplitude'.
type N_Data = Double

-- | Record to hold data for single node of a partial track.
data Node = Node {n_partial_id :: Int -- ^ Partial identifier
                 ,n_time :: N_Time -- ^ Time of the node
                 ,n_frequency :: N_Data -- ^ Frequency at 'n_time'
                 ,n_amplitude :: N_Data -- ^ Amplitude at 'n_time'
                 }
            deriving (Eq,Show)

-- | Silence a 'Node': the result has an 'n_amplitude' of @0@, every
-- other field is carried over.
n_zero_amplitude :: Node -> Node
n_zero_amplitude (Node k t fr _) = Node k t fr 0

-- | Replace the 'n_partial_id' of a 'Node' with /k/, every other
-- field is carried over.
n_set_partial_id :: Int -> Node -> Node
n_set_partial_id k (Node _ t fr a) = Node k t fr a

-- | Rewrite the 'n_time' of a 'Node' by passing it through /f/;
-- identifier, frequency and amplitude are carried over.
n_temporal_f :: (N_Time -> N_Time) -> Node -> Node
n_temporal_f f (Node k t fr a) = Node k (f t) fr a

-- * Seq

-- | A sequence of partial 'Node' data, i.e. the nodes of one partial
-- track together with summary fields describing them.
--
-- The summary fields are stored next to 's_data' (the 'Node's
-- themselves) and are expected to agree with it as described at each
-- field.  The constructor does not check this, so code that builds or
-- edits a 'Seq' directly is responsible for keeping them consistent.
data Seq = Seq {s_identifier :: Int -- ^ '==' to 'n_partial_id' at 's_data'.
               ,s_start_time :: N_Time -- ^ 'minimum' 'n_time' at 's_data'.
               ,s_end_time :: N_Time -- ^ 'maximum' 'n_time' at 's_data'.
               ,s_nodes :: Int -- ^ '==' to 'length' 's_data'
               ,s_data :: [Node]}
           deriving (Eq,Show)

-- | Apply /f/ at 's_data' of 'Seq' and re-calculate temporal bounds.
--
-- The new bounds are the 'minimum' and 'maximum' 'n_time' of the
-- mapped nodes.  When 's_data' is empty there is nothing to measure,
-- so the existing bounds are retained instead of leaving 'minmax'
-- error values in the result.
--
-- 's_identifier' and 's_nodes' are copied unchanged: 'map' preserves
-- the node count, and /f/ should leave 'n_partial_id' alone since the
-- identifier is not re-derived from the nodes.
s_map :: (Node -> Node) -> Seq -> Seq
s_map f (Seq i st et n d) =
    case map f d of
      [] -> Seq i st et n []
      d' -> let (s,e) = minmax (map n_time d')
            in Seq i s e n d'

-- | Summarise a 'Seq': project every 'Node' of 's_data' with /g/ and
-- reduce the resulting list with /f/.
s_summarise :: ([a] -> b) -> (Node -> a) -> Seq -> b
s_summarise f g s = f (map g (s_data s))

-- | Largest 'n_amplitude' among the nodes of a 'Seq' ('maximum', so
-- 's_data' must be non-empty).
s_max_amplitude :: Seq -> N_Data
s_max_amplitude s = maximum (map n_amplitude (s_data s))

-- | Smallest 'n_amplitude' among the nodes of a 'Seq' ('minimum', so
-- 's_data' must be non-empty).
s_min_amplitude :: Seq -> N_Data
s_min_amplitude s = minimum (map n_amplitude (s_data s))

-- | Average ('mean') of 'n_amplitude' over the nodes of a 'Seq'.
s_mean_amplitude :: Seq -> N_Data
s_mean_amplitude = mean . map n_amplitude . s_data

-- | Average ('mean') of 'n_frequency' over the nodes of a 'Seq'.
s_mean_frequency :: Seq -> N_Data
s_mean_frequency = mean . map n_frequency . s_data

-- | Length in time of a 'Seq', that is the stored 's_end_time' less
-- the stored 's_start_time'.
s_duration :: Seq -> N_Time
s_duration (Seq _ st et _ _) = et - st

-- | Re-label a 'Seq' as partial /k/: 's_identifier' and the
-- 'n_partial_id' of every 'Node' at 's_data' are both set to /k/.
s_set_identifier :: Int -> Seq -> Seq
s_set_identifier k (Seq _ st et n d) =
    Seq k st et n (map (n_set_partial_id k) d)

-- | Do two 'Seq' carry the same 's_identifier'?  No other field is
-- compared.
s_eq_identifier :: Seq -> Seq -> Bool
s_eq_identifier p q = s_identifier p == s_identifier q

-- | Union of two 'Seq' lists where equality is decided by
-- 's_identifier' alone ('unionBy').
s_union :: [Seq] -> [Seq] -> [Seq]
s_union p q = unionBy ((==) `on` s_identifier) p q

-- | Apply transform /f/ at 'n_time' of every 'Node' and at the
-- temporal bounds of the 'Seq'.
--
-- The bounds are obtained by transforming the stored bounds (the
-- nodes are not re-scanned) and then ordering the two results, so
-- that an order-reversing /f/ (for instance a time reversal) still
-- gives 's_start_time' '<=' 's_end_time'.  For an order-preserving
-- /f/ this is simply @(f start,f end)@.
--
-- If /f/ is neither order-preserving nor order-reversing the
-- transformed bounds need not be the extreme node times; in that case
-- use @'s_map' ('n_temporal_f' f)@, which re-calculates them from the
-- data.
s_temporal_f :: (N_Time -> N_Time) -> Seq -> Seq
s_temporal_f f (Seq i st et n d) =
    let st' = f st
        et' = f et
    in Seq i (min st' et') (max st' et') n (map (n_temporal_f f) d)

-- * PTPF

-- | A 'PTPF' is a set of 'Seq'.
--
-- The partial count is stored explicitly beside the list; the record
-- itself does not tie the two together.
data PTPF = PTPF {p_partials :: Int -- ^ Number of partials; 'p_from_seq' and 'p_filter' set this to 'length' of 'p_seq'.
                 ,p_seq :: [Seq]}
            deriving (Eq,Show)

-- | Earliest 's_start_time' over all 'Seq' of a 'PTPF' ('minimum', so
-- 'p_seq' must be non-empty).
p_start_time :: PTPF -> N_Time
p_start_time p = minimum (map s_start_time (p_seq p))

-- | Latest 's_end_time' over all 'Seq' of a 'PTPF' ('maximum', so
-- 'p_seq' must be non-empty).
p_end_time :: PTPF -> N_Time
p_end_time p = maximum (map s_end_time (p_seq p))

-- | Total node count of a 'PTPF', taken as the 'sum' of the stored
-- 's_nodes' fields (the node lists themselves are not traversed).
p_nodes :: PTPF -> Int
p_nodes p = sum [s_nodes s | s <- p_seq p]

-- | Build a 'PTPF' from a list of 'Seq'.  Each 'Seq' is re-labelled
-- (see 's_set_identifier') with its position in the list, counting
-- from @0@, and 'p_partials' is the length of the list.
p_from_seq :: [Seq] -> PTPF
p_from_seq s =
    PTPF {p_partials = length s
         ,p_seq = zipWith s_set_identifier [0..] s}

-- | Apply transform /f/ to the times of every 'Seq' of a 'PTPF' (see
-- 's_temporal_f'); the partial count is unchanged.
p_temporal_f :: (N_Time -> N_Time) -> PTPF -> PTPF
p_temporal_f f = p_map (s_temporal_f f)

-- | Apply /f/ to every 'Seq' at 'p_seq'.  'map' keeps the number of
-- entries, so 'p_partials' is carried over unchanged.
p_map :: (Seq -> Seq) -> PTPF -> PTPF
p_map f p = p {p_seq = map f (p_seq p)}

-- | Keep only the 'Seq' that satisfy /f/.  'p_partials' is re-counted
-- from the retained list; identifiers of the retained 'Seq' are left
-- as they were.
p_filter :: (Seq -> Bool) -> PTPF -> PTPF
p_filter f (PTPF _ s) = PTPF (length kept) kept
    where kept = filter f s

-- | Apply /f/ to every 'Node' of every 'Seq' of a 'PTPF', by way of
-- 's_map' at each 'Seq'.
p_node_map :: (Node -> Node) -> PTPF -> PTPF
p_node_map = p_map . s_map

-- * Parser

-- | Text type consumed by the parser: a lazy, 'Char'-indexed
-- 'C.ByteString'.
type STR = C.ByteString

-- | Read a leading 'Int' with 'C.readInt'.  Any text following the
-- number is ignored, and input that does not begin with an integer
-- reads as @0@.
str_int :: C.ByteString -> Int
str_int s =
    case C.readInt s of
      Just (k,_) -> k
      Nothing -> 0

-- | Read a leading floating point number with 'readDouble' and
-- convert it ('realToFrac') to the requested type.  Any text
-- following the number is ignored, and input that 'readDouble'
-- rejects reads as @0@.
--
-- The conversion only needs 'Fractional', so that is all that is
-- asked of the result type.
str_double :: Fractional n => C.ByteString -> n
str_double s =
    case readDouble s of
      Just (d,_) -> realToFrac d
      Nothing -> 0

-- | Split a line into its white space separated fields.
--
-- 'C.words' is used rather than splitting at each single space, so
-- that runs of spaces, leading or trailing blanks, tabs and a stray
-- carriage return do not produce empty or polluted fields (which the
-- fixed-arity consumers of this function cannot cope with).
--
-- > str_words (C.pack "0  3 ") == map C.pack ["0","3"]
str_words :: C.ByteString -> [C.ByteString]
str_words = C.words

-- | Split text at each newline character and discard the lines that
-- are empty.
str_lines :: C.ByteString -> [C.ByteString]
str_lines s = [l | l <- C.split '\n' s, not (C.null l)]

-- | Build the 'Node' of partial /n/ from the text of its time,
-- frequency and amplitude fields (in that order); each field is read
-- with 'str_double'.
ptpf_node :: Int -> (STR,STR,STR) -> Node
ptpf_node n (t,f,a) =
    Node {n_partial_id = n
         ,n_time = str_double t
         ,n_frequency = str_double f
         ,n_amplitude = str_double a}

-- | Parse 'Seq' from pair of input lines.
--
-- The first line must hold exactly four fields: partial identifier,
-- node count, start time and end time.  The second line holds the
-- node data as consecutive time, frequency, amplitude triples.
--
-- The amplitude of the final node is set to @0@
-- ('n_zero_amplitude') regardless of the value in the input.
--
-- This is a partial function, it calls 'error' when the first line
-- does not have four fields or when the declared node count differs
-- from the number of nodes read (and 'triples' does likewise when the
-- data line is not a whole number of triples).
ptpf_seq :: (STR,STR) -> Seq
ptpf_seq (i,j) =
    case str_words i of
      [ix,n,st,et] ->
          let ix' = str_int ix
              n' = str_int n
              p = map (ptpf_node ix') (triples (str_words j))
          in if n' /= length p
             then error ("ptpf_seq: node count: " ++ show (n',length p))
             else Seq ix' (str_double st) (str_double et) n' (at_last n_zero_amplitude p)
      w -> error ("ptpf_seq: expected 4 header fields, found " ++ show (length w))

-- | Parse header section, result is number of partials.
--
-- The header is exactly four lines:
--
-- > par-text-partials-format
-- > point-type time frequency amplitude
-- > partials-count N
-- > partials-data
--
-- Lines one, two and four must match literally.  Line three must
-- begin with @partials-count @ and the text after that prefix must
-- begin with an integer ('C.readInt'), which is the result.  Anything
-- else gives 'Nothing'.
ptpf_header :: [STR] -> Maybe Int
ptpf_header h =
    let mk = C.pack
        r0 = mk "par-text-partials-format"
        r1 = mk "point-type time frequency amplitude"
        r2 = mk "partials-count "
        r3 = mk "partials-data"
    in case h of
         [h0,h1,h2,h3]
             -- The count line is only trusted once its keyword has
             -- been verified; a failed guard falls through to Nothing.
             | h0 == r0 && h1 == r1 && h3 == r3 && r2 `C.isPrefixOf` h2 ->
                 fmap fst (C.readInt (C.drop (C.length r2) h2))
         _ -> Nothing

-- | Parse 'PTPF' at 'STR'.
--
-- Empty lines are dropped ('str_lines').  The first four remaining
-- lines are the header ('ptpf_header'); the rest are read two at a
-- time, each pair describing one partial ('ptpf_seq').
--
-- The result is 'Left' with a message when the header is not
-- accepted, when the data lines do not pair up, or when the number of
-- partials read differs from the count declared in the header.
--
-- Note that 'ptpf_seq' reports malformed partial data by calling
-- 'error', so such input is not turned into a 'Left' here.
parse_ptpf :: STR -> Either String PTPF
parse_ptpf s =
    let l = str_lines s
        (h,d) = splitAt 4 l
    in case ptpf_header h of
         Nothing -> Left "parse_ptpf: illegal header"
         Just np
             -- A dangling unpaired line would otherwise surface as an
             -- error value hidden inside the list of partials.
             | odd (length d) ->
                 Left ("parse_ptpf: odd number of data lines: " ++ show (length d))
             | otherwise ->
                 let p = map ptpf_seq (duples d)
                     k = length p
                 in if k /= np
                    then Left ("parse_ptpf: partial count: " ++ show (np,k))
                    else Right (PTPF np p)

-- * Operations

-- | All 'Node's of all partials, in ascending order of 'n_time' and
-- grouped into runs that share the same 'n_time'.  Each group is
-- paired with that shared time.
ptpf_time_asc :: PTPF -> [(N_Time,[Node])]
ptpf_time_asc p =
    let by_time = compare `on` n_time
        same_time = (==) `on` n_time
        nodes = concatMap s_data (p_seq p)
    -- groupBy never yields an empty group, so head is safe here.
    in [(n_time (head g),g) | g <- groupBy same_time (sortBy by_time nodes)]