-- | Typed ('E', 'V', 'A') and untyped ('UE', 'UV', 'UA') expression trees,
-- their primitive types and constants, and the combinators used to build them.
module Language.Atom.Expressions
(
-- Expression, variable and array types
E (..)
, V (..)
, UE (..)
, UV (..)
, UVLocality (..)
, A (..)
, UA (..)
, Expr (..)
, Expression (..)
, Variable (..)
, Type (..)
, Const (..)
, Width (..)
, TypeOf (..)
, bytes
-- Conversion to, and traversal of, untyped expressions
, ue
, uv
, ueUpstream
, nearestUVs
, arrayIndices
-- Element-type classes
, NumE
, IntegralE
, FloatingE
, EqE
, OrdE
-- Constants and variable reads
, true
, false
, value
-- Boolean logic
, not_
, (&&.)
, (||.)
, and_
, or_
, any_
, all_
, imply
-- Comparisons
, (==.)
, (/=.)
, (<.)
, (<=.)
, (>.)
, (>=.)
-- Selection, clamping and integer division
, min_
, minimum_
, max_
, maximum_
, limit
, div_
, mod_
, mux
-- Array indexing
, (!)
, (!.)
-- Smart constructors for untyped boolean expressions
, ubool
, unot
, uand
, uor
, ueq
, umux
) where
import Data.Bits
import Data.Function (on)
import Data.Int
import Data.List
import Data.Ratio
import Data.Word
-- Operator fixities: array indexing binds tightest (9), then the comparison
-- operators (4, non-associative), then conjunction (3), then disjunction (2).
infixl 9 !, !.
infix 4 ==., /=., <., <=., >., >=.
infixl 3 &&.
infixl 2 ||.
-- | Tags for the primitive element types an expression, variable or constant
-- can have.  The derived 'Ord' follows the declaration order below.
data Type
  = Bool
  | Int8 | Int16 | Int32 | Int64
  | Word8 | Word16 | Word32 | Word64
  | Float | Double
  deriving (Show, Read, Eq, Ord)
-- | A literal value, wrapped in the constructor that identifies its
-- primitive type.  The derived 'Ord' follows the declaration order below.
data Const
  = CBool Bool
  | CInt8 Int8 | CInt16 Int16 | CInt32 Int32 | CInt64 Int64
  | CWord8 Word8 | CWord16 Word16 | CWord32 Word32 | CWord64 Word64
  | CFloat Float | CDouble Double
  deriving (Eq, Ord)
-- | Renders a constant as text.  Booleans are shown as @1@ and @0@; every
-- other constant uses the 'show' of the wrapped value.
instance Show Const where
  show (CBool True) = "1"
  show (CBool False) = "0"
  show (CInt8 v) = show v
  show (CInt16 v) = show v
  show (CInt32 v) = show v
  show (CInt64 v) = show v
  show (CWord8 v) = show v
  show (CWord16 v) = show v
  show (CWord32 v) = show v
  show (CWord64 v) = show v
  show (CFloat v) = show v
  show (CDouble v) = show v
-- | A typed expression 'E' wrapped in a constructor that records its element
-- type, so expressions of different element types can share one type.
data Expression
  = EBool (E Bool)
  | EInt8 (E Int8)
  | EInt16 (E Int16)
  | EInt32 (E Int32)
  | EInt64 (E Int64)
  | EWord8 (E Word8)
  | EWord16 (E Word16)
  | EWord32 (E Word32)
  | EWord64 (E Word64)
  | EFloat (E Float)
  | EDouble (E Double)
-- | A typed variable 'V' wrapped in a constructor that records its element
-- type, so variables of different element types can share one type.
data Variable
  = VBool (V Bool)
  | VInt8 (V Int8)
  | VInt16 (V Int16)
  | VInt32 (V Int32)
  | VInt64 (V Int64)
  | VWord8 (V Word8)
  | VWord16 (V Word16)
  | VWord32 (V Word32)
  | VWord64 (V Word64)
  | VFloat (V Float)
  | VDouble (V Double)
  deriving Eq
-- | A typed variable: an untyped variable 'UV' tagged with a phantom element
-- type (the type parameter does not appear in the payload).
data V a
  = V UV
  deriving Eq
-- | An untyped variable, described entirely by its 'UVLocality'.
data UV
  = UV UVLocality
  deriving (Show, Eq, Ord)
-- | Where an untyped variable lives: either an element of an array ('UA')
-- selected by an index expression, or an external variable carrying a
-- 'String' (presumably its name -- confirm against callers) and its 'Type'.
data UVLocality
  = Array UA UE
  | External String Type
  deriving (Show, Eq, Ord)
