-- | Determinants.
--
-- TODO: specialized prime fields; fast C implementation
--

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Math.Algebra.Determinant where

--------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.ST

import Data.Array.Base
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Data.Array.ST

import Data.List
import Data.Ratio
import Data.STRef

import Data.Bits
import Data.Word
import Data.Int

import Foreign.C
import Foreign.Ptr
import Foreign.Marshal

import System.IO.Unsafe as Unsafe
import System.Random

import Debug.Trace
import GHC.IO ( unsafeIOToST )

import Math.Algebra.ModP

--------------------------------------------------------------------------------
-- * matrices

-- | Matrices indexed by @(row, column)@; every function in this module expects
-- the index range to start at @(1,1)@.
type Matrix a = Array (Int,Int) a

-- | Print a matrix, one bracketed row per line.
printMatrix :: Show a => Matrix a -> IO ()
printMatrix = putStrLn . showMatrix

-- | Render a matrix as a multi-line string.
showMatrix :: Show a => Matrix a -> String
showMatrix = unlines . showMatrix'

-- | Render each row as @[ x y z ]@, right-aligning the entries of every column.
showMatrix' :: Show a => Matrix a -> [String]
showMatrix' mat = map mkRow (transpose cols)
  where
    ((1,1),(n,m)) = bounds mat
    cols = map extend [ [ show (mat!(i,j)) | i <- [1..n] ] | j <- [1..m] ]
    mkRow strs = "[ " ++ intercalate " " strs ++ " ]"

    -- Left-pad the strings of one column to the width of its widest entry.
    -- The @0 :@ keeps this total for a column without entries (n == 0).
    extend :: [String] -> [String]
    extend xs = map pad xs
      where
        width = maximum (0 : map length xs)
        pad s = replicate (width - length s) ' ' ++ s

--------------------------------------------------------------------------------
-- * a type class for determinants

-- | Element types together with their preferred determinant algorithm.
class (Eq a, Num a, Show a) => Determinant a where
  determinant :: Matrix a -> a

-- fraction-free Bareiss elimination (with row pivoting)
instance Determinant Integer where
  determinant = bareissDeterminantFullRank

-- fraction-free Bareiss elimination; intermediate products may overflow Int
instance Determinant Int where
  determinant = bareissDeterminantFullRank

-- Gaussian elimination over the rationals
instance Determinant Rational where
  determinant = gaussElimDeterminant

-- Gaussian elimination in the C routine det_modp
instance Determinant Zp where
  determinant = gaussElimDeterminantInt64

--------------------------------------------------------------------------------
-- * C implementation of determinant in a prime field (gaussian elimination, fitting into 64 bit)

foreign import ccall "c_det.h inv_modp" c_inv_modp :: Int64 -> Int64 -> Int64
foreign import ccall "c_det.h det_modp" c_det_modp :: Int64 -> CInt -> Ptr Int64 -> IO Int64

-- | Determinant modulo the prime given as first argument, computed by the C
-- routine @det_modp@.
--
-- NOTE(review): the entries presumably have to be reduced into @[0, p)@
-- already (the 'test' driver reduces them with 'mod') -- confirm against c_det.c.
--
-- 'Unsafe.unsafePerformIO' is used on the assumption that @det_modp@ has no
-- effects beyond the temporary buffer 'ioFastDetModP' allocates for it.
fastDetModP :: Int64 -> Matrix Int64 -> Int64
fastDetModP p mat = Unsafe.unsafePerformIO $ ioFastDetModP p mat

-- | 'IO' version of 'fastDetModP'. Copies the entries in 'elems' order (which
-- is row-major for @(row, column)@ indices) into a fresh C array and passes it,
-- together with the dimension, to @det_modp@.
ioFastDetModP :: Int64 -> Matrix Int64 -> IO Int64
ioFastDetModP p mat = do
  let ((1,1),(n,_)) = bounds mat
  withArray (elems mat) $ \ptr -> c_det_modp p (fromIntegral n :: CInt) ptr

-- | Determinant over 'Zp' computed by the C routine (via 'fastDetModP') using
-- the prime @p@ from "Math.Algebra.ModP".
gaussElimDeterminantInt64 :: Matrix Zp -> Zp
gaussElimDeterminantInt64 mat = Zp (fromIntegral d)
  where
    d = fastDetModP (fromIntegral p) (fmap (fromIntegral . fromZp) mat)

--------------------------------------------------------------------------------
-- * Bareiss determinant algorithm

type STMatrix s a = STArray s (Int,Int) a

