{-# LANGUAGE CPP                        #-}
{-# LANGUAGE Safe                       #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE RebindableSyntax           #-}
{-# LANGUAGE RankNTypes                 #-}

module Data.Semimodule.Matrix where
{-(
    type M11
  , type M12
  , type M13
  , type M14
  , type M21
  , type M31
  , type M41
  , type M22
  , type M23
  , type M24
  , type M32
  , type M33
  , type M34
  , type M42
  , type M43
  , type M44
  , lensRep
  , grateRep
  , tran
  , row
  , rows
  , col
  , cols
  , (.#)
  , (.*)
  , (#.)
  , (*.)
  , (.#.)
  , (`dot`)
  , outer
  , diag
  , dirac
  , identity
  , transpose
  , trace
  , diagonal
  , bdet2
  , det2
  , inv1
  , inv2
  , bdet3
  , det3
  , inv3
  , bdet4
  , det4
  , inv4
  , m11
  , m12
  , m13
  , m14
  , m21
  , m31
  , m41
  , m22
  , m23
  , m24
  , m32
  , m33
  , m34
  , m42
  , m43
  , m44
  ) where
-}

import safe Data.Bool
import safe Data.Distributive
import safe Data.Functor.Compose
import safe Data.Functor.Rep
import safe Data.Semifield
import safe Data.Semigroup.Additive
import safe Data.Semigroup.Multiplicative
import safe Data.Semimodule
import safe Data.Semimodule.Transform
import safe Data.Semimodule.Vector
import safe Data.Semiring
import safe Data.Tuple
import safe Prelude hiding (Num(..), Fractional(..), sum, negate)

-- All matrices use row-major representation.

-- | A 1x1 matrix: one row ('V1') holding one entry ('V1').
type M11 a = V1 (V1 a)

-- | A 1x2 matrix: one row ('V1') of two entries ('V2').
type M12 a = V1 (V2 a)

-- | A 1x3 matrix: one row ('V1') of three entries ('V3').
type M13 a = V1 (V3 a)

-- | A 1x4 matrix: one row ('V1') of four entries ('V4').
type M14 a = V1 (V4 a)

-- | A 2x1 matrix: two rows ('V2'), each holding one entry ('V1').
type M21 a = V2 (V1 a)

-- | A 3x1 matrix: three rows ('V3'), each holding one entry ('V1').
type M31 a = V3 (V1 a)

-- | A 4x1 matrix: four rows ('V4'), each holding one entry ('V1').
type M41 a = V4 (V1 a)

-- | A 2x2 matrix: two rows ('V2') of two entries ('V2').
type M22 a = V2 (V2 a)

-- | A 2x3 matrix: two rows ('V2') of three entries ('V3').
type M23 a = V2 (V3 a)

-- | A 2x4 matrix: two rows ('V2') of four entries ('V4').
type M24 a = V2 (V4 a)

-- | A 3x2 matrix: three rows ('V3') of two entries ('V2').
type M32 a = V3 (V2 a)

-- | A 3x3 matrix: three rows ('V3') of three entries ('V3').
type M33 a = V3 (V3 a)

-- | A 3x4 matrix: three rows ('V3') of four entries ('V4').
type M34 a = V3 (V4 a)

-- | A 4x2 matrix: four rows ('V4') of two entries ('V2').
type M42 a = V4 (V2 a)

-- | A 4x3 matrix: four rows ('V4') of three entries ('V3').
type M43 a = V4 (V3 a)

-- | A 4x4 matrix: four rows ('V4') of four entries ('V4').
type M44 a = V4 (V4 a)


-- | A van Laarhoven-style lens onto the element stored at index @i@.
--
-- The focus is read with 'index'. The structure is rebuilt with 'tabulate',
-- placing the new value at every position whose index equals @i@ and keeping
-- the original element everywhere else.
lensRep :: Eq (Rep f) => Representable f => Rep f -> forall g. Functor g => (a -> g a) -> f a -> g (f a)
lensRep i f s = fmap rebuild (f (index s i))
  where
    rebuild b = tabulate (\j -> bool (index s j) b (i == j))
{-# INLINE lensRep #-}

-- | Build a representable container position by position.
--
-- For each index @i@ the supplied function receives @i@ together with the
-- outer functor in which every container has been replaced by its element at @i@.
grateRep :: Representable f => forall g. Functor g => (Rep f -> g a -> b) -> g (f a) -> f b
grateRep iab s = tabulate cell
  where
    cell i = iab i ((\fa -> index fa i) <$> s)
{-# INLINE grateRep #-}


-- | Create a unit vector at an index.
--
-- The result is 'tabulate'd from 'dirac' applied to the given index.
--
-- >>> idx I21 :: V2 Int
-- V2 1 0
--
-- >>> idx I42 :: V4 Int
-- V4 0 1 0 0
--
idx :: Semiring a => Free f => Rep f -> f a
idx = tabulate . dirac
{-# INLINE idx #-}
infix 6 `dot`

-- | Dot product: the 'sum' of the element-wise products of two vectors.
--
-- >>> V3 1 2 3 `dot` V3 1 2 3
-- 14
--
dot :: Semiring a => Free f => Foldable f => f a -> f a -> a
dot x = sum . liftR2 (*) x
{-# INLINE dot #-}

-- | Squared /l2/ norm of a vector, i.e. the vector 'dot'ted with itself.
--
quadrance :: Semiring a => Free f => Foldable f => f a -> a
quadrance v = dot v v
{-# INLINE quadrance #-}

-- | Squared /l2/ norm of the difference between two vectors
-- (the 'quadrance' of @f - g@).
--
qd :: FreeModule a f => Foldable f => f a -> f a -> a
qd u v = quadrance (u - v)
{-# INLINE qd #-}

-- | Wrap a matrix as a 'Tran'.
--
-- The wrapped function tabulates its argument into a vector, multiplies the
-- matrix by that vector with '.#', and indexes into the result.
--
-- @ ('.#') = 'app' . 'tran' @
tran :: Semiring a => Basis b f => Basis c g => Foldable g => f (g a) -> Tran a b c
tran m = Tran (index . (m .#) . tabulate)

-- | Retrieve a row of a row-major matrix or element of a row vector.
--
-- This is 'index' with its arguments swapped.
--
-- >>> row I21 (V2 1 2)
-- 1
--
row :: Representable f => Rep f -> f a -> a
row i v = index v i
{-# INLINE row #-}

-- | Retrieve a column of a row-major matrix.
--
-- The matrix is flipped inside-out with 'distribute' and the requested
-- position is then read with 'index'.
--
-- >>> row I22 . col I31 $ V2 (V3 1 2 3) (V3 4 5 6)
-- 4
--
col :: Functor f => Representable g => Rep g -> f (g a) -> f a
col j m = index (distribute m) j
{-# INLINE col #-}

infixl 7 #.

-- | Multiply a matrix on the left by a row vector.
--
-- Entry @j@ of the result is the 'dot' product of the vector with column @j@
-- of the matrix.
--
-- >>> V2 1 2 #. m23 3 4 5 6 7 8
-- V3 15 18 21
--
-- >>> V2 1 2 #. m23 3 4 5 6 7 8 #. m32 1 0 0 0 0 0
-- V2 15 0
--
(#.) :: (Semiring a, Free f, Foldable f, Free g) => f a -> f (g a) -> g a
v #. m = tabulate columnDot
  where
    columnDot j = dot v (col j m)
{-# INLINE (#.) #-}

infixr 7 .#, .#.

-- | Multiply a matrix on the right by a column vector.
--
-- Entry @i@ of the result is the 'dot' product of row @i@ of the matrix with
-- the vector.
--
-- @ ('.#') = 'app' . 'tran' @
--
-- >>> app (tran $ m23 1 2 3 4 5 6) (V3 7 8 9) :: V2 Int
-- V2 50 122
-- >>> m23 1 2 3 4 5 6 .# V3 7 8 9 :: V2 Int
-- V2 50 122
-- >>> m22 1 0 0 0 .# m23 1 2 3 4 5 6 .# V3 7 8 9
-- V2 50 0
--
(.#) :: (Semiring a, Free f, Free g, Foldable g) => f (g a) -> g a -> f a
m .# v = tabulate rowDot
  where
    rowDot i = dot (row i m) v
{-# INLINE (.#) #-}

-- | Multiply two matrices.
--
-- The entry at @(i, j)@ is the 'dot' product of row @i@ of the left matrix
-- with column @j@ of the right matrix. The result is tabulated as a 'Compose'
-- and then unwrapped.
--
-- >>> m22 1 2 3 4 .#. m22 1 2 3 4 :: M22 Int
-- V2 (V2 7 10) (V2 15 22)
-- 
-- >>> m23 1 2 3 4 5 6 .#. m32 1 2 3 4 4 5 :: M22 Int
-- V2 (V2 19 25) (V2 43 58)
--
(.#.) :: (Semiring a, Free f, Free g, Free h, Foldable g) => f (g a) -> g (h a) -> f (h a)
(.#.) x y = getCompose (tabulate entry)
  where
    entry (i, j) = dot (row i x) (col j y)
{-# INLINE (.#.) #-}

-- | Outer product of two vectors.
--
-- The entry at row @i@, column @j@ is @x_i * y_j@: the element taken from the
-- left vector is the left factor of the product. The order only matters when
-- the underlying semiring is not commutative (for example a semiring of
-- square matrices).
--
-- >>> V2 1 1 `outer` V2 1 1
-- V2 (V2 1 1) (V2 1 1)
--
outer :: Semiring a => Functor f => Functor g => f a -> g a -> f (g a)
outer x y = fmap (\xi -> fmap (xi *) y) x
{-# INLINE outer #-}

-- | Obtain a diagonal matrix from a vector.
--
-- Row @i@ holds the vector's @i@-th element at position @i@ and 'zero' at
-- every other position.
--
-- >>> diag (V2 2 3)
-- V2 (V2 2 0) (V2 0 3)
--
diag :: (Additive-Monoid) a => Free f => f a -> f (f a)
diag f = imapRep mkRow f
  where
    mkRow i x = imapRep (\j _ -> bool zero x (i == j)) f
{-# INLINE diag #-}

-- | Identity matrix: the 'diag'onal matrix built from a vector of 'one's.
--
-- >>> identity :: M44 Int
-- V4 (V4 1 0 0 0) (V4 0 1 0 0) (V4 0 0 1 0) (V4 0 0 0 1)
--
-- >>> identity :: V3 (V3 Int)
-- V3 (V3 1 0 0) (V3 0 1 0) (V3 0 0 1)
--
identity :: Semiring a => Free f => f (f a)
identity = diag (pureRep one)
{-# INLINE identity #-}

-- | Compute the trace of a matrix, i.e. the 'sum' of its 'diagonal'.
--
-- >>> trace $ V2 (V2 1.0 2.0) (V2 3.0 4.0)
-- 5.0
--
trace :: Semiring a => Free f => Foldable f => f (f a) -> a
trace m = sum (diagonal m)
{-# INLINE trace #-}

-- | Obtain the diagonal of a matrix as a vector.
--
-- Implemented with 'bindRep' and 'id'.
--
-- >>> diagonal $ V2 (V2 1.0 2.0) (V2 3.0 4.0)
-- V2 1.0 4.0
--
diagonal :: Representable f => f (f a) -> f a
diagonal m = bindRep m id
{-# INLINE diagonal #-}

-- | Retrieve a single matrix entry: take column @j@ with 'col', then read
-- position @i@ of that column with 'row'.
ij :: Representable f => Representable g => Rep f -> Rep g -> f (g a) -> a
ij i j m = row i (col j m)

-- | 1x1 matrix inverse over a field.
--
-- Applies 'recip' to the entry and passes the result through 'transpose'.
--
-- >>> inv1 $ m11 4.0 :: M11 Double
-- V1 (V1 0.25)
--
inv1 :: Field a => M11 a -> M11 a
inv1 m = transpose (fmap (fmap recip) m)

-- | 2x2 matrix bideterminant over a commutative semiring.
--
-- Returns the pair (product of the main diagonal, product of the
-- anti-diagonal). Their difference is the determinant (see 'det2').
--
-- >>> bdet2 $ m22 1 2 3 4
-- (4,6)
--
bdet2 :: Semiring a => Basis I2 f => Basis I2 g => f (g a) -> (a, a)
bdet2 m = (mainDiag, antiDiag)
  where
    mainDiag = ij I21 I21 m * ij I22 I22 m
    antiDiag = ij I21 I22 m * ij I22 I21 m
{-# INLINE bdet2 #-}

-- | 2x2 matrix determinant over a commutative ring.
--
-- Computed as the difference of the two components of 'bdet2':
--
-- @
-- 'det2' = 'uncurry' ('-') . 'bdet2'
-- @
--
-- >>> det2 $ m22 1 2 3 4 :: Double
-- -2.0
--
det2 :: Ring a => Basis I2 f => Basis I2 g => f (g a) -> a
det2 m = pos - neg
  where
    (pos, neg) = bdet2 m
{-# INLINE det2 #-}

-- | 2x2 matrix inverse over a field.
--
-- Each row of the adjugate matrix is scaled by the reciprocal of 'det2'.
--
-- >>> inv2 $ m22 1 2 3 4 :: M22 Double
-- V2 (V2 (-2.0) 1.0) (V2 1.5 (-0.5))
--
inv2 :: Field a => M22 a -> M22 a
inv2 m = fmap (lscaleDef (recip (det2 m))) adjugate
  where
    -- Swap the diagonal entries and negate the off-diagonal ones.
    adjugate = m22 (ij I22 I22 m) (-(ij I21 I22 m)) (-(ij I22 I21 m)) (ij I21 I21 m)
{-# INLINE inv2 #-}

-- | 3x3 matrix bideterminant over a commutative semiring.
--
-- Returns the pair @(evens, odds)@: the sum of the three products taken with
-- a positive sign in the determinant and the sum of the three products taken
-- with a negative sign. Their difference is the determinant.
--
-- >>> bdet3 (V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9))
-- (225,225)
--
bdet3 :: Semiring a => Basis I3 f => Basis I3 g => f (g a) -> (a, a)
bdet3 m = (evens, odds) where
  evens = a*e*i + g*b*f + d*h*c
  odds  = a*h*f + d*b*i + g*e*c
  -- Entries named in row-major order:
  --   a b c
  --   d e f
  --   g h i
  a = ij I31 I31 m
  b = ij I31 I32 m
  c = ij I31 I33 m
  d = ij I32 I31 m
  e = ij I32 I32 m
  f = ij I32 I33 m
  g = ij I33 I31 m
  h = ij I33 I32 m
  i = ij I33 I33 m
{-# INLINE bdet3 #-}

-- | 3x3 matrix determinant over a commutative ring.
--
-- @
-- 'det3' = 'uncurry' ('-') . 'bdet3'
-- @
--
-- Implementation uses a cofactor expansion (along the first column) to avoid
-- loss of precision.
--
-- >>> det3 (V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9))
-- 0
--
det3 :: Ring a => Basis I3 f => Basis I3 g => f (g a) -> a
det3 m = a * minorA - d * minorD + g * minorG
  where
    -- 2x2 minors obtained by deleting the first column and one row.
    minorA = e*i-f*h
    minorD = b*i-c*h
    minorG = b*f-c*e

    -- Entries named in row-major order.
    at u v = ij u v m
    a = at I31 I31
    b = at I31 I32
    c = at I31 I33
    d = at I32 I31
    e = at I32 I32
    f = at I32 I33
    g = at I33 I31
    h = at I33 I32
    i = at I33 I33
{-# INLINE det3 #-}

-- | 3x3 matrix inverse over a field.
--
-- Each row of the adjugate matrix (built from 2x2 determinants of the
-- entries) is scaled by the reciprocal of 'det3'.
--
-- >>> inv3 $ m33 1 2 4 4 2 2 1 1 1 :: M33 Double
-- V3 (V3 0.0 0.5 (-1.0)) (V3 (-0.5) (-0.75) 3.5) (V3 0.5 0.25 (-1.5))
--
inv3 :: Field a => M33 a -> M33 a
inv3 m = fmap (lscaleDef (recip (det3 m))) adjugate
  where
    -- Entries named in row-major order.
    at u v = ij u v m
    a = at I31 I31
    b = at I31 I32
    c = at I31 I33
    d = at I32 I31
    e = at I32 I32
    f = at I32 I33
    g = at I33 I31
    h = at I33 I32
    i = at I33 I33

    -- Determinant of the 2x2 matrix with rows @(q, r)@ and @(s, t)@.
    minor q r s t = det2 (m22 q r s t)

    -- The sign of each cofactor is absorbed by the order of the arguments.
    adjugate = m33
      (minor e f h i) (minor c b i h) (minor b c e f)
      (minor f d i g) (minor a c g i) (minor c a f d)
      (minor d e g h) (minor b a h g) (minor a b d e)
{-# INLINE inv3 #-}

-- | 4x4 matrix bideterminant over a commutative semiring.
--
-- Returns the pair @(evens, odds)@: the sum of the products taken with a
-- positive sign in the determinant and the sum of the products taken with a
-- negative sign. Their difference is the determinant.
--
-- >>> bdet4 (V4 (V4 1 2 3 4) (V4 5 6 7 8) (V4 9 10 11 12) (V4 13 14 15 16))
-- (27728,27728)
--
bdet4 :: Semiring a => Basis I4 f => Basis I4 g => f (g a) -> (a, a)
bdet4 x = (evens, odds)
  where
    evens = a * (f*k*p + g*l*n + h*j*o) +
            b * (g*i*p + e*l*o + h*k*m) +
            c * (e*j*p + f*l*m + h*i*n) +
            d * (f*i*o + e*k*n + g*j*m)
    odds =  a * (g*j*p + f*l*o + h*k*n) +
            b * (e*k*p + g*l*m + h*i*o) +
            c * (f*i*p + e*l*n + h*j*m) +
            d * (e*j*o + f*k*m + g*i*n)

    -- Entries named in row-major order:
    --   a b c d
    --   e f g h
    --   i j k l
    --   m n o p
    at u v = ij u v x
    a = at I41 I41
    b = at I41 I42
    c = at I41 I43
    d = at I41 I44
    e = at I42 I41
    f = at I42 I42
    g = at I42 I43
    h = at I42 I44
    i = at I43 I41
    j = at I43 I42
    k = at I43 I43
    l = at I43 I44
    m = at I44 I41
    n = at I44 I42
    o = at I44 I43
    p = at I44 I44
{-# INLINE bdet4 #-}

-- | 4x4 matrix determinant over a commutative ring.
--
-- @
-- 'det4' = 'uncurry' ('-') . 'bdet4'
-- @
--
-- This implementation uses a cofactor expansion to avoid loss of precision:
-- the 2x2 minors of the first two rows (@s0@..@s5@) are paired with the 2x2
-- minors of the last two rows (@c0@..@c5@).
--
-- >>> det4 (m44 1 0 3 2 2 0 2 1 0 0 0 1 0 3 4 0 :: M44 Rational)
-- (-12) % 1
--
det4 :: Ring a => Basis I4 f => Basis I4 g => f (g a) -> a
det4 x = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
  where
    -- 2x2 minors taken from the first two rows.
    s0 = i00 * i11 - i10 * i01
    s1 = i00 * i12 - i10 * i02
    s2 = i00 * i13 - i10 * i03
    s3 = i01 * i12 - i11 * i02
    s4 = i01 * i13 - i11 * i03
    s5 = i02 * i13 - i12 * i03

    -- 2x2 minors taken from the last two rows.
    c5 = i22 * i33 - i32 * i23
    c4 = i21 * i33 - i31 * i23
    c3 = i21 * i32 - i31 * i22
    c2 = i20 * i33 - i30 * i23
    c1 = i20 * i32 - i30 * i22
    c0 = i20 * i31 - i30 * i21

    -- Entry @iRC@ sits at zero-based row @R@, column @C@.
    at u v = ij u v x
    i00 = at I41 I41
    i01 = at I41 I42
    i02 = at I41 I43
    i03 = at I41 I44
    i10 = at I42 I41
    i11 = at I42 I42
    i12 = at I42 I43
    i13 = at I42 I44
    i20 = at I43 I41
    i21 = at I43 I42
    i22 = at I43 I43
    i23 = at I43 I44
    i30 = at I44 I41
    i31 = at I44 I42
    i32 = at I44 I43
    i33 = at I44 I44
{-# INLINE det4 #-}

-- | 4x4 matrix inverse over a field.
--
-- The determinant and the unscaled inverse are both assembled from the 2x2
-- minors of the first two rows (@s0@..@s5@) and of the last two rows
-- (@c0@..@c5@). Every row is then scaled by the reciprocal of the determinant.
--
-- >>> row I41 $ inv4 (m44 1 0 3 2 2 0 2 1 0 0 0 1 0 3 4 0 :: M44 Rational)
-- V4 (6 % (-12)) ((-9) % (-12)) ((-3) % (-12)) (0 % (-12))
--
inv4 :: Field a => M44 a -> M44 a
inv4 x = fmap (lscaleDef (recip det)) x'
  where
    -- Entry @iRC@ sits at zero-based row @R@, column @C@.
    at u v = ij u v x
    i00 = at I41 I41
    i01 = at I41 I42
    i02 = at I41 I43
    i03 = at I41 I44
    i10 = at I42 I41
    i11 = at I42 I42
    i12 = at I42 I43
    i13 = at I42 I44
    i20 = at I43 I41
    i21 = at I43 I42
    i22 = at I43 I43
    i23 = at I43 I44
    i30 = at I44 I41
    i31 = at I44 I42
    i32 = at I44 I43
    i33 = at I44 I44

    -- 2x2 minors taken from the first two rows.
    s0 = i00 * i11 - i10 * i01
    s1 = i00 * i12 - i10 * i02
    s2 = i00 * i13 - i10 * i03
    s3 = i01 * i12 - i11 * i02
    s4 = i01 * i13 - i11 * i03
    s5 = i02 * i13 - i12 * i03

    -- 2x2 minors taken from the last two rows.
    c5 = i22 * i33 - i32 * i23
    c4 = i21 * i33 - i31 * i23
    c3 = i21 * i32 - i31 * i22
    c2 = i20 * i33 - i30 * i23
    c1 = i20 * i32 - i30 * i22
    c0 = i20 * i31 - i30 * i21

    det = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0

    -- Unscaled inverse, given in row-major order.
    x' = m44 (i11 * c5 - i12 * c4 + i13 * c3)
             (-i01 * c5 + i02 * c4 - i03 * c3)
             (i31 * s5 - i32 * s4 + i33 * s3)
             (-i21 * s5 + i22 * s4 - i23 * s3)
             (-i10 * c5 + i12 * c2 - i13 * c1)
             (i00 * c5 - i02 * c2 + i03 * c1)
             (-i30 * s5 + i32 * s2 - i33 * s1)
             (i20 * s5 - i22 * s2 + i23 * s1)
             (i10 * c4 - i11 * c2 + i13 * c0)
             (-i00 * c4 + i01 * c2 - i03 * c0)
             (i30 * s4 - i31 * s2 + i33 * s0)
             (-i20 * s4 + i21 * s2 - i23 * s0)
             (-i10 * c3 + i11 * c1 - i12 * c0)
             (i00 * c3 - i01 * c1 + i02 * c0)
             (-i30 * s3 + i31 * s1 - i32 * s0)
             (i20 * s3 - i21 * s1 + i22 * s0)
{-# INLINE inv4 #-}



-- | Construct a 1x1 matrix from its single entry.
--
-- >>> m11 1 :: M11 Int
-- V1 (V1 1)
--
m11 :: a -> M11 a
m11 = V1 . V1
{-# INLINE m11 #-}

-- | Construct a 1x2 matrix from the entries of its single row.
--
-- >>> m12 1 2 :: M12 Int
-- V1 (V2 1 2)
--
m12 :: a -> a -> M12 a
m12 a = V1 . V2 a
{-# INLINE m12 #-}

-- | Construct a 1x3 matrix from the entries of its single row.
--
-- >>> m13 1 2 3 :: M13 Int
-- V1 (V3 1 2 3)
--
m13 :: a -> a -> a -> M13 a
m13 a b = V1 . V3 a b
{-# INLINE m13 #-}

-- | Construct a 1x4 matrix from the entries of its single row.
--
-- >>> m14 1 2 3 4 :: M14 Int
-- V1 (V4 1 2 3 4)
--
m14 :: a -> a -> a -> a -> M14 a
m14 a b c = V1 . V4 a b c
{-# INLINE m14 #-}

-- | Construct a 2x1 matrix; each argument becomes a one-element row.
--
-- >>> m21 1 2 :: M21 Int
-- V2 (V1 1) (V1 2)
--
m21 :: a -> a -> M21 a
m21 a b = V2 r1 r2
  where
    r1 = V1 a
    r2 = V1 b
{-# INLINE m21 #-}

-- | Construct a 3x1 matrix; each argument becomes a one-element row.
--
-- >>> m31 1 2 3 :: M31 Int
-- V3 (V1 1) (V1 2) (V1 3)
--
m31 :: a -> a -> a -> M31 a
m31 a b c = V3 r1 r2 r3
  where
    r1 = V1 a
    r2 = V1 b
    r3 = V1 c
{-# INLINE m31 #-}

-- | Construct a 4x1 matrix; each argument becomes a one-element row.
--
-- >>> m41 1 2 3 4 :: M41 Int
-- V4 (V1 1) (V1 2) (V1 3) (V1 4)
--
m41 :: a -> a -> a -> a -> M41 a
m41 a b c d = V4 r1 r2 r3 r4
  where
    r1 = V1 a
    r2 = V1 b
    r3 = V1 c
    r4 = V1 d
{-# INLINE m41 #-}

-- | Construct a 2x2 matrix.
--
-- Arguments are in row-major order: the first two fill the first row, the
-- last two fill the second row.
--
-- >>> m22 1 2 3 4 :: M22 Int
-- V2 (V2 1 2) (V2 3 4)
--
m22 :: a -> a -> a -> a -> M22 a
m22 a b c d = V2 r1 r2
  where
    r1 = V2 a b
    r2 = V2 c d
{-# INLINE m22 #-}

-- | Construct a 2x3 matrix.
--
-- Arguments are in row-major order: three entries per row, two rows.
--
m23 :: a -> a -> a -> a -> a -> a -> M23 a
m23 a b c d e f = V2 r1 r2
  where
    r1 = V3 a b c
    r2 = V3 d e f
{-# INLINE m23 #-}

-- | Construct a 2x4 matrix.
--
-- Arguments are in row-major order: four entries per row, two rows.
--
m24 :: a -> a -> a -> a -> a -> a -> a -> a -> M24 a
m24 a b c d e f g h = V2 r1 r2
  where
    r1 = V4 a b c d
    r2 = V4 e f g h
{-# INLINE m24 #-}

-- | Construct a 3x2 matrix.
--
-- Arguments are in row-major order: two entries per row, three rows.
--
m32 :: a -> a -> a -> a -> a -> a -> M32 a
m32 a b c d e f = V3 r1 r2 r3
  where
    r1 = V2 a b
    r2 = V2 c d
    r3 = V2 e f
{-# INLINE m32 #-}

-- | Construct a 3x3 matrix.
--
-- Arguments are in row-major order: three entries per row, three rows.
--
m33 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> M33 a
m33 a b c d e f g h i = V3 r1 r2 r3
  where
    r1 = V3 a b c
    r2 = V3 d e f
    r3 = V3 g h i
{-# INLINE m33 #-}

-- | Construct a 3x4 matrix.
--
-- Arguments are in row-major order: four entries per row, three rows.
--
m34 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M34 a
m34 a b c d e f g h i j k l = V3 r1 r2 r3
  where
    r1 = V4 a b c d
    r2 = V4 e f g h
    r3 = V4 i j k l
{-# INLINE m34 #-}

-- | Construct a 4x2 matrix.
--
-- Arguments are in row-major order: two entries per row, four rows.
--
m42 :: a -> a -> a -> a -> a -> a -> a -> a -> M42 a
m42 a b c d e f g h = V4 r1 r2 r3 r4
  where
    r1 = V2 a b
    r2 = V2 c d
    r3 = V2 e f
    r4 = V2 g h
{-# INLINE m42 #-}

-- | Construct a 4x3 matrix.
--
-- Arguments are in row-major order: three entries per row, four rows.
--
m43 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M43 a
m43 a b c d e f g h i j k l = V4 r1 r2 r3 r4
  where
    r1 = V3 a b c
    r2 = V3 d e f
    r3 = V3 g h i
    r4 = V3 j k l
{-# INLINE m43 #-}

-- | Construct a 4x4 matrix.
--
-- Arguments are in row-major order: four entries per row, four rows.
--
m44 :: a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> M44 a
m44 a b c d e f g h i j k l m n o p = V4 r1 r2 r3 r4
  where
    r1 = V4 a b c d
    r2 = V4 e f g h
    r3 = V4 i j k l
    r4 = V4 m n o p
{-# INLINE m44 #-}



-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

--TODO autogenerate and extend to higher dims

-- 1x1 matrices: multiplication is the matrix product '.#.' and the unit is
-- the 'identity' matrix.
instance Semiring a => Semigroup (Multiplicative (M11 a)) where
  Multiplicative x <> Multiplicative y = Multiplicative (x .#. y)

instance Semiring a => Monoid (Multiplicative (M11 a)) where
  mempty = pure identity

instance Semiring a => Presemiring (M11 a)
instance Semiring a => Semiring (M11 a)
instance Ring a => Ring (M11 a)

-- Over a field, division multiplies by the inverse ('inv1') on the right.
instance Field a => Magma (Multiplicative (M11 a)) where
  Multiplicative x << Multiplicative y = Multiplicative (x .#. inv1 y)

instance Field a => Quasigroup (Multiplicative (M11 a))
instance Field a => Loop (Multiplicative (M11 a))
instance Field a => Group (Multiplicative (M11 a))

-- 2x2 matrices: multiplication is the matrix product '.#.' and the unit is
-- the 'identity' matrix.
instance Semiring a => Semigroup (Multiplicative (M22 a)) where
  Multiplicative x <> Multiplicative y = Multiplicative (x .#. y)

instance Semiring a => Monoid (Multiplicative (M22 a)) where
  mempty = pure identity

instance Semiring a => Presemiring (M22 a)
instance Semiring a => Semiring (M22 a)
instance Ring a => Ring (M22 a)


-- 3x3 matrices: multiplication is the matrix product '.#.' and the unit is
-- the 'identity' matrix.
instance Semiring a => Semigroup (Multiplicative (M33 a)) where
  Multiplicative x <> Multiplicative y = Multiplicative (x .#. y)

instance Semiring a => Monoid (Multiplicative (M33 a)) where
  mempty = pure identity

instance Semiring a => Presemiring (M33 a)
instance Semiring a => Semiring (M33 a)
instance Ring a => Ring (M33 a)


-- 4x4 matrices: multiplication is the matrix product '.#.' and the unit is
-- the 'identity' matrix.
instance Semiring a => Semigroup (Multiplicative (M44 a)) where
  Multiplicative x <> Multiplicative y = Multiplicative (x .#. y)

instance Semiring a => Monoid (Multiplicative (M44 a)) where
  mempty = pure identity

instance Semiring a => Presemiring (M44 a)
instance Semiring a => Semiring (M44 a)
instance Ring a => Ring (M44 a)

{-
-- | A 1x2 matrix.
type M12 a = V1 (V2 a)

-- | A 1x3 matrix.
type M13 a = V1 (V3 a)

-- | A 1x4 matrix.
type M14 a = V1 (V4 a)

-- | A 2x1 matrix.
type M21 a = V2 (V1 a)

-- | A 3x1 matrix.
type M31 a = V3 (V1 a)

-- | A 4x1 matrix.
type M41 a = V4 (V1 a)

-- | A 2x3 matrix.
type M23 a = V2 (V3 a)

-- | A 2x4 matrix.
type M24 a = V2 (V4 a)

-- | A 3x2 matrix.
type M32 a = V3 (V2 a)

-- | A 3x4 matrix.
type M34 a = V3 (V4 a)

-- | A 4x2 matrix.
type M42 a = V4 (V2 a)

-- | A 4x3 matrix.
type M43 a = V4 (V3 a)
-}

-- Column vectors (Nx1 matrices): square NxN matrices act on the left and 1x1
-- matrices act on the right, both by matrix multiplication.

instance Semiring a => LeftSemimodule (M22 a) (M21 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M11 a) (M21 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M22 a) (M11 a) (M21 a)

instance Semiring a => LeftSemimodule (M33 a) (M31 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M11 a) (M31 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M33 a) (M11 a) (M31 a)


instance Semiring a => LeftSemimodule (M44 a) (M41 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M11 a) (M41 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M44 a) (M11 a) (M41 a)


-- Row vectors (1xN matrices): 1x1 matrices act on the left and square NxN
-- matrices act on the right, both by matrix multiplication.

instance Semiring a => LeftSemimodule (M11 a) (M12 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M22 a) (M12 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M11 a) (M22 a) (M12 a)

instance Semiring a => LeftSemimodule (M11 a) (M13 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M33 a) (M13 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M11 a) (M33 a) (M13 a)


instance Semiring a => LeftSemimodule (M11 a) (M14 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M44 a) (M14 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M11 a) (M44 a) (M14 a)


-- 2x2 matrices act on 2x3 matrices from the left by matrix multiplication.
--
-- >>> m22 1 0 0 0 *. m23 1 2 3 4 5 6 :: M23 Integer
-- V2 (V3 1 2 3) (V3 0 0 0)
-- >>> m22 0 0 1 0 *. m22 1 0 0 0 *. m23 1 2 3 4 5 6 :: M23 Integer
-- V2 (V3 0 0 0) (V3 1 2 3)
-- >>> (m22 0 0 1 0 * m22 1 0 0 0) *. m23 1 2 3 4 5 6 :: M23 Integer
-- V2 (V3 0 0 0) (V3 1 2 3)
instance Semiring a => LeftSemimodule (M22 a) (M23 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

-- 3x3 matrices act on 2x3 matrices from the right by matrix multiplication.
--
-- >>> m23 1 2 3 4 5 6 .* m33 1 0 0 0 0 0 0 0 0 :: M23 Integer
-- V2 (V3 1 0 0) (V3 4 0 0)
instance Semiring a => RightSemimodule (M33 a) (M23 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M22 a) (M33 a) (M23 a)


-- Remaining rectangular matrices (2x4, 3x2, 4x2): the square matrix matching
-- the row count acts on the left and the square matrix matching the column
-- count acts on the right, both by matrix multiplication.

instance Semiring a => LeftSemimodule (M22 a) (M24 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M44 a) (M24 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M22 a) (M44 a) (M24 a)


instance Semiring a => LeftSemimodule (M33 a) (M32 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M22 a) (M32 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M33 a) (M22 a) (M32 a)


instance Semiring a => LeftSemimodule (M44 a) (M42 a) where
  lscale l v = l .#. v
  {-# INLINE lscale #-}

instance Semiring a => RightSemimodule (M22 a) (M42 a) where
  rscale r v = v .#. r
  {-# INLINE rscale #-}

instance Semiring a => Bisemimodule (M44 a) (M22 a) (M42 a)