-- | A typed array: an untyped array 'UA' tagged with a phantom element type.
data A a
  = A UA
  deriving Eq
-- | An untyped array.  The list of constants determines the array's element
-- type (its first entry is inspected by the 'TypeOf' instance).  The 'Int'
-- and 'String' fields are presumably an identifier and a name -- confirm
-- against the code that constructs arrays.
data UA
  = UA Int String [Const]
  deriving (Show, Eq, Ord)
-- | A typed expression tree.  The type parameter is the element type the
-- expression evaluates to.
data E a where
  -- Leaves: a variable reference and a literal value.
  VRef :: V a -> E a
  Const :: a -> E a
  -- Conversion between numeric element types.
  Cast :: (NumE a, NumE b) => E a -> E b
  -- Arithmetic.
  Add :: NumE a => E a -> E a -> E a
  Sub :: NumE a => E a -> E a -> E a
  Mul :: NumE a => E a -> E a -> E a
  Div :: NumE a => E a -> E a -> E a
  Mod :: IntegralE a => E a -> E a -> E a
  -- Boolean logic.
  Not :: E Bool -> E Bool
  And :: E Bool -> E Bool -> E Bool
  -- Bitwise operations; 'Shift' takes a fixed 'Int' shift amount.
  BWNot :: IntegralE a => E a -> E a
  BWAnd :: IntegralE a => E a -> E a -> E a
  BWOr :: IntegralE a => E a -> E a -> E a
  Shift :: IntegralE a => E a -> Int -> E a
  -- Comparisons, producing a boolean expression.
  Eq :: (EqE a, OrdE a) => E a -> E a -> E Bool
  Lt :: OrdE a => E a -> E a -> E Bool
  -- Selection between two expressions on a boolean condition.
  Mux :: E Bool -> E a -> E a -> E a
  -- Conversions between floating-point expressions and words of the same
  -- width (presumably bit-pattern reinterpretation -- confirm with the
  -- code generator).
  F2B :: E Float -> E Word32
  D2B :: E Double -> E Word64
  B2F :: E Word32 -> E Float
  B2D :: E Word64 -> E Double
-- | Placeholder instance: 'show' raises an error for every expression.
instance Show (E a) where
  show = \_ -> error "Show (E a) not implemented"
-- | Two typed expressions are equal when their untyped forms (see 'ue')
-- are equal.
instance Expr a => Eq (E a) where
  lhs == rhs = ue lhs == ue rhs
-- | An untyped expression tree.  Its constructors mirror those of 'E',
-- except that 'UCast' records its target 'Type' explicitly and 'UAnd' holds
-- a list of conjuncts rather than exactly two operands.
data UE
  = UVRef UV
  | UConst Const
  | UCast Type UE
  | UAdd UE UE
  | USub UE UE
  | UMul UE UE
  | UDiv UE UE
  | UMod UE UE
  | UNot UE
  | UAnd [UE]
  | UBWNot UE
  | UBWAnd UE UE
  | UBWOr UE UE
  | UShift UE Int
  | UEq UE UE
  | ULt UE UE
  | UMux UE UE UE
  | UF2B UE
  | UD2B UE
  | UB2F UE
  | UB2D UE
  deriving (Show, Eq, Ord)
-- | Things that have a size measured in bits.
class Width a where
  -- | The size of the value in bits.
  width :: a -> Int
-- | The number of whole bytes needed to hold a value: its 'width' in bits
-- divided by 8, rounded up.
bytes :: Width a => a -> Int
bytes a
  | leftover == 0 = whole
  | otherwise = whole + 1
  where
    (whole, leftover) = width a `divMod` 8
-- | Bit width of each primitive type; 'Bool' occupies a single bit.
instance Width Type where
  width Bool = 1
  width Int8 = 8
  width Int16 = 16
  width Int32 = 32
  width Int64 = 64
  width Word8 = 8
  width Word16 = 16
  width Word32 = 32
  width Word64 = 64
  width Float = 32
  width Double = 64
-- The remaining 'Width' instances all defer to the width of the value's
-- 'Type', as reported by 'typeOf'.

instance Width Const where
  width c = width (typeOf c)

instance Expr a => Width (E a) where
  width e = width (typeOf e)

instance Expr a => Width (V a) where
  width v = width (typeOf v)

instance Width UE where
  width e = width (typeOf e)

instance Width UV where
  width v = width (typeOf v)
