```{-# LANGUAGE UndecidableInstances #-}

-- | Operations on matrices (doubly-nested parallel vectors). All operations in
-- this module assume rectangular matrices.

module Feldspar.Matrix where

import qualified Prelude as P
import Data.List (genericLength)
import qualified Data.TypeLevel as TL

import Feldspar.Prelude
import Feldspar.Core
import Feldspar.Vector

-- | A matrix is a vector of rows, where each row is a vector of scalar
-- elements.
type Matrix e = Vector (Vector (Data e))

-- | Converts a matrix to a core array: every row is frozen first, then the
-- resulting vector of row arrays is frozen as a whole.
freezeMatrix :: Type a => Matrix a -> Data [[a]]
freezeMatrix mat = freezeVector (map freezeVector mat)

-- | Converts a core array to a matrix: the outer array is unfrozen into a
-- vector, and then each of its elements is unfrozen into a row vector.
unfreezeMatrix :: Type a => Data [[a]] -> Matrix a
unfreezeMatrix arr = map unfreezeVector (unfreezeVector arr)

-- | Converts a core array to a matrix with explicitly given dimensions.
--
-- The first length argument is the number of rows (the outer vector), and the
-- second is the number of columns (the inner vectors).
unfreezeMatrix' :: Type a => Length -> Length -> Data [[a]] -> Matrix a
unfreezeMatrix' rows cols arr =
    map (unfreezeVector' cols) (unfreezeVector' rows arr)

-- | Constructs a matrix from a Haskell list of rows. The elements are stored
-- in a core array, which is then unfrozen with 'unfreezeMatrix'.
matrix :: Type a => [[a]] -> Matrix a
matrix rows = unfreezeMatrix (value rows)

-- | Constructs a matrix from an index function.
--
-- @indexedMat m n ixf@:
--
--   * @m@ is the number of rows.
--
--   * @n@ is the number of columns.
--
--   * @ixf@ is a function mapping indexes to elements (the first argument is
--     the row index; the second argument is the column index).
indexedMat
    :: Data Length
    -> Data Length
    -> (Data Index -> Data Index -> Data a)
    -> Matrix a
indexedMat m n ixf = indexed m (\row -> indexed n (ixf row))

-- | Transpose of a matrix: element @(i, j)@ of the result is element
-- @(j, i)@ of the input.
--
-- The column count of the input is taken from its first row. That is selected
-- by a conditional on the row count, the same guard 'flatten' uses. An empty
-- input therefore gives a zero-length result instead of reading the length of
-- a row that does not exist.
transpose :: Type a => Matrix a -> Matrix a
transpose a = indexedMat cols rows $ \y x -> a ! x ! y
  where
    rows = length a
    -- Only consult the first row when there is one.
    cols = (rows == 0) ? (0, length (head a))

-- | Concatenates the rows of a matrix into a single vector, in row-major
-- order.
--
-- The row length is read from the first row, guarded by the row count, so an
-- empty matrix yields a vector of length zero.
flatten :: Type a => Matrix a -> Vector (Data a)
flatten mat = Indexed (rows * cols) elemAt Empty
  where
    rows = length mat
    cols = (rows == 0) ? (0, length (head mat))

    -- Split a linear index into (row, column) coordinates.
    elemAt i = mat ! (i `div` cols) ! (i `mod` cols)
    -- TODO Should use linear indexing

-- | The diagonal vector of a square matrix: element @i@ is @m ! i ! i@.
--
-- It happens to work if the number of rows is less than the number of
-- columns, but not the other way around. Row @i@ is indexed at column @i@,
-- which is out of bounds once @i@ reaches the number of columns.
diagonal :: Type a => Matrix a -> Vector (Data a)
diagonal mat = zipWith (!) mat indices
  where
    indices = 0 ... (length mat - 1)

-- | @distributeL f x v@ applies @f x@ to every element of @v@, distributing a
-- left operand over a vector.
distributeL :: (a -> b -> c) -> a -> Vector b -> Vector c
distributeL f x = map (f x)

-- | @distributeR f v y@ applies @f@ to every element of @v@, with @y@ as the
-- second argument, distributing a right operand over a vector.
distributeR :: (a -> b -> c) -> Vector a -> b -> Vector c
distributeR f v y = map (\x -> f x y) v

-- | Overloaded multiplication between scalars ('Data'), vectors ('DVector')
-- and matrices ('Matrix'). The associated type 'Prod' gives the result type
-- for each combination of operand types.
--
-- No fixity declaration is given for '***' in this module, so it has the
-- default fixity (@infixl 9@).
class Mul l r
  where
    -- | Result type of multiplying an @l@ by an @r@.
    type Prod l r

    -- | General multiplication operator
    (***) :: l -> r -> Prod l r

-- | Scalar times scalar: ordinary multiplication.
instance Numeric a => Mul (Data a) (Data a)
  where
    type Prod (Data a) (Data a) = Data a
    x *** y = x * y

-- | Scalar times vector: every element is multiplied by the scalar.
instance Numeric a => Mul (Data a) (DVector a)
  where
    type Prod (Data a) (DVector a) = DVector a
    s *** v = distributeL (***) s v

-- | Vector times scalar: every element is multiplied by the scalar.
instance Numeric a => Mul (DVector a) (Data a)
  where
    type Prod (DVector a) (Data a) = DVector a
    v *** s = distributeR (***) v s

-- | Scalar times matrix: every row is scaled, using the scalar-vector
-- instance.
instance Numeric a => Mul (Data a) (Matrix a)
  where
    type Prod (Data a) (Matrix a) = Matrix a
    s *** mat = distributeL (***) s mat

-- | Matrix times scalar: every row is scaled, using the vector-scalar
-- instance.
instance Numeric a => Mul (Matrix a) (Data a)
  where
    type Prod (Matrix a) (Data a) = Matrix a
    mat *** s = distributeR (***) mat s

-- | Vector times vector: the scalar product, computed by 'scalarProd'.
instance Numeric a => Mul (DVector a) (DVector a)
  where
    type Prod (DVector a) (DVector a) = Data a
    xs *** ys = scalarProd xs ys

-- | Row vector times matrix: element @j@ of the result is the scalar product
-- of the vector with column @j@ of the matrix, i.e. row @j@ of its transpose.
instance Numeric a => Mul (DVector a) (Matrix a)
  where
    type Prod (DVector a) (Matrix a) = (DVector a)
    row *** mat = distributeL (***) row columns
      where
        columns = transpose mat

-- | Matrix times column vector: element @i@ of the result is the scalar
-- product of row @i@ with the vector.
instance Numeric a => Mul (Matrix a) (DVector a)
  where
    type Prod (Matrix a) (DVector a) = (DVector a)
    mat *** vec = distributeR (***) mat vec

-- | Matrix product: row @i@ of the result is row @i@ of the left matrix
-- multiplied by the right matrix, using the vector-matrix instance.
instance Numeric a => Mul (Matrix a) (Matrix a)
  where
    type Prod (Matrix a) (Matrix a) = (Matrix a)
    lhs *** rhs = distributeR (***) lhs rhs

-- | Matrix multiplication: '***' specialised to two matrices.
mulMat :: Numeric a => Matrix a -> Matrix a -> Matrix a
mulMat lhs rhs = lhs *** rhs

-- | Values that can be combined element by element: scalars ('Data') and
-- (possibly nested) vectors of them.
class Syntactic t => ElemWise t
  where
    -- | The scalar type found at the innermost level of nesting.
    type Elem t

    -- | Operator for general element-wise combination: lifts a binary scalar
    -- operation to the whole structure.
    elemWise :: (Elem t -> Elem t -> Elem t) -> t -> t -> t

-- | Base case: a scalar is its own element, so the operation is applied
-- directly.
instance Type a => ElemWise (Data a)
  where
    type Elem (Data a) = Data a
    elemWise op x y = op x y

-- | Recursive case: zip the two vectors, combining corresponding elements
-- with the element-wise operation of the inner level.
instance (ElemWise e, Syntactic (Vector e)) => ElemWise (Vector e)
  where
    type Elem (Vector e) = Elem e
    elemWise op xs ys = zipWith (elemWise op) xs ys

-- | Element-wise addition of scalars, vectors or matrices. This module gives
-- it no fixity declaration, so it is @infixl 9@; parenthesise when mixing it
-- with other operators.
(.+) :: (ElemWise a, Num (Elem a)) => a -> a -> a
x .+ y = elemWise (+) x y

-- | Element-wise subtraction of scalars, vectors or matrices. This module
-- gives it no fixity declaration, so it is @infixl 9@; parenthesise when
-- mixing it with other operators.
(.-) :: (ElemWise a, Num (Elem a)) => a -> a -> a
x .- y = elemWise (-) x y

-- | Element-wise (Hadamard) multiplication of scalars, vectors or matrices.
-- This module gives it no fixity declaration, so it is @infixl 9@ and does
-- not bind tighter than '.+'; parenthesise when mixing operators.
(.*) :: (ElemWise a, Num (Elem a)) => a -> a -> a
x .* y = elemWise (*) x y

-- * Wrapping for matrices

-- | A matrix result is wrapped by freezing it into a nested core array.
instance (Type a) => Wrap (Matrix a) (Data [[a]]) where
    wrap mat = freezeMatrix mat

-- | A matrix argument is wrapped as a core array whose dimensions are carried
-- as type-level naturals: @row@ is the number of rows and @col@ the number of
-- columns. These are turned into 'Length's for 'unfreezeMatrix''.
instance (Wrap t u, Type a, TL.Nat row, TL.Nat col)
    => Wrap (Matrix a -> t) (Data' (row,col) [[a]] -> u)
  where
    wrap f (Data' d) = wrap (f (unfreezeMatrix' rows cols d))
      where
        rows = natToLength (undefined :: row)
        cols = natToLength (undefined :: col)

        -- Reflect a type-level natural into a value-level length.
        natToLength :: TL.Nat n => n -> Length
        natToLength = fromInteger . toInteger . TL.toInt
```