{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE FlexibleContexts      #-}

module LinearAlgebra.TypedSpaces.Matrix
        ( Matrix (..)
        , L.Storable
        , L.Element
        , L.Numeric
        , L.Field
        , rows
        , cols
        , fromRows
        , toRows
        , fromColumns
        , toColumns
        , takeRows
        , takeColumns
        , tr
        , Sparse (..)
        , toDense
        , addToSparse
        , (<.>)
        , (#>)
        , (<#)
        , (#)
        , inv
        , pinv ) where

import qualified Foreign.Storable as L
import qualified Numeric.LinearAlgebra as L
import LinearAlgebra.TypedSpaces.Classes
import LinearAlgebra.TypedSpaces.Vector

-- | A dense matrix whose rows are indexed by the space @i@ and whose
-- columns are indexed by the space @j@. Multiplying it on the right by a
-- @'Vector' j@ yields a @'Vector' i@ (see '#>').
--
-- The parameters @i@ and @j@ are phantom: they carry no runtime data.
-- They exist only so that the type checker rejects products of
-- incompatible index spaces.
newtype Matrix i j a = Matrix
  { matrix :: L.Matrix a  -- ^ the underlying hmatrix matrix
  } deriving Show

instance (Isomorphism Int i, Isomorphism Int j)
      => CIndexed (Matrix i j) (i,j) where

  -- Look up the element at a typed (row, column) position. Each typed
  -- index is converted to a raw Int with 'bw' before being handed to
  -- hmatrix's 'L.atIndex'.
  (!) (Matrix mat) (r, c) = L.atIndex mat (bw r, bw c)

instance CFunctor (Matrix i j) where

  -- Element types that may be mapped over: whatever hmatrix's
  -- 'L.cmap' on matrices requires, plus 'Num'.
  type CFun (Matrix i j) a =
    (L.Container L.Matrix a, L.Storable a, Num a)

  -- Map element-wise. The phantom index spaces are left untouched.
  cmap f = Matrix . L.cmap f . matrix

----------------------------------------------------------------------

-- | Number of rows, i.e. the size of the @i@ index space as stored.
rows :: Matrix i j a -> Int
rows (Matrix m) = L.rows m

-- | Number of columns, i.e. the size of the @j@ index space as stored.
cols :: Matrix i j a -> Int
cols (Matrix m) = L.cols m

-- | Build a matrix from its rows. Each row is a vector over the column
-- space @j@. The work is delegated to hmatrix's 'L.fromRows'.
fromRows :: L.Element a => [Vector j a] -> Matrix i j a
fromRows rs = Matrix $ L.fromRows [ vector r | r <- rs ]

-- | Split a matrix into its rows, each a vector over the column space @j@.
toRows :: L.Element a => Matrix i j a -> [Vector j a]
toRows (Matrix m) = [ Vector r | r <- L.toRows m ]

-- | Build a matrix from its columns. Each column is a vector over the row
-- space @i@. The work is delegated to hmatrix's 'L.fromColumns'.
fromColumns :: L.Element a => [Vector i a] -> Matrix i j a
fromColumns cs = Matrix $ L.fromColumns [ vector c | c <- cs ]

-- | Split a matrix into its columns, each a vector over the row space @i@.
toColumns :: L.Element a => Matrix i j a -> [Vector i a]
toColumns (Matrix m) = [ Vector c | c <- L.toColumns m ]

-- | Keep only the first @n@ rows, via hmatrix's 'L.takeRows'.
-- The index-space types are unchanged.
takeRows :: L.Element a => Int -> Matrix i j a -> Matrix i j a
takeRows n = Matrix . L.takeRows n . matrix

-- | Keep only the first @n@ columns, via hmatrix's 'L.takeColumns'.
-- The index-space types are unchanged.
takeColumns :: L.Element a => Int -> Matrix i j a -> Matrix i j a
takeColumns n = Matrix . L.takeColumns n . matrix

-- | Transpose, swapping the row and column index spaces.
--
-- This calls hmatrix's @tr'@ rather than @tr@. NOTE(review): per the
-- hmatrix docs, @tr'@ is the plain transpose while @tr@ is the conjugate
-- transpose, so complex matrices are not conjugated here — confirm this is
-- intended.
tr :: (L.Transposable (L.Matrix a) (L.Matrix a))
   => Matrix i j a -> Matrix j i a
tr = Matrix . L.tr' . matrix

----------------------------------------------------------------------

-- | A sparse matrix stored as a list of @((row, column), value)@ entries
-- with raw 'Int' indices. As with 'Matrix', @i@ and @j@ are phantom
-- row and column index spaces.
newtype Sparse i j a = Sparse
  { sparse :: [((Int, Int), a)]  -- ^ the stored entries
  } deriving Show

-- | Convert a sparse matrix to a dense one.
--
-- An empty 'Sparse' becomes a 0×0 matrix. This case is handled here
-- because hmatrix's 'L.toDense' sizes its result by taking the maximum of
-- the entry indices, which fails on an empty list.
--
-- NOTE(review): for non-empty input the result's dimensions come from the
-- largest row/column index present, not from the @i@/@j@ index spaces.
-- Trailing all-zero rows or columns are therefore not represented —
-- confirm this is acceptable to callers.
toDense :: Sparse i j Double -> Matrix i j Double
toDense (Sparse [])      = Matrix ((0 L.>< 0) [])
toDense (Sparse entries) = Matrix (L.toDense entries)

-- | Prepend one entry, given at a typed @(row, column)@ position, to a
-- sparse matrix. Both indices are converted to raw 'Int's with 'bw'.
-- No deduplication is done here. If the position already has an entry,
-- both are kept in the list, and the consumer (e.g. 'toDense') decides
-- which one wins.
addToSparse :: (Isomorphism Int i, Isomorphism Int j)
            => ((i,j),a) -> Sparse i j a -> Sparse i j a
addToSparse ((r, c), x) (Sparse entries) =
  Sparse $ ((bw r, bw c), x) : entries

----------------------------------------------------------------------

infixr 8 <.>
-- | Dot product of two vectors over the same index space, delegated to
-- hmatrix's '(L.<.>)'. NOTE(review): hmatrix documents this as a
-- conjugating dot product for complex elements — confirm against the
-- version in use.
(<.>) :: (L.Numeric a) => Vector i a -> Vector i a -> a
(<.>) (Vector u) (Vector v) = (L.<.>) u v

infixr 8 #>
-- | Matrix–vector product. A @Matrix i j@ consumes a vector over @j@ and
-- produces one over @i@.
(#>) :: (L.Numeric a) => Matrix i j a -> Vector j a -> Vector i a
(#>) (Matrix m) (Vector v) = Vector $ (L.#>) m v

infixl 8 <#
-- | Vector–matrix product. A row vector over @i@ times a @Matrix i j@
-- gives a vector over @j@.
(<#) :: (L.Numeric a) => Vector i a -> Matrix i j a -> Vector j a
(<#) (Vector v) (Matrix m) = Vector $ (L.<#) v m

infixr 8 #
-- | Matrix product. The shared inner index space @j@ must match, so the
-- composition is checked at the type level.
(#) :: (L.Numeric a) => Matrix i j a -> Matrix j k a -> Matrix i k a
(#) (Matrix lhs) (Matrix rhs) = Matrix $ (L.<>) lhs rhs

-- | Matrix inverse, which maps vectors over @i@ back to vectors over @j@.
--
-- NOTE(review): hmatrix's 'L.inv' needs a square, non-singular matrix and
-- reports failure as a runtime error rather than via 'Maybe'. The
-- phantom types do not enforce squareness.
inv :: (L.Field a) => Matrix i j a -> Matrix j i a
inv = Matrix . L.inv . matrix

-- | Moore–Penrose pseudo-inverse, via hmatrix's 'L.pinv'. Like 'inv', it
-- swaps the row and column index spaces, but it also applies to
-- non-square matrices.
pinv :: (L.Field a) => Matrix i j a -> Matrix j i a
pinv = Matrix . L.pinv . matrix