{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Downcast (
Downcast(..)
) where
import Prelude hiding ( Ordering(..), const )
import Data.Bits
import Foreign.C.Types
import Data.Array.Accelerate.AST ( tupleIdxToInt )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Type
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Flags
import LLVM.AST.Type.Global
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.Compare
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import LLVM.AST.Type.Terminator
import qualified LLVM.AST.Type.Instruction.RMW as RMW
import qualified LLVM.AST.Attribute as L
import qualified LLVM.AST.AddrSpace as L
import qualified LLVM.AST.CallingConvention as L
import qualified LLVM.AST.Constant as LC
import qualified LLVM.AST.Float as L
import qualified LLVM.AST.FloatingPointPredicate as FP
import qualified LLVM.AST.Global as L
import qualified LLVM.AST.Instruction as L
import qualified LLVM.AST.IntegerPredicate as IP
import qualified LLVM.AST.Name as L
import qualified LLVM.AST.Operand as L
import qualified LLVM.AST.RMWOperation as LA
import qualified LLVM.AST.Type as L
-- | Convert a value of the typed LLVM AST used by the code generator into
-- the corresponding untyped @llvm-hs-pure@ value.
class Downcast typed untyped where
  downcast :: typed -> untyped

-- | Lists are converted element by element.
instance Downcast a a' => Downcast [a] [a'] where
  downcast = fmap downcast

-- | 'Nothing' is preserved; the payload of a 'Just' is converted.
instance Downcast a a' => Downcast (Maybe a) (Maybe a') where
  downcast = fmap downcast

-- | Both components of a pair are converted independently.
instance (Downcast a a', Downcast b b') => Downcast (a,b) (a',b') where
  downcast (l, r) = (downcast l, downcast r)

-- | Whichever side of the 'Either' is present is converted.
instance (Downcast a a', Downcast b b') => Downcast (Either a b) (Either a' b') where
  downcast = either (Left . downcast) (Right . downcast)
-- | Wrap flags passed to the integer @add@/@sub@/@mul@/@shl@ instructions
-- built in this module. Both are 'False': neither no-signed-wrap nor
-- no-unsigned-wrap is requested.
nsw, nuw :: Bool
nsw = False
nuw = False

-- | Fast-math flags passed to the floating-point @fadd@/@fsub@/@fmul@/@fdiv@
-- instructions built in this module.
fmf :: FastMathFlags
fmf = UnsafeAlgebra

-- | Metadata attached to every instruction and terminator built in this
-- module: none.
md :: L.InstructionMetadata
md = []
-- | The flag is set ('True') only for 'NoUnsignedWrap'.
instance Downcast NUW Bool where
  downcast w =
    case w of
      NoUnsignedWrap -> True
      UnsignedWrap   -> False

-- | The flag is set ('True') only for 'NoSignedWrap'.
instance Downcast NSW Bool where
  downcast w =
    case w of
      NoSignedWrap -> True
      SignedWrap   -> False
-- | The typed and untyped ASTs share the same fast-math flag type, so the
-- conversion is the identity.
instance Downcast FastMathFlags L.FastMathFlags where
  downcast flags = flags

-- | Named and anonymous (numbered) names map onto their untyped equivalents.
instance Downcast (Name a) L.Name where
  downcast nm =
    case nm of
      Name s   -> L.Name s
      UnName i -> L.UnName i
-- | Tail-call marker used for every call built in this module: none.
tailcall :: Maybe L.TailCallKind
tailcall =
  Nothing
