{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Numeric.Rounded.Hardware.Vector.Storable
  ( -- * Conversion between @VS.Vector a@ and @VS.Vector (Rounded r a)@
    coercion
  , fromVectorOfRounded
  , toVectorOfRounded
  , coercionM
  , fromMVectorOfRounded
  , toMVectorOfRounded
    -- * Specialized functions
  , roundedSum
  , zipWith_roundedAdd
  , zipWith_roundedSub
  , zipWith_roundedMul
  , zipWith3_roundedFusedMultiplyAdd
  , zipWith_roundedDiv
  , map_roundedSqrt
  , sum
  , zipWith_add
  , zipWith_sub
  , zipWith_mul
  , zipWith3_fusedMultiplyAdd
  , zipWith_div
  , map_sqrt
  ) where
import           Data.Coerce
import           Data.Proxy
import           Data.Type.Coercion
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import           Foreign.Storable
import           Numeric.Rounded.Hardware.Internal
import           Prelude hiding (sum)
import           Unsafe.Coerce

--
-- Conversion between 'VS.Vector a' and 'VS.Vector (Rounded r a)'
--
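-- 'VS.Vector' will have a nominal role on its element type after vector-0.13,
-- so 'coerce' would not be usable for the conversions below.
-- See:
--     * https://github.com/haskell/vector/issues/223
--     * https://github.com/haskell/vector/pull/235
--
-- However, we know that 'Storable (Rounded r a)' is the same as 'Storable a',
-- which is what the 'unsafeCoerce'-based conversions below rely on.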
--

-- | A 'Coercion' witness between a storable vector of @a@ and a storable
-- vector of @'Rounded' r a@.
--
-- The witness is obtained by applying 'unsafeCoerce' to the reflexive
-- @Coercion (VS.Vector a) (VS.Vector a)@, so the type checker does not verify
-- it. It is only sound if the 'Storable' instance of @'Rounded' r a@ behaves
-- exactly like the one for @a@.
coercion :: Coercion (VS.Vector a) (VS.Vector (Rounded r a))
coercion = unsafeCoerce (Coercion :: Coercion (VS.Vector a) (VS.Vector a))

-- | Reinterpret a storable vector of @'Rounded' r a@ as a storable vector of
-- the underlying element type @a@.
--
-- Implemented with 'unsafeCoerce'; this is only sound if the 'Storable'
-- instance of @'Rounded' r a@ behaves exactly like the one for @a@.
fromVectorOfRounded :: VS.Vector (Rounded r a) -> VS.Vector a
fromVectorOfRounded = unsafeCoerce

-- | Reinterpret a storable vector of @a@ as a storable vector of
-- @'Rounded' r a@, for any rounding-mode index @r@ chosen by the caller.
--
-- Implemented with 'unsafeCoerce'; this is only sound if the 'Storable'
-- instance of @'Rounded' r a@ behaves exactly like the one for @a@.
toVectorOfRounded :: VS.Vector a -> VS.Vector (Rounded r a)
toVectorOfRounded = unsafeCoerce

-- | A 'Coercion' witness between a mutable storable vector of @a@ and a
-- mutable storable vector of @'Rounded' r a@ (same state token @s@).
--
-- The witness is obtained by applying 'unsafeCoerce' to the reflexive
-- @Coercion (VSM.MVector s a) (VSM.MVector s a)@, so the type checker does not
-- verify it. It is only sound if the 'Storable' instance of @'Rounded' r a@
-- behaves exactly like the one for @a@.
coercionM :: Coercion (VSM.MVector s a) (VSM.MVector s (Rounded r a))
coercionM = unsafeCoerce (Coercion :: Coercion (VSM.MVector s a) (VSM.MVector s a))

-- | Reinterpret a mutable storable vector of @'Rounded' r a@ as a mutable
-- storable vector of the underlying element type @a@.
--
-- Implemented with 'unsafeCoerce'; this is only sound if the 'Storable'
-- instance of @'Rounded' r a@ behaves exactly like the one for @a@.
fromMVectorOfRounded :: VSM.MVector s (Rounded r a) -> VSM.MVector s a
fromMVectorOfRounded = unsafeCoerce

-- | Reinterpret a mutable storable vector of @a@ as a mutable storable vector
-- of @'Rounded' r a@, for any rounding-mode index @r@ chosen by the caller.
--
-- Implemented with 'unsafeCoerce'; this is only sound if the 'Storable'
-- instance of @'Rounded' r a@ behaves exactly like the one for @a@.
toMVectorOfRounded :: VSM.MVector s a -> VSM.MVector s (Rounded r a)
toMVectorOfRounded = unsafeCoerce

--
-- Vector Operations
--

-- | Equivalent to 'VS.sum'
--
-- Unwraps the vector with 'fromVectorOfRounded', calls 'roundedSum' with the
-- rounding mode selected by the type-level index @r@, and wraps the scalar
-- result back into @'Rounded' r a@ with 'coerce'.
sum
  :: forall r a. (Rounding r, Storable a, RoundedRing_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> Rounded r a
sum v = coerce (roundedSum r (fromVectorOfRounded v))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE sum #-}

-- | Equivalent to @'VS.zipWith' (+)@
--
-- Unwraps both operands with 'fromVectorOfRounded', calls
-- 'zipWith_roundedAdd' with the rounding mode selected by the type-level
-- index @r@, and rewraps the result with 'toVectorOfRounded'.
zipWith_add
  :: forall r a. (Rounding r, Storable a, RoundedRing_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
zipWith_add v1 v2 =
  toVectorOfRounded
    (zipWith_roundedAdd r (fromVectorOfRounded v1) (fromVectorOfRounded v2))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE zipWith_add #-}

-- | Equivalent to @'VS.zipWith' (-)@
--
-- Unwraps both operands with 'fromVectorOfRounded', calls
-- 'zipWith_roundedSub' with the rounding mode selected by the type-level
-- index @r@, and rewraps the result with 'toVectorOfRounded'.
zipWith_sub
  :: forall r a. (Rounding r, Storable a, RoundedRing_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
zipWith_sub v1 v2 =
  toVectorOfRounded
    (zipWith_roundedSub r (fromVectorOfRounded v1) (fromVectorOfRounded v2))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE zipWith_sub #-}

-- | Equivalent to @'VS.zipWith' (*)@
--
-- Unwraps both operands with 'fromVectorOfRounded', calls
-- 'zipWith_roundedMul' with the rounding mode selected by the type-level
-- index @r@, and rewraps the result with 'toVectorOfRounded'.
zipWith_mul
  :: forall r a. (Rounding r, Storable a, RoundedRing_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
zipWith_mul v1 v2 =
  toVectorOfRounded
    (zipWith_roundedMul r (fromVectorOfRounded v1) (fromVectorOfRounded v2))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE zipWith_mul #-}

-- | Equivalent to @'VS.zipWith3' fusedMultiplyAdd@
--
-- Unwraps the three operands with 'fromVectorOfRounded', calls
-- 'zipWith3_roundedFusedMultiplyAdd' with the rounding mode selected by the
-- type-level index @r@ (operands passed in the same order as received), and
-- rewraps the result with 'toVectorOfRounded'.
zipWith3_fusedMultiplyAdd
  :: forall r a. (Rounding r, Storable a, RoundedRing_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
zipWith3_fusedMultiplyAdd v1 v2 v3 =
  toVectorOfRounded
    (zipWith3_roundedFusedMultiplyAdd r
      (fromVectorOfRounded v1)
      (fromVectorOfRounded v2)
      (fromVectorOfRounded v3))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE zipWith3_fusedMultiplyAdd #-}

-- | Equivalent to @'VS.zipWith' (/)@
--
-- Unwraps both operands with 'fromVectorOfRounded', calls
-- 'zipWith_roundedDiv' with the rounding mode selected by the type-level
-- index @r@, and rewraps the result with 'toVectorOfRounded'.
zipWith_div
  :: forall r a. (Rounding r, Storable a, RoundedFractional_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
zipWith_div v1 v2 =
  toVectorOfRounded
    (zipWith_roundedDiv r (fromVectorOfRounded v1) (fromVectorOfRounded v2))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE zipWith_div #-}

-- | Equivalent to @'VS.map' sqrt@
--
-- Unwraps the operand with 'fromVectorOfRounded', calls 'map_roundedSqrt'
-- with the rounding mode selected by the type-level index @r@, and rewraps
-- the result with 'toVectorOfRounded'.
map_sqrt
  :: forall r a. (Rounding r, Storable a, RoundedSqrt_Vector VS.Vector a)
  => VS.Vector (Rounded r a)
  -> VS.Vector (Rounded r a)
map_sqrt v = toVectorOfRounded (map_roundedSqrt r (fromVectorOfRounded v))
  where
    -- Value-level rounding mode corresponding to the phantom index @r@.
    r = rounding (Proxy :: Proxy r)
{-# INLINE map_sqrt #-}