module Numeric.Matrix.Class
( MatrixCalculus (..)
, SquareMatrixCalculus (..)
, Matrix2x2 (..)
, MatrixProduct (..)
, MatrixInverse (..)
, prodF, prodD
, prodI8, prodI16, prodI32, prodI64
, prodW8, prodW16, prodW32, prodW64
) where
import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
#include "MachDeps.h"
#include "HsBaseConfig.h"
import Numeric.Commons
import Numeric.Vector.Class
import Numeric.Vector.Family (Vector)
import Numeric.Matrix.Family (Matrix)
-- | Operations on a dense matrix type @v@ with element type @t@ and
--   dimensions @n@ and @m@.
--
--   The functional dependencies make the matrix type alone determine its
--   element type and both dimensions.
class MatrixCalculus t (n :: Nat) (m :: Nat) v | v -> t, v -> n, v -> m where
  -- | A matrix built from a single element value.
  broadcastMat :: t -> v
  -- | Element lookup by two 'Int' indices.
  --   NOTE(review): which index is the row and which is the column is
  --   decided by the instances; confirm there.
  indexMat :: Int -> Int -> v -> t
  -- | Convert into a matrix type @w@ whose dimensions are swapped (@m@, @n@).
  transpose :: (MatrixCalculus t m n w, PrimBytes w) => v -> w
  -- | The @n@ and @m@ dimensions, respectively, as 'Int' values.
  dimN, dimM :: v -> Int
  -- | Select a column by index; the result is a vector type of dimension @n@.
  indexCol :: (VectorCalculus t n w, PrimBytes w) => Int -> v -> w
  -- | Select a row by index; the result is a vector type of dimension @m@.
  indexRow :: (VectorCalculus t m w, PrimBytes w) => Int -> v -> w
-- | Operations on a square matrix type @v@ of dimension @n@ with elements
--   of type @t@. The matrix type determines both @t@ and @n@.
class SquareMatrixCalculus t (n :: Nat) v | v -> t, v -> n where
  -- | The identity matrix (presumably; defined by the instances).
  eye :: v
  -- | A matrix built from one element value placed on the diagonal
  --   (the off-diagonal contents are up to the instances).
  diag :: t -> v
  -- | Scalar summaries of the matrix: determinant and trace.
  det, trace :: v -> t
  -- | Read the diagonal out as a vector type of dimension @n@.
  fromDiag :: (VectorCalculus t n w, PrimBytes w) => v -> w
  -- | Build a matrix from a vector type of dimension @n@ used as the diagonal.
  toDiag :: (VectorCalculus t n w, PrimBytes w) => w -> v
-- | Conversions between 2-by-2 matrices and pairs of 2-element vectors,
--   for element type @t@.
class Matrix2x2 t where
  -- | Assemble a 2-by-2 matrix from two 2-element vectors.
  --   NOTE(review): whether the arguments are rows or columns is decided by
  --   the instances; confirm there.
  mat22 :: Vector t 2 -> Vector t 2 -> Matrix t 2 2
  -- | Split a 2-by-2 matrix into its two rows, or its two columns.
  rowsOfM22, colsOfM22 :: Matrix t 2 2 -> (Vector t 2, Vector t 2)
-- | Multiplication of a left operand of type @lhs@ by a right operand of
--   type @rhs@, giving a @res@.
--
--   No functional dependencies are declared, so the result type must be
--   fixed by the calling context.
class MatrixProduct lhs rhs res where
  prod :: lhs -> rhs -> res
