module Feldspar.Primitive.Representation where
import Data.Array
import Data.Bits
import Data.Complex
import Data.Int
import Data.Typeable
import Data.Word
import Data.Constraint (Dict (..))
import Language.Embedded.Expression
import Language.Embedded.Imperative.CMD (IArr (..))
import Language.Syntactic
import Language.Syntactic.TH
import Language.Syntactic.Functional
-- | Alias for 'Word32'. By its name this is meant for lengths; it is not used
-- in the part of the file visible here.
type Length = Word32
-- | Alias for 'Word32'. Used as the index type of the arrays read by 'ArrIx'.
type Index = Word32
-- | Value-level representation of the supported primitive types.
--
-- Every constructor fixes the type index @a@ to one concrete type, so pattern
-- matching on a 'PrimTypeRep' reveals which type @a@ is.
data PrimTypeRep a where
  -- Boolean
  BoolT          :: PrimTypeRep Bool
  -- Signed integers
  Int8T          :: PrimTypeRep Int8
  Int16T         :: PrimTypeRep Int16
  Int32T         :: PrimTypeRep Int32
  Int64T         :: PrimTypeRep Int64
  -- Unsigned integers
  Word8T         :: PrimTypeRep Word8
  Word16T        :: PrimTypeRep Word16
  Word32T        :: PrimTypeRep Word32
  Word64T        :: PrimTypeRep Word64
  -- Floating point
  FloatT         :: PrimTypeRep Float
  DoubleT        :: PrimTypeRep Double
  -- Complex numbers
  ComplexFloatT  :: PrimTypeRep (Complex Float)
  ComplexDoubleT :: PrimTypeRep (Complex Double)
-- | Representation of the signed integer types only.
data IntTypeRep a where
  Int8Type  :: IntTypeRep Int8
  Int16Type :: IntTypeRep Int16
  Int32Type :: IntTypeRep Int32
  Int64Type :: IntTypeRep Int64
-- | Representation of the unsigned integer types only.
data WordTypeRep a where
  Word8Type  :: WordTypeRep Word8
  Word16Type :: WordTypeRep Word16
  Word32Type :: WordTypeRep Word32
  Word64Type :: WordTypeRep Word64
-- | Representation of any integer type: either a signed one ('IntType') or an
-- unsigned one ('WordType').
data IntWordTypeRep a where
  IntType  :: IntTypeRep a  -> IntWordTypeRep a
  WordType :: WordTypeRep a -> IntWordTypeRep a
-- | Representation of the floating point types only.
data FloatingTypeRep a where
  FloatType  :: FloatingTypeRep Float
  DoubleType :: FloatingTypeRep Double
-- | Representation of the complex number types only.
data ComplexTypeRep a where
  ComplexFloatType  :: ComplexTypeRep (Complex Float)
  ComplexDoubleType :: ComplexTypeRep (Complex Double)
-- | A grouped view of 'PrimTypeRep': the same set of types, but organised by
-- family (boolean, integer, floating, complex) so that a whole family can be
-- matched with a single pattern. Convert with 'viewPrimTypeRep' and
-- 'unviewPrimTypeRep'.
data PrimTypeView a where
  PrimTypeBool     :: PrimTypeView Bool
  PrimTypeIntWord  :: IntWordTypeRep a  -> PrimTypeView a
  PrimTypeFloating :: FloatingTypeRep a -> PrimTypeView a
  PrimTypeComplex  :: ComplexTypeRep a  -> PrimTypeView a
-- Standalone-derived 'Show' instances for all of the type representations.
-- The instances are polymorphic in the type index @a@.
deriving instance Show (PrimTypeRep a)
deriving instance Show (IntTypeRep a)
deriving instance Show (WordTypeRep a)
deriving instance Show (IntWordTypeRep a)
deriving instance Show (FloatingTypeRep a)
deriving instance Show (ComplexTypeRep a)
deriving instance Show (PrimTypeView a)
-- | Convert a flat 'PrimTypeRep' into its grouped 'PrimTypeView'.
--
-- Integer representations end up under 'PrimTypeIntWord' (split further into
-- 'IntType' and 'WordType'), floating point ones under 'PrimTypeFloating' and
-- complex ones under 'PrimTypeComplex'.
viewPrimTypeRep :: PrimTypeRep a -> PrimTypeView a
viewPrimTypeRep rep = case rep of
  BoolT          -> PrimTypeBool
  Int8T          -> PrimTypeIntWord (IntType Int8Type)
  Int16T         -> PrimTypeIntWord (IntType Int16Type)
  Int32T         -> PrimTypeIntWord (IntType Int32Type)
  Int64T         -> PrimTypeIntWord (IntType Int64Type)
  Word8T         -> PrimTypeIntWord (WordType Word8Type)
  Word16T        -> PrimTypeIntWord (WordType Word16Type)
  Word32T        -> PrimTypeIntWord (WordType Word32Type)
  Word64T        -> PrimTypeIntWord (WordType Word64Type)
  FloatT         -> PrimTypeFloating FloatType
  DoubleT        -> PrimTypeFloating DoubleType
  ComplexFloatT  -> PrimTypeComplex ComplexFloatType
  ComplexDoubleT -> PrimTypeComplex ComplexDoubleType
