--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Geometry.Matrix
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--
-- Type-indexed matrices: the number of rows and columns is part of the type.
--
--------------------------------------------------------------------------------
module Data.Geometry.Matrix(
    Matrix(Matrix)
  , identityMatrix

  , multM
  , mult

  , Invertible(..)
  , HasDeterminant(..)
  ) where

import           Control.Lens                           (imap)
import           Data.Coerce
import           Data.Geometry.Matrix.Internal          (mkRow)
import           Data.Geometry.Vector
import           Data.Geometry.Vector.VectorFamilyPeano
import           Linear.Matrix                          (M22, M33, M44, (!*!), (!*))
import qualified Linear.Matrix                          as Lin

--------------------------------------------------------------------------------
-- * Matrices

-- | A matrix of n rows, each of m columns, storing values of type r.
--
-- The outer 'Vector' holds one entry per row (n of them); each inner 'Vector'
-- holds the m entries of that row.
newtype Matrix n m r = Matrix (Vector n (Vector m r))

-- All instances are derived structurally from the nested 'Vector's. The
-- Functor/Foldable/Traversable instances act on the individual entries of
-- type r. Standalone deriving is used because every instance needs the
-- 'Arity' constraints on both dimensions.
deriving instance (Show r, Arity n, Arity m) => Show (Matrix n m r)
deriving instance (Eq r, Arity n, Arity m)   => Eq (Matrix n m r)
deriving instance (Ord r, Arity n, Arity m)  => Ord (Matrix n m r)
deriving instance (Arity n, Arity m)         => Functor (Matrix n m)
deriving instance (Arity n, Arity m)         => Foldable (Matrix n m)
deriving instance (Arity n, Arity m)         => Traversable (Matrix n m)

-- | Matrix product.
--
-- Multiplies an @r x c@ matrix by a @c x c'@ matrix, giving an @r x c'@
-- matrix. The types force the inner dimensions to agree.
--
-- Delegates to linear's '!*!' on the underlying nested vectors.
multM :: (Arity r, Arity c, Arity c', Num a)
      => Matrix r c a -> Matrix c c' a -> Matrix r c' a
(Matrix a) `multM` (Matrix b) = Matrix $ a !*! b

-- | Matrix * column vector.
--
-- Multiplies an @n x m@ matrix by a column vector of length @m@, giving a
-- vector of length @n@.
--
-- Delegates to linear's '!*' on the rows of the matrix.
mult :: (Arity m, Arity n, Num r) => Matrix n m r -> Vector m r -> Vector n r
(Matrix m) `mult` v = m !* v

-- | Produces the @d x d@ identity matrix.
--
-- @pure 1@ is a length-@d@ vector of ones. 'imap' passes each index @i@,
-- together with that @1@, to 'mkRow', so row @i@ is @mkRow i 1@.
-- (Presumably this is the row with @1@ in column @i@ and @0@ elsewhere;
-- see 'mkRow'.)
identityMatrix :: (Arity d, Num r) => Matrix d d r
identityMatrix = Matrix $ imap mkRow (pure 1)

-- | Class of matrices that are invertible.
--
-- Instances exist for square matrices of size 1 to 4 over any 'Fractional'
-- entry type.
--
-- NOTE(review): no instance checks for singularity. Inverting a singular
-- matrix yields whatever the underlying computation produces (presumably a
-- division by a zero determinant); confirm before relying on it.
class Invertible n r where
  inverse' :: Matrix n n r -> Matrix n n r

-- | The inverse of a 1x1 matrix is the reciprocal of its single entry.
-- This mirrors the 1x1 'HasDeterminant' instance.
instance Fractional r => Invertible 1 r where
  inverse' = fmap recip

instance Fractional r => Invertible 2 r where
  -- >>> inverse' $ Matrix $ Vector2 (Vector2 1 2) (Vector2 3 4.0)
  -- Matrix Vector2 [Vector2 [-2.0,1.0],Vector2 [1.5,-0.5]]
  inverse' = withM22 Lin.inv22

instance Fractional r => Invertible 3 r where
  -- >>> inverse' $ Matrix $ Vector3 (Vector3 1 2 4) (Vector3 4 2 2) (Vector3 1 1 1.0)
  -- Matrix Vector3 [Vector3 [0.0,0.5,-1.0],Vector3 [-0.5,-0.75,3.5],Vector3 [0.5,0.25,-1.5]]
  inverse' = withM33 Lin.inv33

instance Fractional r => Invertible 4 r where
  -- Computed by linear's 'Lin.inv44' on the coerced 'M44' representation.
  inverse' = withM44 Lin.inv44

-- | Class of matrices that have a determinant.
--
-- Instances exist for square matrices of size 1 to 4.
class Arity d => HasDeterminant d where
  -- | The determinant of a @d x d@ matrix.
  det :: Num r => Matrix d d r -> r

-- | The determinant of a 1x1 matrix is its single entry.
instance HasDeterminant 1 where
  det (Matrix (Vector1 (Vector1 x))) = x

-- The 2x2 to 4x4 instances 'coerce' the matrix to linear's M22/M33/M44 at no
-- runtime cost and use linear's closed-form determinants.
instance HasDeterminant 2 where
  det = Lin.det22 . coerce

instance HasDeterminant 3 where
  det = Lin.det33 . coerce

instance HasDeterminant 4 where
  det = Lin.det44 . coerce

--------------------------------------------------------------------------------
-- Boilerplate code for converting between Matrix and M22/M33/M44.

-- | Runs a function written against linear's 'M22' on a 2x2 'Matrix'.
--
-- Both conversions are 'coerce's, so nothing is copied. This only typechecks
-- because @Matrix 2 2 a@ is representationally equal to @M22 a@.
withM22 :: (M22 a -> M22 b) -> Matrix 2 2 a -> Matrix 2 2 b
withM22 f = coerce . f . coerce

-- | Runs a function written against linear's 'M33' on a 3x3 'Matrix'.
--
-- Both conversions are 'coerce's, so nothing is copied. This only typechecks
-- because @Matrix 3 3 a@ is representationally equal to @M33 a@.
withM33 :: (M33 a -> M33 b) -> Matrix 3 3 a -> Matrix 3 3 b
withM33 f = coerce . f . coerce

-- | Runs a function written against linear's 'M44' on a 4x4 'Matrix'.
--
-- Both conversions are 'coerce's, so nothing is copied. This only typechecks
-- because @Matrix 4 4 a@ is representationally equal to @M44 a@.
withM44 :: (M44 a -> M44 b) -> Matrix 4 4 a -> Matrix 4 4 b
withM44 f = coerce . f . coerce