-- | Things whose primitive 'Type' can be determined.
class TypeOf a where
  -- | The primitive type of the value.
  typeOf :: a -> Type
-- | The type of a constant is given by its constructor.
instance TypeOf Const where
  typeOf (CBool _) = Bool
  typeOf (CInt8 _) = Int8
  typeOf (CInt16 _) = Int16
  typeOf (CInt32 _) = Int32
  typeOf (CInt64 _) = Int64
  typeOf (CWord8 _) = Word8
  typeOf (CWord16 _) = Word16
  typeOf (CWord32 _) = Word32
  typeOf (CWord64 _) = Word64
  typeOf (CFloat _) = Float
  typeOf (CDouble _) = Double
-- | An array element has its array's type; an external variable carries
-- its type directly.
instance TypeOf UV where
  typeOf (UV (Array arr _)) = typeOf arr
  typeOf (UV (External _ ty)) = ty
-- | A typed variable has the type of the untyped variable it wraps.
instance TypeOf (V a) where
  typeOf (V untyped) = typeOf untyped
-- | The type of an array is the type of the first constant in its list.
--
-- An array with an empty constant list has no determinable type; that case
-- raises an error that identifies the offending array instead of the generic
-- empty-list failure from 'head'.
instance TypeOf UA where
  typeOf (UA _ _ (c : _)) = typeOf c
  typeOf ua@(UA _ _ []) =
    error $ "TypeOf UA: cannot determine the element type of an array with no constants: " ++ show ua
-- | A typed array has the type of the untyped array it wraps.
instance TypeOf (A a) where
  typeOf (A arr) = typeOf arr
-- | The type of an untyped expression.  Logical and comparison nodes are
-- 'Bool'; a cast has its recorded target type; arithmetic, bitwise and shift
-- nodes take the type of their first operand; a mux takes the type of its
-- "true" branch; the float/word conversions have fixed result types.
instance TypeOf UE where
  typeOf (UVRef uvar) = typeOf uvar
  typeOf (UCast target _) = target
  typeOf (UConst c) = typeOf c
  typeOf (UAdd a _) = typeOf a
  typeOf (USub a _) = typeOf a
  typeOf (UMul a _) = typeOf a
  typeOf (UDiv a _) = typeOf a
  typeOf (UMod a _) = typeOf a
  typeOf (UNot _) = Bool
  typeOf (UAnd _) = Bool
  typeOf (UBWNot a) = typeOf a
  typeOf (UBWAnd a _) = typeOf a
  typeOf (UBWOr a _) = typeOf a
  typeOf (UShift a _) = typeOf a
  typeOf (UEq _ _) = Bool
  typeOf (ULt _ _) = Bool
  typeOf (UMux _ a _) = typeOf a
  typeOf (UF2B _) = Word32
  typeOf (UD2B _) = Word64
  typeOf (UB2F _) = Float
  typeOf (UB2D _) = Double
-- | The type of a typed expression is the type tag of its element type.
instance Expr a => TypeOf (E a) where
  typeOf e = eType e
-- | Element types that can appear in expressions.
class Eq a => Expr a where
  -- | The 'Type' tag corresponding to the element type.
  eType :: E a -> Type
  -- | Wrap a plain value in the 'Const' constructor for its type.
  constant :: a -> Const
  -- | Wrap a typed expression in the 'Expression' constructor for its type.
  expression :: E a -> Expression
  -- | Wrap a typed variable in the 'Variable' constructor for its type.
  variable :: V a -> Variable
  -- | Convert an expression into a 64-bit word expression.
  rawBits :: E a -> E Word64
-- | Boolean expressions.  The raw-bit form is 1 when the expression is true
-- and 0 otherwise.
instance Expr Bool where
  eType = const Bool
  constant b = CBool b
  expression e = EBool e
  variable v = VBool v
  rawBits flag = mux flag 1 0
-- | 8-bit signed expressions; raw bits are obtained with a 'Cast'.
instance Expr Int8 where
  eType = const Int8
  constant x = CInt8 x
  expression e = EInt8 e
  variable v = VInt8 v
  rawBits e = Cast e
-- | 16-bit signed expressions; raw bits are obtained with a 'Cast'.
instance Expr Int16 where
  eType = const Int16
  constant x = CInt16 x
  expression e = EInt16 e
  variable v = VInt16 v
  rawBits e = Cast e
-- | 32-bit signed expressions; raw bits are obtained with a 'Cast'.
instance Expr Int32 where
  eType = const Int32
  constant x = CInt32 x
  expression e = EInt32 e
  variable v = VInt32 v
  rawBits e = Cast e