-- | Convert a grouped 'PrimTypeView' back into the flat 'PrimTypeRep'
-- (the inverse direction of 'viewPrimTypeRep').
unviewPrimTypeRep :: PrimTypeView a -> PrimTypeRep a
unviewPrimTypeRep view = case view of
  PrimTypeBool -> BoolT
  PrimTypeIntWord (IntType intRep) -> case intRep of
    Int8Type  -> Int8T
    Int16Type -> Int16T
    Int32Type -> Int32T
    Int64Type -> Int64T
  PrimTypeIntWord (WordType wordRep) -> case wordRep of
    Word8Type  -> Word8T
    Word16Type -> Word16T
    Word32Type -> Word32T
    Word64Type -> Word64T
  PrimTypeFloating floatRep -> case floatRep of
    FloatType  -> FloatT
    DoubleType -> DoubleT
  PrimTypeComplex complexRep -> case complexRep of
    ComplexFloatType  -> ComplexFloatT
    ComplexDoubleType -> ComplexDoubleT
-- | Width of an integer type as indicated by its name.
--
-- Returns @Just 8@, @Just 16@, @Just 32@ or @Just 64@ for the signed and
-- unsigned integer representations, and 'Nothing' for every other type.
primTypeIntWidth :: PrimTypeRep a -> Maybe Int
primTypeIntWidth rep = case rep of
  Int8T   -> Just 8
  Int16T  -> Just 16
  Int32T  -> Just 32
  Int64T  -> Just 64
  Word8T  -> Just 8
  Word16T -> Just 16
  Word32T -> Just 32
  Word64T -> Just 64
  _       -> Nothing
-- | Class of primitive types, i.e. types that have a 'PrimTypeRep'.
-- Every member also supports 'Eq', 'Show' and 'Typeable'.
class (Eq a, Show a, Typeable a) => PrimType' a where
  -- | The representation of the type @a@.
  primTypeRep :: PrimTypeRep a
-- One 'PrimType'' instance per 'PrimTypeRep' constructor.

-- Boolean
instance PrimType' Bool where
  primTypeRep = BoolT

-- Signed integers
instance PrimType' Int8 where
  primTypeRep = Int8T

instance PrimType' Int16 where
  primTypeRep = Int16T

instance PrimType' Int32 where
  primTypeRep = Int32T

instance PrimType' Int64 where
  primTypeRep = Int64T

-- Unsigned integers
instance PrimType' Word8 where
  primTypeRep = Word8T

instance PrimType' Word16 where
  primTypeRep = Word16T

instance PrimType' Word32 where
  primTypeRep = Word32T

instance PrimType' Word64 where
  primTypeRep = Word64T

-- Floating point
instance PrimType' Float where
  primTypeRep = FloatT

instance PrimType' Double where
  primTypeRep = DoubleT

-- Complex numbers
instance PrimType' (Complex Float) where
  primTypeRep = ComplexFloatT

instance PrimType' (Complex Double) where
  primTypeRep = ComplexDoubleT
-- | Representation of the type of a value. The value itself is never
-- inspected; only its type is used to select the representation.
primTypeOf :: PrimType' a => a -> PrimTypeRep a
primTypeOf = const primTypeRep
-- | Check whether two type representations denote the same type.
--
-- When both arguments are the same constructor the result is a 'Dict'
-- witnessing @a ~ b@; for any other combination it is 'Nothing'.
primTypeEq :: PrimTypeRep a -> PrimTypeRep b -> Maybe (Dict (a ~ b))
primTypeEq repA repB = case (repA, repB) of
  (BoolT, BoolT)                   -> Just Dict
  (Int8T, Int8T)                   -> Just Dict
  (Int16T, Int16T)                 -> Just Dict
  (Int32T, Int32T)                 -> Just Dict
  (Int64T, Int64T)                 -> Just Dict
  (Word8T, Word8T)                 -> Just Dict
  (Word16T, Word16T)               -> Just Dict
  (Word32T, Word32T)               -> Just Dict
  (Word64T, Word64T)               -> Just Dict
  (FloatT, FloatT)                 -> Just Dict
  (DoubleT, DoubleT)               -> Just Dict
  (ComplexFloatT, ComplexFloatT)   -> Just Dict
  (ComplexDoubleT, ComplexDoubleT) -> Just Dict
  _                                -> Nothing
