module Numeric.Array.Family.ArrayF () where
import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types (Float (..), Int (..),
RuntimeRep (..), isTrue#)
import Numeric.Array.ElementWise
import Numeric.Array.Family
import Numeric.Commons
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Dimensions.Traverse
import Numeric.TypeLits
import Numeric.Matrix.Type
-- Instantiate the shared array template for the Float element type.
-- MachDeps.h is included for the platform size and alignment constants
-- referenced by the macros below.
-- NOTE(review): the macros are presumably consumed by Array.h, included at the
-- end of this list, to generate the array code for this element type; confirm
-- against that header.
#include "MachDeps.h"
#define ARR_TYPE ArrayF
#define ARR_FROMSCALAR FromScalarF#
#define ARR_CONSTR ArrayF#
#define EL_TYPE_BOXED Float
#define EL_TYPE_PRIM Float#
#define EL_RUNTIME_REP 'FloatRep
#define EL_CONSTR F#
#define EL_SIZE SIZEOF_HSFLOAT#
#define EL_ALIGNMENT ALIGNMENT_HSFLOAT#
#define EL_ZERO 0.0#
#define EL_ONE 1.0#
#define EL_MINUS_ONE -1.0#
#define INDEX_ARRAY indexFloatArray#
#define WRITE_ARRAY writeFloatArray#
#define OP_EQ eqFloat#
#define OP_NE neFloat#
#define OP_GT gtFloat#
#define OP_GE geFloat#
#define OP_LT ltFloat#
#define OP_LE leFloat#
#define OP_PLUS plusFloat#
#define OP_MINUS minusFloat#
#define OP_TIMES timesFloat#
#define OP_NEGATE negateFloat#
#include "Array.h"
-- | Numeric operations on float arrays, built from the unboxed Float
--   primitives via the zipV (binary) and mapV (unary) combinators.
--   NOTE(review): zipV, mapV and broadcastArray are defined outside this
--   chunk; they are assumed to act element-wise and to replicate a scalar
--   over the whole array respectively.
instance Num (ArrayF ds) where
  (+) = zipV plusFloat#
  -- Subtraction must be bound to the (-) operator; a unit pattern here would
  -- leave the class method undefined.
  (-) = zipV minusFloat#
  (*) = zipV timesFloat#
  negate = mapV negateFloat#
  -- Absolute value: keep non-negative values, flip the sign of the others.
  abs = mapV (\x -> if isTrue# (geFloat# x 0.0#)
                    then x
                    else negateFloat# x
             )
  -- Sign: 1 for positive, -1 for negative, 0 otherwise.
  signum = mapV (\x -> if isTrue# (gtFloat# x 0.0#)
                       then 1.0#
                       else if isTrue# (ltFloat# x 0.0#)
                            then -1.0#
                            else 0.0#
                )
  -- Convert the integer to the element type first, then hand it to
  -- broadcastArray.
  fromInteger = broadcastArray . fromInteger
-- | Division-related operations on float arrays, expressed with the same
--   zipV / mapV combinators as the Num instance.
instance Fractional (ArrayF ds) where
  -- Convert the rational to the element type, then pass it to broadcastArray.
  fromRational q = broadcastArray (fromRational q)
  -- Reciprocal is one divided by the element.
  recip = mapV (\x -> 1.0# `divideFloat#` x)
  (/) = zipV (\x y -> divideFloat# x y)
-- | Transcendental functions on float arrays. Functions with a matching
--   Float primitive are mapped directly; the inverse hyperbolic functions
--   and logBase are spelled out through logFloat# and sqrtFloat#.
instance Floating (ArrayF ds) where
  pi = broadcastArray pi
  exp = mapV expFloat#
  log = mapV logFloat#
  sqrt = mapV sqrtFloat#
  sin = mapV sinFloat#
  cos = mapV cosFloat#
  tan = mapV tanFloat#
  asin = mapV asinFloat#
  acos = mapV acosFloat#
  atan = mapV atanFloat#
  -- Hyperbolic sine needs the hyperbolic primitive, not the circular one.
  sinh = mapV sinhFloat#
  cosh = mapV coshFloat#
  tanh = mapV tanhFloat#
  (**) = zipV powerFloat#
  -- logBase b y = log y / log b
  logBase = zipV (\x y -> logFloat# y `divideFloat#` logFloat# x)
  -- asinh x = log (x + sqrt (1 + x*x))
  asinh = mapV (\x -> logFloat# (x `plusFloat#`
                sqrtFloat# (1.0# `plusFloat#` timesFloat# x x)))
  -- acosh x = log (x + (x + 1) * sqrt ((x - 1) / (x + 1)))
  acosh = mapV (\x -> case plusFloat# x 1.0# of
                 y -> logFloat# ( x `plusFloat#` timesFloat# y
                        (sqrtFloat# (minusFloat# x 1.0# `divideFloat#` y))
                      )
               )
  -- atanh x = 0.5 * log ((1 + x) / (1 - x))
  atanh = mapV (\x -> 0.5# `timesFloat#`
                logFloat# (plusFloat# 1.0# x `divideFloat#` minusFloat# 1.0# x))
-- | Transposition for float matrices whose two dimensions are both at least 2.
instance (KnownDim n, KnownDim m, ArrayF '[n,m] ~ Array Float '[n,m], 2 <= n, 2 <= m)
      => MatrixCalculus Float n m where
  -- Materialised array: allocate a fresh buffer of the same byte size and
  -- copy every element to its mirrored position. Element (i, j) is read from
  -- offs + j * n + i and written to j + m * i.
  transpose (KnownDataFrame (ArrayF# offs nm arr)) =
      case runRW# build of
        (# _, frozen #) -> fromBytes (# 0#, nm, frozen #)
    where
      rows = case dimVal' @n of I# k -> k
      cols = case dimVal' @m of I# k -> k
      byteCount = rows *# cols *# SIZEOF_HSFLOAT#
      build s0 = case newByteArray# byteCount s0 of
        (# s1, dst #) ->
          let moveElem i j s = writeFloatArray# dst (j +# cols *# i)
                                 (indexFloatArray# arr (offs +# j *# rows +# i)) s
          in unsafeFreezeByteArray# dst (loop2# rows cols moveElem s1)
  -- Broadcast scalar: the same scalar is rebuilt and coerced to the
  -- transposed type.
  transpose (KnownDataFrame (FromScalarF# x)) = unsafeCoerce# $ FromScalarF# x
-- | Square-matrix operations (identity, scalar diagonal, determinant, trace)
--   for float matrices.
--   NOTE(review): loop1# and fromBytes are defined outside this chunk; loop1#
--   is assumed to run its callback for indices 0 .. n-1 while threading the
--   state token.
instance ( KnownDim n, ArrayF '[n,n] ~ Array Float '[n,n] )
      => SquareMatrixCalculus Float n where
  -- Identity matrix: zero-fill an n*n buffer, then write 1 at every
  -- (n+1)-th element, i.e. on the main diagonal.
  eye = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop1# n
               (\j s' -> writeFloatArray# marr (j *# n1) 1.0# s'
               ) (setByteArray# marr 0# bs 0# s1) of
           s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# n, r #)
    where
      n1 = n +# 1#
      n = case dimVal' @n of I# np -> np
      bs = n *# n *# SIZEOF_HSFLOAT#
  -- Diagonal matrix with the given scalar on the diagonal; same construction
  -- as eye with v in place of 1.
  diag (KnownDataFrame (Scalar (F# v))) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop1# n
               (\j s' -> writeFloatArray# marr (j *# n1) v s'
               ) (setByteArray# marr 0# bs 0# s1) of
           s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes (# 0#, n *# n, r #)
    where
      n1 = n +# 1#
      n = case dimVal' @n of I# np -> np
      bs = n *# n *# SIZEOF_HSFLOAT#
  -- Determinant by elimination with pivot search on a private copy (mat) of
  -- the input; vec is an n-element scratch buffer used when swapping blocks.
  -- For each step i: pick the block with the largest-magnitude entry at
  -- position i, swap it into place (flipping the sign of the running
  -- product), eliminate position i from the following blocks, and multiply
  -- the running product by the pivot. A zero pivot ends the loop with 0.
  det (KnownDataFrame (ArrayF# off nsqr arr)) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, mat #) -> case newByteArray#
                                (n *# SIZEOF_HSFLOAT#)
                                (copyByteArray# arr offb mat 0# bs s1) of
           (# s2, vec #) ->
             let f i x s | isTrue# (i >=# n) = (# s, x #)
                         | otherwise =
                             let !(# s' , j #) = maxInRowRem# n n i mat s
                                 !(# s'', x' #) = if isTrue# (i /=# j)
                                                  then (# swapCols# n i j vec mat s'
                                                        , negateFloat# x #)
                                                  else (# s', x #)
                                 !(# s''', y #) = clearRowEnd# n n i mat s''
                             in if isTrue# (eqFloat# 0.0# y)
                                then (# s''', 0.0# #)
                                else f (i +# 1#) (timesFloat# x' y) s'''
             in f 0# 1.0# s2
     ) of (# _, r #) -> KnownDataFrame (Scalar (F# r))
    where
      n = case dimVal' @n of I# np -> np
      offb = off *# SIZEOF_HSFLOAT#
      bs = nsqr *# SIZEOF_HSFLOAT#
  -- Broadcast scalar: every element equals x. A 1x1 matrix has determinant x,
  -- matching what the materialised branch above computes for n = 1; for
  -- larger n the result is 0 as before.
  det (KnownDataFrame (FromScalarF# x))
    | isTrue# (n ==# 1#) = KnownDataFrame (Scalar (F# x))
    | otherwise = 0
    where
      n = case dimVal' @n of I# np -> np
  -- Trace: sum every (n+1)-th element starting at the array offset. The loop
  -- stops as soon as the index reaches nsqr, so an empty buffer is never read.
  trace (KnownDataFrame (ArrayF# off nsqr a)) = KnownDataFrame (Scalar (F# (loop' 0# 0.0#)))
    where
      n1 = n +# 1#
      n = case dimVal' @n of I# np -> np
      loop' i acc | isTrue# (i >=# nsqr) = acc
                  | otherwise = loop' (i +# n1)
                          (indexFloatArray# a (off +# i) `plusFloat#` acc)
  -- Broadcast scalar: n diagonal entries, each equal to x.
  trace (KnownDataFrame (FromScalarF# x)) = KnownDataFrame (Scalar (F# (x `timesFloat#` n)))
    where
      n = case fromIntegral (dimVal' @n) of F# np -> np
-- | Matrix inversion for float matrices with n at least 2, carried out on an
--   augmented work buffer.
--   NOTE(review): loop1# is defined outside this chunk and is assumed to run
--   its callback for indices 0 .. n-1. No check for a zero pivot is visible
--   here, so a singular input presumably yields non-finite values; confirm.
instance (KnownNat n, ArrayF '[n,n] ~ Array Float '[n,n], 2 <= n) => MatrixInverse Float n where
  -- mat holds n blocks of nn = 2n floats (bs * 2 bytes in total); vec is a
  -- scratch buffer of one such block (vs * 2 bytes) used by swapCols#.
  inverse (KnownDataFrame (ArrayF# offs nsqr arr)) = case runRW#
     ( \s0 -> case newByteArray# (bs *# 2#) s0 of
         (# s1, mat #) -> case newByteArray# (vs *# 2#)
                -- Zero mat, then for each i copy n source floats into the
                -- first half of block i and write 1 at position i of the
                -- second half, so the second halves together hold the identity.
                (loop1# n (\i s -> writeFloatArray# mat
                                (i *# nn +# i +# n) 1.0#
                                (copyByteArray# arr (offb +# i *# vs)
                                                mat (2# *# i *# vs) vs s))
                 (setByteArray# mat 0# (bs *# 2#) 0# s1)
                ) of
           (# s2, vec #) ->
              -- Elimination step i: find the pivot block, swap it into
              -- place if needed, then clear position i in every other block
              -- and normalise block i (clearRowAll#).
              let f i s | isTrue# (i >=# n) = s
                        | otherwise =
                            let !(# s' , j #) = maxInRowRem# nn n i mat s
                                s'' = if isTrue# (i /=# j) then swapCols# nn i j vec mat s'
                                      else s'
                                !(# s''', _ #) = clearRowAll# nn n i mat s''
                            in f (i +# 1#) s'''
              -- After elimination, move the second half of every block to
              -- the front of the buffer, shrink the buffer to bs bytes and
              -- freeze it.
              in unsafeFreezeByteArray# mat
                  ( shrinkMutableByteArray# mat bs
                   (
                    loop1# n (\i s ->
                       copyMutableByteArray# mat
                                             (2# *# i *# vs +# vs)
                                             mat (i *# vs) vs s)
                    (f 0# s2)
                   )
                  )
     ) of (# _, r #) -> KnownDataFrame (ArrayF# 0# nsqr r)
    where
      -- nn: floats per augmented block; vs: bytes in n floats;
      -- bs: bytes in n*n floats; offb: source offset in bytes.
      nn = 2# *# n
      n = case dimVal' @n of I# np -> np
      vs = n *# SIZEOF_HSFLOAT#
      bs = n *# n *# SIZEOF_HSFLOAT#
      offb = offs *# SIZEOF_HSFLOAT#
  -- A broadcast scalar is rejected with a runtime error.
  inverse (KnownDataFrame (FromScalarF# _)) = error "Cannot take inverse of a degenerate matrix"
-- | Exchange two contiguous blocks of n floats inside a mutable buffer.
--
--   Arguments, in order: block length n (in floats), index i of the first
--   block, index j of the second block, a scratch buffer with room for one
--   block, the buffer holding the blocks, and the state token.
--   The copy from block j to block i happens inside the same buffer, so the
--   two blocks are expected to be distinct; the callers in this file only
--   call it when i and j differ.
swapCols# :: Int#
          -> Int#
          -> Int#
          -> MutableByteArray# s
          -> MutableByteArray# s
          -> State# s
          -> State# s
swapCols# n i j vec mat s0 =
    let stashed = copyMutableByteArray# mat offI vec 0# blockBytes s0   -- block i -> scratch
        moved = copyMutableByteArray# mat offJ mat offI blockBytes stashed  -- block j -> block i
    in  copyMutableByteArray# vec 0# mat offJ blockBytes moved          -- scratch -> block j
  where
    blockBytes = n *# SIZEOF_HSFLOAT#
    offI = i *# blockBytes
    offJ = j *# blockBytes
-- | Elimination step used by the determinant: eliminate position i from all
--   blocks that follow block i.
--
--   Arguments, in order: block length n (in floats), number of blocks m,
--   current step i, the work buffer, and the state token.
--   The pivot is the element at (n + 1) * i. For every block k in
--   i+1 .. m-1 the element at position i is read, overwritten with zero, and
--   the remaining n - i - 1 elements of block k are reduced by that element
--   divided by the pivot, times the matching elements of block i.
--   Returns the new state token together with the pivot value.
clearRowEnd# :: Int#
             -> Int#
             -> Int#
             -> MutableByteArray# s
             -> State# s
             -> (# State# s, Float# #)
clearRowEnd# n m i mat s0 =
    case readFloatArray# mat pivotIx s0 of
      (# s1, pivot #) ->
        let invPivot = 1.0# `divideFloat#` pivot
            sweep k s
              | isTrue# (k >=# m) = s
              | otherwise =
                  let ix = k *# n +# i
                  in case readFloatArray# mat ix s of
                       (# sRead, lead #) ->
                         sweep (k +# 1#)
                           (multNRem# tailLen (ix +# 1#) srcIx
                              (lead `timesFloat#` invPivot) mat
                              (writeFloatArray# mat ix 0.0# sRead))
        in (# sweep (i +# 1#) s1, pivot #)
  where
    -- Index of the pivot, of the first element after it in block i, and the
    -- number of elements after the pivot in a block.
    pivotIx = (n +# 1#) *# i
    srcIx = pivotIx +# 1#
    tailLen = n -# i -# 1#
-- | Elimination step used by the inverse: eliminate position i from every
--   block except block i, then normalise block i.
--
--   Arguments, in order: block length n (in floats), number of blocks m,
--   current step i, the work buffer, and the state token.
--   Work is done in this order: blocks i+1 .. m-1, then blocks 0 .. i-1,
--   then the pivot element itself is set to 1, and finally the elements of
--   block i at positions i+1 .. n-1 are multiplied by the reciprocal of the
--   original pivot. Returns the new state token and the original pivot.
clearRowAll# :: Int#
             -> Int#
             -> Int#
             -> MutableByteArray# s
             -> State# s
             -> (# State# s, Float# #)
clearRowAll# n m i mat s0 =
    case readFloatArray# mat diagIx s0 of
      (# s1, pivot #) ->
        let invPivot = 1.0# `divideFloat#` pivot
            -- Clear position i in blocks k .. kEnd-1.
            eliminate k kEnd s
              | isTrue# (k >=# kEnd) = s
              | otherwise =
                  let ix = k *# n +# i
                  in case readFloatArray# mat ix s of
                       (# sRead, lead #) ->
                         eliminate (k +# 1#) kEnd
                           (multNRem# restLen (ix +# 1#) srcIx
                              (lead `timesFloat#` invPivot) mat
                              (writeFloatArray# mat ix 0.0# sRead))
            -- Scale the elements of block i at positions k .. n-1.
            normalise k s
              | isTrue# (k >=# n) = s
              | otherwise =
                  let ix = n *# i +# k
                  in case readFloatArray# mat ix s of
                       (# sRead, v #) ->
                         normalise (k +# 1#)
                           (writeFloatArray# mat ix (timesFloat# v invPivot) sRead)
            sAfter = eliminate (i +# 1#) m s1
            sBefore = eliminate 0# i sAfter
            sUnit = writeFloatArray# mat diagIx 1.0# sBefore
        in (# normalise (i +# 1#) sUnit, pivot #)
  where
    diagIx = (n +# 1#) *# i
    srcIx = diagIx +# 1#
    restLen = n -# i -# 1#
-- | In-place update of a run of floats: for t in 0 .. n-1,
--   mat[x0 + t] becomes mat[x0 + t] - mat[y0 + t] * a.
--
--   Arguments, in order: number of elements n, first destination index x0,
--   first source index y0, multiplier a, the buffer, and the state token.
--   The recursion only stops when the counter reaches exactly zero, so n is
--   expected to be non-negative.
multNRem# :: Int#
          -> Int#
          -> Int#
          -> Float#
          -> MutableByteArray# s
          -> State# s
          -> State# s
multNRem# n x0 y0 a mat s
  | isTrue# (n ==# 0#) = s
  | otherwise =
      case readFloatArray# mat y0 s of
        (# s1, src #) -> case readFloatArray# mat x0 s1 of
          (# s2, dst #) ->
            multNRem# (n -# 1#) (x0 +# 1#) (y0 +# 1#) a mat
              (writeFloatArray# mat x0 (dst `minusFloat#` (src `timesFloat#` a)) s2)
-- | Pivot search: among blocks i .. m-1, find the block whose element at
--   position i has the largest absolute value.
--
--   Arguments, in order: block length n (in floats), number of blocks m,
--   current step i, the buffer, and the state token.
--   The search starts from the element at (n + 1) * i and only moves to a
--   later block when its magnitude is strictly greater, so ties keep the
--   earliest block. Returns the new state token and the chosen block index.
maxInRowRem# :: Int#
             -> Int#
             -> Int#
             -> MutableByteArray# s
             -> State# s
             -> (# State# s, Int# #)
maxInRowRem# n m i mat s0 =
    case readFloatArray# mat ((n +# 1#) *# i) s0 of
      (# s1, d #) -> scan i (magnitude d) i s1
  where
    magnitude x
      | isTrue# (x `geFloat#` 0.0#) = x
      | otherwise = negateFloat# x
    scan best bestAbs k s
      | isTrue# (k >=# m) = (# s, best #)
      | otherwise =
          case readFloatArray# mat (n *# k +# i) s of
            (# s', cand #)
              | isTrue# (magnitude cand `gtFloat#` bestAbs)
                  -> scan k (magnitude cand) (k +# 1#) s'
              | otherwise
                  -> scan best bestAbs (k +# 1#) s'
-- | Two-level counting loop over a state token. The callback receives
--   (i, j) with i in 0 .. n-1 varying fastest and j in 0 .. m-1 varying
--   slowest. Both loops stop only on exact equality with their bound, so the
--   bounds are expected to be non-negative.
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s)
       -> State# s -> State# s
loop2# n m f = outer 0#
  where
    -- Advance j; each pass runs the whole inner loop first.
    outer j s
      | isTrue# (j ==# m) = s
      | otherwise = outer (j +# 1#) (inner 0# j s)
    -- Run the callback for i = 0 .. n-1 at a fixed j.
    inner i j s
      | isTrue# (i ==# n) = s
      | otherwise = inner (i +# 1#) j (f i j s)