-- | Lower a typed instruction to its untyped LLVM counterpart.
--
-- Integer add/sub/mul/shl use the module defaults 'nsw'/'nuw', and
-- floating-point add/sub/mul/div use 'fmf'. Every instruction carries the
-- empty metadata 'md'. Division, remainder, extension, float/int conversions
-- and comparisons pick the signed or unsigned LLVM instruction from the
-- operand type.
instance Downcast (Instruction a) L.Instruction where
  -- Arithmetic: one typed constructor covers both integer and floating-point
  -- operands; the 'NumType' selects the LLVM instruction.
  downcast (Add t x y) =
    case t of
      IntegralNumType{} -> L.Add nsw nuw (downcast x) (downcast y) md
      FloatingNumType{} -> L.FAdd fmf (downcast x) (downcast y) md
  downcast (Sub t x y) =
    case t of
      IntegralNumType{} -> L.Sub nsw nuw (downcast x) (downcast y) md
      FloatingNumType{} -> L.FSub fmf (downcast x) (downcast y) md
  downcast (Mul t x y) =
    case t of
      IntegralNumType{} -> L.Mul nsw nuw (downcast x) (downcast y) md
      FloatingNumType{} -> L.FMul fmf (downcast x) (downcast y) md
  -- Integer quotient/remainder depend on signedness; the 'exact' flag of
  -- sdiv/udiv is not set.
  downcast (Quot t x y)
    | signed t  = L.SDiv False (downcast x) (downcast y) md
    | otherwise = L.UDiv False (downcast x) (downcast y) md
  downcast (Rem t x y)
    | signed t  = L.SRem (downcast x) (downcast y) md
    | otherwise = L.URem (downcast x) (downcast y) md
  downcast (Div _ x y)     = L.FDiv fmf (downcast x) (downcast y) md
  downcast (ShiftL _ x i)  = L.Shl nsw nuw (downcast x) (downcast i) md
  downcast (ShiftRL _ x i) = L.LShr False (downcast x) (downcast i) md
  downcast (ShiftRA _ x i) = L.AShr False (downcast x) (downcast i) md
  -- Bitwise and logical connectives share the same LLVM instructions.
  downcast (BAnd _ x y)    = L.And (downcast x) (downcast y) md
  downcast (LAnd x y)      = L.And (downcast x) (downcast y) md
  downcast (BOr _ x y)     = L.Or (downcast x) (downcast y) md
  downcast (LOr x y)       = L.Or (downcast x) (downcast y) md
  downcast (BXor _ x y)    = L.Xor (downcast x) (downcast y) md
  -- Logical negation is xor with the constant 'True'.
  downcast (LNot x)        = L.Xor (downcast x) (downcast (scalar scalarType True)) md

  -- The typed tuple index counts from the right-most component (index 0 maps
  -- to position @size - 1@), whereas the LLVM index counts from the left.
  downcast (ExtractValue _ tix tup) =
    L.ExtractValue (downcast tup) [fromIntegral $ sizeOfTuple - tupleIdxToInt tix - 1] md
    where
      sizeOfTuple
        | PrimType p  <- typeOf tup
        , TupleType t <- p
        = go t
        | otherwise
        = $internalError "downcast" "unexpected operand type to ExtractValue"

      -- number of 'PairTuple' layers, i.e. the number of components
      go :: TupleType t -> Int
      go (PairTuple t _) = 1 + go t
      go _               = 0

  -- Memory access: no atomic ordering and an alignment field of 0.
  downcast (Load _ v p)        = L.Load (downcast v) (downcast p) Nothing 0 md
  downcast (Store v p x)       = L.Store (downcast v) (downcast p) (downcast x) Nothing 0 md
  downcast (GetElementPtr n i) = L.GetElementPtr False (downcast n) (downcast i) md
  downcast (Fence a)           = L.Fence (downcast a) md
  downcast (CmpXchg _ v p x y a m) =
    L.CmpXchg (downcast v) (downcast p) (downcast x) (downcast y) (downcast a) (downcast m) md
  -- The integral type is paired with the operation so min/max can pick the
  -- signed or unsigned variant.
  downcast (AtomicRMW t v op p x a) =
    L.AtomicRMW (downcast v) (downcast (t,op)) (downcast p) (downcast x) (downcast a) md

  -- Conversions.
  downcast (Trunc _ t x)  = L.Trunc (downcast x) (downcast t) md
  downcast (FTrunc _ t x) = L.FPTrunc (downcast x) (downcast t) md
  downcast (Ext t t' x)
    | signed t  = L.SExt (downcast x) (downcast t') md
    | otherwise = L.ZExt (downcast x) (downcast t') md
  downcast (FExt _ t x)   = L.FPExt (downcast x) (downcast t) md
  downcast (FPToInt _ t x)
    | signed t  = L.FPToSI (downcast x) (downcast t) md
    | otherwise = L.FPToUI (downcast x) (downcast t) md
  downcast (IntToFP t t' x)
    | either signed signed t = L.SIToFP (downcast x) (downcast t') md
    | otherwise              = L.UIToFP (downcast x) (downcast t') md
  downcast (BitCast t x)  = L.BitCast (downcast x) (downcast t) md
  downcast (PtrCast t x)  = L.BitCast (downcast x) (downcast t) md

  -- Control/data flow.
  downcast (Phi t incoming) = L.Phi (downcast t) (downcast incoming) md
  downcast (Select _ p x y) = L.Select (downcast p) (downcast x) (downcast y) md
  -- The same 'GlobalFunction' is downcast twice: once to the callee operand
  -- and once to the list of argument operands.
  downcast (Call f attrs)   = L.Call tailcall L.C [] (downcast f) (downcast f) (downcast attrs) md

  -- Comparisons: floating-point types use fcmp; other types use icmp with
  -- the signed or unsigned predicate.
  downcast (Cmp t p x y) =
    let
        -- ==, <, <=, >, >= use *ordered* predicates (false if either operand
        -- is NaN). Inequality must use the *unordered* 'FP.UNE' (true if
        -- either operand is NaN). With 'FP.ONE', @NaN /= NaN@ would be False,
        -- which contradicts IEEE-754 and makes @x /= y@ disagree with
        -- @not (x == y)@.
        fp EQ = FP.OEQ
        fp NE = FP.UNE
        fp LT = FP.OLT
        fp LE = FP.OLE
        fp GT = FP.OGT
        fp GE = FP.OGE

        si EQ = IP.EQ
        si NE = IP.NE
        si LT = IP.SLT
        si LE = IP.SLE
        si GT = IP.SGT
        si GE = IP.SGE

        ui EQ = IP.EQ
        ui NE = IP.NE
        ui LT = IP.ULT
        ui LE = IP.ULE
        ui GT = IP.UGT
        ui GE = IP.UGE
    in
    case t of
      NumScalarType FloatingNumType{} -> L.FCmp (fp p) (downcast x) (downcast y) md
      _ | signed t                    -> L.ICmp (si p) (downcast x) (downcast y) md
        | otherwise                   -> L.ICmp (ui p) (downcast x) (downcast y) md
-- | Only 'Volatile' accesses map to 'True'.
instance Downcast Volatility Bool where
  downcast v =
    case v of
      Volatile    -> True
      NonVolatile -> False

-- | Synchronisation scope. From llvm-hs-pure 5.0 the cross-thread scope is
-- spelled 'L.System'; older versions use 'L.CrossThread'.
instance Downcast Synchronisation L.SynchronizationScope where
  downcast s =
    case s of
      SingleThread -> L.SingleThread
#if MIN_VERSION_llvm_hs_pure(5,0,0)
      CrossThread  -> L.System
#else
      CrossThread  -> L.CrossThread
#endif

-- | Memory orderings map one-to-one onto their untyped namesakes.
instance Downcast MemoryOrdering L.MemoryOrdering where
  downcast o =
    case o of
      Unordered              -> L.Unordered
      Monotonic              -> L.Monotonic
      Acquire                -> L.Acquire
      Release                -> L.Release
      AcquireRelease         -> L.AcquireRelease
      SequentiallyConsistent -> L.SequentiallyConsistent
-- | Atomic read-modify-write operations. The integral type is consulted only
-- for 'RMW.Min' and 'RMW.Max', which choose the signed or unsigned variant.
instance Downcast (IntegralType t, RMW.RMWOperation) LA.RMWOperation where
  downcast (t, op) =
    case op of
      RMW.Exchange -> LA.Xchg
      RMW.Add      -> LA.Add
      RMW.Sub      -> LA.Sub
      RMW.And      -> LA.And
      RMW.Or       -> LA.Or
      RMW.Xor      -> LA.Xor
      RMW.Nand     -> LA.Nand
      RMW.Min      -> if signed t then LA.Min else LA.UMin
      RMW.Max      -> if signed t then LA.Max else LA.UMax
-- | A typed named instruction: both the binding name (if present) and the
-- payload are converted.
instance (Downcast (i a) i') => Downcast (Named i a) (L.Named i') where
  downcast named =
    case named of
      lhs := op -> downcast lhs L.:= downcast op
      Do op     -> L.Do (downcast op)

-- | An already-untyped named value: the name is kept as-is and only the
-- payload is converted.
instance Downcast a b => Downcast (L.Named a) (L.Named b) where
  downcast named =
    case named of
      lhs L.:= rhs -> lhs L.:= downcast rhs
      L.Do rhs     -> L.Do (downcast rhs)
-- | Constants.
--
-- Integral and non-numeric scalars become 'LC.Int' constants whose width is
-- the bit count of the corresponding LLVM type. 'Float'/'CFloat' become
-- single-precision and 'Double'/'CDouble' double-precision 'LC.Float'
-- constants.
instance Downcast (Constant a) LC.Constant where
  downcast (UndefConstant t) =
    LC.Undef (downcast t)
  downcast (GlobalReference t n) =
    LC.GlobalReference (downcast t) (downcast n)
  downcast (ScalarConstant (NumScalarType (IntegralNumType t)) x)
    | IntegralDict <- integralDict t
    = LC.Int (L.typeBits (downcast t)) (fromIntegral x)
  downcast (ScalarConstant (NumScalarType (FloatingNumType t)) x) =
    LC.Float $
      case t of
        TypeFloat{}   -> L.Single x
        TypeDouble{}  -> L.Double x
        TypeCFloat{}  -> case x of CFloat f  -> L.Single f
        TypeCDouble{} -> case x of CDouble d -> L.Double d
  downcast (ScalarConstant (NonNumScalarType t) x) =
    LC.Int (L.typeBits (downcast t)) $
      case t of
        TypeBool{}   -> toInteger (fromEnum x)
        TypeChar{}   -> toInteger (fromEnum x)
        TypeCChar{}  -> toInteger (fromEnum x)
        TypeCUChar{} -> toInteger (fromEnum x)
        TypeCSChar{} -> toInteger (fromEnum x)
-- | Operands are references to local names or constants.
instance Downcast (Operand a) L.Operand where
  downcast op =
    case op of
      LocalReference t n -> L.LocalReference (downcast t) (downcast n)
      ConstantOperand c  -> L.ConstantOperand (downcast c)

-- | Metadata in operand position is wrapped in 'L.MetadataOperand'.
instance Downcast Metadata L.Operand where
  downcast m = L.MetadataOperand (downcast m)

-- | Metadata: strings, constants (wrapped as constant operands) and nodes.
instance Downcast Metadata L.Metadata where
  downcast m =
    case m of
      MetadataStringOperand s   -> L.MDString s
      MetadataConstantOperand c -> L.MDValue (L.ConstantOperand c)
      MetadataNodeOperand n     -> L.MDNode (downcast n)

-- | Metadata nodes: inline nodes have their elements converted; references
-- are passed through unchanged.
instance Downcast MetadataNode L.MetadataNode where
  downcast node =
    case node of
      MetadataNode elems      -> L.MetadataNode (downcast elems)
      MetadataNodeReference r -> L.MetadataNodeReference r
-- | Block terminators, each carrying the empty metadata 'md'.
instance Downcast (Terminator a) L.Terminator where
  downcast term =
    case term of
      Ret                   -> L.Ret Nothing md
      RetVal v              -> L.Ret (Just (downcast v)) md
      Br dest               -> L.Br (downcast dest) md
      CondBr c onTrue onFalse -> L.CondBr (downcast c) (downcast onTrue) (downcast onFalse) md
      Switch s dflt alts    -> L.Switch (downcast s) (downcast dflt) (downcast alts) md

-- | A basic-block label becomes a plain 'L.Name'.
instance Downcast Label L.Name where
  downcast (Label lbl) = L.Name lbl
-- | Function parameters. Pointer-typed parameters are annotated with
-- 'L.NoAlias' and 'L.NoCapture'; all other parameters carry no attributes.
instance Downcast (Parameter a) L.Parameter where
  downcast (Parameter t x) =
    L.Parameter (downcast t) (downcast x) $
      case t of
        PtrPrimType{} -> [L.NoAlias, L.NoCapture]
        _             -> []
-- | A global function used as a callee: a constant reference to its symbol.
-- The reference is typed as a pointer (address space 0) to a non-variadic
-- function type built from the parameter and result types.
instance Downcast (GlobalFunction args t) L.CallableOperand where
  downcast f = Right (L.ConstantOperand (LC.GlobalReference fnPtrTy fnName))
    where
      (argTys, retTy, fnName) = signatureOf f
      fnPtrTy                 = L.PointerType (L.FunctionType retTy argTys False) (L.AddrSpace 0)

      -- parameter types (outermost 'Lam' first), result type and name
      signatureOf :: GlobalFunction xs r -> ([L.Type], L.Type, L.Name)
      signatureOf (Body ret n)     = ([], downcast ret, downcast n)
      signatureOf (Lam pty _ rest) =
        let (tys, ret, n) = signatureOf rest
        in  (downcast pty : tys, ret, n)

-- | The actual arguments of a call: each parameter operand, in order, with no
-- parameter attributes.
instance Downcast (GlobalFunction args t) [(L.Operand, [L.ParameterAttribute])] where
  downcast fn =
    case fn of
      Body{}       -> []
      Lam _ x rest -> (downcast x, []) : downcast rest

-- | A global function as an 'L.Global' built from 'L.functionDefaults'. It
-- sets the name, the result type and non-variadic parameters named with
-- anonymous numbers 0, 1, ...
instance Downcast (GlobalFunction args t) L.Global where
  downcast f =
    L.functionDefaults
      { L.name       = fnName
      , L.returnType = retTy
      , L.parameters = (zipWith param argTys [0 ..], False)
      }
    where
      (argTys, retTy, fnName) = signatureOf f
      param ty i              = L.Parameter ty (L.UnName i) []

      -- parameter types (outermost 'Lam' first), result type and name
      signatureOf :: GlobalFunction xs r -> ([L.Type], L.Type, L.Name)
      signatureOf (Body ret n)     = ([], downcast ret, downcast n)
      signatureOf (Lam pty _ rest) =
        let (tys, ret, n) = signatureOf rest
        in  (downcast pty : tys, ret, n)
-- | Function attributes map one-to-one onto their untyped namesakes.
instance Downcast FunctionAttribute L.FunctionAttribute where
  downcast attr =
    case attr of
      NoReturn     -> L.NoReturn
      NoUnwind     -> L.NoUnwind
      ReadOnly     -> L.ReadOnly
      ReadNone     -> L.ReadNone
      AlwaysInline -> L.AlwaysInline
      NoDuplicate  -> L.NoDuplicate
      Convergent   -> L.Convergent

-- | Attribute group identifiers keep their number.
instance Downcast GroupID L.GroupID where
  downcast (GroupID gid) = L.GroupID gid
-- | 'VoidType' or a primitive type.
instance Downcast (Type a) L.Type where
  downcast ty =
    case ty of
      VoidType   -> L.VoidType
      PrimType p -> downcast p

-- | Primitive types: scalars, pointers (with their address space), arrays,
-- and tuples, which become non-packed structures of their flattened
-- component types.
instance Downcast (PrimType a) L.Type where
  downcast pt =
    case pt of
      ScalarPrimType s   -> downcast s
      PtrPrimType e addr -> L.PointerType (downcast e) addr
      ArrayType n e      -> L.ArrayType n (downcast e)
      TupleType tt       -> L.StructureType False (flatten tt)
    where
      -- component types, left to right
      flatten :: TupleType t -> [L.Type]
      flatten UnitTuple       = []
      flatten (SingleTuple s) = [downcast s]
      flatten (PairTuple l r) = flatten l ++ flatten r

-- | Scalar types dispatch on numeric vs. non-numeric.
instance Downcast (ScalarType a) L.Type where
  downcast st =
    case st of
      NumScalarType n    -> downcast n
      NonNumScalarType n -> downcast n

-- | Bounded types dispatch on integral vs. non-numeric.
instance Downcast (BoundedType t) L.Type where
  downcast bt =
    case bt of
      IntegralBoundedType i -> downcast i
      NonNumBoundedType n   -> downcast n

-- | Numeric types dispatch on integral vs. floating.
instance Downcast (NumType a) L.Type where
  downcast nt =
    case nt of
      IntegralNumType i -> downcast i
      FloatingNumType f -> downcast f
-- | Integral types map to LLVM integer types of the same bit width.
--
-- Fixed-width types use literal widths. Platform-dependent types ('Int',
-- 'Word' and the C integer types) take their width from 'finiteBitSize' of
-- the corresponding Haskell type, so the LLVM type always agrees with the
-- Haskell representation. The previous @$( [| ... |] )@ splices were
-- no-ops: a quotation spliced straight back in is just the expression.
instance Downcast (IntegralType a) L.Type where
  downcast t =
    case t of
      TypeInt{}     -> bitsOf (0 :: Int)
      TypeInt8{}    -> L.IntegerType 8
      TypeInt16{}   -> L.IntegerType 16
      TypeInt32{}   -> L.IntegerType 32
      TypeInt64{}   -> L.IntegerType 64
      TypeWord{}    -> bitsOf (0 :: Word)
      TypeWord8{}   -> L.IntegerType 8
      TypeWord16{}  -> L.IntegerType 16
      TypeWord32{}  -> L.IntegerType 32
      TypeWord64{}  -> L.IntegerType 64
      TypeCShort{}  -> bitsOf (0 :: CShort)
      TypeCUShort{} -> bitsOf (0 :: CUShort)
      TypeCInt{}    -> bitsOf (0 :: CInt)
      TypeCUInt{}   -> bitsOf (0 :: CUInt)
      TypeCLong{}   -> bitsOf (0 :: CLong)
      TypeCULong{}  -> bitsOf (0 :: CULong)
      TypeCLLong{}  -> bitsOf (0 :: CLLong)
      TypeCULLong{} -> bitsOf (0 :: CULLong)
    where
      -- LLVM integer type as wide as the given fixed-size Haskell type
      bitsOf :: FiniteBits b => b -> L.Type
      bitsOf = L.IntegerType . fromIntegral . finiteBitSize
-- | 'Float'/'CFloat' are single precision; 'Double'/'CDouble' are double
-- precision.
instance Downcast (FloatingType a) L.Type where
  downcast ft =
    case ft of
      TypeFloat{}   -> L.FloatingPointType L.FloatFP
      TypeCFloat{}  -> L.FloatingPointType L.FloatFP
      TypeDouble{}  -> L.FloatingPointType L.DoubleFP
      TypeCDouble{} -> L.FloatingPointType L.DoubleFP

-- | Non-numeric scalars are integers: 'Bool' is 1 bit, 'Char' 32 bits and
-- the C character types 8 bits.
instance Downcast (NonNumType a) L.Type where
  downcast nt =
    case nt of
      TypeBool{}   -> L.IntegerType 1
      TypeChar{}   -> L.IntegerType 32
      TypeCChar{}  -> L.IntegerType 8
      TypeCSChar{} -> L.IntegerType 8
      TypeCUChar{} -> L.IntegerType 8