{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module Data.Tensor.Matrix where
import Data.List (foldl')
import Data.Proxy
import Data.Ratio
import Data.Tensor.Tensor
import Data.Tensor.Type
import GHC.TypeLits
-- | A 'Matrix' whose first two type arguments are both @a@, with elements of
-- type @n@ (presumably a square @a@-by-@a@ matrix; confirm against the
-- definition of 'Matrix').
type SimpleMatrix a n = Matrix a a n
-- | 'dot' restricted to two 'SimpleMatrix' operands that share the same size
-- parameter @a@ and element type @n@; the result has that same type.
dotM :: (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> SimpleMatrix a n -> SimpleMatrix a n
dotM lhs rhs = lhs `dot` rhs
-- | Extract the diagonal of a 'SimpleMatrix' as a 'Vector'.
--
-- The resulting tensor function expects a one-element shape @[s]@ and a
-- one-element index @[i]@, and reads the matrix with shape @[s, s]@ at
-- position @[i, i]@.
diag :: SimpleMatrix a n -> Vector a n
diag (Tensor t) = Tensor pick
  where
    -- Forward a vector lookup to the matching diagonal cell of the matrix.
    pick [s] [i] = t [s, s] [i, i]
-- | Contract the index pair @(i0, i1)@ of the matrix with 'contraction' and
-- read the resulting tensor using an empty shape and an empty index
-- (presumably the sum of the diagonal entries; depends on 'contraction').
trace :: (KnownNat a, Num n) => SimpleMatrix a n -> n
trace t = contracted [] []
  where
    Tensor contracted = contraction (i0, i1) t
-- | LU-style factorisation of a 'SimpleMatrix' in exact 'Ratio' arithmetic,
-- pivoting by column swaps.
--
-- Returns @(l, u, p)@. The fold starts from @(identity, t lifted to 'Ratio',
-- identity)@ and visits the pivot positions @0 .. a - 1@ in order. Each step
-- either swaps two columns of both @u@ and @p@ and retries, or multiplies @u@
-- on the left by an elimination matrix and @l@ on the right by its
-- counterpart. Assuming 'dotM' is the ordinary matrix product, the intended
-- invariant appears to be @t * p = l * u@.
--
-- When a pivot is zero and no entry to its right in the same row is non-zero,
-- that step leaves the state unchanged.
lu :: forall a n . (KnownNat a, Integral n)
  => SimpleMatrix a n
  -> (SimpleMatrix a (Ratio n), SimpleMatrix a (Ratio n), SimpleMatrix a n)
lu t = (lower, upper, perm)
  where
    size = toNat (Proxy :: Proxy a)
    (lower, upper, perm, _) =
      foldl' step (identity, fmap (% 1) t, identity, [size, size]) ([0 .. size - 1] :: [Int])

    -- One pivot step at position k. The shape list is threaded through the
    -- accumulator so this helper does not capture anything from 'lu'.
    {-# INLINE step #-}
    step acc@(lo, up@(Tensor f), pm@(Tensor fp), shp) k
      | pivot == 0 =
          -- Zero pivot: take the first column to the right whose entry in
          -- row k is non-zero, swap it into place, and retry the same k.
          case [c | c <- [k + 1 .. dim - 1], entry [k, c] /= 0] of
            []      -> acc
            (c : _) -> step (lo, swapCols k c entry, swapCols k c (fp shp), shp) k
      | otherwise = (lo `dotM` undo, elim `dotM` up, pm, shp)
      where
        entry = f shp
        -- The shape is always the two-element list built in 'lu'.
        (dim : _) = shp
        pivot = entry [k, k]
        -- Reciprocal of the pivot.
        scale = denominator pivot % numerator pivot
        -- Scaled so the diagonal becomes pivot / pivot.
        elim = Tensor $ \_ ix -> scale * elimAt k entry ix
        undo = Tensor $ \_ ix -> scale * undoAt k entry ix

    -- Unscaled entry of the elimination matrix for pivot p: the negated entry
    -- of g below the pivot in column p, g's pivot value on the diagonal,
    -- zero elsewhere.
    {-# INLINE elimAt #-}
    elimAt p g [r, c]
      | r > p && c == p = negate (g [r, c])
      | r == c          = g [p, p]
      | otherwise       = 0

    -- Same layout as 'elimAt' but without the negation below the pivot.
    {-# INLINE undoAt #-}
    undoAt p g [r, c]
      | r > p && c == p = g [r, c]
      | r == c          = g [p, p]
      | otherwise       = 0

    -- Tensor that reads g with columns c1 and c2 exchanged.
    {-# INLINE swapCols #-}
    swapCols c1 c2 g = Tensor $ \_ [r, c] -> g [r, source c]
      where
        source c
          | c == c1   = c2
          | c == c2   = c1
          | otherwise = c
-- | Determinant of an integral matrix computed from its 'lu' factorisation.
--
-- Multiplies the numerators of the diagonal products of @l@ and @u@ with
-- 'det' of the permutation factor @p@. If the diagonal product of @u@ has a
-- zero numerator the result is 0 and the other factors are not evaluated.
det' :: forall a n . (KnownNat a, Integral n) => SimpleMatrix a n -> n
det' t
  | upperDiag == 0 = 0
  | otherwise      = diagProduct dims order l * upperDiag * det p
  where
    (l, u, p) = lu t
    dims = shape t
    order = head dims
    upperDiag = diagProduct dims order u

    -- Numerator of the product of the first k diagonal entries of a tensor
    -- read with shape shp.
    {-# INLINE diagProduct #-}
    diagProduct shp k (Tensor f) =
      numerator (product [entry [i, i] | i <- ([0 .. k - 1] :: [Int])])
      where
        entry = f shp
-- | Determinant by cofactor expansion along row 0.
--
-- The matrix is accessed through the index function obtained from
-- 'runTensor'; each recursion level works on a re-indexed view that skips
-- row 0 and one column, so no intermediate matrices are built. A matrix of
-- size @n@ triggers @n@ recursive calls of size @n - 1@, so this is only
-- practical for small sizes.
--
-- The 0-by-0 matrix has determinant 1 (the empty product). Without the
-- explicit case the general equation would sum over an empty range and
-- return 0.
det :: forall a n. (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> n
det = let n = toNat (Proxy :: Proxy a) in go n . runTensor
  where
    {-# INLINE go #-}
    -- Empty matrix: empty product.
    go 0 _ = 1
    go 1 f = f [0,0]
    go n f = sum $ fmap (g2 f n) ([0.. n-1] :: [Int])
    -- Signed term for column i of row 0: the entry times the determinant of
    -- the minor with row 0 and column i removed, negated for odd i.
    -- NOTE(review): 'mult' is a project helper, assumed to behave like (*).
    {-# INLINE g2 #-}
    g2 f n i =
      let -- View of f without row 0 and column i.
          f' [x,y] = if y >= i then f [x+1,y +1] else f [x+1,y]
      in f [0,i] `mult` (if even i then go (n-1) f' else - (go (n-1) f'))