{-# LANGUAGE GADTs, ExistentialQuantification, FlexibleInstances #-} {-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {- CodeGen.Program. Joel Svensson 2012..2015 Notes: 2013-03-17: Codegeneration is changing -} module Obsidian.CodeGen.Program where import Obsidian.Exp import Obsidian.Globs import Obsidian.Types import Obsidian.Atomic import qualified Obsidian.Program as P import Data.Word import Data.Supply import Data.List import System.IO.Unsafe import Control.Monad.State import Control.Applicative --------------------------------------------------------------------------- -- New Intermediate representation --------------------------------------------------------------------------- type IMList a = [(Statement a,a)] type IM = IMList () out :: a -> [(a,())] out a = [(a,())] -- Atomic operations data AtOp = AtInc | AtAdd IExp | AtSub IExp | AtExch IExp data HLevel = Thread | Warp | Block | Grid -- Statements data Statement t = SAssign IExp [IExp] IExp | SAtomicOp IExp IExp AtOp | SCond IExp (IMList t) | SSeqFor String IExp (IMList t) | SBreak | SSeqWhile IExp (IMList t) -- Iters Body | SForAll HLevel IExp (IMList t) | SDistrPar HLevel IExp (IMList t) -- Memory Allocation.. | SAllocate Name Word32 Type | SDeclare Name Type -- Synchronisation | SSynchronize --------------------------------------------------------------------------- -- Collect and pass around data during first step compilation data Context = Context { ctxNWarps :: Maybe Word32, ctxNThreads :: Maybe Word32, ctxGLBUsesTid :: Bool, ctxGLBUsesWid :: Bool} newtype CM a = CM (State Context a) deriving (Monad, MonadState Context, Functor, Applicative) runCM :: CM a -> Context -> a --runCM (CM cm) ctx = evalState cm ctx runCM (CM cm) = evalState cm evalCM :: CM a -> Context -> (a, Context) evalCM (CM cm) = runState cm setUsesTid :: CM () setUsesTid = modify $ \ctx -> ctx { ctxGLBUsesTid = True } setUsesWid :: CM () setUsesWid = modify $ \ctx -> ctx { ctxGLBUsesWid = True } enterWarp :: Word32 -> CM () enterWarp n = modify $ \ctx -> ctx { ctxNWarps = Just n } enterThread :: Word32 -> CM () enterThread n = modify $ \ctx -> ctx {ctxNThreads = Just n} clearWarp :: CM () clearWarp = modify $ \ctx -> ctx {ctxNWarps = Nothing} getNWarps :: CM (Maybe Word32) getNWarps = do ctx <- get return $ ctxNWarps ctx getNThreads :: CM (Maybe Word32) getNThreads = do ctx <- get return $ ctxNThreads ctx emptyCtx :: Context emptyCtx = Context Nothing Nothing False False --------------------------------------------------------------------------- -- Sort these out and improve! usesWarps :: IMList t -> Bool usesWarps = any (go . fst) where go (SDistrPar _ _ im) = usesWarps im go (SForAll Warp _ _) = True go _ = False usesTid :: IMList t -> Bool usesTid = any (go . fst) where go (SDistrPar _ _ im) = usesTid im go (SForAll Block _ _) = True go (SSeqFor _ _ im) = usesTid im go _ = False usesBid :: IMList t -> Bool usesBid = any (go . fst) where go (SDistrPar Block _ _) = True -- usesBid im -- go (SForAll Block _ _) = True go _ = False usesGid :: IMList t -> Bool usesGid = any (go . fst) where go (SForAll Grid _ _) = True go _ = False --------------------------------------------------------------------------- -- COmpilation of Program to IM --------------------------------------------------------------------------- compileStep1 :: Compile t => P.Program t a -> IM compileStep1 p = snd $ runCM (compile ns p) emptyCtx where ns = unsafePerformIO$ newEnumSupply class Compile t where compile :: Supply Int -> P.Program t a -> CM (a,IM) -- Compile Thread program instance Compile P.Thread where -- Can add cases for P.ForAll here. -- Turn into sequential loop. Could be important to make push -- operate uniformly across entire hierarchy. compile s (P.ForAll n f) = do let (i1,i2) = split2 s nom = "i" ++ show (supplyValue i1) v = variable nom p = f v (a,im) <- compile i2 p return ((),out $ SSeqFor nom (expToIExp n) im) compile s (P.Allocate nom n t) = do (Just nt) <- getNThreads -- must be a Just at this point nw' <- getNWarps let nw = case nw' of Nothing -> 1 Just i -> i return ((),out $ SAllocate nom (nt*nw*n) t) compile _ (P.Sync) = return ((),[]) compile s p = cs s p -- Compile Warp program instance Compile P.Warp where compile s (P.DistrPar n f) = error "Currently not supported to distribute over the threads, use ForAll instead!" compile s (P.ForAll n@(Literal n') f) = do -- setup context to know number of threads -- executing enterThread n' let p = f (variable "warpIx") (a,im) <- compile s p return (a, out $ SForAll Warp (expToIExp n) im) --undefined -- compile a warp program that iterates over a space n large compile s (P.Allocate nom n t) = do (Just nw) <- getNWarps -- Must be a Just here, or something is wrong! return ((),out $ SAllocate nom (nw*n) t) compile s (P.Bind p f) = do let (s1,s2) = split2 s (a,im1) <- compile s1 p (b,im2) <- compile s2 (f a) return (b,(im1 ++ im2)) compile s (P.Return a) = return (a,[]) compile s (P.Identifier) = return (supplyValue s, []) compile s (P.Sync) = return ((),[]) -- Why no fallthrough here ? -- Adding (must have been a horrible mistake!) compile s p = cs s p -- Compile Block program instance Compile P.Block where compile s (P.ForAll n@(Literal n') f) = do -- Set up the context to know the number of -- concurrent thread programs that are executing. enterThread n' setUsesTid let nom = "tid" v = variable nom p = f v (a,im) <- compile s p -- in this case a could be () (since it is guaranteed to be anyway). a return (a,out (SForAll Block (expToIExp n) im)) compile s (P.DistrPar n'@(Literal n) f) = do {- Distribute work over warps! -} -- Set up the context for the compilation -- of the Warp code. -- BUG: Something like this is needed for distribution -- over threads too! -- FIXED: Bug mentioned above should be (at least) partially fixed. enterWarp n -- Number of active warps are stored in the context. (a,im) <- compile s (f (variable "warpID")) return (a, out (SDistrPar Warp (expToIExp n') im)) compile s (P.Allocate id n t) = return ((),out (SAllocate id n t)) compile s (P.Sync) = return ((),out (SSynchronize)) compile s p = cs s p -- Compile a Grid Program instance Compile P.Grid where {- Distribute over blocks -} compile s (P.DistrPar n f) = do -- Need to generate IM here that the backend can read desired number of blocks from let p = f (variable "bid") -- "blockIdx.x") -- (BlockIdx X) (a, im) <- compile s p -- (f (BlockIdx X)) return (a, out (SDistrPar Block (expToIExp n) im)) compile s (P.Allocate _ _ _) = error "Allocate at level Grid" compile s p = cs s p {- ForAll cannot happen here! -} --------------------------------------------------------------------------- -- General compilation --------------------------------------------------------------------------- cs :: forall t a . Compile t => Supply Int -> P.Program t a -> CM (a,IM) cs i P.Identifier = return $ (supplyValue i, []) cs i (P.Assign name ix e) = return $ ((),out (SAssign (IVar name (typeOf e)) (map expToIExp ix) (expToIExp e))) cs i (P.AtomicOp name ix atom) = case atom of AtomicInc -> return $ ((),out (SAtomicOp (IVar name Word32) (expToIExp ix) AtInc)) AtomicAdd e -> error $ "CodeGen.Program: AtomicAdd is not implemented" AtomicSub e -> error $ "CodeGen.Program: AtomicSub is not implemented" AtomicExch e -> error $ "CodeGen.Program: AtomicExch is not implemented" cs i (P.Cond bexp p) = do ((),im) <- compile i p return ((),out (SCond (expToIExp bexp) im)) cs i (P.SeqFor n f) = do let (i1,i2) = split2 i nom = "i" ++ show (supplyValue i1) v = variable nom p = f v (a,im) <- compile i2 p return (a,out (SSeqFor nom (expToIExp n) im)) cs i (P.SeqWhile b p) = do (a,im) <- compile i p return (a, out (SSeqWhile (expToIExp b) im)) cs i (P.Break) = return ((), out SBreak) cs i (P.Allocate id n t) = error $ "CodeGen.Program: Allocate without a hierarchy designation" cs i (P.Declare id t) = return ((),out (SDeclare id t)) cs i (P.Bind p f) = do let (s1,s2) = split2 i (a,im1) <- compile s1 p (b,im2) <- compile s2 (f a) return (b,im1 ++ im2) cs i (P.Return a) = return (a,[]) -- Unhandled cases cs i p = error $ "CodeGen.Program: unhandled in cs: " ++ P.printPrg p -- compile i p --------------------------------------------------------------------------- -- Turning IM to strings (outdated and broken) --------------------------------------------------------------------------- printIM :: Show a => IMList a -> String printIM im = concatMap printStm im -- Print a Statement with metadata printStm :: Show a => (Statement a,a) -> String printStm (SAssign name [] e,m) = show name ++ " = " ++ show e ++ ";" ++ meta m printStm (SAssign name ix e,m) = show name ++ "[" ++ concat (intersperse "," (map show ix)) ++ "]" ++ " = " ++ show e ++ ";" ++ meta m --printStm (SAtomicOp res arr ix op,m) = -- res ++ " = " ++ -- printAtomic op ++ "(" ++ arr ++ "[" ++ show ix ++ "]);" ++ meta m printStm (SAllocate name n t,m) = name ++ " = malloc(" ++ show n ++ ");" ++ meta m printStm (SDeclare name t,m) = show t ++ " " ++ name ++ ";" ++ meta m printStm (SCond bexp im,m) = "if " ++ show bexp ++ "{\n" ++ concatMap printStm im ++ "\n};" ++ meta m printStm (SSynchronize,m) = "sync();" ++ meta m printStm (SSeqFor name n im,m) = "for " ++ name ++ " in [0.." ++ show n ++"] do" ++ meta m ++ concatMap printStm im ++ "\ndone;\n" printStm (SForAll Warp n im,m) = "forAll wid" ++ " in [0.." ++ show n ++"] do" ++ meta m ++ concatMap printStm im ++ "\ndone;\n" printStm (SForAll Block n im,m) = "forAll tid" ++ " in [0.." ++ show n ++"] do" ++ meta m ++ concatMap printStm im ++ "\ndone;\n" printStm (SForAll Grid n im,m) = "forAll gid in [0.." ++ show n ++"] do" ++ meta m ++ concatMap printStm im ++ "\ndone;\n" printStm (SDistrPar lvl n im,m) = "forAll gid in [0.." ++ show n ++"] do" ++ meta m ++ concatMap printStm im ++ "\ndone;\n" meta :: Show a => a -> String meta m = "\t//" ++ show m ++ "\n"