-- | 64-bit signed expressions; raw bits are obtained with a 'Cast'.
instance Expr Int64 where
  eType = const Int64
  constant x = CInt64 x
  expression e = EInt64 e
  variable v = VInt64 v
  rawBits e = Cast e
-- | 8-bit unsigned expressions; raw bits are obtained with a 'Cast'.
instance Expr Word8 where
  eType = const Word8
  constant x = CWord8 x
  expression e = EWord8 e
  variable v = VWord8 v
  rawBits e = Cast e
-- | 16-bit unsigned expressions; raw bits are obtained with a 'Cast'.
instance Expr Word16 where
  eType = const Word16
  constant x = CWord16 x
  expression e = EWord16 e
  variable v = VWord16 v
  rawBits e = Cast e
-- | 32-bit unsigned expressions; raw bits are obtained with a 'Cast'.
instance Expr Word32 where
  eType = const Word32
  constant x = CWord32 x
  expression e = EWord32 e
  variable v = VWord32 v
  rawBits e = Cast e
-- | 64-bit unsigned expressions; these already are their own raw-bit form.
instance Expr Word64 where
  eType = const Word64
  constant x = CWord64 x
  expression e = EWord64 e
  variable v = VWord64 v
  rawBits e = e
-- | Single-precision expressions.  Raw bits are produced by converting to a
-- 32-bit word with 'F2B' and then casting that word to 64 bits.
instance Expr Float where
  eType = const Float
  constant x = CFloat x
  expression e = EFloat e
  variable v = VFloat v
  rawBits e = Cast (F2B e)
-- | Double-precision expressions.  Raw bits are produced by converting to a
-- 64-bit word with 'D2B'.
instance Expr Double where
  eType = const Double
  constant x = CDouble x
  expression e = EDouble e
  variable v = VDouble v
  rawBits e = D2B e
-- | Element types that support arithmetic in expressions.  The class has no
-- methods of its own; it bundles 'Num', 'Expr', 'EqE' and 'OrdE'.
-- Instances cover every integer and floating-point element type (not 'Bool').
class (Num a, Expr a, EqE a, OrdE a) => NumE a
instance NumE Int8
instance NumE Int16
instance NumE Int32
instance NumE Int64
instance NumE Word8
instance NumE Word16
instance NumE Word32
instance NumE Word64
instance NumE Float
instance NumE Double
-- | Integer element types usable in expressions.
class (NumE a, Integral a) => IntegralE a where
  -- | Whether the element type is signed.  The expression argument is
  -- ignored by every instance; only its type matters.
  signed :: E a -> Bool

-- The @IntN@ types are signed.
instance IntegralE Int8 where
  signed = const True

instance IntegralE Int16 where
  signed = const True

instance IntegralE Int32 where
  signed = const True

instance IntegralE Int64 where
  signed = const True

-- The @WordN@ types are unsigned.
instance IntegralE Word8 where
  signed = const False

instance IntegralE Word16 where
  signed = const False

instance IntegralE Word32 where
  signed = const False

instance IntegralE Word64 where
  signed = const False
-- | Element types whose expressions can be compared for equality.  The class
-- has no methods; instances cover every element type, including 'Bool'.
class (Eq a, Expr a) => EqE a
instance EqE Bool
instance EqE Int8
instance EqE Int16
instance EqE Int32
instance EqE Int64
instance EqE Word8
instance EqE Word16
instance EqE Word32
instance EqE Word64
instance EqE Float
instance EqE Double
-- | Element types whose expressions can be ordered.  The class has no
-- methods; instances cover the integer and floating-point element types
-- (there is no instance for 'Bool').
class (Eq a, Ord a, EqE a) => OrdE a
instance OrdE Int8
instance OrdE Int16
instance OrdE Int32
instance OrdE Int64
instance OrdE Word8
instance OrdE Word16
instance OrdE Word32
instance OrdE Word64
instance OrdE Float
instance OrdE Double
-- | Floating-point element types usable in expressions.  The class has no
-- methods; the only instances are 'Float' and 'Double'.
class (RealFloat a, NumE a, OrdE a) => FloatingE a
instance FloatingE Float
instance FloatingE Double
-- | Arithmetic on expressions.
--
-- When both operands of '+', '-' or '*' are literal constants the result is
-- computed immediately and stays a 'Const'; otherwise the matching
-- expression node ('Add', 'Sub', 'Mul') is built.
--
-- The subtraction operator had been lost from this instance (the '-'
-- equations, 'negate', and the @-1@ in 'signum'); it is restored here.
instance (Num a, NumE a, OrdE a) => Num (E a) where
  (Const a) + (Const b) = Const $ a + b
  a + b = Add a b

  (Const a) - (Const b) = Const $ a - b
  a - b = Sub a b

  (Const a) * (Const b) = Const $ a * b
  a * b = Mul a b

  -- Negation is expressed as subtraction from zero.
  negate a = 0 - a

  -- Select the negated value when the operand is below zero.
  abs a = mux (a <. 0) (negate a) a

  -- 0 for zero, -1 for negative values, 1 otherwise.
  signum a = mux (a ==. 0) 0 $ mux (a <. 0) (-1) 1

  fromInteger = Const . fromInteger