-- | Recover the 'PrimType'' dictionary for a type from its representation.
--
-- Each branch learns the concrete type of @a@ from the matched constructor,
-- which lets the corresponding instance be packaged in the 'Dict'.
witPrimType :: PrimTypeRep a -> Dict (PrimType' a)
witPrimType rep = case rep of
  BoolT          -> Dict
  Int8T          -> Dict
  Int16T         -> Dict
  Int32T         -> Dict
  Int64T         -> Dict
  Word8T         -> Dict
  Word16T        -> Dict
  Word32T        -> Dict
  Word64T        -> Dict
  FloatT         -> Dict
  DoubleT        -> Dict
  ComplexFloatT  -> Dict
  ComplexDoubleT -> Dict
-- | Symbols of the primitive expression language. The index @sig@ is the
-- symbol's signature: argument types chained with ':->' and ending in a
-- 'Full' result type.
data Primitive sig where
  -- Variables and literals
  FreeVar :: PrimType' a => String -> Primitive (Full a)
  Lit     :: (Eq a, Show a) => a -> Primitive (Full a)

  -- 'Num' operations
  Add  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Sub  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Mul  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Neg  :: (Num a, PrimType' a) => Primitive (a :-> Full a)
  Abs  :: (Num a, PrimType' a) => Primitive (a :-> Full a)
  Sign :: (Num a, PrimType' a) => Primitive (a :-> Full a)

  -- 'Integral' division
  Quot :: (Integral a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Rem  :: (Integral a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Div  :: (Integral a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Mod  :: (Integral a, PrimType' a) => Primitive (a :-> a :-> Full a)

  -- 'Fractional' division
  FDiv :: (Fractional a, PrimType' a) => Primitive (a :-> a :-> Full a)

  -- 'Floating' operations
  Pi    :: (Floating a, PrimType' a) => Primitive (Full a)
  Exp   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Log   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Sqrt  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Pow   :: (Floating a, PrimType' a) => Primitive (a :-> a :-> Full a)
  Sin   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Cos   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Tan   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Asin  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Acos  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Atan  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Sinh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Cosh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Tanh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Asinh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Acosh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
  Atanh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)

  -- Complex numbers: construction, projection and conjugation
  Complex
    :: (Num a, PrimType' a, PrimType' (Complex a))
    => Primitive (a :-> a :-> Full (Complex a))
  Polar
    :: (Floating a, PrimType' a, PrimType' (Complex a))
    => Primitive (a :-> a :-> Full (Complex a))
  Real
    :: (PrimType' a, PrimType' (Complex a))
    => Primitive (Complex a :-> Full a)
  Imag
    :: (PrimType' a, PrimType' (Complex a))
    => Primitive (Complex a :-> Full a)
  Magnitude
    :: (RealFloat a, PrimType' a, PrimType' (Complex a))
    => Primitive (Complex a :-> Full a)
  Phase
    :: (RealFloat a, PrimType' a, PrimType' (Complex a))
    => Primitive (Complex a :-> Full a)
  Conjugate
    :: (Num a, PrimType' (Complex a))
    => Primitive (Complex a :-> Full (Complex a))

  -- Conversions
  I2N
    :: (Integral a, Num b, PrimType' a, PrimType' b)
    => Primitive (a :-> Full b)
  I2B :: (Integral a, PrimType' a) => Primitive (a :-> Full Bool)
  B2I :: (Integral a, PrimType' a) => Primitive (Bool :-> Full a)
  Round
    :: (RealFrac a, Num b, PrimType' a, PrimType' b)
    => Primitive (a :-> Full b)

  -- Boolean logic
  Not :: Primitive (Bool :-> Full Bool)
  And :: Primitive (Bool :-> Bool :-> Full Bool)
  Or  :: Primitive (Bool :-> Bool :-> Full Bool)

  -- Comparisons
  Eq  :: (Eq a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
  NEq :: (Eq a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
  Lt  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
  Gt  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
  Le  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
  Ge  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)

  -- Bit manipulation; the shift amount may be of a different integral type
  BitAnd   :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
  BitOr    :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
  BitXor   :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
  BitCompl :: (Bits a, PrimType' a) => Primitive (a :-> Full a)
  ShiftL
    :: (Bits a, PrimType' a, Integral b, PrimType' b)
    => Primitive (a :-> b :-> Full a)
  ShiftR
    :: (Bits a, PrimType' a, Integral b, PrimType' b)
    => Primitive (a :-> b :-> Full a)

  -- Indexing into an immutable array carried by the symbol itself
  ArrIx :: PrimType' a => IArr Index a -> Primitive (Index :-> Full a)

  -- Conditional: condition, then-value, else-value
  Cond :: Primitive (Bool :-> a :-> a :-> Full a)
-- 'Show' for symbols. This file uses it as the fallback in 'renderSym' and as
-- the basis of 'equal'.
deriving instance Show (Primitive a)
-- Template Haskell splice; presumably generates the @Symbol@ instance for
-- 'Primitive' (see "Language.Syntactic.TH" to confirm).
deriveSymbol ''Primitive
-- | Textual rendering of symbols.
--
-- Free variables render as their name, literals via 'show', and array
-- indexing as @"ArrIx "@ followed by the array's name when the array is
-- symbolic ('IArrComp') or by @"..."@ otherwise. All other symbols fall back
-- to their derived 'Show' output.
instance Render Primitive where
  renderSym sym = case sym of
    FreeVar name          -> name
    Lit value             -> show value
    ArrIx (IArrComp name) -> "ArrIx " ++ name
    ArrIx _               -> "ArrIx ..."
    _                     -> show sym

  renderArgs = renderArgsSmart
-- Empty instance body: every 'StringTree' method uses the class default.
instance StringTree Primitive
-- | Evaluation of symbols to ordinary Haskell values and functions.
--
-- Most symbols map directly onto the standard function of the same meaning.
-- Three cases cannot produce a value and call 'error' instead:
--
-- * 'FreeVar': a free variable has no value.
-- * 'ArrIx' on a run-time array ('IArrRun') with an index outside the
--   array's bounds.
-- * 'ArrIx' on a symbolic array ('IArrComp'), which has no contents.
instance Eval Primitive where
  -- Variables and literals
  evalSym (FreeVar v) = error $ "evaluating free variable " ++ show v
  evalSym (Lit a)     = a

  -- 'Num' operations ('Sub' is binary subtraction)
  evalSym Add  = (+)
  evalSym Sub  = (-)
  evalSym Mul  = (*)
  evalSym Neg  = negate
  evalSym Abs  = abs
  evalSym Sign = signum

  -- Division
  evalSym Quot = quot
  evalSym Rem  = rem
  evalSym Div  = div
  evalSym Mod  = mod
  evalSym FDiv = (/)

  -- 'Floating' operations
  evalSym Pi    = pi
  evalSym Exp   = exp
  evalSym Log   = log
  evalSym Sqrt  = sqrt
  evalSym Pow   = (**)
  evalSym Sin   = sin
  evalSym Cos   = cos
  evalSym Tan   = tan
  evalSym Asin  = asin
  evalSym Acos  = acos
  evalSym Atan  = atan
  evalSym Sinh  = sinh
  evalSym Cosh  = cosh
  evalSym Tanh  = tanh
  evalSym Asinh = asinh
  evalSym Acosh = acosh
  evalSym Atanh = atanh

  -- Complex numbers
  evalSym Complex   = (:+)
  evalSym Polar     = mkPolar
  evalSym Real      = realPart
  evalSym Imag      = imagPart
  evalSym Magnitude = magnitude
  evalSym Phase     = phase
  evalSym Conjugate = conjugate

  -- Conversions (both numeric conversions go via 'Integer')
  evalSym I2N   = fromInteger . toInteger
  evalSym I2B   = (/= 0)
  evalSym B2I   = \a -> if a then 1 else 0
  evalSym Round = fromInteger . round

  -- Boolean logic
  evalSym Not = not
  evalSym And = (&&)
  evalSym Or  = (||)

  -- Comparisons
  evalSym Eq  = (==)
  evalSym NEq = (/=)
  evalSym Lt  = (<)
  evalSym Gt  = (>)
  evalSym Le  = (<=)
  evalSym Ge  = (>=)

  -- Bit manipulation; the shift amount is converted with 'fromIntegral'
  evalSym BitAnd   = (.&.)
  evalSym BitOr    = (.|.)
  evalSym BitXor   = xor
  evalSym BitCompl = complement
  evalSym ShiftL   = \a -> shiftL a . fromIntegral
  evalSym ShiftR   = \a -> shiftR a . fromIntegral

  -- Conditional
  evalSym Cond = \c t f -> if c then t else f

  -- Array indexing: check the bounds explicitly so that the error message
  -- names both the offending index and the valid range.
  evalSym (ArrIx (IArrRun arr)) = \i ->
      if i < l || i > h
        then error $ "ArrIx: index "
                  ++ show (toInteger i)
                  ++ " out of bounds "
                  ++ show (toInteger l, toInteger h)
        else arr ! i
    where
      (l, h) = bounds arr
  evalSym (ArrIx (IArrComp arr)) = error $ "evaluating symbolic array " ++ arr
-- Empty instance body: every 'EvalEnv' method uses the class default, for any
-- environment type @env@.
instance EvalEnv Primitive env
-- | Symbol equality, decided by comparing the 'Show' output of the two
-- symbols. Two symbols are therefore equal exactly when they print the same.
-- Only 'equal' is defined here.
instance Equality Primitive where
  equal lhs rhs = show lhs == show rhs
-- | The symbol domain of primitive expressions: a 'Primitive' symbol paired
-- (via ':&:') with a 'PrimTypeRep'.
type PrimDomain = Primitive :&: PrimTypeRep
-- | A primitive expression of type @a@: a newtype around a full syntax tree
-- over 'PrimDomain'. 'unPrim' extracts the tree.
newtype Prim a = Prim { unPrim :: ASTF PrimDomain a }
-- | 'Prim' is plain sugar around its syntax tree: desugaring unwraps the
-- newtype and sugaring wraps it again.
instance Syntactic (Prim a) where
  type Domain (Prim a)   = PrimDomain
  type Internal (Prim a) = a

  desugar (Prim ast) = ast
  sugar ast          = Prim ast
-- | Evaluate a closed primitive expression to a Haskell value.
--
-- Walks the syntax tree, evaluating every symbol with 'evalSym' (the
-- 'PrimTypeRep' attached to it is ignored) and applying the denotation of
-- each function node to the denotation of its argument.
evalPrim :: Prim a -> a
evalPrim prim = interpret (unPrim prim)
  where
    interpret :: AST PrimDomain sig -> Denotation sig
    interpret node = case node of
      Sym (sym :&: _) -> evalSym sym
      fun :$ arg      -> interpret fun (interpret arg)
-- | Smart constructor for primitive symbols: turns a symbol into a sugared
-- function (or value) over 'PrimDomain', decorating the symbol with the
-- 'PrimTypeRep' of its result type.
sugarSymPrim
  :: ( Signature sig
     , fi  ~ SmartFun dom sig
     , sig ~ SmartSig fi
     , dom ~ SmartSym fi
     , dom ~ PrimDomain
     , SyntacticN f fi
     , sub :<: Primitive
     , PrimType' (DenResult sig)
     )
  => sub sig
  -> f
sugarSymPrim sym = sugarSymDecor primTypeRep sym
-- | Constants and free variables of 'Prim' expressions. Both are restricted
-- to types satisfying 'PrimType''; constants become 'Lit' symbols and
-- variables become 'FreeVar' symbols.
instance FreeExp Prim where
  type FreePred Prim = PrimType'

  constExp value = sugarSymPrim (Lit value)
  varExp name    = sugarSymPrim (FreeVar name)
-- | 'Prim' expressions are evaluated with 'evalPrim'.
instance EvalExp Prim where
  evalExp expr = evalPrim expr
-- | Numeric operations on primitive expressions.
--
-- Integer literals become constant expressions; every other method builds
-- the corresponding 'Primitive' symbol instead of computing a value.
-- Subtraction is defined explicitly as the 'Sub' symbol, so it does not fall
-- back to the class default expressed through 'negate'.
instance (Num a, PrimType' a) => Num (Prim a) where
  fromInteger = constExp . fromInteger
  (+)         = sugarSymPrim Add
  (-)         = sugarSymPrim Sub
  (*)         = sugarSymPrim Mul
  negate      = sugarSymPrim Neg
  abs         = sugarSymPrim Abs
  signum      = sugarSymPrim Sign