{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
module Data.Semiring.Matrix (
type M22
, type M23
, type M24
, type M32
, type M33
, type M34
, type M42
, type M43
, type M44
, m22
, m23
, m24
, m32
, m33
, m34
, m42
, m43
, m44
, row
, col
, (.>)
, (<.)
, (#>)
, (<#)
, (<#>)
, scale
, identity
, transpose
, trace
, diag
, bdet2
, det2
, det2d
, inv2d
, bdet3
, det3
, det3d
, inv3d
, bdet4
, det4
, det4d
, inv4d
) where
import Data.Distributive
import Data.Foldable as Foldable (fold, foldl')
import Data.Functor.Compose
import Data.Functor.Rep
import Data.Group
import Data.Prd
import Data.Ring
import Data.Semigroup.Foldable as Foldable1
import Data.Semiring
import Data.Semiring.Module
import Data.Semiring.V2
import Data.Semiring.V3
import Data.Semiring.V4
import Data.Tuple
import Data.Double.Instance ()
import Prelude hiding (sum, negate)
-- | A 2×2 matrix: an outer 'V2' of rows, each row a 'V2'.
type M22 a = V2 (V2 a)
-- | A 2×3 matrix: two rows, each a 'V3'.
type M23 a = V2 (V3 a)
-- | A 2×4 matrix: two rows, each a 'V4'.
type M24 a = V2 (V4 a)
-- | A 3×2 matrix: three rows, each a 'V2'.
type M32 a = V3 (V2 a)
-- | A 3×3 matrix: three rows, each a 'V3'.
type M33 a = V3 (V3 a)
-- | A 3×4 matrix: three rows, each a 'V4'.
type M34 a = V3 (V4 a)
-- | A 4×2 matrix: four rows, each a 'V2'.
type M42 a = V4 (V2 a)
-- | A 4×3 matrix: four rows, each a 'V3'.
type M43 a = V4 (V3 a)
-- | A 4×4 matrix: four rows, each a 'V4'.
type M44 a = V4 (V4 a)
-- | Build a 2×2 matrix from its entries given in row-major order.
m22 :: a -> a -> a -> a -> M22 a
m22 x00 x01 x10 x11 = V2 r0 r1
  where
    r0 = V2 x00 x01
    r1 = V2 x10 x11
{-# INLINE m22 #-}
-- | Build a 2×3 matrix (two rows of three) from entries in row-major order.
m23 :: a -> a -> a -> a -> a -> a -> M23 a
m23 x00 x01 x02 x10 x11 x12 = V2 r0 r1
  where
    r0 = V3 x00 x01 x02
    r1 = V3 x10 x11 x12
{-# INLINE m23 #-}
-- | Build a 2×4 matrix (two rows of four) from entries in row-major order.
m24 :: a -> a -> a -> a -> a -> a -> a -> a -> M24 a
m24 x00 x01 x02 x03 x10 x11 x12 x13 = V2 r0 r1
  where
    r0 = V4 x00 x01 x02 x03
    r1 = V4 x10 x11 x12 x13
{-# INLINE m24 #-}
-- | Build a 3×2 matrix (three rows of two) from entries in row-major order.
m32 :: a -> a -> a -> a -> a -> a -> M32 a
m32 x00 x01 x10 x11 x20 x21 = V3 r0 r1 r2
  where
    r0 = V2 x00 x01
    r1 = V2 x10 x11
    r2 = V2 x20 x21
{-# INLINE m32 #-}
-- | Build a 3×3 matrix from its entries given in row-major order.
m33 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> M33 a
m33 x00 x01 x02 x10 x11 x12 x20 x21 x22 = V3 r0 r1 r2
  where
    r0 = V3 x00 x01 x02
    r1 = V3 x10 x11 x12
    r2 = V3 x20 x21 x22
{-# INLINE m33 #-}
-- | Build a 3×4 matrix (three rows of four) from entries in row-major order.
m34 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M34 a
m34 x00 x01 x02 x03 x10 x11 x12 x13 x20 x21 x22 x23 = V3 r0 r1 r2
  where
    r0 = V4 x00 x01 x02 x03
    r1 = V4 x10 x11 x12 x13
    r2 = V4 x20 x21 x22 x23
{-# INLINE m34 #-}
-- | Build a 4×2 matrix (four rows of two) from entries in row-major order.
m42 :: a -> a -> a -> a -> a -> a -> a -> a -> M42 a
m42 x00 x01 x10 x11 x20 x21 x30 x31 = V4 r0 r1 r2 r3
  where
    r0 = V2 x00 x01
    r1 = V2 x10 x11
    r2 = V2 x20 x21
    r3 = V2 x30 x31
{-# INLINE m42 #-}
-- | Build a 4×3 matrix (four rows of three) from entries in row-major order.
m43 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M43 a
m43 x00 x01 x02 x10 x11 x12 x20 x21 x22 x30 x31 x32 = V4 r0 r1 r2 r3
  where
    r0 = V3 x00 x01 x02
    r1 = V3 x10 x11 x12
    r2 = V3 x20 x21 x22
    r3 = V3 x30 x31 x32
{-# INLINE m43 #-}
-- | Build a 4×4 matrix from its entries given in row-major order.
m44 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M44 a
m44 x00 x01 x02 x03 x10 x11 x12 x13 x20 x21 x22 x23 x30 x31 x32 x33 =
  V4 r0 r1 r2 r3
  where
    r0 = V4 x00 x01 x02 x03
    r1 = V4 x10 x11 x12 x13
    r2 = V4 x20 x21 x22 x23
    r3 = V4 x30 x31 x32 x33
{-# INLINE m44 #-}
-- | Select row @i@ of a matrix.
--
-- This works for any representable container: it returns the element at
-- index @i@.
row :: Representable f => Rep f -> f c -> c
row i m = index m i
{-# INLINE row #-}
-- | Select column @j@ of a matrix.
--
-- The matrix is transposed with 'distribute' and row @j@ of the result is
-- taken.
col :: Functor f => Representable g => Rep g -> f (g a) -> f a
col j m = index (distribute m) j
{-# INLINE col #-}
infixl 7 <.
-- | Right scalar multiplication: every entry @x@ of the matrix becomes
-- @x >< a@.
(<.) :: Semiring a => Functor f => Functor g => f (g a) -> a -> f (g a)
m <. s = (fmap . fmap) (>< s) m
{-# INLINE (<.) #-}
infixr 7 .>
-- | Left scalar multiplication: every entry @x@ of the matrix becomes
-- @a >< x@.
(.>) :: Semiring a => Functor f => Functor g => a -> f (g a) -> f (g a)
s .> m = (fmap . fmap) (s ><) m
{-# INLINE (.>) #-}
infixl 7 <#
-- | Row-vector × matrix product.
--
-- Entry @j@ of @v <# m@ is @v '<.>' col j m@, i.e. @v@ combined with
-- column @j@ of @m@ via '<.>' (presumably the inner product from
-- "Data.Semiring.Module").
(<#) :: (Semiring a, Free f, Free g) => f a -> f (g a) -> g a
v <# m = tabulate entry
  where
    entry j = v <.> col j m
{-# INLINE (<#) #-}
infixr 7 #>, <#>
-- | Matrix × column-vector product.
--
-- Entry @i@ of @m #> v@ is @row i m '<.>' v@.
(#>) :: (Semiring a, Free f, Free g) => f (g a) -> g a -> f a
m #> v = tabulate entry
  where
    entry i = row i m <.> v
{-# INLINE (#>) #-}
-- | Matrix × matrix product.
--
-- Entry @(i, j)@ of @x <#> y@ is @row i x '<.>' col j y@. The result is
-- tabulated as a 'Compose', whose index is the pair @(Rep f, Rep h)@, and
-- then unwrapped with 'getCompose'.
(<#>) :: (Semiring a, Free f, Free g, Free h) => f (g a) -> g (h a) -> f (h a)
x <#> y = getCompose (tabulate entry)
  where
    entry (i, j) = row i x <.> col j y
{-# INLINE (<#>) #-}
-- | Build a square matrix with the given vector on its diagonal and
-- 'mempty' in every off-diagonal position.
scale :: Monoid a => Free f => f a -> f (f a)
scale v = imapRep onRow v
  where
    -- Row @i@ keeps only the entry at column @i@ (the value @x@ taken from @v@).
    onRow i x = imapRep (pick i x) v
    pick i x j _
      | i == j    = x
      | otherwise = mempty
{-# INLINE scale #-}
-- | The identity matrix: 'sunit' on the diagonal and 'mempty' elsewhere.
identity :: Unital a => Free f => f (f a)
identity = scale ones
  where
    ones = pureRep sunit
{-# INLINE identity #-}
-- | Swap the rows and columns of a matrix; this is exactly 'distribute'.
transpose :: Functor f => Distributive g => f (g a) -> g (f a)
transpose m = distribute m
{-# INLINE transpose #-}
-- | Combine the diagonal entries of a square matrix with '<>' (via 'fold1').
trace :: Semigroup a => Free f => f (f a) -> a
trace m = Foldable1.fold1 (diag m)
{-# INLINE trace #-}
-- | Extract the main diagonal.
--
-- Entry @i@ of the result is entry @i@ of row @i@.
diag :: Representable f => f (f a) -> f a
diag m = bindRep m id
{-# INLINE diag #-}
-- | The two product terms of a 2×2 determinant, kept separate so they can be
-- formed in a semiring that has no subtraction.
--
-- For @m22 a b c d@ this is @(a >< d, b >< c)@.
bdet2 :: Semiring a => M22 a -> (a, a)
bdet2 (V2 (V2 x00 x01) (V2 x10 x11)) = (mainDiag, antiDiag)
  where
    mainDiag = x00 >< x11
    antiDiag = x01 >< x10
{-# INLINE bdet2 #-}
-- | Determinant of a 2×2 matrix over a ring.
--
-- The positive term of 'bdet2' is combined with the negative term via '<<'
-- (presumably group subtraction).
det2 :: Ring a => M22 a -> a
det2 m = let (pos, neg) = bdet2 m in pos << neg
{-# INLINE det2 #-}
-- | Determinant of a 2×2 'Double' matrix: @a * d - b * c@ for @m22 a b c d@.
det2d :: M22 Double -> Double
det2d (V2 (V2 x00 x01) (V2 x10 x11)) = x00 * x11 - x01 * x10
{-# INLINE det2d #-}
-- | Inverse of a 2×2 'Double' matrix: the adjugate scaled by the reciprocal
-- of 'det2d'.
--
-- There is no singularity check: a zero determinant makes the scale factor
-- @1 / 0@ infinite, so the entries come out non-finite.
inv2d :: M22 Double -> M22 Double
inv2d m@(V2 (V2 x00 x01) (V2 x10 x11)) = s .> adjugate
  where
    s = 1 / det2d m
    adjugate = m22 x11 (-x01) (-x10) x00
{-# INLINE inv2d #-}
-- | The positive and negative halves of a 3×3 determinant (Sarrus' rule),
-- kept separate so they can be formed in a semiring without subtraction.
--
-- The factor order inside each product is fixed, which matters for
-- non-commutative semirings.
bdet3 :: Semiring a => M33 a -> (a, a)
bdet3 (V3 (V3 x00 x01 x02) (V3 x10 x11 x12) (V3 x20 x21 x22)) = (pos, neg)
  where
    pos = x00 >< x11 >< x22 <> x20 >< x01 >< x12 <> x10 >< x21 >< x02
    neg = x00 >< x21 >< x12 <> x10 >< x01 >< x22 <> x20 >< x11 >< x02
{-# INLINE bdet3 #-}
-- | Determinant of a 3×3 matrix over a ring.
--
-- The positive half of 'bdet3' is combined with the negative half via '<<'
-- (presumably group subtraction).
det3 :: Ring a => M33 a -> a
det3 m = let (pos, neg) = bdet3 m in pos << neg
{-# INLINE det3 #-}
-- | Determinant of a 3×3 'Double' matrix by cofactor expansion down the
-- first column.
det3d :: M33 Double -> Double
det3d (V3 (V3 a b c)
          (V3 d e f)
          (V3 g h i)) = a * m0 - d * m1 + g * m2
  where
    -- 2×2 minors complementary to a, d and g respectively
    m0 = e * i - f * h
    m1 = b * i - c * h
    m2 = b * f - c * e
{-# INLINE det3d #-}
-- | Inverse of a 3×3 'Double' matrix: the adjugate (transposed cofactor
-- matrix) scaled by the reciprocal of 'det3d'.
--
-- There is no singularity check: a zero determinant gives non-finite entries.
inv3d :: M33 Double -> M33 Double
inv3d m@(V3 (V3 a b c)
            (V3 d e f)
            (V3 g h i)) = (1 / det3d m) .> adjugate
  where
    -- 2×2 determinant q*t - r*s, same operand order as det2d
    cof q r s t = q * t - r * s
    adjugate = m33 (cof e f h i) (cof c b i h) (cof b c e f)
                   (cof f d i g) (cof a c g i) (cof c a f d)
                   (cof d e g h) (cof b a h g) (cof a b d e)
{-# INLINE inv3d #-}
-- | The positive and negative halves of a 4×4 determinant (Leibniz terms),
-- kept separate so they can be formed in a semiring without subtraction.
--
-- Terms are grouped by the first-row entry they start with.
bdet4 :: Semiring a => M44 a -> (a, a)
bdet4 (V4 (V4 a b c d) (V4 e f g h) (V4 i j k l) (V4 m n o p)) = (evens, odds)
  where
    evens = pa <> pb <> pc <> pd
    odds  = na <> nb <> nc <> nd
    -- positive terms
    pa = a >< (f><k><p <> g><l><n <> h><j><o)
    pb = b >< (g><i><p <> e><l><o <> h><k><m)
    pc = c >< (e><j><p <> f><l><m <> h><i><n)
    pd = d >< (f><i><o <> e><k><n <> g><j><m)
    -- negative terms
    na = a >< (g><j><p <> f><l><o <> h><k><n)
    nb = b >< (e><k><p <> g><l><m <> h><i><o)
    nc = c >< (f><i><p <> e><l><n <> h><j><m)
    nd = d >< (e><j><o <> f><k><m <> g><i><n)
{-# INLINE bdet4 #-}
-- | Determinant of a 4×4 matrix over a ring.
--
-- The positive half of 'bdet4' is combined with the negative half via '<<'
-- (presumably group subtraction).
det4 :: Ring a => M44 a -> a
det4 m = let (pos, neg) = bdet4 m in pos << neg
{-# INLINE det4 #-}
-- | Determinant of a 4×4 'Double' matrix by Laplace expansion along the top
-- two rows.
--
-- Each 2×2 minor of the top rows is paired with the complementary minor of
-- the bottom rows.
det4d :: M44 Double -> Double
det4d (V4 (V4 i00 i01 i02 i03)
          (V4 i10 i11 i12 i13)
          (V4 i20 i21 i22 i23)
          (V4 i30 i31 i32 i33)) =
  t0 * b5 - t1 * b4 + t2 * b3 + t3 * b2 - t4 * b1 + t5 * b0
  where
    -- minors of rows 0-1, column pairs (0,1) (0,2) (0,3) (1,2) (1,3) (2,3)
    t0 = i00 * i11 - i10 * i01
    t1 = i00 * i12 - i10 * i02
    t2 = i00 * i13 - i10 * i03
    t3 = i01 * i12 - i11 * i02
    t4 = i01 * i13 - i11 * i03
    t5 = i02 * i13 - i12 * i03
    -- minors of rows 2-3, same column pairs
    b0 = i20 * i31 - i30 * i21
    b1 = i20 * i32 - i30 * i22
    b2 = i20 * i33 - i30 * i23
    b3 = i21 * i32 - i31 * i22
    b4 = i21 * i33 - i31 * i23
    b5 = i22 * i33 - i32 * i23
{-# INLINE det4d #-}
-- | Inverse of a 4×4 'Double' matrix: the adjugate scaled by @recip det@.
--
-- It uses the same 2×2-minor decomposition as 'det4d', and the minors are
-- shared between the determinant and the adjugate entries. There is no
-- singularity check: a zero determinant makes @invDet@ infinite
-- (@recip 0@), so the entries come out non-finite.
inv4d :: M44 Double -> M44 Double
inv4d (V4 (V4 i00 i01 i02 i03)
          (V4 i10 i11 i12 i13)
          (V4 i20 i21 i22 i23)
          (V4 i30 i31 i32 i33)) =
  -- s0..s5: minors of rows 0-1, column pairs (0,1) (0,2) (0,3) (1,2) (1,3) (2,3)
  let s0 = i00 * i11 - i10 * i01
      s1 = i00 * i12 - i10 * i02
      s2 = i00 * i13 - i10 * i03
      s3 = i01 * i12 - i11 * i02
      s4 = i01 * i13 - i11 * i03
      s5 = i02 * i13 - i12 * i03
      -- c0..c5: minors of rows 2-3, same column pairs (listed here from c5 down)
      c5 = i22 * i33 - i32 * i23
      c4 = i21 * i33 - i31 * i23
      c3 = i21 * i32 - i31 * i22
      c2 = i20 * i33 - i30 * i23
      c1 = i20 * i32 - i30 * i22
      c0 = i20 * i31 - i30 * i21
      -- Laplace expansion along rows 0-1; identical to the formula in det4d
      det = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
      invDet = recip det
      -- Each entry below is a signed cofactor expressed through the shared
      -- minors, laid out transposed (i.e. the adjugate); e.g. entry (0,0)
      -- is the determinant of rows/columns 1-3.
  in invDet .> V4 (V4 (i11 * c5 - i12 * c4 + i13 * c3)
                      (-i01 * c5 + i02 * c4 - i03 * c3)
                      (i31 * s5 - i32 * s4 + i33 * s3)
                      (-i21 * s5 + i22 * s4 - i23 * s3))
                  (V4 (-i10 * c5 + i12 * c2 - i13 * c1)
                      (i00 * c5 - i02 * c2 + i03 * c1)
                      (-i30 * s5 + i32 * s2 - i33 * s1)
                      (i20 * s5 - i22 * s2 + i23 * s1))
                  (V4 (i10 * c4 - i11 * c2 + i13 * c0)
                      (-i00 * c4 + i01 * c2 - i03 * c0)
                      (i30 * s4 - i31 * s2 + i33 * s0)
                      (-i20 * s4 + i21 * s2 - i23 * s0))
                  (V4 (-i10 * c3 + i11 * c1 - i12 * c0)
                      (i00 * c3 - i01 * c1 + i02 * c0)
                      (-i30 * s3 + i31 * s1 - i32 * s0)
                      (i20 * s3 - i21 * s1 + i22 * s0))
{-# INLINE inv4d #-}