-- | SDIF matrix functions.
module Sound.SDIF.Matrix where

import qualified Data.ByteString.Lazy as B
import Sound.SDIF.Byte.Matrix
import Sound.SDIF.Type

-- | SDIF matrix data store.
--
-- As filled in by 'decode_matrix', every field other than 'matrix_b'
-- caches the result of the correspondingly named @matrix_b_*@ function
-- applied to the encoded matrix:
--
-- * 'matrix_b': the encoded matrix that the other fields were read from.
--
-- * 'matrix_type': matrix type string, from 'matrix_b_type'.
--
-- * 'matrix_data_type': data type code, from 'matrix_b_data_type'.
--
-- * 'matrix_rows', 'matrix_columns': matrix dimensions; 'matrix_row'
--   and 'matrix_column' bounds-check their index against these.
--
-- * 'matrix_elements': from 'matrix_b_elements' (presumably
--   rows * columns; confirm against "Sound.SDIF.Byte.Matrix").
--
-- * 'matrix_data_size', 'matrix_storage_size': sizes as reported by
--   'matrix_b_data_size' and 'matrix_b_storage_size' (units are not
--   visible from this module; presumably bytes, to be confirmed).
--
-- * 'matrix_v': the decoded elements as one flat list, which
--   'matrix_row' and 'matrix_column' address in row-major order.
data Matrix = Matrix { matrix_b :: B.ByteString
                     , matrix_type :: String
                     , matrix_data_type :: Int
                     , matrix_rows :: Int
                     , matrix_columns :: Int
                     , matrix_elements :: Int
                     , matrix_data_size :: Int
                     , matrix_storage_size :: Int
                     , matrix_v :: [Datum] }
              deriving (Eq, Show)

-- | Decode an encoded matrix into a 'Matrix' record.
--
-- The input is first checked with 'is_matrix_b'; if that check fails
-- the function calls 'error' with @decode_matrix: illegal data@.
-- Otherwise each field of the result is obtained by applying the
-- matching @matrix_b_*@ function to the input, which is itself kept
-- in 'matrix_b'.
decode_matrix :: B.ByteString -> Matrix
decode_matrix bs
    | is_matrix_b bs = decoded
    | otherwise = error "decode_matrix: illegal data"
  where
    -- Fields are only evaluated on demand, after the validity check.
    decoded =
        Matrix
            { matrix_b = bs
            , matrix_type = matrix_b_type bs
            , matrix_data_type = matrix_b_data_type bs
            , matrix_rows = matrix_b_rows bs
            , matrix_columns = matrix_b_columns bs
            , matrix_elements = matrix_b_elements bs
            , matrix_data_size = matrix_b_data_size bs
            , matrix_storage_size = matrix_b_storage_size bs
            , matrix_v = matrix_b_to_matrix_v bs
            }

-- | Section of a list: the elements at zero-based positions /i/ up to,
-- but not including, /j/.
--
-- The first /i/ elements are dropped and at most @j - i@ of the
-- remaining elements are kept, so the result is shorter than @j - i@
-- when the list runs out early.
--
-- > list_section [1..9] 4 6 == [5,6]
list_section :: [a] -> Int -> Int -> [a]
list_section xs i j = take width remainder
  where
    width = j - i
    remainder = drop i xs

-- | Extract the /n/th (zero-based) row of a 'Matrix'.
--
-- 'matrix_v' is read in row-major order: row /n/ is the run of
-- 'matrix_columns' elements starting at offset @n * matrix_columns@.
--
-- Calls 'error' with @matrix_row: domain error@ unless
-- @0 <= n < matrix_rows m@.
matrix_row :: Matrix -> Int -> [Datum]
matrix_row m n
    -- A negative index must be rejected explicitly: 'drop' treats a
    -- negative count as zero, so it would otherwise silently yield the
    -- leading elements of the matrix instead of failing.
    | n < 0 || n >= matrix_rows m = error "matrix_row: domain error"
    | otherwise = list_section (matrix_v m) start (start + width)
  where
    width = matrix_columns m
    start = n * width

-- | Extract the /n/th (zero-based) column of a 'Matrix'.
--
-- 'matrix_v' is read in row-major order: the column consists of the
-- elements at offsets @n + i * matrix_columns@ for each row /i/ in
-- @[0 .. matrix_rows - 1]@.
--
-- Calls 'error' with @matrix_column: domain error@ unless
-- @0 <= n < matrix_columns m@, and with
-- @matrix_column: insufficient data@ if 'matrix_v' holds fewer
-- elements than the declared dimensions require.
matrix_column :: Matrix -> Int -> [Datum]
matrix_column m n
    -- Reject negative indices up front; otherwise they would either
    -- fail later with an unrelated indexing error or read an element
    -- belonging to a different column.
    | n < 0 || n >= nc = error "matrix_column: domain error"
    | otherwise = pick (matrix_rows m) (matrix_v m)
  where
    nc = matrix_columns m
    -- Walk the flat element list one row at a time, taking the /n/th
    -- element of each row.  This visits every element at most once,
    -- rather than re-indexing from the head of the list for each row.
    pick k row_start
        | k <= 0 = []
        | otherwise =
            case drop n row_start of
                (x : _) -> x : pick (k - 1) (drop nc row_start)
                [] -> error "matrix_column: insufficient data"