-- | Types with an inverse operation that returns a value of the same type.
class MatrixInverse mat where
  inverse :: mat -> mat
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   'Float' elements, producing a freshly allocated @n@-by-@k@ result.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodF :: (FloatBytes a, FloatBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodF n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- Result buffer size in bytes: n*k Float cells.
    byteSize = n *# k *# SIZEOF_HSFLOAT#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeFloatArray# marr (i +# n *# j) (dot i j 0# 0.0#) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in increasing l.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc `plusFloat#` (ixF (i +# n *# l) x `timesFloat#` ixF (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   'Double' elements, producing a freshly allocated @n@-by-@k@ result.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodD :: (DoubleBytes a, DoubleBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodD n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- Result buffer size in bytes: n*k Double cells.
    byteSize = n *# k *# SIZEOF_HSDOUBLE#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeDoubleArray# marr (i +# n *# j) (dot i j 0# 0.0##) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in increasing l.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc +## (ixD (i +# n *# l) x *## ixD (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   integer elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 8-bit signed integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodI8 :: (IntBytes a, IntBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodI8 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- One byte per Int8 cell, so the byte count equals the cell count.
    byteSize = n *# k

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeInt8Array# marr (i +# n *# j) (dot i j 0# 0#) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Int#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#) (acc +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   integer elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 16-bit signed integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodI16 :: (IntBytes a, IntBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodI16 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeInt16Array# takes an
    -- index in 2-byte cells. The buffer therefore needs 2 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 2#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeInt16Array# marr (i +# n *# j) (dot i j 0# 0#) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Int#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#) (acc +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   integer elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 32-bit signed integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodI32 :: (IntBytes a, IntBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodI32 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeInt32Array# takes an
    -- index in 4-byte cells. The buffer therefore needs 4 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 4#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeInt32Array# marr (i +# n *# j) (dot i j 0# 0#) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Int#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#) (acc +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   integer elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 64-bit signed integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodI64 :: (IntBytes a, IntBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodI64 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeInt64Array# takes an
    -- index in 8-byte cells. The buffer therefore needs 8 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 8#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeInt64Array# marr (i +# n *# j) (dot i j 0# 0#) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Int#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#) (acc +# (ixI (i +# n *# l) x *# ixI (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   unsigned elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 8-bit unsigned integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodW8 :: (WordBytes a, WordBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodW8 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- One byte per Word8 cell, so the byte count equals the cell count.
    byteSize = n *# k

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeWord8Array# marr (i +# n *# j) (dot i j 0# 0##) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Word#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc `plusWord#` (ixW (i +# n *# l) x `timesWord#` ixW (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   unsigned elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 16-bit unsigned integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodW16 :: (WordBytes a, WordBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodW16 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeWord16Array# takes an
    -- index in 2-byte cells. The buffer therefore needs 2 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 2#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeWord16Array# marr (i +# n *# j) (dot i j 0# 0##) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Word#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc `plusWord#` (ixW (i +# n *# l) x `timesWord#` ixW (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   unsigned elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 32-bit unsigned integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodW32 :: (WordBytes a, WordBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodW32 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeWord32Array# takes an
    -- index in 4-byte cells. The buffer therefore needs 4 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 4#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeWord32Array# marr (i +# n *# j) (dot i j 0# 0##) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Word#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc `plusWord#` (ixW (i +# n *# l) x `timesWord#` ixW (l +# m *# j) y))
-- | Multiply an @n@-by-@m@ matrix @x@ by an @m@-by-@k@ matrix @y@ of
--   unsigned elements, producing a freshly allocated @n@-by-@k@ result whose
--   cells are stored as 64-bit unsigned integers.
--
--   All matrices are addressed column-major: element @(r, c)@ of a matrix
--   with @h@ rows sits at flat index @r + h*c@.
prodW64 :: (WordBytes a, WordBytes b, PrimBytes c) => Int# -> Int# -> Int# -> a -> b -> c
prodW64 n m k x y = case runRW# build of
    (# _, frozen #) -> fromBytes frozen
  where
    -- newByteArray# takes a size in bytes, but writeWord64Array# takes an
    -- index in 8-byte cells. The buffer therefore needs 8 bytes per cell.
    -- With only n*k bytes, the writes would run past the end of the buffer.
    byteSize = n *# k *# 8#

    -- Allocate the buffer, fill every (i, j) cell, then freeze it in place.
    build s0 = case newByteArray# byteSize s0 of
      (# s1, marr #) -> case loop2# n k (fill marr) s1 of
        s2 -> unsafeFreezeByteArray# marr s2

    -- Store the dot product for result cell (i, j).
    fill marr i j s = writeWord64Array# marr (i +# n *# j) (dot i j 0# 0##) s

    -- Sum over l of x(i, l) * y(l, j), accumulated in a full Word#.
    dot i j l acc
      | isTrue# (l ==# m) = acc
      | otherwise =
          dot i j (l +# 1#)
            (acc `plusWord#` (ixW (i +# n *# l) x `timesWord#` ixW (l +# m *# j) y))
-- | Thread a state token through @f i j@ for every index pair with
--   @0 <= i < n@ and @0 <= j < m@. The first index varies fastest:
--   (0,0), (1,0) .. (n-1,0), (0,1), and so on.
--   When either bound is zero, @f@ is never called.
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s) -> State# s -> State# s
loop2# n m f = columns 0#
  where
    -- Outer loop over the second index.
    columns j s
      | isTrue# (j ==# m) = s
      | otherwise = case rows j 0# s of
          s' -> columns (j +# 1#) s'

    -- Inner loop over the first index for a fixed second index j.
    rows j i s
      | isTrue# (i ==# n) = s
      | otherwise = case f i j s of
          s' -> rows j (i +# 1#) s'