-- | Fractional arithmetic on expressions.
--
-- Dividing two literal constants is folded immediately; any other division
-- builds a 'Div' node.
instance (OrdE a, NumE a, Num a, Fractional a) => Fractional (E a) where
  (Const a) / (Const b) = Const $ a / b
  a / b = Div a b

  recip a = 1 / a

  -- Delegate to the element type's own 'fromRational'.  Converting the
  -- numerator and denominator separately and then dividing rounds twice, and
  -- yields Infinity or NaN when either component is too large for the
  -- element type even though the ratio itself is representable.
  fromRational = Const . fromRational
-- | Bitwise operations on integer expressions.
--
-- Each operation is folded immediately when its operands are literal
-- constants and otherwise builds the corresponding expression node.
-- 'rotate' is not supported and raises an error when applied.
instance (Expr a, OrdE a, EqE a, IntegralE a, Bits a) => Bits (E a) where
  Const x .&. Const y = Const (x .&. y)
  x .&. y = BWAnd x y

  Const x .|. Const y = Const (x .|. y)
  x .|. y = BWOr x y

  complement (Const x) = Const (complement x)
  complement x = BWNot x

  -- Exclusive-or is built from and, or and complement.
  xor x y = (x .&. complement y) .|. (complement x .&. y)

  shift (Const x) amount = Const (shift x amount)
  shift x amount = Shift x amount

  rotate = error "E rotate not supported."

  bitSize = width
  isSigned = signed
-- | The constant boolean expression for 'True'.
true :: E Bool
true = Const True
-- | The constant boolean expression for 'False'.
false :: E Bool
false = Const False
-- | Logical negation of a boolean expression.
not_ :: E Bool -> E Bool
not_ e = Not e
-- | Logical conjunction of two boolean expressions.
(&&.) :: E Bool -> E Bool -> E Bool
lhs &&. rhs = And lhs rhs
-- | Logical disjunction, expressed through negation and conjunction:
-- @a || b@ is @not (not a && not b)@.
(||.) :: E Bool -> E Bool -> E Bool
lhs ||. rhs = not_ (not_ lhs &&. not_ rhs)
-- | Conjunction of a list of boolean expressions, folded from the left
-- starting with 'true'.
and_ :: [E Bool] -> E Bool
and_ conds = foldl (&&.) true conds
-- | Disjunction of a list of boolean expressions, folded from the left
-- starting with 'false'.
or_ :: [E Bool] -> E Bool
or_ conds = foldl (||.) false conds
-- | Conjunction of a predicate applied to every element of a list.
all_ :: (a -> E Bool) -> [a] -> E Bool
all_ f a = and_ [f x | x <- a]
-- | Disjunction of a predicate applied to every element of a list.
any_ :: (a -> E Bool) -> [a] -> E Bool
any_ f a = or_ [f x | x <- a]
-- | Logical implication: @imply a b@ is @not a || b@.
imply :: E Bool -> E Bool -> E Bool
imply premise conclusion = not_ premise ||. conclusion
-- The comparison operators.  Only equality ('Eq') and less-than ('Lt') are
-- expression nodes; the other four are derived from them.

-- | Equality of two expressions.
(==.) :: (EqE a, OrdE a) => E a -> E a -> E Bool
lhs ==. rhs = Eq lhs rhs

-- | Inequality: the negation of '==.'.
(/=.) :: (EqE a, OrdE a) => E a -> E a -> E Bool
(/=.) lhs rhs = not_ $ lhs ==. rhs

-- | Less than.
(<.) :: OrdE a => E a -> E a -> E Bool
lhs <. rhs = Lt lhs rhs

-- | Greater than: '<.' with its operands swapped.
(>.) :: OrdE a => E a -> E a -> E Bool
(>.) = flip (<.)

