{-# LANGUAGE CPP #-}
{- |
A module capturing common patterns in ArBB VM code emission and thereby
making it easier to emit code.
-}
module Intel.ArbbVM.Convenience
  ( ifThenElse
  , while
  , readScalarOfSize
  , newConstant
  , arbbSession
  , EmitArbb
  , ConvFunction
  , if_
  , while_
  , readScalar_
  , funDef_
  , funDefS_
  -- , funDefCallable_
  , op_
  , opImm_
  , opDynamic_
  , opDynamicImm_
  , map_
  , call_
  , mapToHost_
  , const_
  , int8_, int16_, int32_, int64_
  , float32_, float64_
  , bool_
  , uint8_, uint16_, uint32_, uint64_
  , const_storable_
  , usize_, isize_
  , incr_int32_
  , copy_, copyImm_
  , local_bool_, local_int8_, local_int32_, local_float32_, local_float64_
  -- , global_nobind_, global_nobind_int32_
    -- Compile does not exist in this way anymore:
  -- , compile_
  , execute_
  , serializeFunction_, serializeFunctionWrapper_
  , finish_
  -- , getBindingNull_
  , getScalarType_, variableFromGlobal_, getFunctionType_
  , createGlobal_, createLocal_, createGlobal_nobind_
  , createDenseBinding_, getDenseType_, freeBinding_, getNestedType_
  , withArray_
  , print_
  , doarith_
  , SimpleArith(V)
  -- , module Intel.ArbbVM.SimpleArith
  , liftIO, liftMs
    -- These should probably be internal only:
  , getCtx, getFun, getConvFun
    -- Low level:
  , setNumThreads_, setDecompDegree_
  ) where

-- Base / boot libraries.
import Control.Monad
import qualified Control.Monad.State.Strict as S
import Data.ByteString.Internal
import Data.Int
import Data.IORef
import Data.Serialize
import Data.Word
import Debug.Trace
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable as Storable
-- import C2HS

-- Project modules.
--import qualified Intel.ArbbVM as VM
import Intel.ArbbVM as VM
-- import Intel.ArbbVM.SimpleArith

--------------------------------------------------------------------------------
-- | The monad for emitting ArBB code: a strict state monad layered over 'IO',
-- carrying an 'ArbbEmissionState'.
type EmitArbb = S.StateT ArbbEmissionState IO
-- | Emission state threaded through 'EmitArbb': the ArBB 'Context', plus a
-- stack (innermost first) of the functions currently being defined.
--
-- The stack is needed because this convenience layer allows nested function
-- definitions. (They will all have global scope to ArBB, however.)
--
-- BJS: ConvFunction in place of Function.
-- Note: if we also include a counter in here we can do gensyms...
type ArbbEmissionState = (Context, [ConvFunction])

-- | BJS: Convenient functions are pairs of inconvenient functions.
-- 'funDef_' builds the 'callable' function from the user's body, plus an
-- 'executable' wrapper whose only content is a call to it.
data ConvFunction = ConvFunction
  { executable :: Function -- ^ just a wrapper
  , callable   :: Function -- ^ "The" function
  } deriving Show

-- | Opaque rendering, needed so that 'ConvFunction' can derive 'Show'.
instance Show Function where
  show _ = "Function"

-- | Allow the user to perform IO inside the emission monad.
liftIO :: IO a -> EmitArbb a
liftIO = S.lift

-- | Run an emission session against the default ArBB context, starting with
-- an empty function stack. The final emission state is discarded.
arbbSession :: EmitArbb a -> IO a
arbbSession m = do
  ctx <- getDefaultContext
  S.evalStateT m (ctx, [])

-- | The innermost function currently being defined (top of the stack).
-- Fails with 'error', mentioning @msg@, when called outside any function
-- definition.
--
-- BJS: ConvFunction (as a name) is a bit inconvenient.
getConvFun :: String -> EmitArbb ConvFunction
getConvFun msg = do
  (_, ls) <- S.get
  case ls of
    []    -> error $ msg ++ " when not inside a function"
    (h:_) -> return h

-- | The 'callable' half of 'getConvFun': the ArBB function that code is
-- currently being emitted into. Fails the same way as 'getConvFun'.
getFun :: String -> EmitArbb Function
getFun msg = fmap callable (getConvFun msg)

-- | The ArBB context of the current session.
getCtx :: EmitArbb Context
getCtx = S.gets fst

------------------------------------------------------------------------------
-- | Map an ArBB array into the host address space, returning a raw pointer.
mapToHost_ :: Variable -> [Word64] -> RangeAccessMode -> EmitArbb (Ptr ())
mapToHost_ var pitch mode = do
  ctx <- getCtx
  liftIO $ mapToHost ctx var pitch mode

--------------------------------------------------------------------------------
-- Convenience functions for common patterns:

-- | Emit @code@ through 'opImm', which (unlike 'op_') takes no enclosing
-- function.
opImm_ :: Opcode -> [Variable] -> [Variable] -> EmitArbb ()
opImm_ code out inp = liftIO $ opImm code out inp

-- | Emit @code@ with outputs @out@ and inputs @inp@ into the function
-- currently being defined. Must be called inside a function definition.
op_ :: Opcode -> [Variable] -> [Variable] -> EmitArbb ()
op_ code out inp = do
  fun <- getFun "Convenience.op_ cannot execute an Opcode"
  liftIO $ op fun code out inp

-- | Like 'opImm_', but via 'opDynamicImm'.
opDynamicImm_ :: Opcode -> [Variable] -> [Variable] -> EmitArbb ()
opDynamicImm_ code out inp = liftIO $ opDynamicImm code out inp

-- | Like 'op_', but via 'opDynamic'. Must be called inside a function
-- definition.
opDynamic_ :: Opcode -> [Variable] -> [Variable] -> EmitArbb ()
opDynamic_ code out inp = do
  fun <- getFun "Convenience.opDynamic_ cannot execute an Opcode"
  liftIO $ opDynamic fun code out inp

-- | An ArBB if/else on the condition variable @c@: @t@ is emitted into the
-- then-branch and @e@ into the else-branch, and their results are discarded.
-- Must be called inside a function definition.
if_ :: Variable -> EmitArbb a -> EmitArbb a1 -> EmitArbb ()
if_ c t e = do
  fun <- getFun "Convenience.if_ cannot execute a conditional"
  liftIO $ ifBranch fun c
  _ <- t -- op myfun ArbbOpSub [c] [a,a]
  liftIO $ elseBranch fun
  _ <- e -- op myfun ArbbOpDiv [c] [a,a]
  liftIO $ endIf fun

-- | An ArBB while loop. Must be called inside a function definition.
-- @cond@ is emitted into the loop's condition block and must return the
-- condition variable; @body@ is emitted into the body block, and its result
-- is returned.
while_ :: (EmitArbb Variable) -> EmitArbb a -> EmitArbb a
while_ cond body = do
  fun <- getFun "Convenience.while_ cannot execute a while loop"
  liftIO $ beginLoop fun ArbbLoopWhile
  liftIO $ beginLoopBlock fun ArbbLoopBlockCond
  lc <- cond
  liftIO $ loopCondition fun lc
  liftIO $ beginLoopBlock fun ArbbLoopBlockBody
  result <- body
  liftIO $ endLoop fun
  return result

-- | Emit a constant of scalar type @st@ whose in-memory representation is
-- exactly the 'Storable' value @n@ (the caller picks the representation).
const_storable_ :: Storable a => ScalarType -> a -> EmitArbb Variable
const_storable_ st n = do
  ctx <- getCtx
  liftIO $ newConstantAlt ctx st n

-- | Emit a constant of scalar type @sty@ from an integral value. This
-- version picks the right in-memory representation for each type.
--
-- This only lets you get at the integral floating point numbers (use
-- 'const_storable_' for others). 'ArbbBoolean' uses the same Int32 encoding
-- as 'bool_': zero is false (0), anything else is true (1).
const_ :: Integral a => ScalarType -> a -> EmitArbb Variable
const_ sty i = case sty of
  ArbbI8      -> const_storable_ sty (fromIntegral i :: Int8)
  ArbbI16     -> const_storable_ sty (fromIntegral i :: Int16)
  ArbbI32     -> const_storable_ sty (fromIntegral i :: Int32)
  ArbbI64     -> const_storable_ sty (fromIntegral i :: Int64)
  ArbbU8      -> const_storable_ sty (fromIntegral i :: Word8)
  ArbbU16     -> const_storable_ sty (fromIntegral i :: Word16)
  ArbbU32     -> const_storable_ sty (fromIntegral i :: Word32)
  ArbbU64     -> const_storable_ sty (fromIntegral i :: Word64)
  ArbbF32     -> const_storable_ sty (fromIntegral i :: Float)
  ArbbF64     -> const_storable_ sty (fromIntegral i :: Double)
  ArbbUsize   -> const_storable_ sty (fromIntegral i :: Word)
  ArbbIsize   -> const_storable_ sty (fromIntegral i :: Int)
  -- Previously missing: fell through to a non-exhaustive pattern failure.
  ArbbBoolean -> const_storable_ sty ((if i == 0 then 0 else 1) :: Int32)

-- | Read the value of scalar variable @v@ back to the host.
readScalar_ :: (Num a, Storable a) => Variable -> EmitArbb a
readScalar_ v = do
  ctx <- getCtx
  -- @z@ only names the result type: @x + z@ unifies it with @a@, so
  -- @sizeOf z@ is the byte size of the requested scalar.
  let z = 0
      size = Storable.sizeOf z
  x <- liftIO $ readScalarOfSize size ctx v
  return (x + z)

-- | A function body: receives the output variables, then the inputs.
type FunBody = [Variable] -> [Variable] -> EmitArbb ()

-- | Set to 'True' to trace function definition and calls on stdout.
debug_fundef = False

{- BJS: This funDef_ situation might need some improvement.
   - one option is create the functions both callable and not
     a Function would need to be a pair of the "callable" function
     and the "executable" function
   - This is a part of ArBB that is very likely to change
-}

{-
funDef_ name outty inty userbody = funDefInternal name outty inty userbody 1
funDefCallable_ name outty inty userbody = funDefInternal name outty inty userbody 0

funDefInternal :: String -> [Type] -> [Type] -> FunBody -> Int -> EmitArbb Function
funDefInternal name outty inty userbody remote = do
  ctx <- getCtx
  fnt <- L getFunctionType ctx outty inty
  fun <- L beginFunction ctx fnt name remote
  when debug_fundef$ print_$ "["++name++"] Function begun."
  invars  <- L forM [0 .. length inty  - 1] (getParameter fun 0)
  outvars <- L forM [0 .. length outty - 1] (getParameter fun 1)
  -- Push on the stack:
  S.modify (\ (c,ls) -> (c, fun:ls))
  -- Now generate body:
  when debug_fundef$ print_$ "["++name++"] Begin body codgen..."
  userbody outvars invars
  when debug_fundef$ print_$ "["++name++"] Done body codgen."
  -- Pop off the stack:
  S.modify (\ (c, h:t) -> (c, t))
  L endFunction fun
  -- EXPERIMENTAL! Compile immediately!!
  when debug_fundef$ print_$ "["++name++"] Function ended. Compiling..."
  L compile fun
  when debug_fundef$ print_$ "["++name++"] Done compiling."
  return fun
-}

-- Flags passed to 'beginFunction' for the two halves of a 'ConvFunction'.
is_callable = 0
is_executable = 1

-- | Placeholder for the not-yet-built 'executable' half while a body is
-- being emitted.
nullfun = Function nullPtr

-- | Define an ArBB function named @name@ with output types @outty@ and input
-- types @inty@, whose body is emitted by @userbody outvars invars@.
--
-- BJS: New funDef: create a pair of funs, the actual fun plus a wrapper.
-- The actual function becomes 'callable' (used by 'call_' and 'map_'). The
-- wrapper, named @name ++ "W"@, contains only a call to it and becomes
-- 'executable' (used by 'execute_').
-- BJS: Names seem to be only there as a help when printing in human
-- readable form.
--
-- While @userbody@ runs, the new function is on top of the function stack,
-- so 'op_', 'if_', 'createLocal_', ... emit into it. During that time its
-- 'executable' half is only the 'nullfun' placeholder.
funDef_ :: String -> [Type] -> [Type] -> FunBody -> EmitArbb ConvFunction
funDef_ name outty inty userbody = do
  ctx <- getCtx
  fnt <- liftIO $ getFunctionType ctx outty inty
  fun <- liftIO $ beginFunction ctx fnt name is_callable
  when debug_fundef $ print_ $ "["++name++"] Function begun."
  -- Parameter kind 0 is used for inputs, 1 for outputs.
  invars  <- liftIO $ forM [0 .. length inty  - 1] (getParameter fun 0)
  outvars <- liftIO $ forM [0 .. length outty - 1] (getParameter fun 1)
  -- Push on the stack:
  S.modify (\(c, ls) -> (c, ConvFunction nullfun fun : ls))
  -- Now generate body:
  when debug_fundef $ print_ $ "["++name++"] Begin body codgen..."
  userbody outvars invars
  when debug_fundef $ print_ $ "["++name++"] Done body codgen."
  -- Pop off the stack (total: never fails on an empty stack):
  S.modify (\(c, ls) -> (c, drop 1 ls))
  liftIO $ endFunction fun
  when debug_fundef $ print_ $ "["++name++"] Function ended."
  -- EXPERIMENTAL! Compile immediately!!
  --liftIO $ compile fun
  --when debug_fundef $ print_ $ "["++name++"] Done compiling."

  -- Also create an executable wrapper that just calls the real function.
  wrapper <- liftIO $ beginFunction ctx fnt (name ++ "W") is_executable
  when debug_fundef $ print_ $ "["++name++ "W" ++"] Wrapper function begun."
  inputs  <- liftIO $ forM [0 .. length inty  - 1] (getParameter wrapper 0)
  outputs <- liftIO $ forM [0 .. length outty - 1] (getParameter wrapper 1)
  when debug_fundef $ print_ $ "["++name++"] Begin wrapper body codgen..."
  liftIO $ callOp wrapper ArbbOpCall fun outputs inputs
  when debug_fundef $ print_ $ "["++name++"] Done wrapper body codgen."
  liftIO $ endFunction wrapper
  when debug_fundef $ print_ $ "["++name++"] Wrapper function ended."
  return $ ConvFunction wrapper fun

-- | Like 'funDef_', but the output/input types are given as scalar types.
-- Umm... what's a good naming convention here?
funDefS_ :: String -> [ScalarType] -> [ScalarType] -> FunBody -> EmitArbb ConvFunction
funDefS_ name outs ins body = do
  outs' <- mapM getScalarType_ outs
  ins'  <- mapM getScalarType_ ins
  funDef_ name outs' ins' body

{-
call_ :: Function -> [Variable] -> [Variable] -> EmitArbb ()
call_ fun out inp = do
  -- At the point of the call the *caller* is on the top of the stack:
  caller <- getFun "Convenience.call_ cannot call function"
  when debug_fundef$ print_ "Call_: got caller function, emitting call opcode..."
  L callOp caller ArbbOpCall fun out inp
  when debug_fundef$ print_ "Call_: Done emitting call opcode."
-}

-- | Emit a call of @fun@ ('callable' half) from the function currently
-- being defined.
call_ :: ConvFunction -> [Variable] -> [Variable] -> EmitArbb ()
call_ fun out inp = do
  -- At the point of the call the *caller* is on the top of the stack:
  caller <- getFun "Convenience.call_ cannot call function"
  when debug_fundef $ print_ "Call_: got caller function, emitting call opcode..."
  liftIO $ callOp caller ArbbOpCall (callable fun) out inp
  when debug_fundef $ print_ "Call_: Done emitting call opcode."

-- | Like 'call_', but emits 'ArbbOpMap' instead of 'ArbbOpCall'.
map_ :: ConvFunction -> [Variable] -> [Variable] -> EmitArbb ()
map_ fun out inp = do
  -- At the point of the call the *caller* is on the top of the stack:
  caller <- getFun "Convenience.map_ cannot call function"
  when debug_fundef $ print_ "Map_: got caller function, emitting map opcode..."
  liftIO $ callOp caller ArbbOpMap (callable fun) out inp
  when debug_fundef $ print_ "Map_: Done emitting map opcode."

--------------------------------------------------------------------------------
-- Iteration Patterns.

-- for_range_ :: Variable -> Variable -> (i -> EmitArbb ()) -> EmitArbb ()
-- This uses C-style [inclusive,exclusive) ranges.
-- for_constRange_ :: Int -> Int -> (i -> EmitArbb ()) -> EmitArbb ()
-- for_range_ start end body = do
--   counter <- local_int32_ "counter"
--   op_ ArbbOpCopy [counter] [zer]
--   while_ (do
--       lc <- local_bool_ "loopcond"
--       op_ ArbbOpLess [lc] [counter,max]
--       return lc)
--     (op_ ArbbOpAdd [counter] [counter,one])

-- | 'withArray' lifted into 'EmitArbb': marshal @ls@ into a temporary array
-- and run @body@ on its pointer.
--
-- Lifting higher order ops like this is trickier: the callback runs in plain
-- IO, so the current emission state is passed in explicitly and the updated
-- state is carried back out through an IORef.
withArray_ :: Storable a => [a] -> (Ptr a -> EmitArbb b) -> EmitArbb b
withArray_ ls body = do
  state <- S.get
  ref   <- liftIO $ newIORef state
  let body2 ptr = do
        (a, s2) <- S.runStateT (body ptr) state
        writeIORef ref s2
        return a
  res <- liftIO $ withArray ls body2
  state2 <- liftIO $ readIORef ref
  S.put state2
  return res

-- | Print a line to stdout from inside the emission monad.
print_ :: String -> EmitArbb ()
print_ = liftIO . putStrLn

--------------------------------------------------------------------------------

-- | liftM for lists: run every action in @ls@, then feed the results to @fn@.
liftMs :: Monad m => ([a] -> m b) -> [m a] -> m b
liftMs fn ls = sequence ls >>= fn

-- These let us lift the slew of ArbbVM functions that expect a Context as
-- their first argument; the session's context is supplied automatically.
lift1 :: (Context -> a -> IO b) -> a -> EmitArbb b
lift2 :: (Context -> a -> b -> IO c) -> a -> b -> EmitArbb c
lift3 :: (Context -> a -> b -> c -> IO d) -> a -> b -> c -> EmitArbb d
lift4 :: (Context -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> EmitArbb e
lift1 fn a       = do ctx <- getCtx; liftIO $ fn ctx a
lift2 fn a b     = do ctx <- getCtx; liftIO $ fn ctx a b
lift3 fn a b c   = do ctx <- getCtx; liftIO $ fn ctx a b c
lift4 fn a b c d = do ctx <- getCtx; liftIO $ fn ctx a b c d

getScalarType_ :: ScalarType -> EmitArbb Type
getScalarType_      = lift1 getScalarType
getDenseType_       = lift2 getDenseType
getNestedType_      = lift1 getNestedType
variableFromGlobal_ = lift1 variableFromGlobal
getFunctionType_    = lift2 getFunctionType
createGlobal_       = lift3 createGlobal
createGlobal_nobind_ = lift2 createGlobalNB

-- | Create a named local of type @ty@ in the function currently being
-- defined. Must be called inside a function definition.
createLocal_ :: Type -> String -> EmitArbb Variable
createLocal_ ty name = do
  f <- getFun "Convenience.createLocal_ cannot create local"
  liftIO $ createLocal f ty name

createDenseBinding_ = lift4 createDenseBinding
freeBinding_        = lift1 freeBinding

-- Low level:
setDecompDegree_ = lift1 setDecompDegree
setNumThreads_   = lift1 setNumThreads

----------------------------------------
-- These are easy ones, no Context or Function argument:

-- compile_ fn = liftIO$ compile fn
-- execute_ a b c = liftIO$ execute a b c
finish_ = liftIO finish
--serializeFunction_ = liftIO . serializeFunction
--getBindingNull_ = liftIO getBindingNull

-- BJS: execute_ nolonger quite as simple: it runs the 'executable' wrapper.
execute_ f o i = liftIO $ execute (executable f) o i
-- BJS: should be a way to serialize the wrapper also!
serializeFunction_ f = liftIO $ serializeFunction (callable f)
serializeFunctionWrapper_ f = liftIO $ serializeFunction (executable f)

-- ... TODO ... Keep going.
--------------------------------------------------------------------------------
-- Lazy, lazy, lazy: here are even more shorthands.

-- | Integral constant shorthands: each emits a constant of the named ArBB
-- scalar type, converting the argument via 'const_'.
int8_, int16_, int32_, int64_ :: Integral t => t -> EmitArbb Variable
int8_  n = const_ ArbbI8  n
int16_ n = const_ ArbbI16 n
int32_ n = const_ ArbbI32 n
int64_ n = const_ ArbbI64 n

uint8_, uint16_, uint32_, uint64_ :: Integral t => t -> EmitArbb Variable
uint8_  n = const_ ArbbU8  n
uint16_ n = const_ ArbbU16 n
uint32_ n = const_ ArbbU32 n
uint64_ n = const_ ArbbU64 n

usize_, isize_ :: Integral t => t -> EmitArbb Variable
usize_ n = const_ ArbbUsize n
isize_ n = const_ ArbbIsize n

-- | Floating-point constant shorthands: the value is stored unchanged via
-- 'const_storable_'.
float32_ :: Float -> EmitArbb Variable
float32_ x = const_storable_ ArbbF32 x

float64_ :: Double -> EmitArbb Variable
float64_ x = const_storable_ ArbbF64 x

-- | A boolean constant, stored as an 'Int32': 1 for 'True', 0 for 'False'.
bool_ :: Bool -> EmitArbb Variable
bool_ b = const_storable_ ArbbBoolean ((if b then 1 else 0) :: Int32)

-- TODO... Keep going...
-- | Increment an int32 variable in place by emitting @var = var + 1@.
-- Must be called inside a function definition (it uses 'op_').
incr_int32_ :: Variable -> EmitArbb ()
incr_int32_ var = do
  one <- int32_ 1
  op_ ArbbOpAdd [var] [var, one]

-- | Emit a copy of @v2@ into @v1@ inside the function being defined.
copy_ :: Variable -> Variable -> EmitArbb ()
copy_ v1 v2 = op_ ArbbOpCopy [v1] [v2]

-- | Like 'copy_', but emitted through 'opImm_' (no enclosing function).
copyImm_ :: Variable -> Variable -> EmitArbb ()
copyImm_ v1 v2 = opImm_ ArbbOpCopy [v1] [v2]

------------------------------------------------------------

-- | Create a named local of scalar type @sty@ in the current function.
local_scalar_ :: ScalarType -> String -> EmitArbb Variable
local_scalar_ sty name = do
  ty <- getScalarType_ sty
  createLocal_ ty name

-- | Named scalar locals of a fixed type; must be called inside a function
-- definition (see 'createLocal_').
local_bool_, local_int8_, local_int32_, local_float32_, local_float64_
  :: String -> EmitArbb Variable
local_bool_    = local_scalar_ ArbbBoolean
local_int32_   = local_scalar_ ArbbI32
local_int8_    = local_scalar_ ArbbI8
local_float32_ = local_scalar_ ArbbF32
local_float64_ = local_scalar_ ArbbF64

{-
global_nobind_ ty name = do
  binding <- getBindingNull_
  g <- createGlobal_ ty name binding
  variableFromGlobal_ g

global_nobind_int32_ name = do
  sty <- getScalarType_ ArbbI32
  global_nobind_ sty name
-}

--------------------------------------------------------------------------------
-- Num instance.

--------------------------------------------------------------------------------
-- OBSOLETE: These were some helpers that didn't use the EmitArbb monad.

-- | IO-level if/else: @t@ runs in the then-branch and @e@ in the else-branch
-- of function @f@; their results are discarded.
ifThenElse :: Function -> Variable -> IO a -> IO a1 -> IO ()
ifThenElse f c t e = do
  ifBranch f c
  _ <- t -- op myfun ArbbOpSub [c] [a,a]
  elseBranch f
  _ <- e -- op myfun ArbbOpDiv [c] [a,a]
  endIf f

-- | IO-level while loop in function @f@: @cond@ emits the condition block
-- and returns the condition variable; @body@ emits the loop body.
while :: Function -> (IO Variable) -> IO a1 -> IO ()
while f cond body = do
  beginLoop f ArbbLoopWhile
  beginLoopBlock f ArbbLoopBlockCond
  lc <- cond
  loopCondition f lc
  beginLoopBlock f ArbbLoopBlockBody
  _ <- body
  endLoop f

-- | Works not just for arrays but anything serializable: encode @x@ and pass
-- a pointer to the first encoded byte to @fn@.
withSerialized :: Serialize a => a -> (Ptr () -> IO b) -> IO b
withSerialized x fn =
  -- The ByteString's data starts @off@ bytes into its buffer; ignoring the
  -- offset would hand @fn@ a pointer to the wrong bytes.
  withForeignPtr fptr $ \p -> fn (p `plusPtr` off)
  where
    (fptr, off, _) = toForeignPtr (encode x)

-- | Create an ArBB constant of type @t@ holding the 'Storable' value @n@, and
-- return it as a variable.
newConstant :: Storable a => Context -> Type -> a -> IO Variable
newConstant ctx t n = do
  -- Could use withSerialized possibly...
  tmp <- withArray [n] $ \x -> createConstant ctx t (castPtr x)
  variableFromGlobal ctx tmp

-- | Like 'newConstant', but the type is given as a 'ScalarType'.
newConstantAlt :: Storable a => Context -> ScalarType -> a -> IO Variable
newConstantAlt ctx st n = do
  t <- getScalarType ctx st
  newConstant ctx t n

-- global/constant shortcuts

-- | Read scalar variable @v@ through a temporary @n@-byte buffer and decode
-- the result. NOTE(review): @n@ must be at least @sizeOf@ the result type
-- and at least the size ArBB writes; callers pass the result's 'sizeOf'.
-- readScalarOfSize :: Storable b => Int -> Context -> Variable -> EmitArbb b
readScalarOfSize :: Storable b => Int -> Context -> Variable -> IO b
readScalarOfSize n ctx v =
  allocaBytes n $ \ptr -> do
    readScalar ctx v ptr
    peek (castPtr ptr)

-- TODO: readScalar of storable should be able to determine size.

----------------------------------------------------------------------------------------------------
-- Numeric instances.
----------------------------------------------------------------------------------------------------

{- It is not clear that this is worth it, yet.
   This is mainly here because I want to use Data.Complex.
-}

-- We could use a richer type for Variable in the convenience interface.
-- data VarPlus = VarPlus Variable Type
--
-- The trick would be to have functions like op_ take [SimpleArith] rather
-- than [Variable]. But to deal with anything other than the "V" variant, the
-- desired type would need to be known.
-- instance Show Variable where show v = ""

-- | Lets 'SimpleArith' derive 'Eq'; actually comparing two 'Variable's
-- always fails with 'error'.
instance Eq Variable where
  _ == _ = error "equality on Variables doesn't make sense yet"

-- | A small expression language over ArBB variables and constants. Build it
-- with the usual numeric operators; 'doarith_' emits the code for it.
data SimpleArith =
     V Variable
   | Const Integer
   | ConstD Double
   | Plus  SimpleArith SimpleArith
   | Times SimpleArith SimpleArith
   | Div   SimpleArith SimpleArith
   | Signum SimpleArith
   | Abs SimpleArith
   | Expon SimpleArith
   | Sqrt SimpleArith
   | Log SimpleArith
   | Sin SimpleArith
   | Cos SimpleArith
   | ASin SimpleArith
   | ACos SimpleArith
   | ATan SimpleArith
   | SinH SimpleArith
   | CosH SimpleArith
   | ASinH SimpleArith
   | ACosH SimpleArith
   | ATanH SimpleArith
-- | ProperFrac SimpleArith  -- UNFINISHED
  deriving (Show, Eq)

instance Num SimpleArith where
  (+)         = Plus
  (*)         = Times
  signum      = Signum
  abs         = Abs
  fromInteger = Const
  -- Without 'negate' (or '-') the class defaults for the two are defined in
  -- terms of each other, so subtraction looped forever. Negation is
  -- expressed with constructors 'doarith_' already handles, and the default
  -- @a - b = a + negate b@ then works too.
  negate (Const i) = Const (negate i)
  negate e         = Times (Const (-1)) e

instance Fractional SimpleArith where
  (/) = Div
  fromRational _ = error "fromRational not implemented yet for SimpleArith"

instance Ord SimpleArith where
  _ < _ = error "< not implemented yet for SimpleArith"
  -- Without 'compare' the remaining Ord methods default to each other and
  -- loop; fail explicitly instead.
  compare _ _ = error "compare not implemented yet for SimpleArith"

instance Real SimpleArith where
  toRational _ = error "toRational not implemented for SimpleArith"

instance Floating SimpleArith where
  pi    = ConstD pi
  exp   = Expon
  sqrt  = Sqrt
  log   = Log
  -- (**) :: a -> a -> a
  -- logBase :: a -> a -> a
  sin   = Sin
  -- tan :: a -> a
  cos   = Cos
  asin  = ASin
  atan  = ATan
  acos  = ACos
  sinh  = SinH
  cosh  = CosH
  asinh = ASinH
  acosh = ACosH
  atanh = ATanH

-- instance Enum SimpleArith where
-- instance Integral SimpleArith where

--class (Real a, Fractional a) => RealFrac a where
instance RealFrac SimpleArith where
  -- properFraction :: Integral b => a -> (b, a)
  properFraction _ = error "properFraction not implemented for SimpleArith"

#if 0
class (Real a, Enum a) => Integral a where
  quot :: a -> a -> a
  rem :: a -> a -> a
  div :: a -> a -> a
  mod :: a -> a -> a
  quotRem :: a -> a -> (a, a)
  divMod :: a -> a -> (a, a)
  toInteger :: a -> Integer

class Enum a where
  succ :: a -> a
  pred :: a -> a
  toEnum :: Int -> a
  fromEnum :: a -> Int
  enumFrom :: a -> [a]
  enumFromThen :: a -> a -> [a]
  enumFromTo :: a -> a -> [a]
  enumFromThenTo :: a -> a -> a -> [a]

class (RealFrac a, Floating a) => RealFloat a where
  floatRadix :: a -> Integer
  floatDigits :: a -> Int
  floatRange :: a -> (Int, Int)
  decodeFloat :: a -> (Integer, Int)
  encodeFloat :: Integer -> Int -> a
  exponent :: a -> Int
  significand :: a -> a
  scaleFloat :: Int -> a -> a
  isNaN :: a -> Bool
  isInfinite :: a -> Bool
  isDenormalized :: a -> Bool
  isNegativeZero :: a -> Bool
  isIEEE :: a -> Bool
  atan2 :: a -> a -> a
#endif

-- | This lets one execute simple arithmetic expressions and store the result.
-- Returns the name of a new local binding that carries the result.
--
-- Every intermediate result gets a fresh local of scalar type @ty_@, and
-- integer constants are emitted with 'const_'. Supported nodes: 'V',
-- 'Const', 'Plus', 'Times' and 'Div'; any other node fails with 'error'.
-- Must be called inside a function definition.
doarith_ :: ScalarType -> SimpleArith -> EmitArbb Variable
doarith_ ty_ expr0 = do
  ty <- getScalarType_ ty_
  let binop op a b = do
        tmp <- createLocal_ ty "tmp"
        a' <- loop a
        b' <- loop b
        op_ op [tmp] [a', b']
        return tmp
      loop expr = case expr of
        V v       -> return v
        Const i   -> const_ ty_ i
        Plus a b  -> binop ArbbOpAdd a b
        Times a b -> binop ArbbOpMul a b
        Div a b   -> binop ArbbOpDiv a b
        _         -> error $ "doarith_: not handled yet: " ++ show expr
  loop expr0

-- data ScalarType = ArbbI8
--                 | ArbbI16
--                 | ArbbI32
--                 | ArbbI64
--                 | ArbbU8
--                 | ArbbU16
--                 | ArbbU32
--                 | ArbbU64
--                 deriving (Enum,Show,Eq)