{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MagicHash             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UnboxedTuples         #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Vector.Base.FloatXN
-- Copyright   :  (c) Artem Chirkin
-- License     :  MIT
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- Orphan class instances for 'VFloatXN', a vector of @n@ 'Float' values
-- backed by a 'ByteArray#'. The numeric classes ('Num', 'Fractional',
-- 'Floating', 'VectorCalculus') require @3 <= n@.
--
-----------------------------------------------------------------------------

module Numeric.Vector.Base.FloatXN () where

#include "MachDeps.h"
#include "HsBaseConfig.h"

import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits

import Numeric.Commons
import Numeric.Vector.Class
import Numeric.Vector.Family


-- | Renders the vector as @{ x1, x2, ..., xn }@.
instance KnownNat n => Show (VFloatXN n) where
  -- Components are folded right-to-left, each one prepended to the already
  -- rendered tail; @drop 1@ strips the leading comma of the first component.
  show x = "{" ++ drop 1 (accumVReverse (\a s -> ", " ++ show (F# a) ++ s) x " }")

-- | Component-wise equality using the primitive 'eqFloat#' / 'neFloat#'.
instance KnownNat n => Eq (VFloatXN n) where
  a == b = accumV2 (\x y r -> r && isTrue# (x `eqFloat#` y)) a b True
  {-# INLINE (==) #-}
  a /= b = accumV2 (\x y r -> r || isTrue# (x `neFloat#` y)) a b False
  {-# INLINE (/=) #-}

-- | Implement partial ordering for `>`, `<`, `>=`, `<=`
--   and lexicographical ordering for `compare`
instance KnownNat n => Ord (VFloatXN n) where
  a > b = accumV2 (\x y r -> r && isTrue# (x `gtFloat#` y)) a b True
  {-# INLINE (>) #-}
  a < b = accumV2 (\x y r -> r && isTrue# (x `ltFloat#` y)) a b True
  {-# INLINE (<) #-}
  a >= b = accumV2 (\x y r -> r && isTrue# (x `geFloat#` y)) a b True
  {-# INLINE (>=) #-}
  a <= b = accumV2 (\x y r -> r && isTrue# (x `leFloat#` y)) a b True
  {-# INLINE (<=) #-}
  -- | Compare lexicographically: the first differing component decides,
  --   because earlier results are on the left of 'mappend'.
  compare a b = accumV2 (\x y r -> r `mappend`
                          if isTrue# (x `gtFloat#` y)
                          then GT
                          else if isTrue# (x `ltFloat#` y)
                               then LT
                               else EQ
                        ) a b EQ
  {-# INLINE compare #-}

-- | Component-wise arithmetic; 'fromInteger' broadcasts to all components.
instance (KnownNat n, 3 <= n) => Num (VFloatXN n) where
  (+) = zipV plusFloat#
  {-# INLINE (+) #-}
  (-) = zipV minusFloat#
  {-# INLINE (-) #-}
  (*) = zipV timesFloat#
  {-# INLINE (*) #-}
  negate = mapV negateFloat#
  {-# INLINE negate #-}
  abs = mapV absF#
  {-# INLINE abs #-}
  signum = mapV (\x -> if isTrue# (x `gtFloat#` 0.0#)
                       then 1.0#
                       else if isTrue# (x `ltFloat#` 0.0#)
                            then -1.0#
                            else 0.0#)
  {-# INLINE signum #-}
  fromInteger = broadcastVec . fromInteger
  {-# INLINE fromInteger #-}

-- | Component-wise division; 'fromRational' broadcasts to all components.
instance (KnownNat n, 3 <= n) => Fractional (VFloatXN n) where
  (/) = zipV divideFloat#
  {-# INLINE (/) #-}
  recip = mapV (divideFloat# 1.0#)
  {-# INLINE recip #-}
  fromRational = broadcastVec . fromRational
  {-# INLINE fromRational #-}

-- | Every function is applied component-wise.
instance (KnownNat n, 3 <= n) => Floating (VFloatXN n) where
  pi = broadcastVec pi
  {-# INLINE pi #-}
  exp = mapV expFloat#
  {-# INLINE exp #-}
  log = mapV logFloat#
  {-# INLINE log #-}
  sqrt = mapV sqrtFloat#
  {-# INLINE sqrt #-}
  sin = mapV sinFloat#
  {-# INLINE sin #-}
  cos = mapV cosFloat#
  {-# INLINE cos #-}
  tan = mapV tanFloat#
  {-# INLINE tan #-}
  asin = mapV asinFloat#
  {-# INLINE asin #-}
  acos = mapV acosFloat#
  {-# INLINE acos #-}
  atan = mapV atanFloat#
  {-# INLINE atan #-}
  -- hyperbolic sine: must use the sinh primop, not sin
  sinh = mapV sinhFloat#
  {-# INLINE sinh #-}
  cosh = mapV coshFloat#
  {-# INLINE cosh #-}
  tanh = mapV tanhFloat#
  {-# INLINE tanh #-}
  (**) = zipV powerFloat#
  {-# INLINE (**) #-}
  -- logBase b x = log x / log b; zipV passes the base first
  logBase = zipV (\x y -> logFloat# y `divideFloat#` logFloat# x)
  {-# INLINE logBase #-}
  -- asinh x = log (x + sqrt (1 + x*x))
  asinh = mapV (\x -> logFloat# (x `plusFloat#` sqrtFloat# (1.0# `plusFloat#` timesFloat# x x)))
  {-# INLINE asinh #-}
  -- acosh x = log (x + (x+1) * sqrt ((x-1)/(x+1)))
  acosh = mapV (\x -> case plusFloat# x 1.0# of
                 y -> logFloat# (x `plusFloat#` timesFloat# y (sqrtFloat# (minusFloat# x 1.0# `divideFloat#` y))))
  {-# INLINE acosh #-}
  -- atanh x = 0.5 * log ((1+x)/(1-x))
  atanh = mapV (\x -> 0.5# `timesFloat#` logFloat# (plusFloat# 1.0# x `divideFloat#` minusFloat# 1.0# x))
  {-# INLINE atanh #-}

instance (KnownNat n, 3 <= n) => VectorCalculus Float n (VFloatXN n) where
  -- | Allocate a fresh array and write the same value into every slot.
  broadcastVec (F# x) = case runRW#
      ( \s0 -> case newByteArray# bs s0 of
          (# s1, marr #) ->
            case loop# n (\i s' -> writeFloatArray# marr i x s') s1 of
              s2 -> unsafeFreezeByteArray# marr s2
      ) of (# _, r #) -> VFloatXN r
    where
      n = dim# (undefined :: VFloatXN n)
      -- element size taken from MachDeps.h, as in the PrimBytes instance
      bs = n *# SIZEOF_HSFLOAT#
  {-# INLINE broadcastVec #-}
  -- | Dot product broadcast to every component.
  a .*. b = broadcastVec $ dot a b
  {-# INLINE (.*.) #-}
  -- | Sum of component-wise products.
  dot a b = F# (accumV2Float (\x y r -> r `plusFloat#` timesFloat# x y) a b 0.0#)
  {-# INLINE dot #-}
  -- | 1-based component access. Out-of-range indices raise 'error' unless
  --   UNSAFE_INDICES is defined at compile time.
  indexVec (I# i) _x@(VFloatXN arr)
#ifndef UNSAFE_INDICES
      | isTrue# ((i ># dim# _x) `orI#` (i <=# 0#))
      = error $ "Bad index " ++ show (I# i) ++ " for " ++ show (dim _x) ++ "D vector"
      | otherwise
#endif
      = F# (indexFloatArray# arr (i -# 1#))
  {-# INLINE indexVec #-}
  -- | L1 norm: sum of absolute values.
  normL1 v = F# (accumVFloat (\x a -> a `plusFloat#` absF# x) v 0.0#)
  {-# INLINE normL1 #-}
  -- | L2 (Euclidean) norm.
  normL2 v = sqrt $ F# (accumVFloat (\x a -> a `plusFloat#` timesFloat# x x) v 0.0#)
  {-# INLINE normL2 #-}
  -- | L(+inf) norm: largest absolute component.
  normLPInf v@(VFloatXN arr) = F# (accumVFloat maxAbs v (absF# (indexFloatArray# arr 0#)))
    where
      maxAbs x a = case absF# x of
        ax -> if isTrue# (ax `geFloat#` a) then ax else a
  {-# INLINE normLPInf #-}
  -- | L(-inf) norm: smallest absolute component.
  normLNInf v@(VFloatXN arr) = F# (accumVFloat minAbs v (absF# (indexFloatArray# arr 0#)))
    where
      minAbs x a = case absF# x of
        ax -> if isTrue# (ax `leFloat#` a) then ax else a
  {-# INLINE normLNInf #-}
  -- | Lp norm: @(sum |x_i|^p) ** (1/p)@. Note the argument order of
  --   'powerFloat#': @powerFloat# base exponent@.
  normLP n v = case realToFrac n of
    F# p -> F# (powerFloat# (accumVFloat (\x a -> a `plusFloat#` powerFloat# (absF# x) p) v 0.0#)
                            (divideFloat# 1.0# p))
  {-# INLINE normLP #-}
  -- | Number of components, taken from the type-level @n@.
  dim x = I# (dim# x)
  {-# INLINE dim #-}

-- | The vector is exposed as its raw 'ByteArray#'.
instance KnownNat n => PrimBytes (VFloatXN n) where
  toBytes (VFloatXN a) = a
  {-# INLINE toBytes #-}
  fromBytes = VFloatXN
  {-# INLINE fromBytes #-}
  byteSize x = SIZEOF_HSFLOAT# *# dim# x
  {-# INLINE byteSize #-}
  byteAlign _ = ALIGNMENT_HSFLOAT#
  {-# INLINE byteAlign #-}

-- | 0-based raw element access (no bounds check).
instance FloatBytes (VFloatXN n) where
  ixF i (VFloatXN a) = indexFloatArray# a i
  {-# INLINE ixF #-}

-----------------------------------------------------------------------------
-- Helpers
-----------------------------------------------------------------------------

-- | Number of components of the vector, read from its type-level length.
dim# :: KnownNat n => VFloatXN n -> Int#
dim# x = case fromInteger (natVal x) of I# n -> n
{-# INLINE dim# #-}

-- | Absolute value of a single component: @x@ when @x >= 0@, otherwise
--   its negation.
absF# :: Float# -> Float#
absF# x = if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x
{-# INLINE absF# #-}
-- | Run a state-threaded action for every index @i@ from @0@ to @n - 1@,
--   in increasing order.
loop# :: Int# -> (Int# -> State# s -> State# s) -> State# s -> State# s
loop# n f = go 0#
  where
    go i s
      | isTrue# (i ==# n) = s
      | otherwise = case f i s of s1 -> go (i +# 1#) s1
{-# INLINE loop# #-}

-- | Allocate a fresh vector of @n@ components and set component @i@
--   (0-based) to @g i@. Shared allocation/fill/freeze code for 'zipV'
--   and 'mapV'; the element size comes from MachDeps.h.
genV# :: Int# -> (Int# -> Float#) -> VFloatXN n
genV# n g = case runRW#
    ( \s0 -> case newByteArray# (n *# SIZEOF_HSFLOAT#) s0 of
        (# s1, marr #) ->
          case loop# n (\i s -> writeFloatArray# marr i (g i) s) s1 of
            s2 -> unsafeFreezeByteArray# marr s2
    ) of (# _, r #) -> VFloatXN r
{-# INLINE genV# #-}

-- | Combine two vectors component-wise with @f@.
zipV :: KnownNat n => (Float# -> Float# -> Float#) -> VFloatXN n -> VFloatXN n -> VFloatXN n
zipV f x@(VFloatXN a) (VFloatXN b) =
  genV# (dim# x) (\i -> f (indexFloatArray# a i) (indexFloatArray# b i))

-- | Apply @f@ to every component.
mapV :: KnownNat n => (Float# -> Float#) -> VFloatXN n -> VFloatXN n
mapV f x@(VFloatXN a) =
  genV# (dim# x) (\i -> f (indexFloatArray# a i))

-- | Left-to-right fold over the components with an unboxed accumulator:
--   @f x_i acc@ for i = 0 .. n-1.
accumVFloat :: KnownNat n => (Float# -> Float# -> Float#) -> VFloatXN n -> Float# -> Float#
accumVFloat f x@(VFloatXN a) = go 0#
  where
    go i acc
      | isTrue# (i ==# n) = acc
      | otherwise = go (i +# 1#) (f (indexFloatArray# a i) acc)
    n = dim# x

-- | Left-to-right fold over pairs of components of two vectors with a
--   boxed accumulator: @f a_i b_i acc@ for i = 0 .. n-1.
accumV2 :: KnownNat n => (Float# -> Float# -> a -> a) -> VFloatXN n -> VFloatXN n -> a -> a
accumV2 f x@(VFloatXN a) (VFloatXN b) = go 0#
  where
    go i acc
      | isTrue# (i ==# n) = acc
      | otherwise = go (i +# 1#) (f (indexFloatArray# a i) (indexFloatArray# b i) acc)
    n = dim# x

-- | Same as 'accumV2' but with an unboxed 'Float#' accumulator.
accumV2Float :: KnownNat n => (Float# -> Float# -> Float# -> Float#) -> VFloatXN n -> VFloatXN n -> Float# -> Float#
accumV2Float f x@(VFloatXN a) (VFloatXN b) = go 0#
  where
    go i acc
      | isTrue# (i ==# n) = acc
      | otherwise = go (i +# 1#) (f (indexFloatArray# a i) (indexFloatArray# b i) acc)
    n = dim# x

-- | Right-to-left fold over the components with a boxed accumulator:
--   @f x_i acc@ for i = n-1 down to 0.
accumVReverse :: KnownNat n => (Float# -> a -> a) -> VFloatXN n -> a -> a
accumVReverse f x@(VFloatXN a) = go (n -# 1#)
  where
    go i acc
      | isTrue# (i ==# -1#) = acc
      | otherwise = go (i -# 1#) (f (indexFloatArray# a i) acc)
    n = dim# x