-- | Less than or equal: the negation of '>.'.
(<=.) :: OrdE a => E a -> E a -> E Bool
(<=.) lhs rhs = not_ $ lhs >. rhs

-- | Greater than or equal: the negation of '<.'.
(>=.) :: OrdE a => E a -> E a -> E Bool
(>=.) lhs rhs = not_ $ lhs <. rhs
-- | The smaller of two expressions; the first is chosen when they are equal.
min_ :: OrdE a => E a -> E a -> E a
min_ a b = mux firstIsSmaller a b
  where
    firstIsSmaller = a <=. b
-- | The smallest of a list of expressions, combined pairwise from the left
-- with 'min_'.  Uses 'foldl1', so the list must be non-empty.
minimum_ :: OrdE a => [E a] -> E a
minimum_ candidates = foldl1 min_ candidates
-- | The larger of two expressions; the first is chosen when they are equal.
max_ :: OrdE a => E a -> E a -> E a
max_ a b = mux firstIsLarger a b
  where
    firstIsLarger = a >=. b
-- | The largest of a list of expressions, combined pairwise from the left
-- with 'max_'.  Uses 'foldl1', so the list must be non-empty.
maximum_ :: OrdE a => [E a] -> E a
maximum_ candidates = foldl1 max_ candidates
-- | Clamp the third argument to the range spanned by the first two.  The
-- bounds may be given in either order: the smaller one becomes the lower
-- limit and the larger one the upper limit.
limit :: OrdE a => E a -> E a -> E a -> E a
limit a b i = max_ lower (min_ upper i)
  where
    lower = min_ a b
    upper = max_ a b
-- | Integer division.  Two literal constants are divided immediately with
-- 'div'; anything else builds a 'Div' node.
div_ :: IntegralE a => E a -> E a -> E a
div_ num den = case (num, den) of
  (Const x, Const y) -> Const (x `div` y)
  _ -> Div num den
-- | Integer modulo.  Two literal constants are reduced immediately with
-- 'mod'; anything else builds a 'Mod' node.
mod_ :: IntegralE a => E a -> E a -> E a
mod_ num den = case (num, den) of
  (Const x, Const y) -> Const (x `mod` y)
  _ -> Mod num den
-- | The expression that reads a variable.
value :: V a -> E a
value var = VRef var
-- | Conditional selection: builds a 'Mux' node from a condition, the
-- expression for the true case and the expression for the false case.
mux :: Expr a => E Bool -> E a -> E a -> E a
mux cond onTrue onFalse = Mux cond onTrue onFalse
-- | Index into an array, yielding the variable for the selected element.
-- The index expression is stored in its untyped form.
(!) :: (Expr a, IntegralE b) => A a -> E b -> V a
(!) (A arr) = \idx -> V (UV (Array arr (ue idx)))
-- | Index into an array and read the selected element as an expression.
(!.) :: (Expr a, IntegralE b) => A a -> E b -> E a
arr !. idx = value (arr ! idx)
-- | The immediate sub-expressions of an untyped expression.  For an array
-- element reference this is the index expression; external variable
-- references and constants have none.
ueUpstream :: UE -> [UE]
ueUpstream (UVRef (UV (Array _ index))) = [index]
ueUpstream (UVRef (UV (External _ _))) = []
ueUpstream (UCast _ a) = [a]
ueUpstream (UConst _) = []
ueUpstream (UAdd a b) = [a, b]
ueUpstream (USub a b) = [a, b]
ueUpstream (UMul a b) = [a, b]
ueUpstream (UDiv a b) = [a, b]
ueUpstream (UMod a b) = [a, b]
ueUpstream (UNot a) = [a]
ueUpstream (UAnd conjuncts) = conjuncts
ueUpstream (UBWNot a) = [a]
ueUpstream (UBWAnd a b) = [a, b]
ueUpstream (UBWOr a b) = [a, b]
ueUpstream (UShift a _) = [a]
ueUpstream (UEq a b) = [a, b]
ueUpstream (ULt a b) = [a, b]
ueUpstream (UMux a b c) = [a, b, c]
ueUpstream (UF2B a) = [a]
ueUpstream (UD2B a) = [a]
ueUpstream (UB2F a) = [a]
ueUpstream (UB2D a) = [a]
-- | The distinct variables reached first on each path down from an
-- expression.  The search stops at a variable reference and does not
-- descend into an array element's index expression.
nearestUVs :: UE -> [UV]
nearestUVs = nub . collect
  where
    collect :: UE -> [UV]
    collect (UVRef var) = [var]
    collect expr = concatMap collect (ueUpstream expr)