-- | Fraction-free (Bareiss) elimination, using exact 'div' at every step.
--
-- Despite its historical name this works for every square matrix. When the
-- current pivot is zero, a lower row with a nonzero entry in the pivot column
-- is swapped up, which flips the sign. If no such row exists, the determinant
-- is 0. Without pivoting, a vanishing leading minor (e.g. in the permutation
-- matrix @[[0,1,0],[1,0,0],[0,0,1]]@) used to cause a division by zero.
--
-- For 'Int' the intermediate products can overflow even if the result fits.
{-# SPECIALIZE bareissDeterminantFullRank :: Matrix Integer -> Integer #-}
{-# SPECIALIZE bareissDeterminantFullRank :: Matrix Int -> Int #-}
bareissDeterminantFullRank :: forall a . Integral a => Matrix a -> a
bareissDeterminantFullRank mat
  | n <= 0    = 1   -- determinant of the empty matrix is 1
  | otherwise = runST $ do
      cur  <- thaw mat      :: ST s (STMatrix s a)
      next <- newArray_ siz :: ST s (STMatrix s a)
      step 1 1 False cur next
  where
    siz@((1,1),(n,_)) = bounds mat

    readAt :: STMatrix s a -> (Int,Int) -> ST s a
    readAt ar ij = unsafeRead ar (index siz ij)

    writeAt :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    writeAt ar ij x = unsafeWrite ar (index siz ij) x

    -- @step k prev neg src dst@ eliminates column k.
    --   * @src@ holds valid entries in rows and columns k..n.
    --   * @prev@ is the previous pivot (1 before the first step).
    --   * @neg@ records whether an odd number of rows has been swapped.
    -- The updated entries for rows and columns k+1..n go into @dst@; the two
    -- buffers then swap roles for the next step.
    step :: Int -> a -> Bool -> STMatrix s a -> STMatrix s a -> ST s a
    step !k !prev !neg src dst
      | k >= n = do
          x <- readAt src (n,n)
          return (if neg then negate x else x)
      | otherwise = do
          column <- forM [k..n] $ \i -> readAt src (i,k)
          case findIndex (/= 0) column of
            Nothing  -> return 0   -- zero pivot column: the matrix is singular
            Just off -> do
              let r = k + off
              when (r > k) $ swapRows src k r
              pivot <- readAt src (k,k)
              forM_ [k+1..n] $ \i -> do
                b <- readAt src (i,k)
                forM_ [k+1..n] $ \j -> do
                  c <- readAt src (k,j)
                  d <- readAt src (i,j)
                  -- exact division by the previous pivot (Sylvester's identity)
                  writeAt dst (i,j) $ (pivot*d - b*c) `div` prev
              step (k+1) pivot (if r > k then not neg else neg) dst src

    -- Exchange row k with row r in columns k..n (earlier columns are never
    -- read again).
    swapRows :: STMatrix s a -> Int -> Int -> ST s ()
    swapRows ar k r = forM_ [k..n] $ \j -> do
      x <- readAt ar (k,j)
      y <- readAt ar (r,j)
      writeAt ar (k,j) y
      writeAt ar (r,j) x

--------------------------------------------------------------------------------
-- * Gaussian elimination

-- | Determinant over a field by Gaussian elimination with column pivoting.
-- The matrix is reduced to upper triangular form and the diagonal is
-- multiplied. Every column exchange flips the sign.
{-# SPECIALIZE gaussElimDeterminant :: Matrix Rational -> Rational #-}
{-# SPECIALIZE gaussElimDeterminant :: Matrix Zp -> Zp #-}
gaussElimDeterminant :: forall a. (Eq a, Show a, Fractional a) => Matrix a -> a
gaussElimDeterminant mat =
  if n <= 0
    then 1 -- determinant of the empty matrix is 1
    else runST $ do
      neg <- newSTRef False
      arr <- thaw mat :: ST s (STMatrix s a)
      worker neg arr 1
  where
    siz@((1,1),(n,_)) = bounds mat

    unsafeReadArray :: STMatrix s a -> (Int,Int) -> ST s a
    unsafeReadArray !ar !ij = unsafeRead ar (index siz ij)

    unsafeWriteArray :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    unsafeWriteArray !ar !ij !x = unsafeWrite ar (index siz ij) x

    -- Product of the diagonal of the (now upper triangular) matrix, negated
    -- after an odd number of column exchanges.
    finish :: STRef s Bool -> STMatrix s a -> ST s a
    finish !neg !arr = do
      diag <- sequence [ unsafeReadArray arr (i,i) | i <- [1..n] ]
      b <- readSTRef neg
      return $ if b then negate $ product diag else product diag

    -- Step i: look for a nonzero pivot in row i among the columns i..n.
    worker :: STRef s Bool -> STMatrix s a -> Int -> ST s a
    worker !neg !arr !i =
      if i >= n
        then finish neg arr
        else do
          ps <- sequence [ unsafeReadArray arr (i,j) | j <- [i..n] ]
          case findIndex (/=0) ps of
            Nothing    -> return 0 -- no pivot: the rest of row i is zero, so the determinant is zero
            Just pivot -> cont neg arr i (i+pivot)

    -- Move the pivot column to position i, then clear column i below the diagonal.
    cont :: STRef s Bool -> STMatrix s a -> Int -> Int -> ST s a
    cont !neg !arr !i !pivot = do
      when (pivot > i) $ xchg neg arr i pivot
      p <- unsafeReadArray arr (i,i)
      forM_ [i+1..n] $ \k -> do
        q <- unsafeReadArray arr (k,i)
        unsafeWriteArray arr (k,i) 0
        let z = q / p
        forM_ [i+1..n] $ \j -> do
          a <- unsafeReadArray arr (i,j)
          b <- unsafeReadArray arr (k,j)
          unsafeWriteArray arr (k,j) (b - a*z)
      worker neg arr (i+1)

    -- Exchange columns i and j in the rows i..n.
    xchg :: STRef s Bool -> STMatrix s a -> Int -> Int -> ST s ()
    xchg !neg !arr !i !j = do
      modifySTRef neg not -- exchanging two columns flips the sign of the determinant
      forM_ [i..n] $ \k -> do
        a <- unsafeReadArray arr (k,i)
        b <- unsafeReadArray arr (k,j)
        unsafeWriteArray arr (k,j) a
        unsafeWriteArray arr (k,i) b

--------------------------------------------------------------------------------
-- * naive determinant algorithm (for testing purposes)

-- | Cofactor (Laplace) expansion along the first remaining row. This takes
-- factorial time and is only meant as a reference implementation for testing.
naiveDeterminant :: forall a. (Num a) => Matrix a -> a
naiveDeterminant mat
  | n <= 0    = 1
  | n == 1    = mat!(1,1)
  | n == 2    = mat!(1,1) * mat!(2,2) - mat!(1,2) * mat!(2,1)
  | otherwise = worker [1..n] [1..n]
  where
    ((1,1),(n,_)) = bounds mat
    signs = cycle [True,False]

    -- determinant of the submatrix formed by the given rows and columns
    worker [] [] = 1
    worker [a] [b] = mat!(a,b)
    worker [a,b] [p,q] = mat!(a,p) * mat!(b,q) - mat!(a,q) * mat!(b,p)
    worker (i:is) js = foldl' (+) 0 (zipWith f signs js)
      where
        f b j = if b then mat!(i,j) * worker is (js\\[j])
                     else negate $ mat!(i,j) * worker is (js\\[j])

--------------------------------------------------------------------------------
-- * random matrices

-- | Build the @n x n@ matrix whose @(i,j)@ entry is @f i j@.
mkSquareMatrix :: (Int -> Int -> a) -> Int -> Matrix a
mkSquareMatrix f n = array ((1,1),(n,n)) [ ((i,j), f i j) | i <- [1..n], j <- [1..n] ]

-- | A deterministic @n x n@ test matrix.
testMatrix :: Num a => Int -> Matrix a
testMatrix n = mkSquareMatrix f n
  where
    f i j = fromIntegral $ 3 + i*i*i - j*j + (4*i*j + 3*i + 5*j + 7) + xor (13+i) (17+j)

-- | Random @n x n@ matrix with entries in @[-10, 10]@.
randomMatrix :: (Random a, Num a) => Int -> IO (Matrix a)
randomMatrix = randomMatrix' 10

-- | Random @n x n@ matrix with entries in @[-bnd, bnd]@.
randomMatrix' :: (Random a, Num a) => a -> Int -> IO (Matrix a)
randomMatrix' bnd n = do
  xs <- replicateM (n*n) (randomRIO (-bnd,bnd))
  return $ listArray ((1,1),(n,n)) xs

-- | Debug helper: print a value from inside 'ST'.
printST :: Show a => a -> ST s ()
printST x = unsafeIOToST (print x)

--------------------------------------------------------------------------------
-- * testing

-- | Compare all determinant algorithms against 'naiveDeterminant' on 100
-- random matrices of every size from 1 to 10, printing each disagreement.
-- Since the naive algorithm takes factorial time, the larger sizes are slow.
test :: IO ()
test =
  forM_ [1..10] $ \n -> do
    putStrLn $ "testing matrices of size " ++ show n ++ " x " ++ show n ++ "..."
    replicateM_ 100 $ do
      imat <- randomMatrix n :: IO (Matrix Integer)
      let mat = fmap fromInteger imat :: Matrix Rational
      let a = naiveDeterminant mat
          b = gaussElimDeterminant mat
      let ia = naiveDeterminant imat :: Integer
          amodp = mkZp $ fromIntegral (mod ia (fromIntegral p))
      let c  = gaussElimDeterminant (fmap mkZp imat)
          d0 = fastDetModP (fromIntegral p) (fmap (\x -> fromIntegral (mod x (fromIntegral p))) imat)
          d  = fromIntegral d0 :: Zp
      -- Bareiss (the Integer instance of 'determinant'), including matrices
      -- with a vanishing leading minor
      let e = bareissDeterminantFullRank imat
      when (a/=b) $ do
        putStrLn "\nERROR!"
        print (a,b)
        print imat
      when (c/=d || d/=amodp) $ do
        putStrLn "\nC ERROR!"
        print (c,d,amodp)
        print imat
      when (e/=ia) $ do
        putStrLn "\nBAREISS ERROR!"
        print (e,ia)
        print imat