{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
module Futhark.CodeGen.ImpCode
( Functions (..)
, Function
, FunctionT (..)
, ValueDesc (..)
, Signedness (..)
, ExternalValue (..)
, Param (..)
, paramName
, Size (..)
, MemSize
, DimSize
, Type (..)
, Space (..)
, SpaceId
, Code (..)
, PrimValue (..)
, ExpLeaf (..)
, Exp
, Volatility (..)
, Arg (..)
, var
, vi32
, index
, ErrorMsg(..)
, ErrorMsgPart(..)
, ArrayContents(..)
, Bytes
, Elements
, elements
, bytes
, withElemType
, sizeToExp
, dimSizeToExp
, memSizeToExp
, module Language.Futhark.Core
, module Futhark.Representation.Primitive
, module Futhark.Analysis.PrimExp
, module Futhark.Representation.Kernels.Sizes
)
where
import Data.Monoid ((<>))
import Data.List
import Data.Loc
import Data.Traversable
import Language.Futhark.Core
import Futhark.Representation.Primitive
import Futhark.Representation.AST.Syntax
(Space(..), SpaceId, ErrorMsg(..), ErrorMsgPart(..))
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Pretty ()
import Futhark.Analysis.PrimExp
import Futhark.Util.Pretty hiding (space)
import Futhark.Representation.Kernels.Sizes (Count(..))
-- | The size of a memory block or of an array dimension: either known
-- at compile time, or held in a scalar variable at run time.
data Size
  = ConstSize Int64
    -- ^ A statically known size.
  | VarSize VName
    -- ^ A size stored in the named scalar variable.
  deriving (Eq, Show)

-- | A memory block size; converted to a byte count by 'memSizeToExp'.
type MemSize = Size

-- | An array dimension size; converted to an element count by
-- 'dimSizeToExp'.
type DimSize = Size

-- | The type of an imperative variable: a scalar of some primitive
-- type, or a memory block of some size residing in some space.
data Type
  = Scalar PrimType
  | Mem MemSize Space
-- | A low-level function parameter: a memory block in some space, or a
-- scalar of a primitive type.
data Param
  = MemParam VName Space
  | ScalarParam VName PrimType
  deriving (Show)

-- | The variable bound by a parameter, whatever its kind.
paramName :: Param -> VName
paramName p =
  case p of
    MemParam v _ -> v
    ScalarParam v _ -> v
-- | A named collection of imperative functions. Combining two
-- collections appends the functions of the second after those of the
-- first; the empty collection contains no functions.
newtype Functions a = Functions [(Name, Function a)]

instance Semigroup (Functions a) where
  Functions xs <> Functions ys = Functions (xs <> ys)

instance Monoid (Functions a) where
  mempty = Functions mempty
-- | Signedness annotation for values at the external interface.
-- 'TypeUnsigned' values are pretty-printed with an "(unsigned)" note;
-- 'TypeDirect' values carry no annotation.
data Signedness
  = TypeUnsigned
  | TypeDirect
  deriving (Eq, Show)
-- | Description of one value crossing the external interface of a
-- function.
data ValueDesc
  = ArrayValue VName Space PrimType Signedness [DimSize]
    -- ^ An array held in the named memory block in the given space,
    -- with its element type, signedness and dimension sizes.
  | ScalarValue PrimType Signedness VName
    -- ^ A scalar of the given type and signedness held in the named
    -- variable.
  deriving (Eq, Show)

-- | A value as seen from outside a function.
data ExternalValue
  = OpaqueValue String [ValueDesc]
    -- ^ A value with a hidden representation: a textual description
    -- together with the values that make it up.
  | TransparentValue ValueDesc
    -- ^ A value exposed directly.
  deriving (Show)
-- | An imperative function: its low-level output and input parameters,
-- its body, and the external description of its results and arguments.
--
-- NOTE(review): 'functionbBody' looks like a typo of @functionBody@, but
-- the field is exported via @FunctionT (..)@ and renaming it would break
-- callers, so it is kept as-is.
data FunctionT a = Function
  { functionEntry :: Bool
    -- ^ Flag presumably marking entry points (not inspected here).
  , functionOutput :: [Param]
    -- ^ Low-level output parameters.
  , functionInput :: [Param]
    -- ^ Low-level input parameters.
  , functionbBody :: Code a
    -- ^ The function body.
  , functionResult :: [ExternalValue]
    -- ^ External description of the results.
  , functionArgs :: [ExternalValue]
    -- ^ External description of the arguments.
  }
  deriving (Show)

-- | Shorthand for 'FunctionT'.
type Function = FunctionT
-- | Initial contents of a statically declared array: an explicit list
-- of values, or a number of zeros.
data ArrayContents
  = ArrayValues [PrimValue]
  | ArrayZeros Int
  deriving (Show)
-- | Imperative code, parameterised by the type @a@ of backend-specific
-- operations embedded through 'Op'.
data Code a
  = Skip
    -- ^ No operation.
  | Code a :>>: Code a
    -- ^ Sequential composition.
  | For VName IntType Exp (Code a)
    -- ^ Counted loop over the named counter of the given integer type;
    -- pretty-printed as @for i:t < bound { body }@.
  | While Exp (Code a)
    -- ^ Loop while the condition holds.
  | DeclareMem VName Space
    -- ^ Declare a memory block variable in the given space.
  | DeclareScalar VName PrimType
    -- ^ Declare a scalar variable.
  | DeclareArray VName Space PrimType ArrayContents
    -- ^ Declare an array with statically known contents.
  | Allocate VName (Count Bytes Exp) Space
    -- ^ Allocate the given number of bytes for the named memory block.
  | Free VName Space
    -- ^ Release the named memory block.
  | Copy VName (Count Bytes Exp) Space VName (Count Bytes Exp) Space (Count Bytes Exp)
    -- ^ @Copy dest destOffset destSpace src srcOffset srcSpace size@.
  | Write VName (Count Elements Exp) PrimType Space Volatility Exp
    -- ^ @Write mem index elemType space volatility value@.
  | SetScalar VName Exp
    -- ^ Assign an expression to a scalar variable.
  | SetMem VName VName Space
    -- ^ @SetMem dest src space@: set memory block @dest@ to @src@.
  | Call [VName] Name [Arg]
    -- ^ Call the named function, binding results to the given names.
  | If Exp (Code a) (Code a)
    -- ^ Conditional with a true branch and a false branch.
  | Assert Exp (ErrorMsg Exp) (SrcLoc, [SrcLoc])
    -- ^ Assertion with an error message (which may embed expressions)
    -- and source locations.
  | Comment String (Code a)
    -- ^ Code annotated with a comment.
  | DebugPrint String (Maybe (PrimType, Exp))
    -- ^ Debug output: a description and an optional typed value.
  | Op a
    -- ^ A backend-specific operation.
  deriving (Show)
-- | Whether a memory access is volatile. Volatile accesses are
-- pretty-printed with a "volatile " marker.
data Volatility
  = Volatile
  | Nonvolatile
  deriving (Eq, Ord, Show)
-- | Sequential composition that absorbs 'Skip' on either side, so that
-- appending 'mempty' never adds a ':>>:' node.
instance Semigroup (Code a) where
  x <> y =
    case (x, y) of
      (Skip, _) -> y
      (_, Skip) -> x
      _ -> x :>>: y

instance Monoid (Code a) where
  mempty = Skip
-- | The leaves of imperative expressions.
data ExpLeaf
  = ScalarVar VName
    -- ^ The value of a scalar variable.
  | SizeOf PrimType
    -- ^ The size of a primitive type; 'withElemType' multiplies
    -- element counts by it to obtain byte counts.
  | Index VName (Count Elements Exp) PrimType Space Volatility
    -- ^ A memory read: @Index mem index elemType space volatility@.
  deriving (Eq, Show)

-- | Imperative expressions: primitive expressions over 'ExpLeaf'.
type Exp = PrimExp ExpLeaf

-- | An argument to a function call.
data Arg
  = ExpArg Exp
    -- ^ An expression argument.
  | MemArg VName
    -- ^ A memory block argument.
  deriving (Show)
-- | Uninhabited unit tag for 'Count's measured in array elements.
data Elements

-- | Uninhabited unit tag for 'Count's measured in bytes.
data Bytes

-- | Tag an expression as a count of elements.
elements :: Exp -> Count Elements Exp
elements e = Count e

-- | Tag an expression as a count of bytes.
bytes :: Exp -> Count Bytes Exp
bytes e = Count e
-- | Turn an element count into a byte count by multiplying it by the
-- 'SizeOf' the element type (as a 32-bit integer expression).
withElemType :: Count Elements Exp -> PrimType -> Count Bytes Exp
withElemType (Count n) t =
  let elemSize = LeafExp (SizeOf t) (IntType Int32)
  in bytes (n * elemSize)

-- | A dimension size as an element count.
dimSizeToExp :: DimSize -> Count Elements Exp
dimSizeToExp d = elements (sizeToExp d)

-- | A memory size as a byte count.
memSizeToExp :: MemSize -> Count Bytes Exp
memSizeToExp m = bytes (sizeToExp m)

-- | A size as a 32-bit integer expression.
--
-- NOTE(review): constant sizes are narrowed from 'Int64' with
-- 'fromIntegral'; assuming 'Int32Value' holds an @Int32@, constants
-- outside the 32-bit range wrap silently.
sizeToExp :: Size -> Exp
sizeToExp size =
  case size of
    VarSize v -> LeafExp (ScalarVar v) (IntType Int32)
    ConstSize c -> ValueExp (IntValue (Int32Value (fromIntegral c)))
-- | A scalar variable of the given type, as an expression.
var :: VName -> PrimType -> Exp
var v t = LeafExp (ScalarVar v) t

-- | A variable of type 'Int32', as an expression.
vi32 :: VName -> Exp
vi32 v = var v (IntType Int32)

-- | A read of element @i@ of type @t@ from memory block @arr@ in space
-- @s@; the resulting expression has type @t@.
index :: VName -> Count Elements Exp -> PrimType -> Space -> Volatility -> Exp
index arr i t s vol =
  let leaf = Index arr i t s vol
  in LeafExp leaf t
-- | Each function is printed under a "Function name:" header with its
-- contents indented; an empty document is interspersed between
-- consecutive functions.
instance Pretty op => Pretty (Functions op) where
  ppr (Functions funs) = stack (intersperse mempty (map pprNamed funs))
    where
      pprNamed (fname, fun) =
        text "Function " <> ppr fname <> colon </> indent 2 (ppr fun)

-- | Prints inputs, outputs, arguments, results and body, each under its
-- own heading. The entry-point flag is not printed.
instance Pretty op => Pretty (FunctionT op) where
  ppr (Function _ outs ins body results args) =
    text "Inputs:" </> listed ins </>
    text "Outputs:" </> listed outs </>
    text "Arguments:" </> listed args </>
    text "Result:" </> listed results </>
    text "Body:" </> indent 2 (ppr body)
    where
      listed :: Pretty a => [a] -> Doc
      listed = indent 2 . stack . map ppr
-- | Scalar parameters print as the type followed by the name. Memory
-- parameters print as "mem", an "@space" suffix (nothing for the
-- default space), then the name.
instance Pretty Param where
  ppr (ScalarParam v t) = ppr t <+> ppr v
  ppr (MemParam v sp) = text "mem" <> spaceNote <+> ppr v
    where
      spaceNote =
        case sp of
          DefaultSpace -> mempty
          Space s -> text "@" <> text s

-- | Scalars print as type and name; arrays print their element type
-- wrapped in one bracket pair per dimension, followed by the memory
-- block and space. Unsigned values get an "(unsigned)" note.
instance Pretty ValueDesc where
  ppr (ScalarValue t signedness v) =
    ppr t <+> ppr v <> unsignedNote
    where
      unsignedNote =
        case signedness of
          TypeDirect -> mempty
          TypeUnsigned -> text " (unsigned)"
  ppr (ArrayValue mem sp et signedness dims) =
    foldr wrap (ppr et) dims <+> text "at" <+> ppr mem <> spaceNote <+> unsignedNote
    where
      wrap d inner = brackets $ inner <> comma <> ppr d
      unsignedNote =
        case signedness of
          TypeDirect -> mempty
          TypeUnsigned -> text " (unsigned)"
      spaceNote =
        case sp of
          DefaultSpace -> mempty
          Space s -> text "@" <> text s
-- | Transparent values print as their description; opaque values print
-- "opaque", their description, and their parts in a braced block.
instance Pretty ExternalValue where
  ppr ev =
    case ev of
      TransparentValue v -> ppr v
      OpaqueValue desc parts ->
        text "opaque" <+> text desc <+>
          nestedBlock "{" "}" (stack $ map ppr parts)

-- | Constant sizes print as the number, variable sizes as the variable.
instance Pretty Size where
  ppr sz =
    case sz of
      VarSize v -> ppr v
      ConstSize k -> ppr k

-- | Explicit contents print as a braced comma-separated list; zeros
-- print as "{0} * n".
instance Pretty ArrayContents where
  ppr contents =
    case contents of
      ArrayZeros n -> braces (text "0") <+> text "*" <+> ppr n
      ArrayValues vs -> braces (commasep $ map ppr vs)
-- | Human-readable rendering of imperative code. Embedded operations
-- are printed with the 'Pretty' instance of @op@.
instance Pretty op => Pretty (Code op) where
  ppr (Op op) = ppr op
  ppr Skip = text "skip"
  ppr (c1 :>>: c2) = ppr c1 </> ppr c2
  -- Printed as "for i:t < limit {", the indented body, then "}".
  ppr (For i it limit body) =
    text "for" <+> ppr i <> text ":" <> ppr it <+> langle <+> ppr limit <+> text "{" </>
    indent 2 (ppr body) </>
    text "}"
  ppr (While cond body) =
    text "while" <+> ppr cond <+> text "{" </>
    indent 2 (ppr body) </>
    text "}"
  ppr (DeclareMem name space) =
    text "var" <+> ppr name <> text ": mem" <> parens (ppr space)
  ppr (DeclareScalar name t) =
    text "var" <+> ppr name <> text ":" <+> ppr t
  ppr (DeclareArray name space t vs) =
    text "array" <+> ppr name <> text "@" <> ppr space <+> text ":" <+> ppr t <+>
    equals <+> ppr vs
  ppr (Allocate name e space) =
    ppr name <+> text "<-" <+> text "malloc" <> parens (ppr e) <> ppr space
  ppr (Free name space) =
    text "free" <> parens (ppr name) <> ppr space
  -- Printed as "name<[volatile ]type space>[i] <- val".
  ppr (Write name i bt space vol val) =
    ppr name <> langle <> vol' <> ppr bt <> ppr space <> rangle <> brackets (ppr i) <+>
    text "<-" <+> ppr val
    where vol' = case vol of
            Volatile -> text "volatile "
            Nonvolatile -> mempty
  ppr (SetScalar name val) =
    ppr name <+> text "<-" <+> ppr val
  ppr (SetMem dest from space) =
    ppr dest <+> text "<-" <+> ppr from <+> text "@" <> ppr space
  -- Only the message and the condition are printed; the source
  -- locations are omitted.
  ppr (Assert e msg _) =
    text "assert" <> parens (commasep [ppr msg, ppr e])
  -- Printed as "memcpy(dest + destoffset destspace, src + srcoffset srcspace, size)".
  ppr (Copy dest destoffset destspace src srcoffset srcspace size) =
    text "memcpy" <>
    parens (ppMemLoc dest destoffset <> ppr destspace <> comma </>
            ppMemLoc src srcoffset <> ppr srcspace <> comma </>
            ppr size)
    where ppMemLoc base offset =
            ppr base <+> text "+" <+> ppr offset
  ppr (If cond tbranch fbranch) =
    text "if" <+> ppr cond <+> text "then {" </>
    indent 2 (ppr tbranch) </>
    text "} else {" </>
    indent 2 (ppr fbranch) </>
    text "}"
  ppr (Call dests fname args) =
    commasep (map ppr dests) <+> text "<-" <+>
    ppr fname <> parens (commasep $ map ppr args)
  ppr (Comment s code) =
    text "--" <+> text s </> ppr code
  -- The description goes through 'show', so it appears as a quoted
  -- string literal.
  ppr (DebugPrint desc (Just (pt, e))) =
    text "debug" <+> parens (commasep [text (show desc), ppr pt, ppr e])
  ppr (DebugPrint desc Nothing) =
    text "debug" <+> parens (text (show desc))
-- | Arguments print as the memory block name or the expression.
instance Pretty Arg where
  ppr arg =
    case arg of
      ExpArg e -> ppr e
      MemArg m -> ppr m

-- | Scalar variables print as their name. Memory reads print as the
-- block, then angle brackets holding an optional "volatile " marker,
-- the element type and an "@space" suffix, then the bracketed index.
instance Pretty ExpLeaf where
  ppr (ScalarVar v) = ppr v
  ppr (SizeOf t) = text "sizeof" <> parens (ppr t)
  ppr (Index mem i t sp vol) =
    ppr mem <> langle <> volNote <> ppr t <> spaceNote <> rangle <> brackets (ppr i)
    where
      spaceNote =
        case sp of
          Space s -> text "@" <> text s
          DefaultSpace -> mempty
      volNote =
        case vol of
          Nonvolatile -> mempty
          Volatile -> text "volatile "
instance Functor Functions where
  fmap = fmapDefault

instance Foldable Functions where
  foldMap = foldMapDefault

-- | Traverses the operations of every function, in order; function
-- names are left untouched.
instance Traversable Functions where
  traverse f (Functions funs) = Functions <$> traverse onFun funs
    where
      onFun (fname, fun) = (,) fname <$> traverse f fun

instance Functor FunctionT where
  fmap = fmapDefault

instance Foldable FunctionT where
  foldMap = foldMapDefault

-- | Only the body mentions the operation type, so only it is traversed;
-- all other fields are rebuilt unchanged.
instance Traversable FunctionT where
  traverse f (Function entry outs ins body results args) =
    fmap (\body' -> Function entry outs ins body' results args) (traverse f body)
instance Functor Code where
  fmap = fmapDefault

instance Foldable Code where
  foldMap = foldMapDefault

-- | Visits every 'Op', left to right, recursing through sequencing,
-- loops, conditionals and comments. Every other constructor is rebuilt
-- unchanged, since its payload does not mention the operation type.
instance Traversable Code where
  traverse f = go
    where
      go code =
        case code of
          x :>>: y -> (:>>:) <$> go x <*> go y
          For i it bound body -> For i it bound <$> go body
          While cond body -> While cond <$> go body
          If cond tb fb -> If cond <$> go tb <*> go fb
          Comment s body -> Comment s <$> go body
          Op op -> Op <$> f op
          Skip -> pure Skip
          DeclareMem v sp -> pure (DeclareMem v sp)
          DeclareScalar v t -> pure (DeclareScalar v t)
          DeclareArray v sp t vs -> pure (DeclareArray v sp t vs)
          Allocate v size sp -> pure (Allocate v size sp)
          Free v sp -> pure (Free v sp)
          Copy dst dstOff dstSp src srcOff srcSp n ->
            pure (Copy dst dstOff dstSp src srcOff srcSp n)
          Write mem i t sp vol val -> pure (Write mem i t sp vol val)
          SetScalar v e -> pure (SetScalar v e)
          SetMem dst src sp -> pure (SetMem dst src sp)
          Assert e msg loc -> pure (Assert e msg loc)
          Call dsts fname args -> pure (Call dsts fname args)
          DebugPrint s v -> pure (DebugPrint s v)
-- | The names declared by a piece of code: declared memory blocks,
-- scalars and arrays, plus loop counters, including those nested
-- inside sequences, branches, loop bodies and comments. 'Op's and all
-- other statements declare nothing. The 'FreeIn' instance for 'Code'
-- binds these names over the rest of a ':>>:' sequence.
declaredIn :: Code a -> Names
declaredIn code =
  case code of
    DeclareMem v _ -> oneName v
    DeclareScalar v _ -> oneName v
    DeclareArray v _ _ _ -> oneName v
    For i _ _ body -> oneName i <> declaredIn body
    x :>>: y -> declaredIn x <> declaredIn y
    If _ tb fb -> declaredIn tb <> declaredIn fb
    While _ body -> declaredIn body
    Comment _ body -> declaredIn body
    _ -> mempty
-- | Free variables of imperative code.
--
-- Names declared in the first half of a ':>>:' (see 'declaredIn') are
-- bound over the whole sequence, and a loop counter is bound over its
-- loop body. Variables referenced by assertion messages count as free.
instance FreeIn a => FreeIn (Code a) where
  freeIn' (x :>>: y) =
    fvBind (declaredIn x) $ freeIn' x <> freeIn' y
  freeIn' Skip =
    mempty
  freeIn' (For i _ bound body) =
    -- The bound is evaluated outside the loop, before the counter is in
    -- scope, so only the body sees the counter as bound.
    freeIn' bound <> fvBind (oneName i) (freeIn' body)
  freeIn' (While cond body) =
    freeIn' cond <> freeIn' body
  freeIn' DeclareMem{} =
    mempty
  freeIn' DeclareScalar{} =
    mempty
  freeIn' DeclareArray{} =
    mempty
  freeIn' (Allocate name size _) =
    freeIn' name <> freeIn' size
  freeIn' (Free name _) =
    freeIn' name
  freeIn' (Copy dest x _ src y _ n) =
    freeIn' dest <> freeIn' x <> freeIn' src <> freeIn' y <> freeIn' n
  freeIn' (SetMem x y _) =
    freeIn' x <> freeIn' y
  freeIn' (Write v i _ _ _ e) =
    freeIn' v <> freeIn' i <> freeIn' e
  freeIn' (SetScalar x y) =
    freeIn' x <> freeIn' y
  freeIn' (Call dests _ args) =
    freeIn' dests <> freeIn' args
  freeIn' (If cond t f) =
    freeIn' cond <> freeIn' t <> freeIn' f
  freeIn' (Assert e msg _) =
    -- The error message embeds run-time expressions; they are needed
    -- to report a failure, so their variables are free too.
    freeIn' e <> foldMap freeIn' msg
  freeIn' (Op op) =
    freeIn' op
  freeIn' (Comment _ code) =
    freeIn' code
  freeIn' (DebugPrint _ v) =
    maybe mempty (freeIn' . snd) v
-- | Memory reads mention both the memory block and the index;
-- 'SizeOf' mentions no variables.
instance FreeIn ExpLeaf where
  freeIn' leaf =
    case leaf of
      ScalarVar v -> freeIn' v
      Index mem i _ _ _ -> freeIn' mem <> freeIn' i
      SizeOf _ -> mempty

instance FreeIn Arg where
  freeIn' arg =
    case arg of
      ExpArg e -> freeIn' e
      MemArg m -> freeIn' m

-- | Only variable sizes have a free variable.
instance FreeIn Size where
  freeIn' sz =
    case sz of
      ConstSize _ -> mempty
      VarSize v -> fvName v