module Numeric.Matrix.Base.FloatXNM () where
#include "MachDeps.h"
#include "HsBaseConfig.h"
import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits
import Numeric.Commons
import Numeric.Matrix.Class
import Numeric.Matrix.Family
instance (KnownNat n, KnownNat m) => Show (MFloatXNM n m) where
show x@(MFloatXNM arr) = "{" ++ drop 2 (loop' (n -# 1#) (m -# 1#) " }")
where
loop' i j acc | isTrue# (i ==# 1#) = acc
| isTrue# (j ==# 1#) = loop' (i -# 1#) (m -# 1#) ('\n':acc)
| otherwise = loop' i (j -# 1#) (", " ++ show (F# (indexFloatArray# arr (i +# n *# j))) ++ acc)
n = dimN# x
m = dimM# x
instance (KnownNat n, KnownNat m) => Eq (MFloatXNM n m) where
a == b = accumV2 (\x y r -> r && isTrue# (x `eqFloat#` y)) a b True
a /= b = accumV2 (\x y r -> r || isTrue# (x `neFloat#` y)) a b False
instance (KnownNat n, KnownNat m) => Ord (MFloatXNM n m) where
a > b = accumV2 (\x y r -> r && isTrue# (x `gtFloat#` y)) a b True
a < b = accumV2 (\x y r -> r && isTrue# (x `ltFloat#` y)) a b True
a >= b = accumV2 (\x y r -> r && isTrue# (x `geFloat#` y)) a b True
a <= b = accumV2 (\x y r -> r && isTrue# (x `leFloat#` y)) a b True
compare a b = accumV2 (\x y r -> r `mappend`
if isTrue# (x `gtFloat#` y)
then GT
else if isTrue# (x `ltFloat#` y)
then LT
else EQ
) a b EQ
instance (KnownNat n, KnownNat m) => Num (MFloatXNM n m) where
(+) = zipV plusFloat#
() = zipV minusFloat#
(*) = zipV timesFloat#
negate = mapV negateFloat#
abs = mapV (\x -> if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x)
signum = mapV (\x -> if isTrue# (x `gtFloat#` 0.0#) then 1.0# else if isTrue# (x `ltFloat#` 0.0#) then 1.0# else 0.0#)
fromInteger = broadcastMat . fromInteger
instance (KnownNat n, KnownNat m) => Fractional (MFloatXNM n m) where
(/) = zipV divideFloat#
recip = mapV (divideFloat# 1.0#)
fromRational = broadcastMat . fromRational
instance (KnownNat n, KnownNat m) => Floating (MFloatXNM n m) where
pi = broadcastMat pi
exp = mapV expFloat#
log = mapV logFloat#
sqrt = mapV sqrtFloat#
sin = mapV sinFloat#
cos = mapV cosFloat#
tan = mapV tanFloat#
asin = mapV asinFloat#
acos = mapV acosFloat#
atan = mapV atanFloat#
sinh = mapV sinFloat#
cosh = mapV coshFloat#
tanh = mapV tanhFloat#
(**) = zipV powerFloat#
logBase = zipV (\x y -> logFloat# y `divideFloat#` logFloat# x)
asinh = mapV (\x -> logFloat# (x `plusFloat#` sqrtFloat# (1.0# `plusFloat#` timesFloat# x x)))
acosh = mapV (\x -> case plusFloat# x 1.0# of
y -> logFloat# ( x `plusFloat#` timesFloat# y (sqrtFloat# (minusFloat# x 1.0# `divideFloat#` y)))
)
atanh = mapV (\x -> 0.5# `timesFloat#` logFloat# (plusFloat# 1.0# x `divideFloat#` minusFloat# 1.0# x))
instance (KnownNat n, KnownNat m) => MatrixCalculus Float n m (MFloatXNM n m) where
broadcastMat (F# x) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\i s' -> writeFloatArray# marr i x s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> MFloatXNM r
where
n = dimN# (undefined :: MFloatXNM n m) *# dimM# (undefined :: MFloatXNM n m)
bs = n *# SIZEOF_HSFLOAT#
indexMat (I# i) (I# j) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
| isTrue# ( (i ># n)
`orI#` (i <=# 0#)
`orI#` (j ># _m)
`orI#` (j <=# 0#)
) = error $ "Bad index (" ++ show (I# i) ++ ", " ++ show (I# j) ++ ") for "
++ show (I# n) ++ "x" ++ show (I# _m) ++ "D matrix"
| otherwise
#endif
= F# (indexFloatArray# arr (i -# 1# +# n *# (j -# 1#)))
where
n = dimN# x
_m = dimM# x
transpose x@(MFloatXNM arr) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop2# n m
(\i j s' -> writeFloatArray# marr (i +# n *# j) (indexFloatArray# arr (i *# m +# j)) s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n = dimN# x
m = dimM# x
bs = n *# m *# SIZEOF_HSFLOAT#
dimN x = I# (dimN# x)
dimM x = I# (dimM# x)
indexCol (I# j) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
| isTrue# ( (j ># dimM# x)
`orI#` (j <=# 0#)
) = error $ "Bad column index " ++ show (I# j) ++ " for "
++ show (I# n) ++ "x" ++ show (dimM x) ++ "D matrix"
| otherwise
#endif
= case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case copyByteArray# arr (bs *# (j -# 1#)) marr 0# bs s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n = dimN# x
bs = n *# SIZEOF_HSFLOAT#
indexRow (I# i) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
| isTrue# ( (i ># n)
`orI#` (i <=# 0#)
) = error $ "Bad row index " ++ show (I# i) ++ " for "
++ show (I# n) ++ "x" ++ show (I# m) ++ "D matrix"
| otherwise
#endif
= case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# m
(\j s' -> writeFloatArray# marr j (indexFloatArray# arr (i -# 1# +# n *# j)) s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n = dimN# x
m = dimM# x
bs = m *# SIZEOF_HSFLOAT#
instance (KnownNat n, KnownNat m) => PrimBytes (MFloatXNM n m) where
toBytes (MFloatXNM a) = a
fromBytes = MFloatXNM
byteSize x = SIZEOF_HSFLOAT# *# dimN# x *# dimM# x
byteAlign _ = ALIGNMENT_HSFLOAT#
instance FloatBytes (MFloatXNM n m) where
ixF i (MFloatXNM a) = indexFloatArray# a i
instance KnownNat n => SquareMatrixCalculus Float n (MFloatXNM n n) where
eye = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\j s' -> writeFloatArray# marr (j *# n1) 1.0# s'
) (setByteArray# marr 0# bs 0# s1) of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n1 = n +# 1#
n = dimN# (undefined :: MFloatXNM n n)
bs = n *# n *# SIZEOF_HSFLOAT#
diag (F# v) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\j s' -> writeFloatArray# marr (j *# n1) v s'
) (setByteArray# marr 0# bs 0# s1) of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n1 = n +# 1#
n = dimN# (undefined :: MFloatXNM n n)
bs = n *# n *# SIZEOF_HSFLOAT#
det v@(MFloatXNM arr) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, mat #) -> case newByteArray# (n *# SIZEOF_HSFLOAT#) (copyByteArray# arr 0# mat 0# bs s1) of
(# s2, vec #) ->
let f i x s | isTrue# (i >=# n) = (# s, x #)
| otherwise =
let (# s' , j #) = maxInRowRem# n n i mat s
(# s'', x' #) = if isTrue# (i /=# j) then (# swapCols# n i j vec mat s', negateFloat# x #)
else (# s', x #)
(# s''', y #) = clearRowEnd# n n i mat s''
in if isTrue# (eqFloat# 0.0# y) then (# s''', 0.0# #)
else f (i +# 1#) (timesFloat# x' y) s'''
in f 0# 1.0# s2
) of (# _, r #) -> F# r
where
n = dimN# v
bs = n *# n *# SIZEOF_HSFLOAT#
trace x@(MFloatXNM a) = F# (loop' 0# 0.0#)
where
n1 = n +# 1#
n = dimN# x
nn = n *# n
loop' i acc | isTrue# (i ># nn) = acc
| otherwise = loop' (i +# n1) (indexFloatArray# a i `plusFloat#` acc)
fromDiag x@(MFloatXNM a) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\j s' -> writeFloatArray# marr j (indexFloatArray# a (j *# n1)) s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
n1 = n +# 1#
n = dimN# x
bs = n *# SIZEOF_HSFLOAT#
toDiag x = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\j s' -> writeFloatArray# marr (j *# n1) (indexFloatArray# a j) s'
) (setByteArray# marr 0# bs 0# s1) of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> fromBytes r
where
a = toBytes x
n1 = n +# 1#
n = dimN# (undefined :: MFloatXNM n n)
bs = n *# n *# SIZEOF_HSFLOAT#
instance KnownNat n => MatrixInverse (MFloatXNM n n) where
inverse v@(MFloatXNM arr) = case runRW#
( \s0 -> case newByteArray# (bs *# 2#) s0 of
(# s1, mat #) -> case newByteArray# (n *# SIZEOF_HSFLOAT#)
(loop# n (\i s -> writeFloatArray# mat (i *# nn +# i +# n) 1.0# (copyByteArray# arr (i *# vs) mat (2# *# i *# vs) vs s))
(setByteArray# mat 0# (bs *# 2#) 0# s1)
) of
(# s2, vec #) ->
let f i s | isTrue# (i >=# n) = s
| otherwise =
let (# s' , j #) = maxInRowRem# nn n i mat s
s'' = if isTrue# (i /=# j) then swapCols# nn i j vec mat s'
else s'
(# s''', _ #) = clearRowAll# nn n i mat s''
in f (i +# 1#) s'''
in unsafeFreezeByteArray# mat
( shrinkMutableByteArray# mat bs
(
loop# n (\i s -> copyMutableByteArray# mat (2# *# i *# vs +# vs) mat (i *# vs) vs s)
(f 0# s2)
)
)
) of (# _, r #) -> MFloatXNM r
where
nn = 2# *# n
n = dimN# v
vs = n *# SIZEOF_HSFLOAT#
bs = n *# n *# SIZEOF_HSFLOAT#
proxyN# :: MFloatXNM n m -> Proxy# n
proxyN# _ = proxy#
dimN# :: KnownNat n => MFloatXNM n m -> Int#
dimN# x = case fromInteger (natVal' (proxyN# x)) of I# n -> n
dimM# :: KnownNat m => MFloatXNM n m -> Int#
dimM# x = case fromInteger (natVal x) of I# n -> n
loop# :: Int# -> (Int# -> State# s -> State# s) -> State# s -> State# s
loop# n f = loop' 0#
where
loop' i s | isTrue# (i ==# n) = s
| otherwise = case f i s of s1 -> loop' (i +# 1#) s1
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s) -> State# s -> State# s
loop2# n m f = loop' 0# 0#
where
loop' i j s | isTrue# (j ==# m) = s
| isTrue# (i ==# n) = loop' 0# (j +# 1#) s
| otherwise = case f i j s of s1 -> loop' (i +# 1#) j s1
zipV :: (KnownNat n, KnownNat m) => (Float# -> Float# -> Float#) -> MFloatXNM n m -> MFloatXNM n m -> MFloatXNM n m
zipV f x@(MFloatXNM a) (MFloatXNM b) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\i s' -> case f (indexFloatArray# a i) (indexFloatArray# b i) of
r -> writeFloatArray# marr i r s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> MFloatXNM r
where
n = dimN# x *# dimM# x
bs = n *# SIZEOF_HSFLOAT#
mapV :: (KnownNat n, KnownNat m) => (Float# -> Float#) -> MFloatXNM n m -> MFloatXNM n m
mapV f x@(MFloatXNM a) = case runRW#
( \s0 -> case newByteArray# bs s0 of
(# s1, marr #) -> case loop# n
(\i s' -> case f (indexFloatArray# a i) of
r -> writeFloatArray# marr i r s'
) s1 of
s2 -> unsafeFreezeByteArray# marr s2
) of (# _, r #) -> MFloatXNM r
where
n = dimN# x *# dimM# x
bs = n *# SIZEOF_HSFLOAT#
accumV2 :: (KnownNat n, KnownNat m) => (Float# -> Float# -> a -> a) -> MFloatXNM n m -> MFloatXNM n m -> a -> a
accumV2 f x@(MFloatXNM a) (MFloatXNM b) = loop' 0#
where
loop' i acc | isTrue# (i ==# n) = acc
| otherwise = loop' (i +# 1#) (f (indexFloatArray# a i) (indexFloatArray# b i) acc)
n = dimN# x *# dimM# x
swapCols# :: Int#
-> Int#
-> Int#
-> MutableByteArray# s
-> MutableByteArray# s
-> State# s
-> State# s
swapCols# n i j vec mat s0 =
case copyMutableByteArray# mat (i *# bs) vec 0# bs s0 of
s1 -> case copyMutableByteArray# mat (j *# bs) mat (i *# bs) bs s1 of
s2 -> copyMutableByteArray# vec 0# mat (j *# bs) bs s2
where
bs = n *# SIZEOF_HSFLOAT#
clearRowEnd# :: Int#
-> Int#
-> Int#
-> MutableByteArray# s
-> State# s
-> (# State# s, Float# #)
clearRowEnd# n m i mat s0 = (# loop' (i +# 1#) s1, y' #)
where
y0 = (n +# 1#) *# i +# 1#
(# s1, y' #) = readFloatArray# mat ((n +# 1#) *# i) s0
yrc = 1.0# `divideFloat#` y'
n' = n -# i -# 1#
loop' k s | isTrue# (k >=# m) = s
| otherwise = loop' (k +# 1#)
( let x0 = k *# n +# i
(# s', a' #) = readFloatArray# mat x0 s
s'' = writeFloatArray# mat x0 0.0# s'
a = a' `timesFloat#` yrc
in multNRem# n' (x0 +# 1#) y0 a mat s''
)
clearRowAll# :: Int#
-> Int#
-> Int#
-> MutableByteArray# s
-> State# s
-> (# State# s, Float# #)
clearRowAll# n m i mat s0 = (# divLoop (i +# 1#) (writeFloatArray# mat ((n +# 1#) *# i) 1.0# (loop' 0# i (loop' (i +# 1#) m s1))), y' #)
where
y0 = (n +# 1#) *# i +# 1#
(# s1, y' #) = readFloatArray# mat ((n +# 1#) *# i) s0
yrc = 1.0# `divideFloat#` y'
n' = n -# i -# 1#
loop' k km s | isTrue# (k >=# km) = s
| otherwise = loop' (k +# 1#) km
( let x0 = k *# n +# i
(# s', a' #) = readFloatArray# mat x0 s
s'' = writeFloatArray# mat x0 0.0# s'
a = a' `timesFloat#` yrc
in multNRem# n' (x0 +# 1#) y0 a mat s''
)
divLoop k s | isTrue# (k >=# n) = s
| otherwise = divLoop (k +# 1#)
( let x0 = n *# i +# k
(# s', x #) = readFloatArray# mat x0 s
in writeFloatArray# mat x0 (timesFloat# x yrc) s'
)
multNRem# :: Int#
-> Int#
-> Int#
-> Float#
-> MutableByteArray# s
-> State# s
-> State# s
multNRem# 0# _ _ _ _ s = s
multNRem# n x0 y0 a mat s = multNRem# (n -# 1#) (x0 +# 1#) (y0 +# 1#) a mat
( case readFloatArray# mat y0 s of
(# s1, y #) -> case readFloatArray# mat x0 s1 of
(# s2, x #) -> writeFloatArray# mat x0 (x `minusFloat#` timesFloat# y a) s2
)
maxInRowRem# :: Int#
-> Int#
-> Int#
-> MutableByteArray# s
-> State# s
-> (# State# s, Int# #)
maxInRowRem# n m i mat s0 = loop' i (abs# v) i s1
where
(# s1, v #) = readFloatArray# mat ((n +# 1#) *# i) s0
abs# x = if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x
loop' ok ov k s | isTrue# (k >=# m) = (# s, ok #)
| otherwise = case readFloatArray# mat (n *# k +# i) s of
(# s', v' #) -> if isTrue# (abs# v' `gtFloat#` ov)
then loop' k (abs# v') (k +# 1#) s'
else loop' ok ov (k +# 1#) s'