{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module Data.Tensor.Matrix where

import           Data.List          (foldl')
import           Data.Proxy
import           Data.Ratio
import           Data.Tensor.Tensor
import           Data.Tensor.Type
import           GHC.TypeLits

-- | A square matrix: a 'Matrix' whose two dimensions are both @a@, with
-- elements of type @n@.
type SimpleMatrix a n = Matrix a a n

-- | <https://en.wikipedia.org/wiki/Matrix_multiplication Matrix multiplication>
-- of two square matrices of the same size.
--
-- This is 'dot' specialised to 'SimpleMatrix', so both arguments and the
-- result share the same dimensions.
dotM :: (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> SimpleMatrix a n -> SimpleMatrix a n
dotM = dot

-- | Main diagonal of a square matrix, returned as a 'Vector' of length @a@.
--
-- Element @[i]@ of the result is element @[i, i]@ of the input. The vector's
-- shape @[s]@ is widened to the matrix shape @[s, s]@ before delegating to the
-- matrix's indexing function.
diag :: SimpleMatrix a n -> Vector a n
diag (Tensor t) = Tensor $ \[s] [i] -> t [s, s] [i, i]

-- | <https://en.wikipedia.org/wiki/Trace_(linear_algebra) Matrix trace>:
-- the sum of the entries on the main diagonal.
--
-- Implemented by contracting index 0 against index 1, which produces a
-- rank-0 tensor. Its single element is read with an empty shape and an
-- empty index.
trace :: (KnownNat a, Num n) => SimpleMatrix a n -> n
trace t =
  let (Tensor f) = contraction (i0, i1) t
  in f [] []

-- | <https://en.wikipedia.org/wiki/LU_decomposition LU decomposition> of an
-- n x n matrix, using column swaps to avoid zero pivots.
--
-- Returns @(l, u, p)@ such that @t · p == l · u@, with @t@ and @p@ read as
-- rationals.
--
-- * @l@ is unit lower triangular.
-- * @p@ is the identity matrix with some columns exchanged.
-- * When the input is nonsingular, @u@ is upper triangular.
--
-- When a zero pivot cannot be replaced by swapping in a column to its right,
-- the elimination step for that column is skipped. That leaves a zero on
-- @u@'s diagonal.
--
-- > λ> a = [1,2,3,2,5,7,3,5,3] :: Tensor '[3,3] Int
-- > λ> (l,u,p) = lu a
-- > λ> l
-- > [[1 % 1,0 % 1,0 % 1],
-- > [2 % 1,1 % 1,0 % 1],
-- > [3 % 1,(-1) % 1,1 % 1]]
-- > λ> u
-- > [[1 % 1,2 % 1,3 % 1],
-- > [0 % 1,1 % 1,1 % 1],
-- > [0 % 1,0 % 1,(-5) % 1]]
-- > λ> p
-- > [[1,0,0],
-- > [0,1,0],
-- > [0,0,1]]
lu :: forall a n . (KnownNat a, Integral n)
   => SimpleMatrix a n
   -> (SimpleMatrix a (Ratio n), SimpleMatrix a (Ratio n), SimpleMatrix a n)
lu t =
  let dim       = toNat (Proxy :: Proxy a)
      (l,u,p,_) = foldl' go (identity, fmap (% 1) t, identity, [dim, dim])
                            ([0 .. dim - 1] :: [Int])
  in (l, u, p)
  where
    -- One elimination step for pivot column i. The state is (l, u, p, shape).
    -- Every step preserves the invariant t · p == l · u.
    {-# INLINE go #-}
    go (l, u@(Tensor f), p@(Tensor fp), s@(rows : _)) i
      | pivot /= 0 =
          let inv  = denominator pivot % numerator pivot  -- exact 1 / pivot
              -- e zeroes column i of u below the diagonal; eInv is its inverse.
              e    = Tensor $ \_ ix -> inv * gi i (f s) ix
              eInv = Tensor $ \_ ix -> inv * gj i (f s) ix
          in (l `dotM` eInv, e `dotM` u, p, s)
      | otherwise =
          -- Zero pivot: find a nonzero entry to its right in row i. Swap that
          -- column into position i in both u and p, then retry the step.
          case filter (\j -> f s [i, j] /= 0) [i + 1 .. rows - 1] of
            []      -> (l, u, p, s)
            (j : _) -> go (l, swap i j (f s), swap i j (fp s), s) i
      where
        pivot = f s [i, i]
    -- Entry [x, y] of the elimination matrix for column k, scaled by the
    -- pivot fs [k, k]:
    --   * the negated entries of column k below the diagonal;
    --   * the pivot on the diagonal;
    --   * zero elsewhere.
    {-# INLINE gi #-}
    gi k fs [x, y]
      | x > k && y == k = - (fs [x, y])
      | x == y          = fs [k, k]
      | otherwise       = 0
    -- Like gi without the negation: the scaled inverse of that elimination
    -- matrix.
    {-# INLINE gj #-}
    gj k fs [x, y]
      | x > k && y == k = fs [x, y]
      | x == y          = fs [k, k]
      | otherwise       = 0
    -- A tensor that reads g with columns c1 and c2 exchanged.
    {-# INLINE swap #-}
    swap c1 c2 g = Tensor $ \_ [x, y] ->
      if y == c1 then g [x, c2] else if y == c2 then g [x, c1] else g [x, y]

-- | Determinant of an n x n integral matrix, computed via the 'lu'
-- decomposition.
--
-- 'lu' returns @(l, u, p)@ with @t · p == l · u@. Here @l@ is unit lower
-- triangular, so its determinant is 1, and @p@ is a column permutation of the
-- identity, so its determinant is ±1. Hence the result is the product of
-- @u@'s diagonal times @det p@.
--
-- If that diagonal product is zero, the matrix is singular and 0 is returned
-- without computing @det p@. Otherwise the product equals @det t · det p@,
-- which is a whole number, so taking its numerator is exact.
det' :: forall a n . (KnownNat a, Integral n) => SimpleMatrix a n -> n
det' t
  | diagU == 0 = 0
  | otherwise  = diagU * det p
  where
    (_, u, p)    = lu t
    s@(rows : _) = shape t
    -- Numerator of the product of u's diagonal entries.
    diagU = case u of
      Tensor f -> numerator $ product [ f s [i, i] | i <- [0 .. rows - 1] ]

-- | <https://en.wikipedia.org/wiki/Determinant Determinant> of an n x n
-- matrix, by cofactor (Laplace) expansion along the first row.
--
-- > λ> a = [2,0,1,3,0,0,2,2,1,2,1,34,3,2,34,4] :: Tensor '[4,4] Int
-- > λ> a
-- > [[2,0,1,3],
-- > [0,0,2,2],
-- > [1,2,1,34],
-- > [3,2,34,4]]
-- > λ> det a
-- > 520
--
-- This implementation is not fast. On the author's machine it handles an
-- 8 x 8 matrix with all entries nonzero in about one second. It should be
-- faster when the matrix contains more zeros.
--
-- The determinant of a 0 x 0 matrix is 1.
det :: forall a n. (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> n
det = go (toNat (Proxy :: Proxy a)) . runTensor
  where
    -- Determinant of the k x k matrix whose entries are read through f.
    {-# INLINE go #-}
    go 0 _ = 1   -- empty matrix; previously an empty sum gave 0
    go 1 f = f [0, 0]
    go k f = sum [ cofactorTerm f k i | i <- [0 .. k - 1] ]
    -- Signed term for column i of the first row: the entry [0, i] times the
    -- determinant of the minor that drops row 0 and column i.
    -- The sub-determinant is left lazy so that 'mult' can presumably skip it
    -- when the entry is zero (behaviour of 'mult' defined elsewhere).
    {-# INLINE cofactorTerm #-}
    cofactorTerm f k i =
      let minor [x, y] = f [x + 1, if y >= i then y + 1 else y]
          sub          = go (k - 1) minor
      in f [0, i] `mult` (if even i then sub else negate sub)