module Numeric.Vector.Base.FloatXN () where
#include "MachDeps.h"
#include "HsBaseConfig.h"
import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits
import Numeric.Commons
import Numeric.Vector.Class
import Numeric.Vector.Family
-- | Renders a vector as @{ c, c, ... }@, formatting every visited component
--   with the 'Show' instance of 'Float'.
instance KnownNat n => Show (VFloatXN n) where
  show vec = '{' : drop 1 (accumVReverse prepend vec " }")
    where
      -- Puts one component, preceded by a ", " separator, in front of the
      -- text accumulated so far. The @drop 1@ above removes the very first
      -- character of the accumulated text before the opening brace is added.
      prepend comp rest = ", " ++ show (F# comp) ++ rest
-- | Component-wise equality: two vectors are equal when every pair of
--   corresponding components compares equal with 'eqFloat#', and unequal
--   when at least one pair differs according to 'neFloat#'.
instance KnownNat n => Eq (VFloatXN n) where
  a == b = accumV2 stillEqual a b True
    where
      -- Conjunction of the running result with the current pair.
      stillEqual x y acc = acc && isTrue# (x `eqFloat#` y)
  a /= b = accumV2 differsSomewhere a b False
    where
      -- Disjunction of the running result with the current pair.
      differsSomewhere x y acc = acc || isTrue# (x `neFloat#` y)
-- | Ordering of vectors.
--
--   The relational operators ('>', '<', '>=', '<=') hold only when the
--   relation holds for every pair of corresponding components, so they form
--   a partial order. 'compare' instead combines the per-component orderings
--   with 'mappend' starting from index 0, i.e. the first component pair that
--   is not 'EQ' decides the result.
instance KnownNat n => Ord (VFloatXN n) where
  a > b = accumV2 allGreater a b True
    where
      allGreater x y ok = ok && isTrue# (x `gtFloat#` y)
  a < b = accumV2 allLess a b True
    where
      allLess x y ok = ok && isTrue# (x `ltFloat#` y)
  a >= b = accumV2 allGreaterEq a b True
    where
      allGreaterEq x y ok = ok && isTrue# (x `geFloat#` y)
  a <= b = accumV2 allLessEq a b True
    where
      allLessEq x y ok = ok && isTrue# (x `leFloat#` y)
  compare a b = accumV2 refine a b EQ
    where
      -- Append the ordering of the current pair to the ordering so far.
      refine x y soFar = soFar `mappend` pairOrdering
        where
          pairOrdering
            | isTrue# (x `gtFloat#` y) = GT
            | isTrue# (x `ltFloat#` y) = LT
            | otherwise                = EQ
-- | Component-wise arithmetic on vectors.
--
--   Binary operations combine corresponding components, unary operations
--   transform each component, and 'fromInteger' produces a vector whose
--   components are all the converted scalar.
instance (KnownNat n, 3 <= n) => Num (VFloatXN n) where
  (+) = zipV plusFloat#
  -- Subtraction must be bound to the (-) method; a unit pattern here is not
  -- a valid method binding.
  (-) = zipV minusFloat#
  (*) = zipV timesFloat#
  negate = mapV negateFloat#
  abs = mapV (\x -> if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x)
  signum = mapV sign
    where
      -- 1 for positive, -1 for negative, 0 otherwise.
      sign x
        | isTrue# (x `gtFloat#` 0.0#) = 1.0#
        | isTrue# (x `ltFloat#` 0.0#) = negateFloat# 1.0#
        | otherwise                   = 0.0#
  fromInteger = broadcastVec . fromInteger
-- | Component-wise division and reciprocal; 'fromRational' converts the
--   rational to a scalar and replicates it into every component.
instance (KnownNat n, 3 <= n) => Fractional (VFloatXN n) where
  (/) = zipV divideFloat#
  recip = mapV (\x -> 1.0# `divideFloat#` x)
  fromRational q = broadcastVec (fromRational q)
-- | Component-wise elementary functions.
--
--   Every method applies the corresponding scalar primitive to each
--   component (or to each pair of components for the binary methods).
--   The inverse hyperbolic functions are expressed through 'logFloat#' and
--   'sqrtFloat#'.
instance (KnownNat n, 3 <= n) => Floating (VFloatXN n) where
  pi = broadcastVec pi
  exp = mapV expFloat#
  log = mapV logFloat#
  sqrt = mapV sqrtFloat#
  sin = mapV sinFloat#
  cos = mapV cosFloat#
  tan = mapV tanFloat#
  asin = mapV asinFloat#
  acos = mapV acosFloat#
  atan = mapV atanFloat#
  -- Hyperbolic sine uses the hyperbolic primitive, not the circular one.
  sinh = mapV sinhFloat#
  cosh = mapV coshFloat#
  tanh = mapV tanhFloat#
  (**) = zipV powerFloat#
  -- logBase b y = log y / log b
  logBase = zipV (\x y -> logFloat# y `divideFloat#` logFloat# x)
  -- asinh x = log (x + sqrt (1 + x*x))
  asinh = mapV (\x -> logFloat# (x `plusFloat#` sqrtFloat# (1.0# `plusFloat#` timesFloat# x x)))
  -- acosh x = log (x + (x + 1) * sqrt ((x - 1) / (x + 1)))
  acosh = mapV (\x -> case plusFloat# x 1.0# of
      y -> logFloat# (x `plusFloat#` timesFloat# y (sqrtFloat# (minusFloat# x 1.0# `divideFloat#` y)))
    )
  -- atanh x = 0.5 * log ((1 + x) / (1 - x))
  atanh = mapV (\x -> 0.5# `timesFloat#` logFloat# (plusFloat# 1.0# x `divideFloat#` minusFloat# 1.0# x))
-- | Vector operations for float vectors stored in a byte array.
instance (KnownNat n, 3 <= n) => VectorCalculus Float n (VFloatXN n) where
  -- Allocates a fresh array of n floats and writes the scalar into every slot.
  broadcastVec (F# x) = case runRW#
      ( \s0 -> case newByteArray# bs s0 of
          (# s1, marr #) -> case loop# n (\i s' -> writeFloatArray# marr i x s') s1 of
            s2 -> unsafeFreezeByteArray# marr s2
      ) of (# _, r #) -> VFloatXN r
    where
      n = dim# (undefined :: VFloatXN n)
      -- n elements times 4 bytes each
      bs = n *# 4#
  -- Dot product replicated into every component of the result.
  a .*. b = broadcastVec $ dot a b
  -- Sum of the products of corresponding components.
  dot a b = F# (accumV2Float (\x y r -> r `plusFloat#` timesFloat# x y) a b 0.0#)
  -- Indices are 1-based: valid values are 1 .. dim. Unless the bounds check
  -- is compiled out, anything else raises an error; the element is then read
  -- from array position (i - 1).
  indexVec (I# i) _x@(VFloatXN arr)
#ifndef UNSAFE_INDICES
    | isTrue# ( (i ># dim# _x)
         `orI#` (i <=# 0#)
              ) = error $ "Bad index " ++ show (I# i) ++ " for " ++ show (dim _x) ++ "D vector"
    | otherwise
#endif
    = F# (indexFloatArray# arr (i -# 1#))
  -- Sum of the absolute values of the components.
  normL1 v = F# (accumVFloat (\x a -> a `plusFloat#` (if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x)) v 0.0#)
  -- Square root of the sum of squared components.
  normL2 v = sqrt $ F# (accumVFloat (\x a -> a `plusFloat#` timesFloat# x x) v 0.0#)
  -- Largest component, seeded with the component at index 0.
  -- NOTE(review): this compares signed values, not absolute values, so the
  -- result can be negative -- confirm against the class contract.
  normLPInf v@(VFloatXN arr) = F# (accumVFloat (\x a -> if isTrue# (x `geFloat#` a) then x else a) v (indexFloatArray# arr 0#))
  -- Smallest component, seeded with the component at index 0.
  -- NOTE(review): this compares signed values, not absolute values --
  -- confirm against the class contract.
  normLNInf v@(VFloatXN arr) = F# (accumVFloat (\x a -> if isTrue# (x `leFloat#` a) then x else a) v (indexFloatArray# arr 0#))
  -- (sum |x_i|^p) ^ (1/p). The base of the outer power is the accumulated
  -- sum and the exponent is 1/p; absolute values keep the sum non-negative
  -- for odd p as well.
  normLP n v = case realToFrac n of
    F# p -> F# (powerFloat#
                  (accumVFloat
                     (\x a -> a `plusFloat#`
                        powerFloat# (if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x) p)
                     v 0.0#)
                  (divideFloat# 1.0# p))
  -- Number of components as a boxed Int.
  dim x = I# (dim# x)
-- | Raw byte representation of a vector: the wrapped array is exposed and
--   re-wrapped as is; the byte size is the per-float size multiplied by the
--   number of components.
instance KnownNat n => PrimBytes (VFloatXN n) where
  toBytes (VFloatXN payload) = payload
  fromBytes payload = VFloatXN payload
  byteSize vec = dim# vec *# SIZEOF_HSFLOAT#
  byteAlign _ = ALIGNMENT_HSFLOAT#
-- | Direct, unchecked access to the float stored at a given array offset.
instance FloatBytes (VFloatXN n) where
  ixF offset (VFloatXN payload) = indexFloatArray# payload offset
-- | Number of components of a vector as an unboxed integer, taken from the
--   type-level length via 'natVal'. The vector value itself is only used to
--   fix the type.
dim# :: KnownNat n => VFloatXN n -> Int#
dim# vec = unboxInt (fromInteger (natVal vec))
  where
    unboxInt (I# len) = len
-- | Threads a state token through an action for every index from 0 up to,
--   but not including, the given count, in ascending order.
loop# :: Int# -> (Int# -> State# s -> State# s) -> State# s -> State# s
loop# count step = go 0#
  where
    go ix st =
      if isTrue# (ix ==# count)
        then st
        else go (ix +# 1#) (step ix st)
-- | Combines two vectors component by component with a binary primitive,
--   writing the results into a freshly allocated array.
zipV :: KnownNat n => (Float# -> Float# -> Float#) -> VFloatXN n -> VFloatXN n -> VFloatXN n
zipV op lhs@(VFloatXN xs) (VFloatXN ys) =
    case runRW# build of
      (# _, frozen #) -> VFloatXN frozen
  where
    len = dim# lhs
    -- len elements times 4 bytes each
    byteCount = len *# 4#
    -- Allocate the output, fill every slot, then freeze it.
    build st0 = case newByteArray# byteCount st0 of
      (# st1, out #) -> unsafeFreezeByteArray# out (loop# len (fill out) st1)
    -- Store the combined components of both inputs at one index.
    fill out ix st =
      writeFloatArray# out ix (op (indexFloatArray# xs ix) (indexFloatArray# ys ix)) st
-- | Applies a unary primitive to every component of a vector, writing the
--   results into a freshly allocated array.
mapV :: KnownNat n => (Float# -> Float#) -> VFloatXN n -> VFloatXN n
mapV op src@(VFloatXN xs) =
    case runRW# build of
      (# _, frozen #) -> VFloatXN frozen
  where
    len = dim# src
    -- len elements times 4 bytes each
    byteCount = len *# 4#
    -- Allocate the output, fill every slot, then freeze it.
    build st0 = case newByteArray# byteCount st0 of
      (# st1, out #) -> unsafeFreezeByteArray# out (loop# len (fill out) st1)
    -- Store the transformed component at one index.
    fill out ix st = writeFloatArray# out ix (op (indexFloatArray# xs ix)) st
-- | Folds the components of a vector into an unboxed float accumulator,
--   visiting indices in ascending order starting from 0.
accumVFloat :: KnownNat n => (Float# -> Float# -> Float#) -> VFloatXN n -> Float# -> Float#
accumVFloat step vec@(VFloatXN arr) = go 0#
  where
    len = dim# vec
    go ix acc =
      if isTrue# (ix ==# len)
        then acc
        else go (ix +# 1#) (step (indexFloatArray# arr ix) acc)
-- | Folds the corresponding components of two vectors into an accumulator
--   of arbitrary type, visiting indices in ascending order starting from 0.
accumV2 :: KnownNat n => (Float# -> Float# -> a -> a) -> VFloatXN n -> VFloatXN n -> a -> a
accumV2 step lhs@(VFloatXN xs) (VFloatXN ys) = go 0#
  where
    len = dim# lhs
    go ix acc =
      if isTrue# (ix ==# len)
        then acc
        else go (ix +# 1#) (step (indexFloatArray# xs ix) (indexFloatArray# ys ix) acc)
-- | Folds the corresponding components of two vectors into an unboxed float
--   accumulator, visiting indices in ascending order starting from 0.
accumV2Float :: KnownNat n => (Float# -> Float# -> Float# -> Float#) -> VFloatXN n -> VFloatXN n -> Float# -> Float#
accumV2Float step lhs@(VFloatXN xs) (VFloatXN ys) = go 0#
  where
    len = dim# lhs
    go ix acc =
      if isTrue# (ix ==# len)
        then acc
        else go (ix +# 1#) (step (indexFloatArray# xs ix) (indexFloatArray# ys ix) acc)
-- | Folds the components of a vector into an accumulator, visiting indices
--   in descending order from the last one down to 0, so the component at
--   index 0 is applied last.
--
--   The walk stops as soon as the index drops below 0. This guarantees that
--   every component (including those at indices 1 and 0) is visited, and
--   that the loop terminates for any vector length.
accumVReverse :: KnownNat n => (Float# -> a -> a) -> VFloatXN n -> a -> a
accumVReverse f x@(VFloatXN a) = loop' (n -# 1#)
  where
    loop' i acc | isTrue# (i <# 0#) = acc
                | otherwise = loop' (i -# 1#) (f (indexFloatArray# a i) acc)
    n = dim# x