{-# LANGUAGE CPP                    #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE UnboxedTuples          #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Matrix.Class
-- Copyright   :  (c) Artem Chirkin
-- License     :  MIT
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- Type classes describing matrix operations, together with low-level
-- matrix product kernels for the primitive element types.
--
-----------------------------------------------------------------------------
module Numeric.Matrix.Class
  ( MatrixCalculus (..)
  , SquareMatrixCalculus (..)
  , Matrix2x2 (..)
  , MatrixProduct (..)
  , MatrixInverse (..)
  , prodF, prodD
  , prodI8, prodI16, prodI32, prodI64
  , prodW8, prodW16, prodW32, prodW64
  ) where

import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types

#include "MachDeps.h"
#include "HsBaseConfig.h"

import Numeric.Commons
import Numeric.Vector.Class
import Numeric.Vector.Family (Vector)
import Numeric.Matrix.Family (Matrix)


-- | Operations available on a matrix type @v@ with element type @t@
--   and dimensions @n@ and @m@.
--   The functional dependencies state that the matrix type determines
--   its element type and both of its dimensions.
class MatrixCalculus t (n :: Nat) (m :: Nat) v | v -> t, v -> n, v -> m where
    -- | Fill Mat with the same value
    broadcastMat :: t -> v
    -- | Get element by its index
    indexMat :: Int -> Int -> v -> t
    -- | Transpose Mat
    transpose :: (MatrixCalculus t m n w, PrimBytes w) => v -> w
    -- | First dimension size of a matrix
    dimN :: v -> Int
    -- | Second dimension size of a matrix
    dimM :: v -> Int
    -- | Get vector column by its index
    indexCol :: (VectorCalculus t n w, PrimBytes w) => Int -> v -> w
    -- | Get vector row by its index
    indexRow :: (VectorCalculus t m w, PrimBytes w) => Int -> v -> w

-- | Operations that only make sense for square matrices
--   of size @n@ with element type @t@.
class SquareMatrixCalculus t (n :: Nat) v | v -> t, v -> n where
    -- | Mat with 1 on diagonal and 0 elsewhere
    eye :: v
    -- | Put the same value on the Mat diagonal, 0 otherwise
    diag :: t -> v
    -- | Determinant of Mat
    det :: v -> t
    -- | Sum of diagonal elements
    trace :: v -> t
    -- | Get the diagonal elements from Mat into Vec
    fromDiag :: (VectorCalculus t n w, PrimBytes w) => v -> w
    -- | Set Vec values into the diagonal elements of Mat
    toDiag :: (VectorCalculus t n w, PrimBytes w) => w -> v

-- | Construction and deconstruction of 2x2 matrices.
class Matrix2x2 t where
    -- | Compose a 2x2 matrix from two 2D vectors
    mat22 :: Vector t 2 -> Vector t 2 -> Matrix t 2 2
    -- | Split a 2x2 matrix into a pair of 2D vectors
    --   (its rows, judging by the name)
    rowsOfM22 :: Matrix t 2 2 -> (Vector t 2, Vector t 2)
    -- | Split a 2x2 matrix into a pair of 2D vectors
    --   (its columns, judging by the name)
    colsOfM22 :: Matrix t 2 2 -> (Vector t 2, Vector t 2)

class MatrixProduct a b c where
    -- | matrix-matrix or matrix-vector product
    prod :: a -> b -> c

class MatrixInverse a where
    -- | Inverse of a value; the result has the same type as the argument
    inverse :: a -> a


-- All product kernels below share the same layout.
--
-- The arguments @n@, @m@, @k@ are dimensions: the left operand is read at
-- offsets @i + n * l@, the right operand at offsets @l + m * j@, and the
-- result element is written at offset @i + n * j@, where @i@ runs over
-- @0..n-1@, @l@ over @0..m-1@ and @j@ over @0..k-1@.
-- That is, an @n x m@ operand is multiplied by an @m x k@ operand
-- giving an @n x k@ result, all stored column by column.
--
-- The result buffer must hold @n * k@ elements, so its size in bytes is
-- @n * k@ times the element size; the write primops index in units of
-- elements, not bytes.


-- | Product kernel for single precision floating point elements.
--   Operands are read with @ixF@; the result is built with @fromBytes@.
prodF :: (FloatBytes a, FloatBytes b, PrimBytes c)
      => Int# -> Int# -> Int# -> a -> b -> c
prodF n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r `plusFloat#` timesFloat# (ixF (i +# n *# l) x)
                                                 (ixF (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeFloatArray# marr (i +# n *# j) (dot i j 0# 0.0#) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    bs = n *# k *# SIZEOF_HSFLOAT#
{-# INLINE prodF #-}

-- | Product kernel for double precision floating point elements.
--   Operands are read with @ixD@; the result is built with @fromBytes@.
prodD :: (DoubleBytes a, DoubleBytes b, PrimBytes c)
      => Int# -> Int# -> Int# -> a -> b -> c
prodD n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r +## (ixD (i +# n *# l) x *## ixD (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeDoubleArray# marr (i +# n *# j) (dot i j 0# 0.0##) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    bs = n *# k *# SIZEOF_HSDOUBLE#
{-# INLINE prodD #-}

-- | Product kernel writing 8 bit signed integer elements.
--   Operands are read with @ixI@; sums are accumulated as @Int#@.
prodI8 :: (IntBytes a, IntBytes b, PrimBytes c)
       => Int# -> Int# -> Int# -> a -> b -> c
prodI8 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeInt8Array# marr (i +# n *# j) (dot i j 0# 0#) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    bs = n *# k *# SIZEOF_INT8#
{-# INLINE prodI8 #-}

-- | Product kernel writing 16 bit signed integer elements.
--   Operands are read with @ixI@; sums are accumulated as @Int#@.
prodI16 :: (IntBytes a, IntBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodI16 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeInt16Array# marr (i +# n *# j) (dot i j 0# 0#) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 2 bytes wide; allocating only n * k bytes
    -- would let the writes run past the end of the buffer.
    bs = n *# k *# SIZEOF_INT16#
{-# INLINE prodI16 #-}

-- | Product kernel writing 32 bit signed integer elements.
--   Operands are read with @ixI@; sums are accumulated as @Int#@.
prodI32 :: (IntBytes a, IntBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodI32 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeInt32Array# marr (i +# n *# j) (dot i j 0# 0#) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 4 bytes wide; see the note in prodI16.
    bs = n *# k *# SIZEOF_INT32#
{-# INLINE prodI32 #-}

-- | Product kernel writing 64 bit signed integer elements.
--   Operands are read with @ixI@; sums are accumulated as @Int#@.
prodI64 :: (IntBytes a, IntBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodI64 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeInt64Array# marr (i +# n *# j) (dot i j 0# 0#) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 8 bytes wide; see the note in prodI16.
    bs = n *# k *# SIZEOF_INT64#
{-# INLINE prodI64 #-}

-- | Product kernel writing 8 bit unsigned integer elements.
--   Operands are read with @ixW@; sums are accumulated as @Word#@.
prodW8 :: (WordBytes a, WordBytes b, PrimBytes c)
       => Int# -> Int# -> Int# -> a -> b -> c
prodW8 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r `plusWord#` timesWord# (ixW (i +# n *# l) x)
                                               (ixW (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeWord8Array# marr (i +# n *# j) (dot i j 0# 0##) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    bs = n *# k *# SIZEOF_WORD8#
{-# INLINE prodW8 #-}

-- | Product kernel writing 16 bit unsigned integer elements.
--   Operands are read with @ixW@; sums are accumulated as @Word#@.
prodW16 :: (WordBytes a, WordBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodW16 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r `plusWord#` timesWord# (ixW (i +# n *# l) x)
                                               (ixW (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeWord16Array# marr (i +# n *# j) (dot i j 0# 0##) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 2 bytes wide; allocating only n * k bytes
    -- would let the writes run past the end of the buffer.
    bs = n *# k *# SIZEOF_WORD16#
{-# INLINE prodW16 #-}

-- | Product kernel writing 32 bit unsigned integer elements.
--   Operands are read with @ixW@; sums are accumulated as @Word#@.
prodW32 :: (WordBytes a, WordBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodW32 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r `plusWord#` timesWord# (ixW (i +# n *# l) x)
                                               (ixW (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeWord32Array# marr (i +# n *# j) (dot i j 0# 0##) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 4 bytes wide; see the note in prodW16.
    bs = n *# k *# SIZEOF_WORD32#
{-# INLINE prodW32 #-}

-- | Product kernel writing 64 bit unsigned integer elements.
--   Operands are read with @ixW@; sums are accumulated as @Word#@.
prodW64 :: (WordBytes a, WordBytes b, PrimBytes c)
        => Int# -> Int# -> Int# -> a -> b -> c
prodW64 n m k x y = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) ->
           let dot i j l r
                 | isTrue# (l ==# m) = r
                 | otherwise = dot i j (l +# 1#)
                     (r `plusWord#` timesWord# (ixW (i +# n *# l) x)
                                               (ixW (l +# m *# j) y))
           in case loop2# n k
               (\i j s -> writeWord64Array# marr (i +# n *# j) (dot i j 0# 0##) s
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
  where
    -- Element writes are 8 bytes wide; see the note in prodW16.
    bs = n *# k *# SIZEOF_WORD64#
{-# INLINE prodW64 #-}


-- | Do something in a loop for int i from 0 to n-1 and j from 0 to m-1.
--   The index i varies fastest: for every j all values of i are visited
--   before j is incremented. The state token is threaded through each call.
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s)
       -> State# s -> State# s
loop2# n m f = go 0# 0#
  where
    go i j s
      | isTrue# (j ==# m) = s
      | isTrue# (i ==# n) = go 0# (j +# 1#) s
      | otherwise         = case f i j s of s1 -> go (i +# 1#) j s1
{-# INLINE loop2# #-}