{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module Data.Tensor.Matrix where
import Data.List (foldl')
import Data.Proxy
import Data.Ratio
import Data.Tensor.Tensor
import Data.Tensor.Type
import GHC.TypeLits
-- | A square matrix: both dimensions are the type-level size @a@ and the
-- elements have type @n@ (i.e. @'Matrix' a a n@, a rank-2 tensor of shape
-- @'[a, a]@).
type SimpleMatrix a n = Matrix a a n
-- | Product of two square matrices of the same size, implemented with the
-- general tensor 'dot' specialised to 'SimpleMatrix'.
--
-- NOTE(review): assumes 'dot' contracts the last index of its first argument
-- with the first index of its second (its type requires @Last s ~ Head s'@),
-- i.e. the ordinary matrix product — confirm against 'dot'.
dotM
  :: (KnownNat a, Num n, Eq n)
  => SimpleMatrix a n
  -> SimpleMatrix a n
  -> SimpleMatrix a n
dotM = dot
-- | Main diagonal of a square matrix, as a vector.
--
-- Entry @k@ of the result is entry @[k, k]@ of the matrix. The vector's
-- one-element shape @[m]@ is widened to the matrix shape @[m, m]@ before the
-- matrix's index function is called.
diag :: SimpleMatrix a n -> Vector a n
diag (Tensor t) = Tensor $ \[s] [i] -> t [s, s] [i, i]
-- | Trace of a square matrix.
--
-- Contracts index 0 with index 1 (@i0@, @i1@) using 'contraction'. That
-- yields a rank-0 tensor, whose single entry is read at the empty shape and
-- the empty index.
--
-- NOTE(review): this is presumably the sum of the diagonal entries, given the
-- semantics of 'contraction' — confirm.
trace :: (KnownNat a, Num n) => SimpleMatrix a n -> n
trace t = let Tensor f = contraction (i0, i1) t in f [] []
-- | LU decomposition with column pivoting, carried out over exact rationals.
--
-- Returns @(l, u, p)@. The input @t@ is first lifted entry-wise with
-- @(% 1)@, and @l@ and @p@ start as 'identity'. Then each diagonal position
-- @k@ from @0@ to @a - 1@ is handled as follows:
--
--   * If the pivot @u[k, k]@ is non-zero, build an elimination matrix @e@ and
--     its inverse @e'@, then set @u := e `dotM` u@ and @l := l `dotM` e'@.
--     Here @e@ is the identity with column @k@ below the diagonal set to
--     @-u[x, k] / u[k, k]@; @e'@ is the same without the negation.
--   * If the pivot is zero, find the first @c > k@ with @u[k, c] /= 0@.
--     Exchange second-index positions @k@ and @c@ in both @u@ and @p@, then
--     retry position @k@.
--   * If no such @c@ exists, position @k@ is skipped. The entries below the
--     diagonal in that column are then not eliminated, so for singular
--     inputs @u@ need not be upper-triangular. The zero pivot does stay on
--     its diagonal.
--
-- NOTE(review): if index @[x, y]@ is row @x@, column @y@ and 'dotM' is the
-- ordinary matrix product, this gives @t · p = l · u@ with @l@ unit
-- lower-triangular — confirm.
--
-- NOTE(review): 'Tensor' wraps an index function, so each step composes
-- closures rather than storing entries. Unless 'Tensor' memoises, reading an
-- entry of @u@ re-evaluates the earlier steps. Confirm before using this on
-- large matrices.
lu :: forall a n . (KnownNat a, Integral n)
   => SimpleMatrix a n
   -> (SimpleMatrix a (Ratio n), SimpleMatrix a (Ratio n), SimpleMatrix a n)
lu t = (l, u, p)
  where
    dim :: Int
    dim = toNat (Proxy :: Proxy a)

    -- Every matrix here has shape [dim, dim]. Index functions take the shape
    -- as their first argument.
    shp = [dim, dim]

    (l, u, p) = foldl' step (identity, fmap (% 1) t, identity) [0 .. dim - 1]

    -- Handle diagonal position i (0-based) of the working matrix.
    step (lo, up@(Tensor f), pm@(Tensor fp)) i
      | pivot /= 0 =
          let inv     = recip pivot
              elim    = Tensor $ \_ ix -> inv * elimEntry negate i (f shp) ix
              elimInv = Tensor $ \_ ix -> inv * elimEntry id i (f shp) ix
          in (lo `dotM` elimInv, elim `dotM` up, pm)
      | otherwise =
          -- Zero pivot: look along [i, _] for a later non-zero entry and
          -- exchange that second-index position with i in both u and p.
          case filter (\j -> f shp [i, j] /= 0) [i + 1 .. dim - 1] of
            []      -> (lo, up, pm)
            (j : _) -> step (lo, swapCols i j (f shp), swapCols i j (fp shp)) i
      where
        pivot = f shp [i, i]

    {-# INLINE elimEntry #-}
    -- Entry [x, y] of the elimination matrix for column k, before scaling
    -- by 1 / pivot:
    --   * sgn (g [x, k]) below the diagonal in column k;
    --   * the pivot g [k, k] on the diagonal;
    --   * 0 elsewhere.
    elimEntry sgn k g [x, y]
      | x > k && y == k = sgn (g [x, y])
      | x == y          = g [k, k]
      | otherwise       = 0

    {-# INLINE swapCols #-}
    -- The tensor whose entry [x, y] is g [x, y'], where y' exchanges the
    -- second-index values c1 and c2 and leaves all other values unchanged.
    -- The shape argument is ignored.
    swapCols c1 c2 g = Tensor $ \_ [x, y] -> g [x, other y]
      where
        other y
          | y == c1   = c2
          | y == c2   = c1
          | otherwise = y
-- | Determinant of an integer matrix, computed via 'lu'.
--
-- Let @(l, u, p) = lu t@. The result is the product of the diagonal of @u@,
-- times the product of the diagonal of @l@, times @'det' p@. Each diagonal
-- product is taken over exact rationals and then reduced to its 'numerator'.
-- If the diagonal product of @u@ is zero, the result is 0 and neither @l@
-- nor @p@ is inspected.
--
-- NOTE(review): reducing with 'numerator' assumes each product is an integer
-- (denominator 1). The diagonal of @l@ is all ones by construction in 'lu',
-- and @p@ is the identity with columns exchanged, so @'det' p@ should be ±1
-- — confirm.
det' :: forall a n . (KnownNat a, Integral n) => SimpleMatrix a n -> n
det' t
  | diagU == 0 = 0
  | otherwise  = diagProduct l * diagU * det p
  where
    (l, u, p) = lu t
    shp = shape t
    diagU = diagProduct u

    {-# INLINE diagProduct #-}
    -- Product of entries [k, k] for k in [0 .. rows - 1], reduced to its
    -- numerator. Here rows is the first component of the shape.
    diagProduct (Tensor g) =
      numerator (product [g shp [k, k] | k <- [0 .. head shp - 1]])
-- | Determinant of a square matrix by cofactor (Laplace) expansion along
-- the first row, i.e. along the entries @[0, i]@.
--
-- Each cofactor term is combined with 'mult', hence the 'Eq' constraint.
-- A 0×0 matrix has determinant 1, the empty product. The expansion makes up
-- to k recursive calls on (k-1)×(k-1) minors, so the cost can grow
-- factorially with the size. For 'Integral' elements, 'det'' is an
-- elimination-based alternative.
det :: forall a n. (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> n
det = go (toNat (Proxy :: Proxy a)) . runTensor
  where
    {-# INLINE go #-}
    -- go k g: determinant of the k×k matrix whose entry [x, y] is g [x, y].
    go 0 _ = 1          -- empty matrix: empty product
    go 1 g = g [0, 0]
    go k g = sum $ fmap (g2 g k) ([0 .. k - 1] :: [Int])

    {-# INLINE g2 #-}
    -- Signed cofactor term for entry [0, i]. The minor drops row 0 and
    -- column i; the sign alternates with i.
    g2 g k i =
      let minor [x, y] = if y >= i then g [x + 1, y + 1] else g [x + 1, y]
          sub = go (k - 1) minor
      in g [0, i] `mult` (if even i then sub else - sub)