-- | The distinct (array, index expression) pairs referenced anywhere in an
-- expression, including those nested inside other index expressions.
arrayIndices :: UE -> [(UA, UE)]
arrayIndices = nub . collect
  where
    collect :: UE -> [(UA, UE)]
    collect (UVRef (UV (Array arr index))) = (arr, index) : collect index
    collect expr = concatMap collect (ueUpstream expr)
-- | Convert a typed expression into its untyped form.
--
-- Most nodes map directly onto the corresponding 'UE' constructor.  Boolean
-- negation, conjunction, equality, less-than and mux go through the
-- simplifying constructors 'unot', 'uand', 'ueq', 'ult' and 'umux', and a
-- cast records the type of the cast's result.
ue :: Expr a => E a -> UE
ue (VRef (V v)) = UVRef v
ue (Const a) = UConst $ constant a
ue t@(Cast a) = UCast (eType t) (ue a)
ue (Add a b) = UAdd (ue a) (ue b)
ue (Sub a b) = USub (ue a) (ue b)
ue (Mul a b) = UMul (ue a) (ue b)
ue (Div a b) = UDiv (ue a) (ue b)
ue (Mod a b) = UMod (ue a) (ue b)
ue (Not a) = unot (ue a)
ue (And a b) = uand (ue a) (ue b)
ue (BWNot a) = UBWNot (ue a)
ue (BWAnd a b) = UBWAnd (ue a) (ue b)
ue (BWOr a b) = UBWOr (ue a) (ue b)
ue (Shift a n) = UShift (ue a) n
ue (Eq a b) = ueq (ue a) (ue b)
ue (Lt a b) = ult (ue a) (ue b)
ue (Mux c a b) = umux (ue c) (ue a) (ue b)
ue (F2B a) = UF2B (ue a)
ue (D2B a) = UD2B (ue a)
ue (B2F a) = UB2F (ue a)
ue (B2D a) = UB2D (ue a)
-- | The untyped variable wrapped by a typed variable.
uv :: V a -> UV
uv typed = case typed of
  V untyped -> untyped
-- | An untyped boolean constant.
ubool :: Bool -> UE
ubool b = UConst (CBool b)
-- | Negate an untyped boolean expression, simplifying where possible: a
-- boolean constant is flipped and a double negation is removed.
unot :: UE -> UE
unot e = case e of
  UConst (CBool b) -> ubool (not b)
  UNot inner -> inner
  _ -> UNot e
-- | Conjoin two untyped boolean expressions, simplifying where possible.
--
-- Identical operands collapse to one; a constant false operand makes the
-- result false; a constant true operand is dropped; existing conjunctions
-- are merged into a single term list.  Everything else is handed to
-- 'reduceAnd'.
uand :: UE -> UE -> UE
uand lhs rhs
  | lhs == rhs = lhs
  | otherwise = case (lhs, rhs) of
      (UConst (CBool False), _) -> lhs
      (_, UConst (CBool False)) -> rhs
      (UConst (CBool True), _) -> rhs
      (_, UConst (CBool True)) -> lhs
      (UAnd xs, UAnd ys) -> reduceAnd (xs ++ ys)
      (UAnd xs, _) -> reduceAnd (rhs : xs)
      (_, UAnd ys) -> reduceAnd (lhs : ys)
      _ -> reduceAnd [lhs, rhs]
-- | Build a conjunction from a list of terms, collapsing it to constant
-- false when the terms contradict each other.  The checks are tried in
-- order:
--
-- 1. some term equals the negation ('unot') of a term in the list;
-- 2. two equalities share an operand and 'ueq' of their other operands
--    folds to constant false;
-- 3. some term is the negation of a non-empty conjunction whose members are
--    all present in the list.
--
-- If none applies the result is a 'UAnd' over the de-duplicated, sorted
-- terms.
reduceAnd :: [UE] -> UE
reduceAnd terms
  | hasNegatedTerm = ubool False
  | hasClashingEqualities = ubool False
  | hasNegatedSubset = ubool False
  | otherwise = UAnd $ sort $ nub terms
  where
    negated = map unot terms

    hasNegatedTerm = any (\e -> any (e ==) negated) terms

    hasClashingEqualities = or [ clash p q | p <- terms, q <- terms ]

    clash :: UE -> UE -> Bool
    clash (UEq a b) (UEq x y)
      | a == x = isFalse $ ueq b y
      | a == y = isFalse $ ueq b x
      | b == x = isFalse $ ueq a y
      | b == y = isFalse $ ueq a x
    clash _ _ = False

    isFalse :: UE -> Bool
    isFalse (UConst (CBool False)) = True
    isFalse _ = False

    hasNegatedSubset = any contradicted terms

    contradicted :: UE -> Bool
    contradicted e =
      let inner = negatedConjuncts e
       in not (null inner) && all (`elem` terms) inner

    negatedConjuncts :: UE -> [UE]
    negatedConjuncts (UNot (UAnd xs)) = xs
    negatedConjuncts _ = []
