-- Copyright (c) 2009, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | Operations on matrices (nested parallel vectors). All operations in this
-- module assume rectangular matrices.

module Feldspar.Matrix where



import qualified Prelude as P

import Types.Data.Ord

import Feldspar.Prelude
import Feldspar.Utils
import Feldspar.Core.Types
import Feldspar.Core
import Feldspar.Vector



-- | A matrix with @m@ rows and @n@ columns, represented as a parallel vector
-- (length @m@) of parallel vectors (length @n@) of scalar 'Data' elements.
type Matrix m n a = Par m :>> Par n :>> Data a



-- | Converts a matrix to a core array.
--
-- Every row is frozen into a core array first; the resulting vector of rows
-- is then frozen into the outer (nested) core array.
freezeMatrix :: (NaturalT m, NaturalT n, Storable a) =>
    Matrix m n a -> Data (m :> n :> a)

freezeMatrix matr = freezeVector (map freezeVector matr)



-- | Converts a core array to a matrix.
--
-- The first argument is the number of rows and the second the number of
-- columns. The outer array is unfrozen into @y@ rows, and each row is
-- unfrozen into a vector of length @x@.
unfreezeMatrix :: (NaturalT m, NaturalT n, Storable a) =>
    Data Int -> Data Int -> Data (m :> n :> a) -> Matrix m n a

unfreezeMatrix y x arr = map (unfreezeVector x) (unfreezeVector y arr)



-- | Constructs a matrix from a nested Haskell list (a list of rows).
--
-- All rows must have the same length; otherwise this calls 'error'. An empty
-- list yields a matrix with zero rows and zero columns.
matrix :: (NaturalT m, NaturalT n, Storable a, ListBased a ~ a) =>
    [[a]] -> Matrix m n a

matrix rows
    | allEqual widths = unfreezeMatrix height width (array rows)
    | otherwise       = error "matrix: Not rectangular"
  where
    widths = P.map P.length rows
    height = value (P.length rows)
    -- The column count is the width of the first row (0 when there are none).
    width  = value $ case widths of
                       []    -> 0
                       (w:_) -> w



-- | Transpose of a matrix.
--
-- Element @(c, r)@ of the result is element @(r, c)@ of the input. The number
-- of result rows is read from the length of the input's first row.
transpose :: Matrix m n a -> Matrix n m a
transpose a = Indexed cols (\c -> Indexed rows (\r -> a ! r ! c))
  where
    rows = length a
    cols = length (head a)

-- | Matrix multiplication.
--
-- Each result row is the scalar product of the corresponding row of @a@ with
-- every column of @b@. The columns are obtained by transposing @b@ once.
mul :: (Primitive a, Num a) => Matrix m n a -> Matrix n p a -> Matrix m p a
mul a b = map rowTimes a
  where
    cols = transpose b
    rowTimes row = map (scalarProd row) cols



-- | Concatenates the rows of a matrix (row-major order).
--
-- Element @i@ of the result is taken from row @i `div` n@ and column
-- @i `mod` n@, where @n@ is the length of a row (the number of columns).
-- Rows are assumed to all have the same length (see the module header).
flatten :: Matrix m n a -> VectorP (m :* n) a
flatten matr = Indexed (m*n) ixf
  where
    m = length matr
    -- Row length is read from the first row, or 0 when there are no rows.
    n = (m==0) ? (0, length (head matr))

    -- The row index advances once every n elements, so both quotient and
    -- remainder are taken with respect to the row length n. Using the row
    -- count m here is only correct for square matrices.
    ixf i = matr ! y ! x
      where
        y = i `div` n
        x = i `mod` n



-- | The diagonal vector of a square matrix.
--
-- Row @i@ contributes its element @i@, for each row index @i@. This also works
-- when the number of rows is less than the number of columns. It does not
-- work the other way around, because supporting that would require some
-- overhead.
diagonal :: Matrix n n a -> VectorP n a
diagonal matr = map (\(row, i) -> row ! i) (zip matr indices)
  where
    indices = enumFromTo 0 (length matr - 1)