{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Tensor.Matrix where
import Data.List (foldl')
import Data.Proxy
import Data.Tensor.Tensor
import Data.Tensor.Type
import GHC.TypeLits
-- | A 'Matrix' whose first two type arguments coincide (a square matrix of
-- size @dim@), with elements of type @e@.
type SimpleMatrix dim e = Matrix dim dim e
-- | 'dot' specialised to two square matrices of the same size, yielding a
-- square matrix of that size.
dotM :: (KnownNat a, Num n) => SimpleMatrix a n -> SimpleMatrix a n -> SimpleMatrix a n
dotM lhs rhs = lhs `dot` rhs
-- | Trace of a square matrix: contract its two indices ('i0' with 'i1') and
-- read the contracted result at the empty shape and the empty index.
trace :: (KnownNat a, Num n) => SimpleMatrix a n -> n
trace mat =
  case contraction mat (i0, i1) of
    Tensor scalarAt -> scalarAt [] []
-- | Fraction-free (integer-only) triangular factorisation, without pivoting.
--
-- Returns @(l, u, m)@: @l@ is lower triangular and @u@ is upper triangular.
-- @m@ is the product of the pivots @u[i][i]@ met during elimination.
-- All three are divided by one common factor before being returned.
-- At each step the matrices built by @gi@ and @gj@ satisfy
-- @gj · gi = pivot² · I@. By induction from @(identity, t, 1)@,
-- @l `dotM` u == m² · t@, assuming 'dot' is ordinary matrix multiplication.
-- The common-factor division preserves that identity.
--
-- No pivoting is performed. If a pivot is zero, @m@ becomes 0 and the result
-- carries no useful information, but it is still returned without crashing.
lu :: forall a n . (KnownNat a, Integral n) => SimpleMatrix a n -> (SimpleMatrix a n, SimpleMatrix a n, n)
lu t =
  let a = toNat (Proxy :: Proxy a)
      (l, u, _, m) = foldl' go (identity, t, [a, a], 1) ([0 .. a - 1] :: [Int])
      -- gcd of m and every entry of l and u. It divides each of them exactly,
      -- unlike the minimum of the pairwise gcds, which made `div` truncate.
      p = foldr gcd (foldr gcd m l) u
      -- Make the returned m non-negative. Never divide by zero: p is 0 only
      -- when m and every entry are 0, and then there is nothing to reduce.
      divisor
        | p == 0 = 1
        | m < 0 = negate p
        | otherwise = p
      g = (`div` divisor)
   in (fmap g l, fmap g u, g m)
  where
    -- One elimination step on column i. Both sides are left-multiplied so
    -- that l·u is scaled by pivot² and m accumulates the pivot.
    go :: (SimpleMatrix a n, SimpleMatrix a n, [Int], n) -> Int -> (SimpleMatrix a n, SimpleMatrix a n, [Int], n)
    go (l, u@(Tensor f), s, m) i =
      let li = Tensor $ \_ -> gi i (f s)
          lj = Tensor $ \_ -> gj i (f s)
       in (l `dotM` lj, li `dotM` u, s, m * f s [i, i])
    -- Elimination matrix for column k: pivot on the diagonal, and
    -- -u[x][k] below the pivot in column k.
    gi k fs [x, y]
      | x > k && y == k = negate (fs [x, y])
      | x == y = fs [k, k]
      | otherwise = 0
    -- Companion of gi, with +u[x][k] below the pivot. gj·gi = pivot²·I.
    gj k fs [x, y]
      | x > k && y == k = fs [x, y]
      | x == y = fs [k, k]
      | otherwise = 0
-- | Determinant computed from the fraction-free factorisation 'lu'.
--
-- 'lu' yields @l `dotM` u == m² · t@, assuming 'dot' is ordinary matrix
-- multiplication. @l@ and @u@ are triangular, so their determinants are the
-- products of their diagonals. Hence @det l · det u == m^(2·dim) · det t@,
-- and the division below is exact.
--
-- The exponent depends on the matrix size @dim@, not on the tensor rank.
-- The rank is always 2 here.
--
-- If the pivot product is 0, the pivot-free elimination broke down; that does
-- not imply a zero determinant. In that case this falls back to cofactor
-- expansion ('det').
det' :: forall a n . (KnownNat a, Integral n) => SimpleMatrix a n -> n
det' t
  | m == 0 = det t
  | otherwise = (diagProduct l * diagProduct u) `div` (m ^ (2 * dim))
  where
    (l, u, m) = lu t
    dim = toNat (Proxy :: Proxy a)
    s = shape t
    -- Product of all dim diagonal entries of a square matrix.
    diagProduct (Tensor f) =
      let fs = f s
       in product (fmap (\i -> fs [i, i]) ([0 .. dim - 1] :: [Int]))
-- | Determinant by cofactor (Laplace) expansion along the first row.
--
-- Works for any 'Num' element type with 'Eq'. An entry that compares equal to
-- 0 contributes 0 immediately, so its minor is never evaluated. A 0x0 matrix
-- has determinant 1 (the empty product), which agrees with 'det''.
det :: forall a n. (KnownNat a, Num n, Eq n) => SimpleMatrix a n -> n
det = let n = toNat (Proxy :: Proxy a) in go n . runTensor
  where
    {-# INLINE go #-}
    -- go k f: determinant of the k x k matrix whose entries are read via f.
    go :: Int -> ([Int] -> n) -> n
    go 0 _ = 1
    go 1 f = f [0, 0]
    go n f = sum $ zipWith (g2 f n) ([0 .. n - 1] :: [Int]) (cycle [1, -1])
    {-# INLINE g2 #-}
    -- Signed contribution of entry (0, i): sign * f[0,i] * det(minor).
    g2 f n i sign =
      case f [0, i] of
        0 -> 0
        v -> sign * v * go (n - 1) minor
      where
        -- The minor drops row 0 and column i.
        minor [x, y]
          | y >= i = f [x + 1, y + 1]
          | otherwise = f [x + 1, y]