-- | Disjoin two untyped boolean expressions, expressed through 'unot' and
-- 'uand' so that their simplifications apply.
uor :: UE -> UE -> UE
uor lhs rhs = unot $ unot lhs `uand` unot rhs
-- | Equality of two untyped expressions, simplifying where possible:
-- structurally identical operands give constant true, and two constants of
-- the same type are compared immediately.  Anything else builds a 'UEq'.
ueq :: UE -> UE -> UE
ueq lhs rhs
  | lhs == rhs = ubool True
  | otherwise = case (lhs, rhs) of
      (UConst (CBool x), UConst (CBool y)) -> ubool $ x == y
      (UConst (CInt8 x), UConst (CInt8 y)) -> ubool $ x == y
      (UConst (CInt16 x), UConst (CInt16 y)) -> ubool $ x == y
      (UConst (CInt32 x), UConst (CInt32 y)) -> ubool $ x == y
      (UConst (CInt64 x), UConst (CInt64 y)) -> ubool $ x == y
      (UConst (CWord8 x), UConst (CWord8 y)) -> ubool $ x == y
      (UConst (CWord16 x), UConst (CWord16 y)) -> ubool $ x == y
      (UConst (CWord32 x), UConst (CWord32 y)) -> ubool $ x == y
      (UConst (CWord64 x), UConst (CWord64 y)) -> ubool $ x == y
      (UConst (CFloat x), UConst (CFloat y)) -> ubool $ x == y
      (UConst (CDouble x), UConst (CDouble y)) -> ubool $ x == y
      _ -> UEq lhs rhs
-- | Less-than on two untyped expressions, simplifying where possible:
-- structurally identical operands give constant false, and two constants of
-- the same type are compared immediately.  Anything else builds a 'ULt'.
ult :: UE -> UE -> UE
ult lhs rhs
  | lhs == rhs = ubool False
  | otherwise = case (lhs, rhs) of
      (UConst (CBool x), UConst (CBool y)) -> ubool $ x < y
      (UConst (CInt8 x), UConst (CInt8 y)) -> ubool $ x < y
      (UConst (CInt16 x), UConst (CInt16 y)) -> ubool $ x < y
      (UConst (CInt32 x), UConst (CInt32 y)) -> ubool $ x < y
      (UConst (CInt64 x), UConst (CInt64 y)) -> ubool $ x < y
      (UConst (CWord8 x), UConst (CWord8 y)) -> ubool $ x < y
      (UConst (CWord16 x), UConst (CWord16 y)) -> ubool $ x < y
      (UConst (CWord32 x), UConst (CWord32 y)) -> ubool $ x < y
      (UConst (CWord64 x), UConst (CWord64 y)) -> ubool $ x < y
      (UConst (CFloat x), UConst (CFloat y)) -> ubool $ x < y
      (UConst (CDouble x), UConst (CDouble y)) -> ubool $ x < y
      _ -> ULt lhs rhs
-- | Conditional selection on untyped expressions, simplifying where
-- possible.  The rules are tried in order:
--
-- 1. equal branches: the condition is irrelevant;
-- 2. boolean branches: rewritten as @(c && t) || (not c && f)@;
-- 3. constant condition: the matching branch is chosen;
-- 4. negated condition: the branches are swapped;
-- 5. a nested mux on the same condition in either branch is flattened.
--
-- Otherwise a 'UMux' node is built.
umux :: UE -> UE -> UE -> UE
umux cond onTrue onFalse
  | onTrue == onFalse = onFalse
  | typeOf onTrue == Bool = uor (uand cond onTrue) (uand (unot cond) onFalse)
  | otherwise = case (cond, onTrue, onFalse) of
      (UConst (CBool c), _, _) -> if c then onTrue else onFalse
      (UNot c, _, _) -> umux c onFalse onTrue
      (_, UMux c2 t _, _) | cond == c2 -> umux cond t onFalse
      (_, _, UMux c2 _ f) | cond == c2 -> umux cond onTrue f
      _ -> UMux cond onTrue onFalse