{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE RankNTypes    #-}
{-# LANGUAGE UnboxedTuples #-}

-- | The @repa-plugin@ rewrites client code to use these primitives.
--
--   The plugin will use whatever names are in scope, so if you want to debug
--   your code you can import a different implementation of these primitives
--   into the module to be vectorized.
--
module Data.Array.Repa.Series.Prim
        ( Primitives (..)
        , primitives )
where
import Data.Array.Repa.Series.Vector    as V
import Data.Array.Repa.Series.Series    as S
import Data.Array.Repa.Series.Ref       as Ref
import GHC.Exts
import GHC.Types

-- | Explicit state token threaded through the primitives below.
type World = State# RealWorld


-- | Primitives needed by the repa-plugin.
--
--   In 'primitives' the first three fields are bound to 'undefined';
--   presumably they exist only so the plugin can see the 'Series', 'Vector'
--   and 'Ref' types (NOTE(review): confirm against the plugin).
data Primitives
  = Primitives
  { prim_Series           :: forall k a. Series k a
  , prim_Vector           :: forall a.   Vector a
  , prim_Ref              :: forall a.   Ref a

    -- Int ----------------------------------------------------------------
    -- Arith Int
  , prim_addInt           :: Int# -> Int# -> Int#
  , prim_subInt           :: Int# -> Int# -> Int#
  , prim_mulInt           :: Int# -> Int# -> Int#
  , prim_divInt           :: Int# -> Int# -> Int#
  , prim_modInt           :: Int# -> Int# -> Int#
  , prim_remInt           :: Int# -> Int# -> Int#

    -- Eq Int
  , prim_eqInt            :: Int# -> Int# -> Bool
  , prim_neqInt           :: Int# -> Int# -> Bool
  , prim_gtInt            :: Int# -> Int# -> Bool
  , prim_geInt            :: Int# -> Int# -> Bool
  , prim_ltInt            :: Int# -> Int# -> Bool
  , prim_leInt            :: Int# -> Int# -> Bool

    -- Ref Int
  , prim_newRefInt        :: Int#    -> World -> (# World, Ref Int #)
  , prim_readRefInt       :: Ref Int -> World -> (# World, Int# #)
  , prim_writeRefInt      :: Ref Int -> Int# -> World -> World

    -- Ref (Int,Int)
  , prim_newRefInt_T2     :: (# Int#, Int# #) -> World -> (# World, Ref (Int,Int) #)
  , prim_readRefInt_T2    :: Ref (Int,Int) -> World -> (# World, (# Int#, Int# #) #)
  , prim_writeRefInt_T2   :: Ref (Int,Int) -> (# Int#, Int# #) -> World -> World

    -- Vector Int
  , prim_newVectorInt     :: Int# -> World -> (# World, Vector Int #)
  , prim_readVectorInt    :: Vector Int -> Int# -> World -> (# World, Int# #)
  , prim_writeVectorInt   :: Vector Int -> Int# -> Int# -> World -> World
  , prim_sliceVectorInt   :: Int# -> Vector Int -> World -> (# World, Vector Int #)

    -- Loop
  , prim_rateOfSeries     :: forall k a. Series k a -> Int#
  , prim_loop             :: Int# -> (Int# -> World -> World) -> World -> World
  , prim_guard            :: Ref Int -> Bool -> (Int# -> World -> World) -> World -> World

    -- Next
  , prim_nextInt          :: forall k. Series k Int -> Int# -> World -> (# World, Int# #)
  , prim_nextInt_T2       :: forall k. Series k (Int,Int) -> Int# -> World -> (# World, (# Int#, Int# #) #)
  }


-- | Table of primitives used by the repa-plugin.
primitives :: Primitives
primitives
  = Primitives
  { prim_Series           = undefined
  , prim_Vector           = undefined
  , prim_Ref              = undefined

    -- Arith Int
  , prim_addInt           = (+#)
  , prim_subInt           = (-#)
  , prim_mulInt           = (*#)
  , prim_divInt           = repa_divInt
  , prim_modInt           = repa_modInt
  , prim_remInt           = repa_remInt

    -- Eq Int
  , prim_eqInt            = repa_eqInt
  , prim_neqInt           = repa_neqInt
  , prim_gtInt            = repa_gtInt
  , prim_geInt            = repa_geInt
  , prim_ltInt            = repa_ltInt
  , prim_leInt            = repa_leInt

    -- Ref Int
  , prim_newRefInt        = repa_newRefInt
  , prim_readRefInt       = repa_readRefInt
  , prim_writeRefInt      = repa_writeRefInt

    -- Ref (Int,Int)
  , prim_newRefInt_T2     = repa_newRefInt_T2
  , prim_readRefInt_T2    = repa_readRefInt_T2
  , prim_writeRefInt_T2   = repa_writeRefInt_T2

    -- Vector Int
  , prim_newVectorInt     = repa_newVectorInt
  , prim_readVectorInt    = repa_readVectorInt
  , prim_writeVectorInt   = repa_writeVectorInt
  , prim_sliceVectorInt   = repa_sliceVectorInt

    -- Loop
  , prim_rateOfSeries     = repa_rateOfSeries
  , prim_loop             = repa_loop
  , prim_guard            = repa_guard

    -- Next
  , prim_nextInt          = repa_nextInt
  , prim_nextInt_T2       = repa_nextInt_T2
  }


-- Utils ----------------------------------------------------------------------
-- | Run an 'IO' action against an explicit world token, returning the new
--   token together with the action's result.
unwrapIO' :: IO a -> World -> (# World, a #)
unwrapIO' (IO f) = f
{-# INLINE unwrapIO' #-}


-- | Run an 'IO' action against an explicit world token, returning only the
--   new token and discarding the action's result.
unwrapIO_ :: IO a -> World -> World
unwrapIO_ (IO f) world
 = case f world of
        (# world', _ #) -> world'
{-# INLINE unwrapIO_ #-}


-- Arith Int ------------------------------------------------------------------
-- These go through the boxed Prelude operations so they keep the standard
-- 'div' / 'mod' / 'rem' semantics, including the exception on a zero divisor.

-- | Integer division rounding towards negative infinity (Prelude 'div').
repa_divInt :: Int# -> Int# -> Int#
repa_divInt i1 i2
 = case div (I# i1) (I# i2) of
        I# i3 -> i3
{-# INLINE repa_divInt #-}


-- | Modulus matching 'repa_divInt' (Prelude 'mod').
repa_modInt :: Int# -> Int# -> Int#
repa_modInt i1 i2
 = case mod (I# i1) (I# i2) of
        I# i3 -> i3
{-# INLINE repa_modInt #-}


-- | Remainder of truncating division (Prelude 'rem').
repa_remInt :: Int# -> Int# -> Int#
repa_remInt i1 i2
 = case rem (I# i1) (I# i2) of
        I# i3 -> i3
{-# INLINE repa_remInt #-}


-- Eq Int ---------------------------------------------------------------------
-- Comparisons are done on boxed values so the result is a 'Bool' on every GHC
-- version (the raw comparison primops return 'Int#' on GHC >= 7.8).

repa_eqInt  :: Int# -> Int# -> Bool
repa_eqInt  i1 i2 = I# i1 == I# i2
{-# INLINE repa_eqInt #-}

repa_neqInt :: Int# -> Int# -> Bool
repa_neqInt i1 i2 = I# i1 /= I# i2
{-# INLINE repa_neqInt #-}

repa_gtInt  :: Int# -> Int# -> Bool
repa_gtInt  i1 i2 = I# i1 >  I# i2
{-# INLINE repa_gtInt #-}

repa_geInt  :: Int# -> Int# -> Bool
repa_geInt  i1 i2 = I# i1 >= I# i2
{-# INLINE repa_geInt #-}

repa_ltInt  :: Int# -> Int# -> Bool
repa_ltInt  i1 i2 = I# i1 <  I# i2
{-# INLINE repa_ltInt #-}

repa_leInt  :: Int# -> Int# -> Bool
repa_leInt  i1 i2 = I# i1 <= I# i2
{-# INLINE repa_leInt #-}


-- Ref Int --------------------------------------------------------------------
-- | Allocate a new 'Ref' holding the given value.
repa_newRefInt :: Int# -> World -> (# World, Ref Int #)
repa_newRefInt x
 = unwrapIO' (Ref.new (I# x))
{-# INLINE repa_newRefInt #-}


-- | Read the current value of a 'Ref', unboxed.
repa_readRefInt :: Ref Int -> World -> (# World, Int# #)
repa_readRefInt ref world
 = case unwrapIO' (Ref.read ref) world of
        (# world', I# i #) -> (# world', i #)
{-# INLINE repa_readRefInt #-}


-- | Overwrite the value held in a 'Ref'.
repa_writeRefInt :: Ref Int -> Int# -> World -> World
repa_writeRefInt ref val
 = unwrapIO_ (Ref.write ref (I# val))
{-# INLINE repa_writeRefInt #-}


-- Ref (Int,Int) --------------------------------------------------------------
-- | Allocate a new 'Ref' holding the given pair.
repa_newRefInt_T2 :: (# Int#, Int# #) -> World -> (# World, Ref (Int,Int) #)
repa_newRefInt_T2 (# x, y #)
 = unwrapIO' (Ref.new (I# x, I# y))
{-# INLINE repa_newRefInt_T2 #-}


-- | Read the current pair held in a 'Ref', unboxing both components.
repa_readRefInt_T2 :: Ref (Int,Int) -> World -> (# World, (# Int#, Int# #) #)
repa_readRefInt_T2 ref world
 = case unwrapIO' (Ref.read ref) world of
        (# world', (I# i, I# j) #) -> (# world', (# i, j #) #)
{-# INLINE repa_readRefInt_T2 #-}


-- | Overwrite the pair held in a 'Ref'.
repa_writeRefInt_T2 :: Ref (Int,Int) -> (# Int#, Int# #) -> World -> World
repa_writeRefInt_T2 ref (# i, j #)
 = unwrapIO_ (Ref.write ref (I# i, I# j))
{-# INLINE repa_writeRefInt_T2 #-}


-- Vector Int -----------------------------------------------------------------
-- | Allocate a vector with 'V.new'; the argument is presumably the length
--   (NOTE(review): confirm against "Data.Array.Repa.Series.Vector").
repa_newVectorInt :: Int# -> World -> (# World, Vector Int #)
repa_newVectorInt len
 = unwrapIO' (V.new len)
{-# INLINE repa_newVectorInt #-}


-- | Read the element at the given index, unboxed.
repa_readVectorInt :: Vector Int -> Int# -> World -> (# World, Int# #)
repa_readVectorInt vec ix world
 = case unwrapIO' (V.read vec ix) world of
        (# world', I# i #) -> (# world', i #)
{-# INLINE repa_readVectorInt #-}


-- | Write a value at the given index.
repa_writeVectorInt :: Vector Int -> Int# -> Int# -> World -> World
repa_writeVectorInt vec ix val
 = unwrapIO_ (V.write vec ix (I# val))
{-# INLINE repa_writeVectorInt #-}


-- | Wraps 'V.take'; presumably yields a vector of the first @len@ elements
--   (NOTE(review): confirm whether this copies or aliases).
repa_sliceVectorInt :: Int# -> Vector Int -> World -> (# World, Vector Int #)
repa_sliceVectorInt len vec
 = unwrapIO' (V.take len vec)
{-# INLINE repa_sliceVectorInt #-}


-- Loop combinators -----------------------------------------------------------
-- | Primitive stateful loop combinator.
--
--   Runs @worker@ for every index @0 .. len-1@ in ascending order, threading
--   the world token through each call. Does nothing when @len <= 0@.
repa_loop :: Int# -> (Int# -> World -> World) -> World -> World
repa_loop len worker world0
 = go 0# world0
 where  go :: Int# -> World -> World
        go ix world
         -- Compare boxed values: on GHC >= 7.8 '(>=#)' returns 'Int#', so it
         -- cannot be used directly as a guard.
         | I# ix >= I# len = world
         | otherwise       = go (ix +# 1#) (worker ix world)
{-# INLINE repa_loop #-}


-- | Guard an inner context with a flag.
--
--   When the flag is 'False' the world is returned unchanged. Otherwise the
--   counter in @ref@ is read and incremented, and @worker@ is run with the
--   value the counter held before the increment.
repa_guard :: Ref Int -> Bool -> (Int# -> World -> World) -> World -> World
repa_guard ref flag worker world0
 | not flag
 = world0

 | otherwise
 = case repa_readRefInt ref world0 of
        (# world1, ix #) -> worker ix (repa_writeRefInt ref (ix +# 1#) world1)
{-# INLINE repa_guard #-}


-- Series ---------------------------------------------------------------------
-- | Get the rate (length) of a series.
repa_rateOfSeries :: Series k a -> Int#
repa_rateOfSeries s
 = seriesLength s
{-# INLINE repa_rateOfSeries #-}


-- | Get the next element of a series.
--
--   Reads the element at the given index with 'S.index' and returns it
--   unboxed; the world token is passed through unchanged.
repa_nextInt :: Series k Int -> Int# -> World -> (# World, Int# #)
repa_nextInt series pos tok
 | I# elt <- S.index series pos
 = (# tok, elt #)
{-# INLINE repa_nextInt #-}


-- | Pair-valued variant of 'repa_nextInt': reads the element at the given
--   index and returns both components unboxed, passing the world token
--   through unchanged.
--
--   TODO generalise to other element types.
repa_nextInt_T2 :: Series k (Int,Int) -> Int# -> World -> (# World, (# Int#, Int# #) #)
repa_nextInt_T2 series pos tok
 | (I# fstElt, I# sndElt) <- S.index series pos
 = (# tok, (# fstElt, sndElt #) #)
{-# INLINE repa_nextInt_T2 #-}