{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE CPP                       #-}
{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE KindSignatures            #-}
{-# LANGUAGE MagicHash                 #-}
{-# LANGUAGE MultiParamTypeClasses     #-}
{-# LANGUAGE PolyKinds                 #-}
{-# LANGUAGE RecordWildCards           #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeApplications          #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE UnboxedSums               #-}
{-# LANGUAGE UnboxedTuples             #-}
{-# LANGUAGE UndecidableInstances      #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Matrix
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
--
-----------------------------------------------------------------------------
module Numeric.Matrix
  ( MatrixTranspose (..)
  , SquareMatrix (..)
  , MatrixDeterminant (..)
  , MatrixInverse (..)
  , MatrixLU (..), LUFact (..)
  , Matrix
  , Mat22f, Mat23f, Mat24f
  , Mat32f, Mat33f, Mat34f
  , Mat42f, Mat43f, Mat44f
  , Mat22d, Mat23d, Mat24d
  , Mat32d, Mat33d, Mat34d
  , Mat42d, Mat43d, Mat44d
  , mat22, mat33, mat44
  , (%*)
  , pivotMat, luSolve
  ) where

import Control.Monad                          (foldM)
import Data.Foldable                          (forM_, foldl')
import Data.List                              (delete)
import GHC.Base
import Numeric.DataFrame.Contraction          ((%*))
import Numeric.DataFrame.Internal.Array.Class
import Numeric.DataFrame.Internal.Array.Family as AFam
import Numeric.DataFrame.Shape
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Class
import Numeric.PrimBytes
import Numeric.Scalar
import Numeric.Vector

import Control.Monad.ST
import Numeric.DataFrame.ST


-- | Compose a 2x2D matrix from two vectors (via '<::>').
mat22 :: ( PrimBytes (Vector (t :: Type) 2)
         , PrimBytes (Matrix t 2 2)
         )
      => Vector t 2 -> Vector t 2 -> Matrix t 2 2
mat22 = (<::>)

-- | Compose a 3x3D matrix; the i-th argument is copied in at index @(1, i)@,
--   i.e. it becomes the i-th column.
mat33 :: ( PrimBytes (t :: Type)
         , PrimBytes (Vector t 3)
         , PrimBytes (Matrix t 3 3)
         )
      => Vector t 3 -> Vector t 3 -> Vector t 3 -> Matrix t 3 3
mat33 a b c = runST $ do
    mmat <- newDataFrame
    copyDataFrame a (1:*1:*U) mmat
    copyDataFrame b (1:*2:*U) mmat
    copyDataFrame c (1:*3:*U) mmat
    unsafeFreezeDataFrame mmat

-- | Compose a 4x4D matrix; the i-th argument is copied in at index @(1, i)@,
--   i.e. it becomes the i-th column.
mat44 :: forall (t :: Type)
       . ( PrimBytes t
         , PrimBytes (Vector t (4 :: Nat))
         , PrimBytes (Matrix t (4 :: Nat) (4 :: Nat))
         )
      => Vector t (4 :: Nat)
      -> Vector t (4 :: Nat)
      -> Vector t (4 :: Nat)
      -> Vector t (4 :: Nat)
      -> Matrix t (4 :: Nat) (4 :: Nat)
mat44 a b c d = runST $ do
    mmat <- newDataFrame
    copyDataFrame a (1:*1:*U) mmat
    copyDataFrame b (1:*2:*U) mmat
    copyDataFrame c (1:*3:*U) mmat
    copyDataFrame d (1:*4:*U) mmat
    unsafeFreezeDataFrame mmat



-- Element (i, j) of an n x m matrix lives at flat offset i + n * j.
-- The result (m x n) is generated in its own flat order: j runs fastest,
-- wrapping to the next i when it reaches m.
instance ( KnownDim n, KnownDim m
         , PrimArray t (Matrix t n m)
         , PrimArray t (Matrix t m n)
         ) => MatrixTranspose t (n :: Nat) (m :: Nat) where
    transpose df = case elemSize0 df of
      -- a broadcast (single-element) frame stays a broadcast
      0# -> broadcast (ix# 0# df)
      nm | I# m <- fromIntegral $ dimVal' @m
         , I# n <- fromIntegral $ dimVal' @n
         -> let f ( I# j, I# i )
                  | isTrue# (j ==# m) = f ( 0 , I# (i +# 1#) )
                  | otherwise         = (# ( I# (j +# 1#), I# i )
                                         , ix# (j *# n +# i) df
                                         #)
            in case gen# nm f (0,0) of
                (# _, r #) -> r

-- Transpose of a matrix whose dimensions are only known at runtime:
-- recover the static dimensions and reuse the Nat instance above.
instance MatrixTranspose (t :: Type) (xn :: XNat) (xm :: XNat) where
    transpose (XFrame (df :: DataFrame t ns))
      | ((D :: Dim n) :* (D :: Dim m) :* U) <- dims @Nat @ns
      , E <- AFam.inferPrimElem @t @n @'[m]
      = XFrame (transpose df :: Matrix t m n)
    transpose _ = error "MatrixTranspose/transpose: impossible argument"

instance (KnownDim n, PrimArray t (Matrix t n n), Num t) => SquareMatrix t n where
    -- The generator state counts down the zeros left before the next
    -- diagonal element: 0 emits a diagonal value and resets to n,
    -- any other k emits a zero, so diagonals land every n + 1 elements.
    eye
      | n@(I# n#) <- fromIntegral $ dimVal' @n
      = let f 0 = (# n, 1 #)
            f k = (# k - 1, 0 #)
        in case gen# (n# *# n#) f 0 of
            (# _, r #) -> r

    diag se
      | n@(I# n#) <- fromIntegral $ dimVal' @n
      , e <- unScalar se
      = let f 0 = (# n, e #)
            f k = (# k - 1, 0 #)
        in case gen# (n# *# n#) f 0 of
            (# _, r #) -> r

    -- Sum of the diagonal: walk backwards from the last diagonal element
    -- (flat offset n*n - 1) in steps of n + 1.
    -- An empty (0x0) matrix has trace 0; without the n > 0 guard the walk
    -- would start at offset -1 and never reach 0.
    trace df
      | I# n <- fromIntegral $ dimVal' @n
      , isTrue# (n ># 0#)
      , n1 <- n +# 1#
      = let f 0# = ix# 0# df
            f k  = ix# k df + f (k -# n1)
        in scalar $ f (n *# n -# 1#)
      | otherwise = scalar 0


-- Each column of the identity is solved against the LU factors,
-- giving the corresponding column of the inverse.
instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t
         , PrimArray t (Matrix t n n)
         , PrimArray t (Vector t n)
         , PrimBytes (Vector t n)
         , PrimBytes (Matrix t n n)
         ) => MatrixInverse t n where
    inverse m = ewmap (luSolve (lu m)) eye


instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t, PrimArray t (Matrix t n n)
         ) => MatrixDeterminant t n where
    -- From P A = L U:  det A = sign(P) * product (diag U).
    -- The 'lu' instance in this module writes 1 to every L[j,j],
    -- so the diagonal of L does not contribute and is not multiplied in.
    -- An empty (0x0) matrix has determinant 1; without the n > 0 guard
    -- the diagonal walk would start at offset -1 and never reach 0.
    det m
      | isTrue# (n ># 0#) = prodF (luUpper f) * luPermSign f
      | otherwise         = scalar 1
      where
        f = lu m
        !(I# n) = fromIntegral $ dimVal' @n
        n1 = n +# 1#
        -- product of the diagonal, walking back from offset n*n - 1
        prodF a = scalar $ prodF' (n *# n -# 1#) a
        prodF' 0# a = ix# 0# a
        prodF' k  a = ix# k a * prodF' (k -# n1) a



instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t, PrimArray t (Matrix t n n)
         ) => MatrixLU t n where
    -- Doolittle-style LU decomposition: P A = L U with a unit diagonal in L.
    -- Rows are permuted once up front by 'pivotMat' (no pivoting during
    -- elimination), so a zero pivot U[j,j] can still occur; in that case the
    -- sub-diagonal entries of column j of L are set to 0.
    -- Both factors are built column by column in two zero-initialised
    -- mutable byte arrays, then frozen.
    lu m' = case runRW# go of
        (# _, (# bu, bl #) #) -> LUFact
          { luLower    = fromElems 0# nn bl
          , luUpper    = fromElems 0# nn bu
          , luPerm     = p
          , luPermSign = si
          }
      where
        (m, p, si) = pivotMat m'
        !(I# n) = fromIntegral $ dimVal' @n
        nn = n *# n
        tbs = byteSize @t undefined
        bsize = nn *# tbs
        -- element (i, j) of the pivoted matrix (flat offset i + n * j)
        ixm i j = ix# (i +# n *# j) m
        -- state-threading loop over indices i .. k-1 with an accumulator
        loop :: (Int# -> a -> State# s -> (# State# s, a #))
             -> Int# -> Int# -> a -> State# s -> (# State# s, a #)
        loop f i k x s
          | isTrue# (i ==# k) = (# s, x #)
          | otherwise = case f i x s of
              (# s', y #) -> loop f ( i +# 1# ) k y s'
        go s0
          | (# s1, mbl #) <- newByteArray# bsize s0
          , (# s2, mbu #) <- newByteArray# bsize s1
          , s3 <- setByteArray# mbl 0# bsize 0# s2
          , s4 <- setByteArray# mbu 0# bsize 0# s3
          , readL  <- \i j -> readArray  @t mbl (i +# n *# j)
          , readU  <- \i j -> readArray  @t mbu (i +# n *# j)
          , writeL <- \i j -> writeArray @t mbl (i +# n *# j)
          , writeU <- \i j -> writeArray @t mbu (i +# n *# j)
            -- U[i,j] = A[i,j] - sum over k < i of U[k,j] * L[i,k]
          , computeU <- \i j ->
              let f k x s
                    | (# s' , ukj #) <- readU k j s
                    , (# s'', lik #) <- readL i k s'
                    = (# s'', x - ukj * lik #)
              in loop f 0# i (ixm i j)
            -- numerator of L[i,j]: A[i,j] - sum over k < j of U[k,j] * L[i,k]
          , computeL' <- \i j ->
              let f k x s
                    | (# s' , ukj #) <- readU k j s
                    , (# s'', lik #) <- readL i k s'
                    = (# s'', x - ukj * lik #)
              in loop f 0# j (ixm i j)
            -- for each column j: L[j,j] = 1, U[0..j,j], then L[j+1..n-1,j]
          , (# sr, () #) <- loop
              ( \j _ sj -> case sj of
                  sj0
                    | sj1 <- writeL j j 1 sj0
                    , (# sj2, () #) <- loop
                        ( \i _ sij0 -> case computeU i j sij0 of
                            (# sij1, uij #) -> (# writeU i j uij sij1, () #)
                        ) 0# j () sj1
                    , (# sj3, ujj #) <- computeU j j sj2
                    , sj4 <- writeU j j ujj sj3
                    -> case ujj of
                        0 -> loop
                          ( \i _ sij -> (# writeL i j 0 sij, () #) )
                          (j +# 1#) n () sj4
                        x -> loop
                          ( \i _ sij0 -> case computeL' i j sij0 of
                              (# sij1, lij #) -> (# writeL i j (lij / x) sij1, () #)
                          ) (j +# 1#) n () sj4
              ) 0# n () s4
          , (# sf0, bl #) <- unsafeFreezeByteArray# mbl sr
          , (# sf1, bu #) <- unsafeFreezeByteArray# mbu sf0
          = (# sf1, (# bu, bl #) #)


-- | Solve @Ax = b@ problem given LU decomposition of A.
--
--   With @PA = LU@: first solve @Ly = Pb@ by forward substitution
--   (L has a unit diagonal, so no division), then @Ux = y@ by backward
--   substitution.
luSolve :: forall (t :: Type) (n :: Nat)
         . ( KnownDim n, Ord t, Fractional t
           , PrimBytes t, PrimArray t (Matrix t n n), PrimArray t (Vector t n))
        => LUFact t n -> Vector t n -> Vector t n
luSolve LUFact {..} b = x
  where
    -- Pb = LUx
    pb = luPerm %* b
    !n@(I# n#) = fromIntegral $ dimVal' @n
    -- Ly = Pb
    y :: Vector t n
    y = runST $ do
      my <- newDataFrame
      let ixA (I# i) (I# j) = scalar $ ix# (i +# n# *# j) luLower
          ixB (I# i) = scalar $ ix# i pb
      forM_ [0..n-1] $ \i -> do
        v <- foldM ( \v j -> do
                       dj <- readDataFrameOff my j
                       return $ v - dj * ixA i j
                   ) (ixB i) [0..i-1]
        writeDataFrameOff my i v
      unsafeFreezeDataFrame my
    -- Ux = y
    x = runST $ do
      mx <- newDataFrame
      let ixA (I# i) (I# j) = scalar $ ix# (i +# n# *# j) luUpper
          ixB (I# i) = scalar $ ix# i y
      forM_ [n-1, n-2 .. 0] $ \i -> do
        v <- foldM ( \v j -> do
                       dj <- readDataFrameOff mx j
                       return $ v - dj * ixA i j
                   ) (ixB i) [i+1..n-1]
        writeDataFrameOff mx i (v / ixA i i)
      unsafeFreezeDataFrame mx


-- | Permute rows so that the largest-magnitude element of each column
--   ends up on the diagonal.
--   Pivots are chosen from the values of the input matrix (once, up front).
--
--   Returns the permuted matrix @PM@, the permutation matrix @P@, and the
--   sign of the permutation (+1 or -1, from the parity of its inversions).
--
--   Invariants of the resulting matrix @M@:
--
--   * forall j >= i: |M[i,i]| >= |M[j,i]|
--   * if M[i,i] == 0 then forall j >= i: |M[i+1,i+1]| >= |M[j,i+1]|
pivotMat :: forall (t :: Type) (n :: k)
          . ( KnownDim n, PrimArray t (Matrix t n n), Ord t, Num t )
         => Matrix t n n
         -> ( Matrix t n n -- permutated matrix
            , Matrix t n n -- permutation matrix
            , Scalar t     -- sign of permutation matrix
            )
pivotMat m =
    ( -- row k of the result is row (rowOrder !! k) of the input
      let f ( j, [] ) = f (j+1, rowOrder)
          f ( j, i:is ) = (# (j, is), ix i j #)
      in case gen# nn f (0,rowOrder) of
          (# _, r #) -> r
      -- P[k,j] = 1 iff rowOrder !! k == j
    , let f ( j, [] ) = f (j+1, rowOrder)
          f ( j, x:xs )
            | j == x    = (# ( j, xs), 1 #)
            | otherwise = (# ( j, xs), 0 #)
      in case gen# nn f (0,rowOrder) of
          (# _, r #) -> r
    , if countMisordered rowOrder `rem` 2 == 1 then -1 else 1
    )
  where
    -- permuted row ordering
    rowOrder = uncurry fillPass $ searchPass 0 [0..n-1]
    -- matrix size
    !n@(I# n#) = fromIntegral $ dimVal' @n
    -- sign of permutations: number of inversions in the row ordering
    countMisordered :: [Int] -> Int
    countMisordered [] = 0
    countMisordered (i:is) =
      foldl' (\c j -> if i > j then succ c else c) 0 is + countMisordered is
    nn = n# *# n#
    -- element (i, j) of the input (flat offset i + j * n)
    ix (I# i) (I# j) = ix# (i +# j *# n#) m
    -- largest |M[i,j]| over the candidate rows; (0, 0) if all are zero
    findMax :: Int -> [Int] -> (t, Int)
    findMax j = foldl' (\(ox, oi) i -> let x = abs (ix i j)
                                       in if x > ox then (x, i) else (ox, oi)
                       ) (0, 0)
    -- search maximums, leaving Nothing where all rows are 0
    searchPass :: Int -> [Int] -> ([Int], [Maybe Int])
    searchPass j is
      | j == n = (is, [])
      | otherwise = case findMax j is of
          (0, _) -> (Nothing:) <$> searchPass (j+1) is
          (_, i) -> (Just i:) <$> searchPass (j+1) (delete i is)
    -- replace Nothings with remaining row numbers
    fillPass :: [Int] -> [Maybe Int] -> [Int]
    fillPass _ [] = []
    fillPass js (Just i : is) = i : fillPass js is
    fillPass (j:js) (Nothing : is) = j : fillPass js is
    fillPass [] (Nothing : is) = 0 : fillPass [] is