{-# LANGUAGE TemplateHaskell #-}

{-# OPTIONS_GHC -fwarn-incomplete-patterns #-}

-- | Primitive Feldspar expressions

module Feldspar.Primitive.Representation where



import Data.Array
import Data.Bits
import Data.Complex
import Data.Int
import Data.Typeable
import Data.Word

import Data.Constraint (Dict (..))

import Language.Embedded.Expression
import Language.Embedded.Imperative.CMD (IArr (..))

import Language.Syntactic
import Language.Syntactic.TH
import Language.Syntactic.Functional



--------------------------------------------------------------------------------
-- * Types
--------------------------------------------------------------------------------

-- | Lengths are represented as 32-bit unsigned words
type Length = Word32

-- | Indexes are represented as 32-bit unsigned words
type Index = Word32

-- | Representation of primitive supported types
data PrimTypeRep a
  where
    BoolT          :: PrimTypeRep Bool
    Int8T          :: PrimTypeRep Int8
    Int16T         :: PrimTypeRep Int16
    Int32T         :: PrimTypeRep Int32
    Int64T         :: PrimTypeRep Int64
    Word8T         :: PrimTypeRep Word8
    Word16T        :: PrimTypeRep Word16
    Word32T        :: PrimTypeRep Word32
    Word64T        :: PrimTypeRep Word64
    FloatT         :: PrimTypeRep Float
    DoubleT        :: PrimTypeRep Double
    ComplexFloatT  :: PrimTypeRep (Complex Float)
    ComplexDoubleT :: PrimTypeRep (Complex Double)

-- | Representation of the signed integer types
data IntTypeRep a
  where
    Int8Type  :: IntTypeRep Int8
    Int16Type :: IntTypeRep Int16
    Int32Type :: IntTypeRep Int32
    Int64Type :: IntTypeRep Int64

-- | Representation of the unsigned integer types
data WordTypeRep a
  where
    Word8Type  :: WordTypeRep Word8
    Word16Type :: WordTypeRep Word16
    Word32Type :: WordTypeRep Word32
    Word64Type :: WordTypeRep Word64

-- | Representation of a type that is either a signed or an unsigned integer
data IntWordTypeRep a
  where
    IntType  :: IntTypeRep a  -> IntWordTypeRep a
    WordType :: WordTypeRep a -> IntWordTypeRep a

-- | Representation of the floating point types
data FloatingTypeRep a
  where
    FloatType  :: FloatingTypeRep Float
    DoubleType :: FloatingTypeRep Double

-- | Representation of the complex number types
data ComplexTypeRep a
  where
    ComplexFloatType  :: ComplexTypeRep (Complex Float)
    ComplexDoubleType :: ComplexTypeRep (Complex Double)

-- | A different view of 'PrimTypeRep' that allows matching on similar types
data PrimTypeView a
  where
    PrimTypeBool     :: PrimTypeView Bool
    PrimTypeIntWord  :: IntWordTypeRep a  -> PrimTypeView a
    PrimTypeFloating :: FloatingTypeRep a -> PrimTypeView a
    PrimTypeComplex  :: ComplexTypeRep a  -> PrimTypeView a

deriving instance Show (PrimTypeRep a)
deriving instance Show (IntTypeRep a)
deriving instance Show (WordTypeRep a)
deriving instance Show (IntWordTypeRep a)
deriving instance Show (FloatingTypeRep a)
deriving instance Show (ComplexTypeRep a)
deriving instance Show (PrimTypeView a)

-- | Convert a 'PrimTypeRep' to its grouped 'PrimTypeView'
viewPrimTypeRep :: PrimTypeRep a -> PrimTypeView a
viewPrimTypeRep rep = case rep of
    BoolT          -> PrimTypeBool
    Int8T          -> PrimTypeIntWord (IntType Int8Type)
    Int16T         -> PrimTypeIntWord (IntType Int16Type)
    Int32T         -> PrimTypeIntWord (IntType Int32Type)
    Int64T         -> PrimTypeIntWord (IntType Int64Type)
    Word8T         -> PrimTypeIntWord (WordType Word8Type)
    Word16T        -> PrimTypeIntWord (WordType Word16Type)
    Word32T        -> PrimTypeIntWord (WordType Word32Type)
    Word64T        -> PrimTypeIntWord (WordType Word64Type)
    FloatT         -> PrimTypeFloating FloatType
    DoubleT        -> PrimTypeFloating DoubleType
    ComplexFloatT  -> PrimTypeComplex ComplexFloatType
    ComplexDoubleT -> PrimTypeComplex ComplexDoubleType

-- | Convert a 'PrimTypeView' back to the flat 'PrimTypeRep' (the inverse of
-- 'viewPrimTypeRep')
unviewPrimTypeRep :: PrimTypeView a -> PrimTypeRep a
unviewPrimTypeRep view = case view of
    PrimTypeBool                          -> BoolT
    PrimTypeIntWord (IntType Int8Type)    -> Int8T
    PrimTypeIntWord (IntType Int16Type)   -> Int16T
    PrimTypeIntWord (IntType Int32Type)   -> Int32T
    PrimTypeIntWord (IntType Int64Type)   -> Int64T
    PrimTypeIntWord (WordType Word8Type)  -> Word8T
    PrimTypeIntWord (WordType Word16Type) -> Word16T
    PrimTypeIntWord (WordType Word32Type) -> Word32T
    PrimTypeIntWord (WordType Word64Type) -> Word64T
    PrimTypeFloating FloatType            -> FloatT
    PrimTypeFloating DoubleType           -> DoubleT
    PrimTypeComplex ComplexFloatType      -> ComplexFloatT
    PrimTypeComplex ComplexDoubleType     -> ComplexDoubleT

-- | Width in bits of a signed or unsigned integer type; 'Nothing' for every
-- other primitive type
primTypeIntWidth :: PrimTypeRep a -> Maybe Int
primTypeIntWidth rep = case rep of
    Int8T   -> Just 8
    Int16T  -> Just 16
    Int32T  -> Just 32
    Int64T  -> Just 64
    Word8T  -> Just 8
    Word16T -> Just 16
    Word32T -> Just 32
    Word64T -> Just 64
    _       -> Nothing

-- | Primitive supported types
class (Eq a, Show a, Typeable a) => PrimType' a
  where
    -- | Reify a primitive type
    primTypeRep :: PrimTypeRep a

instance PrimType' Bool
  where
    primTypeRep = BoolT

instance PrimType' Int8
  where
    primTypeRep = Int8T

instance PrimType' Int16
  where
    primTypeRep = Int16T

instance PrimType' Int32
  where
    primTypeRep = Int32T

instance PrimType' Int64
  where
    primTypeRep = Int64T

instance PrimType' Word8
  where
    primTypeRep = Word8T

instance PrimType' Word16
  where
    primTypeRep = Word16T

instance PrimType' Word32
  where
    primTypeRep = Word32T

instance PrimType' Word64
  where
    primTypeRep = Word64T

instance PrimType' Float
  where
    primTypeRep = FloatT

instance PrimType' Double
  where
    primTypeRep = DoubleT

instance PrimType' (Complex Float)
  where
    primTypeRep = ComplexFloatT

instance PrimType' (Complex Double)
  where
    primTypeRep = ComplexDoubleT

-- | Convenience function; like 'primTypeRep' but with an extra argument to
-- constrain the type parameter. The extra argument is ignored.
primTypeOf :: PrimType' a => a -> PrimTypeRep a
primTypeOf = const primTypeRep

-- | Check whether two type representations are equal. On success the result
-- carries a proof that the two represented types coincide.
primTypeEq :: PrimTypeRep a -> PrimTypeRep b -> Maybe (Dict (a ~ b))
primTypeEq repA repB = case (repA, repB) of
    (BoolT,          BoolT)          -> Just Dict
    (Int8T,          Int8T)          -> Just Dict
    (Int16T,         Int16T)         -> Just Dict
    (Int32T,         Int32T)         -> Just Dict
    (Int64T,         Int64T)         -> Just Dict
    (Word8T,         Word8T)         -> Just Dict
    (Word16T,        Word16T)        -> Just Dict
    (Word32T,        Word32T)        -> Just Dict
    (Word64T,        Word64T)        -> Just Dict
    (FloatT,         FloatT)         -> Just Dict
    (DoubleT,        DoubleT)        -> Just Dict
    (ComplexFloatT,  ComplexFloatT)  -> Just Dict
    (ComplexDoubleT, ComplexDoubleT) -> Just Dict
    _                                -> Nothing

-- | Reflect a 'PrimTypeRep' to a 'PrimType'' constraint
witPrimType :: PrimTypeRep a -> Dict (PrimType' a)
witPrimType rep = case rep of
    BoolT          -> Dict
    Int8T          -> Dict
    Int16T         -> Dict
    Int32T         -> Dict
    Int64T         -> Dict
    Word8T         -> Dict
    Word16T        -> Dict
    Word32T        -> Dict
    Word64T        -> Dict
    FloatT         -> Dict
    DoubleT        -> Dict
    ComplexFloatT  -> Dict
    ComplexDoubleT -> Dict



--------------------------------------------------------------------------------
-- * Expressions
--------------------------------------------------------------------------------

-- | Primitive operations
data Primitive sig
  where
    -- Variables and literals
    FreeVar
        :: PrimType' a
        => String -> Primitive (Full a)
    Lit
        :: (Eq a, Show a)
        => a -> Primitive (Full a)

    -- Arithmetic
    Add
        :: (Num a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Sub
        :: (Num a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Mul
        :: (Num a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Neg
        :: (Num a, PrimType' a)
        => Primitive (a :-> Full a)
    Abs
        :: (Num a, PrimType' a)
        => Primitive (a :-> Full a)
    Sign
        :: (Num a, PrimType' a)
        => Primitive (a :-> Full a)

    -- Integral division
    Quot
        :: (Integral a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Rem
        :: (Integral a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Div
        :: (Integral a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Mod
        :: (Integral a, PrimType' a)
        => Primitive (a :-> a :-> Full a)

    -- Fractional and floating point operations
    FDiv
        :: (Fractional a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Pi
        :: (Floating a, PrimType' a)
        => Primitive (Full a)
    Exp
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Log
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Sqrt
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Pow
        :: (Floating a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    Sin
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Cos
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Tan
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Asin
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Acos
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Atan
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Sinh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Cosh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Tanh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Asinh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Acosh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)
    Atanh
        :: (Floating a, PrimType' a)
        => Primitive (a :-> Full a)

    -- Complex numbers
    Complex
        :: (Num a, PrimType' a, PrimType' (Complex a))
        => Primitive (a :-> a :-> Full (Complex a))
    Polar
        :: (Floating a, PrimType' a, PrimType' (Complex a))
        => Primitive (a :-> a :-> Full (Complex a))
    Real
        :: (PrimType' a, PrimType' (Complex a))
        => Primitive (Complex a :-> Full a)
    Imag
        :: (PrimType' a, PrimType' (Complex a))
        => Primitive (Complex a :-> Full a)
    Magnitude
        :: (RealFloat a, PrimType' a, PrimType' (Complex a))
        => Primitive (Complex a :-> Full a)
    Phase
        :: (RealFloat a, PrimType' a, PrimType' (Complex a))
        => Primitive (Complex a :-> Full a)
    Conjugate
        :: (Num a, PrimType' (Complex a))
        => Primitive (Complex a :-> Full (Complex a))

    -- Conversions
    I2N
        :: (Integral a, Num b, PrimType' a, PrimType' b)
        => Primitive (a :-> Full b)
    I2B
        :: (Integral a, PrimType' a)
        => Primitive (a :-> Full Bool)
    B2I
        :: (Integral a, PrimType' a)
        => Primitive (Bool :-> Full a)
    Round
        :: (RealFrac a, Num b, PrimType' a, PrimType' b)
        => Primitive (a :-> Full b)

    -- Boolean logic
    Not :: Primitive (Bool :-> Full Bool)
    And :: Primitive (Bool :-> Bool :-> Full Bool)
    Or  :: Primitive (Bool :-> Bool :-> Full Bool)

    -- Comparisons
    Eq
        :: (Eq a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)
    NEq
        :: (Eq a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)
    Lt
        :: (Ord a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)
    Gt
        :: (Ord a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)
    Le
        :: (Ord a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)
    Ge
        :: (Ord a, PrimType' a)
        => Primitive (a :-> a :-> Full Bool)

    -- Bit manipulation
    BitAnd
        :: (Bits a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    BitOr
        :: (Bits a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    BitXor
        :: (Bits a, PrimType' a)
        => Primitive (a :-> a :-> Full a)
    BitCompl
        :: (Bits a, PrimType' a)
        => Primitive (a :-> Full a)
    ShiftL
        :: (Bits a, PrimType' a, Integral b, PrimType' b)
        => Primitive (a :-> b :-> Full a)
    ShiftR
        :: (Bits a, PrimType' a, Integral b, PrimType' b)
        => Primitive (a :-> b :-> Full a)

    -- Array indexing
    ArrIx
        :: PrimType' a
        => IArr Index a -> Primitive (Index :-> Full a)

    -- Conditional
    Cond :: Primitive (Bool :-> a :-> a :-> Full a)

deriving instance Show (Primitive a)

-- The `PrimType'` constraints on certain symbols require an explanation: The
-- constraints are actually not needed for anything in the modules in
-- `Feldspar.Primitive.*`, but they are needed by `Feldspar.Run.Compile`. They
-- guarantee to the compiler that these symbols don't operate on tuples.
--
-- It would seem more consistent to have a `PrimType'` constraint on all
-- polymorphic symbols. However, this would prevent using some symbols for
-- non-primitive types in `Feldspar.Representation`. For example, `Lit` and
-- `Cond` are used in `Feldspar.Representation`, and there they can also be
-- used for tuple types.
-- The current design was chosen because it "just works".

deriveSymbol ''Primitive

-- | Free variables are rendered as their name, literals using 'show', and
-- array indexing as @ArrIx@ followed by the array's string (for 'IArrComp')
-- or by @...@ (for any other array). All other symbols use the derived
-- 'Show' instance.
instance Render Primitive
  where
    renderSym (FreeVar v)            = v
    renderSym (Lit a)                = show a
    renderSym (ArrIx (IArrComp arr)) = "ArrIx " ++ arr
    renderSym (ArrIx _)              = "ArrIx ..."
    renderSym s                      = show s
    renderArgs = renderArgsSmart

instance StringTree Primitive

-- | Evaluation of primitive symbols as ordinary Haskell functions.
--
-- Evaluating a 'FreeVar' or indexing an 'IArrComp' array calls 'error', as
-- does indexing an 'IArrRun' array outside of its bounds.
instance Eval Primitive
  where
    evalSym (FreeVar v) = error $ "evaluating free variable " ++ show v
    evalSym (Lit a)     = a
    evalSym Add         = (+)
    evalSym Sub         = (-)
    evalSym Mul         = (*)
    evalSym Neg         = negate
    evalSym Abs         = abs
    evalSym Sign        = signum
    evalSym Quot        = quot
    evalSym Rem         = rem
    evalSym Div         = div
    evalSym Mod         = mod
    evalSym FDiv        = (/)
    evalSym Pi          = pi
    evalSym Exp         = exp
    evalSym Log         = log
    evalSym Sqrt        = sqrt
    evalSym Pow         = (**)
    evalSym Sin         = sin
    evalSym Cos         = cos
    evalSym Tan         = tan
    evalSym Asin        = asin
    evalSym Acos        = acos
    evalSym Atan        = atan
    evalSym Sinh        = sinh
    evalSym Cosh        = cosh
    evalSym Tanh        = tanh
    evalSym Asinh       = asinh
    evalSym Acosh       = acosh
    evalSym Atanh       = atanh
    evalSym Complex     = (:+)
    evalSym Polar       = mkPolar
    evalSym Real        = realPart
    evalSym Imag        = imagPart
    evalSym Magnitude   = magnitude
    evalSym Phase       = phase
    evalSym Conjugate   = conjugate
    evalSym I2N         = fromInteger . toInteger
    evalSym I2B         = (/=0)
    evalSym B2I         = \a -> if a then 1 else 0
    evalSym Round       = fromInteger . round
    evalSym Not         = not
    evalSym And         = (&&)
    evalSym Or          = (||)
    evalSym Eq          = (==)
    evalSym NEq         = (/=)
    evalSym Lt          = (<)
    evalSym Gt          = (>)
    evalSym Le          = (<=)
    evalSym Ge          = (>=)
    evalSym BitAnd      = (.&.)
    evalSym BitOr       = (.|.)
    evalSym BitXor      = xor
    evalSym BitCompl    = complement
    evalSym ShiftL      = \a -> shiftL a . fromIntegral
    evalSym ShiftR      = \a -> shiftR a . fromIntegral
    evalSym Cond        = \c t f -> if c then t else f
    evalSym (ArrIx (IArrRun arr)) = \i ->
        -- Check the index against the array bounds explicitly so that the
        -- failure reports both the offending index and the valid range
        if i < l || i > h
          then error $ "ArrIx: index "
                    ++ show (toInteger i)
                    ++ " out of bounds "
                    ++ show (toInteger l, toInteger h)
          else arr!i
      where
        (l,h) = bounds arr
    evalSym (ArrIx (IArrComp arr)) = error $ "evaluating symbolic array " ++ arr

-- | Assumes no occurrences of 'FreeVar' and concrete representation of arrays
instance EvalEnv Primitive env

-- | Two symbols are considered equal when their 'show' output is equal
instance Equality Primitive
  where
    equal s1 s2 = show s1 == show s2
      -- NOTE: It is very important not to use `renderSym` here, because it will
      -- render all concrete arrays equal.

      -- This method uses string comparison. It is probably slightly more
      -- efficient to pattern match directly on the constructors. Unfortunately
      -- `deriveEquality ''Primitive` doesn't work, so it gets quite tedious to
      -- write it with pattern matching.

-- | 'Primitive' symbols decorated with the 'PrimTypeRep' of their result
type PrimDomain = Primitive :&: PrimTypeRep

-- | Primitive expressions
newtype Prim a = Prim { unPrim :: ASTF PrimDomain a }

instance Syntactic (Prim a)
  where
    type Domain (Prim a)   = PrimDomain
    type Internal (Prim a) = a
    desugar = unPrim
    sugar   = Prim

-- | Evaluate a closed expression
evalPrim :: Prim a -> a
evalPrim = go . unPrim
  where
    -- The type decoration is ignored; only the symbol is evaluated
    go :: AST PrimDomain sig -> Denotation sig
    go (Sym (s :&: _)) = evalSym s
    go (f :$ a)        = go f $ go a

-- | Build a sugared expression from a 'Primitive' symbol, passing the
-- 'primTypeRep' of the symbol's result type to 'sugarSymDecor' as decoration
sugarSymPrim
    :: ( Signature sig
       , fi  ~ SmartFun dom sig
       , sig ~ SmartSig fi
       , dom ~ SmartSym fi
       , dom ~ PrimDomain
       , SyntacticN f fi
       , sub :<: Primitive
       , PrimType' (DenResult sig)
       )
    => sub sig -> f
sugarSymPrim = sugarSymDecor primTypeRep

-- | Constants are built from 'Lit' and variables from 'FreeVar'
instance FreeExp Prim
  where
    type FreePred Prim = PrimType'
    constExp = sugarSymPrim . Lit
    varExp   = sugarSymPrim . FreeVar

instance EvalExp Prim
  where
    evalExp = evalPrim



--------------------------------------------------------------------------------
-- * Interface
--------------------------------------------------------------------------------

-- | Numeric literals become 'Lit' constants; every other method maps to the
-- corresponding 'Primitive' symbol
instance (Num a, PrimType' a) => Num (Prim a)
  where
    fromInteger = constExp . fromInteger
    (+)         = sugarSymPrim Add
    (-)         = sugarSymPrim Sub
    (*)         = sugarSymPrim Mul
    negate      = sugarSymPrim Neg
    abs         = sugarSymPrim Abs
    signum      = sugarSymPrim Sign