module Language.Atom.Expressions
  (
  
    E     (..)
  , V     (..)
  , UE    (..)
  , UV    (..)
  , UVLocality (..)
  , A     (..)
  , UA    (..)
  , Expr  (..)
  , Expression (..)
  , Variable   (..)
  , Type  (..)
  , Const (..)
  , Width (..)
  , TypeOf (..)
  , bytes
  , ue
  , uv
  , ueUpstream
  , nearestUVs
  , arrayIndices
  , NumE
  , IntegralE
  , FloatingE
  , EqE
  , OrdE
  
  , true
  , false
  
  , value
  
  , not_
  , (&&.)
  , (||.)
  , and_
  , or_
  , any_
  , all_
  , imply
  
  , (==.)
  , (/=.)
  , (<.)
  , (<=.)
  , (>.)
  , (>=.)
  , min_
  , minimum_
  , max_
  , maximum_
  , limit
  
  , div_
  , mod_
  
  , mux
  
  , (!)
  , (!.)
  
  , ubool
  , unot
  , uand
  , uor
  , ueq
  , umux
  ) where
import Data.Bits
import Data.Function (on)
import Data.Int
import Data.List
import Data.Ratio
import Data.Word
-- Operator fixities, from tightest to loosest binding: array indexing (9),
-- comparisons (4, non-associative), conjunction (3), disjunction (2).
infixl 9 !, !.
infix  4 ==., /=., <., <=., >., >=.
infixl 3 &&. 
infixl 2 ||. 
-- | The primitive types that constants, variables and expressions can have.
-- The bit width of each is given by the 'Width' instance for 'Type'.
data Type
  = Bool
  | Int8
  | Int16
  | Int32
  | Int64
  | Word8
  | Word16
  | Word32
  | Word64
  | Float
  | Double
  deriving (Show, Read, Eq, Ord)
-- | A literal value tagged with its primitive type; there is one constructor
-- per constructor of 'Type'.
data Const
  = CBool   Bool
  | CInt8   Int8
  | CInt16  Int16
  | CInt32  Int32
  | CInt64  Int64
  | CWord8  Word8
  | CWord16 Word16
  | CWord32 Word32
  | CWord64 Word64
  | CFloat  Float
  | CDouble Double
  deriving (Eq, Ord)
-- | Renders a constant as text.  Booleans are shown as @1@ and @0@; every
-- other constant uses the 'show' of the wrapped value.
instance Show Const where
  show (CBool   flag) = if flag then "1" else "0"
  show (CInt8   n)    = show n
  show (CInt16  n)    = show n
  show (CInt32  n)    = show n
  show (CInt64  n)    = show n
  show (CWord8  n)    = show n
  show (CWord16 n)    = show n
  show (CWord32 n)    = show n
  show (CWord64 n)    = show n
  show (CFloat  x)    = show x
  show (CDouble x)    = show x
-- | A typed expression 'E' wrapped in a constructor that records its element
-- type, so expressions of different types can be handled uniformly.
data Expression
  = EBool   (E Bool)
  | EInt8   (E Int8)
  | EInt16  (E Int16)
  | EInt32  (E Int32)
  | EInt64  (E Int64)
  | EWord8  (E Word8)
  | EWord16 (E Word16)
  | EWord32 (E Word32)
  | EWord64 (E Word64)
  | EFloat  (E Float)
  | EDouble (E Double)
-- | A typed variable 'V' wrapped in a constructor that records its element
-- type, so variables of different types can be handled uniformly.
data Variable
  = VBool   (V Bool)
  | VInt8   (V Int8)
  | VInt16  (V Int16)
  | VInt32  (V Int32)
  | VInt64  (V Int64)
  | VWord8  (V Word8)
  | VWord16 (V Word16)
  | VWord32 (V Word32)
  | VWord64 (V Word64)
  | VFloat  (V Float)
  | VDouble (V Double) deriving Eq
-- | A typed variable: a phantom-typed wrapper around an untyped 'UV'.
data V a = V UV deriving Eq
-- | An untyped variable, described entirely by its 'UVLocality'.
data UV = UV UVLocality deriving (Show, Eq, Ord)
-- | Where a variable lives: either an element of an array ('UA') selected by
-- an index expression ('UE'), or an external variable carrying a 'String'
-- (presumably its name -- confirm against callers) and its 'Type'.
data UVLocality = Array UA UE | External String Type deriving (Show, Eq, Ord)
-- | A typed array: a phantom-typed wrapper around an untyped 'UA'.
data A a = A UA deriving Eq
-- | An untyped array.  The 'Int' and 'String' fields are presumably an
-- identifier and a name (confirm against callers); the 'Const' list holds the
-- array's constants, whose first element determines the array's 'Type'.
data UA = UA Int String [Const] deriving (Show, Eq, Ord)
-- | A typed expression.  The type index @a@ is the Haskell type of the value
-- the expression denotes; class constraints on the constructors restrict
-- which operations are available at which types.
data E a where
  -- Leaves: a variable reference and a literal.
  VRef    :: V a -> E a
  Const   :: a -> E a
  -- Conversion between numeric types.
  Cast    :: (NumE a, NumE b) => E a -> E b
  -- Arithmetic.
  Add     :: NumE a => E a -> E a -> E a
  Sub     :: NumE a => E a -> E a -> E a
  Mul     :: NumE a => E a -> E a -> E a
  Div     :: NumE a => E a -> E a -> E a
  Mod     :: IntegralE a => E a -> E a -> E a
  -- Boolean logic.
  Not     :: E Bool -> E Bool
  And     :: E Bool -> E Bool -> E Bool
  -- Bitwise operations; the shift amount is a plain 'Int', not an expression.
  BWNot   :: IntegralE a => E a -> E a
  BWAnd   :: IntegralE a => E a -> E a -> E a
  BWOr    :: IntegralE a => E a -> E a -> E a
  Shift   :: IntegralE a => E a -> Int -> E a
  -- Comparisons.
  Eq      :: (EqE a, OrdE a) => E a -> E a -> E Bool
  Lt      :: OrdE a => E a -> E a -> E Bool
  -- Conditional: condition, then two expressions of the result type.
  Mux     :: E Bool -> E a -> E a -> E a
  -- Conversions between floating-point types and same-width words
  -- (presumably bit-level reinterpretation -- confirm against the back end).
  F2B     :: E Float  -> E Word32
  D2B     :: E Double -> E Word64
  B2F     :: E Word32 -> E Float
  B2D     :: E Word64 -> E Double
-- | Typed expressions cannot be rendered: any attempt to 'show' one raises
-- an error.
instance Show (E a) where
  show = const (error "Show (E a) not implemented")
-- | Two typed expressions are equal when their untyped forms (see 'ue')
-- are equal.
instance Expr a => Eq (E a) where
  lhs == rhs = ue lhs == ue rhs
-- | An untyped expression: the type-erased counterpart of 'E', produced by
-- 'ue'.  The result type of any node can be recovered with 'typeOf'.
data UE
  = UVRef UV
  | UConst Const
  | UCast  Type UE  -- target type, then the expression being converted
  | UAdd   UE UE
  | USub   UE UE
  | UMul   UE UE
  | UDiv   UE UE
  | UMod   UE UE
  | UNot   UE
  | UAnd   [UE]     -- n-ary conjunction (see 'uand' and 'reduceAnd')
  | UBWNot UE
  | UBWAnd UE UE
  | UBWOr  UE UE
  | UShift UE Int
  | UEq    UE UE
  | ULt    UE UE
  | UMux   UE UE UE -- condition, then two expressions of the result type
  | UF2B   UE
  | UD2B   UE
  | UB2F   UE
  | UB2D   UE
  deriving (Show, Eq, Ord)
-- | Things that have a size measured in bits.
class Width a where
  -- | The size of the value in bits.
  width :: a -> Int

-- | The number of whole bytes needed to hold 'width' bits, i.e. the width
-- divided by eight and rounded up.
bytes :: Width a => a -> Int
bytes a = case width a `divMod` 8 of
  (whole, 0) -> whole
  (whole, _) -> whole + 1
-- | Bit width of each primitive type; 'Bool' occupies a single bit.
instance Width Type where
  width Bool   = 1
  width Int8   = 8
  width Int16  = 16
  width Int32  = 32
  width Int64  = 64
  width Word8  = 8
  width Word16 = 16
  width Word32 = 32
  width Word64 = 64
  width Float  = 32
  width Double = 64
-- The width of anything with a 'TypeOf' instance is the width of its type.
instance Width Const           where width c = width (typeOf c)
instance Expr a => Width (E a) where width e = width (typeOf e)
instance Expr a => Width (V a) where width v = width (typeOf v)
instance Width UE              where width e = width (typeOf e)
instance Width UV              where width v = width (typeOf v)
-- | Things whose primitive 'Type' can be determined.
class TypeOf a where
  -- | The primitive type of the given value.
  typeOf :: a -> Type
-- | The type of a constant is determined by its constructor.
instance TypeOf Const where
  typeOf (CBool   _) = Bool
  typeOf (CInt8   _) = Int8
  typeOf (CInt16  _) = Int16
  typeOf (CInt32  _) = Int32
  typeOf (CInt64  _) = Int64
  typeOf (CWord8  _) = Word8
  typeOf (CWord16 _) = Word16
  typeOf (CWord32 _) = Word32
  typeOf (CWord64 _) = Word64
  typeOf (CFloat  _) = Float
  typeOf (CDouble _) = Double
-- | An array element has the type of its array; an external variable
-- carries its type directly.
instance TypeOf UV where
  typeOf (UV (Array arr _))     = typeOf arr
  typeOf (UV (External _ ty))   = ty
-- | A typed variable has the type of the untyped variable it wraps.
instance TypeOf (V a) where
  typeOf var = case var of
    V untyped -> typeOf untyped
-- | The type of an array is the type of the first constant in its constant
-- list.  An array with no constants has no determinable type; this raises an
-- error that names the array's 'String' field instead of the uninformative
-- failure of 'head' on an empty list.
instance TypeOf UA where
  typeOf (UA _ _ (c : _)) = typeOf c
  typeOf (UA _ name [])   =
    error $ "typeOf: array " ++ show name ++ " has no constants, so its type cannot be determined"
-- | A typed array has the type of the untyped array it wraps.
instance TypeOf (A a) where
  typeOf arr = case arr of
    A untyped -> typeOf untyped
-- | Result type of an untyped expression.  Arithmetic and bitwise nodes take
-- the type of their first operand, logical and comparison nodes are 'Bool',
-- a mux takes the type of its first branch, and casts and float/bit
-- conversions have fixed result types.
instance TypeOf UE where
  typeOf (UVRef var)    = typeOf var
  typeOf (UCast ty _)   = ty
  typeOf (UConst c)     = typeOf c
  typeOf (UAdd x _)     = typeOf x
  typeOf (USub x _)     = typeOf x
  typeOf (UMul x _)     = typeOf x
  typeOf (UDiv x _)     = typeOf x
  typeOf (UMod x _)     = typeOf x
  typeOf (UNot _)       = Bool
  typeOf (UAnd _)       = Bool
  typeOf (UBWNot x)     = typeOf x
  typeOf (UBWAnd x _)   = typeOf x
  typeOf (UBWOr x _)    = typeOf x
  typeOf (UShift x _)   = typeOf x
  typeOf (UEq _ _)      = Bool
  typeOf (ULt _ _)      = Bool
  typeOf (UMux _ x _)   = typeOf x
  typeOf (UF2B _)       = Word32
  typeOf (UD2B _)       = Word64
  typeOf (UB2F _)       = Float
  typeOf (UB2D _)       = Double
-- | The type of a typed expression is given by 'eType'.
instance Expr a => TypeOf (E a) where
  typeOf expr = eType expr
-- | Haskell types that can appear as the type index of an 'E' expression.
class Eq a => Expr a where
  -- | The primitive 'Type' corresponding to @a@.
  eType      :: E a -> Type
  -- | Wrap a Haskell value as a type-tagged 'Const'.
  constant   :: a -> Const
  -- | Wrap a typed expression in the matching 'Expression' constructor.
  expression :: E a -> Expression
  -- | Wrap a typed variable in the matching 'Variable' constructor.
  variable   :: V a -> Variable
  -- | Convert an expression of this type to a 'Word64' expression.
  rawBits    :: E a -> E Word64
-- 'Expr' instances, one per primitive type.  Each ties the Haskell type to
-- its 'Type' tag and to the matching 'Const', 'Expression' and 'Variable'
-- constructors, and says how to convert an expression to 'Word64'.

-- | Booleans convert to the words 1 and 0.
instance Expr Bool where
  eType        = const Bool
  constant     = CBool
  expression   = EBool
  variable     = VBool
  rawBits cond = mux cond 1 0

instance Expr Int8 where
  eType      = const Int8
  constant   = CInt8
  expression = EInt8
  variable   = VInt8
  rawBits e  = Cast e

instance Expr Int16 where
  eType      = const Int16
  constant   = CInt16
  expression = EInt16
  variable   = VInt16
  rawBits e  = Cast e

instance Expr Int32 where
  eType      = const Int32
  constant   = CInt32
  expression = EInt32
  variable   = VInt32
  rawBits e  = Cast e

instance Expr Int64 where
  eType      = const Int64
  constant   = CInt64
  expression = EInt64
  variable   = VInt64
  rawBits e  = Cast e

instance Expr Word8 where
  eType      = const Word8
  constant   = CWord8
  expression = EWord8
  variable   = VWord8
  rawBits e  = Cast e

instance Expr Word16 where
  eType      = const Word16
  constant   = CWord16
  expression = EWord16
  variable   = VWord16
  rawBits e  = Cast e

instance Expr Word32 where
  eType      = const Word32
  constant   = CWord32
  expression = EWord32
  variable   = VWord32
  rawBits e  = Cast e

-- | A 'Word64' expression is already in the target form.
instance Expr Word64 where
  eType      = const Word64
  constant   = CWord64
  expression = EWord64
  variable   = VWord64
  rawBits e  = e

-- | Floats go through 'F2B' to a 'Word32', which is then cast to 'Word64'.
instance Expr Float where
  eType      = const Float
  constant   = CFloat
  expression = EFloat
  variable   = VFloat
  rawBits e  = Cast (F2B e)

-- | Doubles go through 'D2B', which already yields a 'Word64'.
instance Expr Double where
  eType      = const Double
  constant   = CDouble
  expression = EDouble
  variable   = VDouble
  rawBits e  = D2B e
-- | Expression types that support arithmetic: every integer and
-- floating-point type, but not 'Bool'.
class (Num a, Expr a, EqE a, OrdE a) => NumE a
instance NumE Int8
instance NumE Int16
instance NumE Int32
instance NumE Int64
instance NumE Word8
instance NumE Word16
instance NumE Word32
instance NumE Word64
instance NumE Float
instance NumE Double
-- | Integer expression types.
class (NumE a, Integral a) => IntegralE a where
  -- | 'True' for the @Int*@ types, 'False' for the @Word*@ types.  The
  -- argument is used only to select the instance.
  signed :: E a -> Bool

instance IntegralE Int8   where signed = const True
instance IntegralE Int16  where signed = const True
instance IntegralE Int32  where signed = const True
instance IntegralE Int64  where signed = const True
instance IntegralE Word8  where signed = const False
instance IntegralE Word16 where signed = const False
instance IntegralE Word32 where signed = const False
instance IntegralE Word64 where signed = const False
-- | Expression types that can be compared for equality: all primitive types.
class (Eq a, Expr a) => EqE a
instance EqE Bool
instance EqE Int8
instance EqE Int16
instance EqE Int32
instance EqE Int64
instance EqE Word8
instance EqE Word16
instance EqE Word32
instance EqE Word64
instance EqE Float
instance EqE Double
-- | Expression types that can be ordered: every integer and floating-point
-- type, but not 'Bool'.
class (Eq a, Ord a, EqE a) => OrdE a
instance OrdE Int8
instance OrdE Int16
instance OrdE Int32
instance OrdE Int64
instance OrdE Word8
instance OrdE Word16
instance OrdE Word32
instance OrdE Word64
instance OrdE Float
instance OrdE Double
-- | Floating-point expression types: 'Float' and 'Double'.
class (RealFloat a, NumE a, OrdE a) => FloatingE a
instance FloatingE Float
instance FloatingE Double
-- | Arithmetic on typed expressions.  Operations on two constants are folded
-- immediately; anything else builds the corresponding expression node.
instance (Num a, NumE a, OrdE a) => Num (E a) where
  (Const a) + (Const b) = Const $ a + b
  a + b = Add a b
  (Const a) - (Const b) = Const $ a - b
  a - b = Sub a b
  (Const a) * (Const b) = Const $ a * b
  a * b = Mul a b
  -- Negation is subtraction from zero.
  negate a = 0 - a
  abs a = mux (a <. 0) (negate a) a
  -- 0 for zero, -1 for negative values, 1 for positive values.
  signum a = mux (a ==. 0) 0 $ mux (a <. 0) (-1) 1
  fromInteger = Const . fromInteger
-- | Division on typed expressions.  Dividing two constants is folded
-- immediately; anything else builds a 'Div' node.
instance (OrdE a, NumE a, Num a, Fractional a) => Fractional (E a) where
  Const x / Const y = Const (x / y)
  x / y             = Div x y
  recip x = 1 / x
  -- A rational literal becomes the constant numerator / denominator,
  -- evaluated in the underlying type.
  fromRational q = Const (fromInteger (numerator q) / fromInteger (denominator q))
-- | Bitwise operations on typed integer expressions.  Operations on
-- constants are folded immediately; anything else builds the corresponding
-- node.  'xor' is expressed through and/or/complement, and 'rotate' is
-- unsupported and raises an error.
instance (Expr a, OrdE a, EqE a, IntegralE a, Bits a) => Bits (E a) where
  Const x .&. Const y = Const (x .&. y)
  x .&. y             = BWAnd x y
  complement (Const x) = Const (complement x)
  complement x         = BWNot x
  Const x .|. Const y = Const (x .|. y)
  x .|. y             = BWOr x y
  xor x y = (x .&. complement y) .|. (complement x .&. y)
  shift (Const x) amount = Const (shift x amount)
  shift x amount         = Shift x amount
  rotate   = error "E rotate not supported."
  bitSize  = width
  isSigned = signed
-- | The constant true expression.
true :: E Bool
true = Const True

-- | The constant false expression.
false :: E Bool
false = Const False

-- | Logical negation.
not_ :: E Bool -> E Bool
not_ operand = Not operand

-- | Logical conjunction.
(&&.) :: E Bool -> E Bool -> E Bool
lhs &&. rhs = And lhs rhs

-- | Logical disjunction, built from negation and conjunction.
(||.) :: E Bool -> E Bool -> E Bool
lhs ||. rhs = not_ (not_ lhs &&. not_ rhs)
-- | Conjunction of a list of conditions; the empty list yields 'true'.
-- Uses the strict left fold so no chain of suspended applications builds up;
-- the resulting expression is the same as with a lazy left fold.
and_ :: [E Bool] -> E Bool
and_ = foldl' (&&.) true

-- | Disjunction of a list of conditions; the empty list yields 'false'.
or_ :: [E Bool] -> E Bool
or_ = foldl' (||.) false

-- | True iff the predicate holds for every element of the list.
all_ :: (a -> E Bool) -> [a] -> E Bool
all_ f a = and_ $ map f a

-- | True iff the predicate holds for at least one element of the list.
any_ :: (a -> E Bool) -> [a] -> E Bool
any_ f a = or_ $ map f a
-- | Logical implication: @imply a b@ is @not a || b@.
imply :: E Bool -> E Bool -> E Bool
imply a b = (||.) (not_ a) b
-- | Equality of two expressions.
(==.) :: (EqE a, OrdE a) => E a -> E a -> E Bool
x ==. y = Eq x y

-- | Inequality: the negation of '==.'.
(/=.) :: (EqE a, OrdE a) => E a -> E a -> E Bool
(/=.) x y = not_ $ x ==. y

-- | Strictly less than.
(<.) :: OrdE a => E a -> E a -> E Bool
x <. y = Lt x y

-- | Strictly greater than: '<.' with the operands swapped.
(>.) :: OrdE a => E a -> E a -> E Bool
(>.) x y = y <. x

-- | Less than or equal: the negation of '>.'.
(<=.) :: OrdE a => E a -> E a -> E Bool
(<=.) x y = not_ $ x >. y

-- | Greater than or equal: the negation of '<.'.
(>=.) :: OrdE a => E a -> E a -> E Bool
(>=.) x y = not_ $ x <. y
-- | The smaller of two expressions; the first is chosen when they are equal.
min_ :: OrdE a => E a -> E a -> E a
min_ x y = mux (x <=. y) x y

-- | The smallest of a list of expressions.  The list must be non-empty.
minimum_ :: OrdE a => [E a] -> E a
minimum_ candidates = foldl1 min_ candidates

-- | The larger of two expressions; the first is chosen when they are equal.
max_ :: OrdE a => E a -> E a -> E a
max_ x y = mux (x >=. y) x y

-- | The largest of a list of expressions.  The list must be non-empty.
maximum_ :: OrdE a => [E a] -> E a
maximum_ candidates = foldl1 max_ candidates
-- | Clamp the third argument to the range spanned by the first two.  The
-- bounds may be given in either order.
limit :: OrdE a => E a -> E a -> E a -> E a
limit a b i = max_ lower (min_ upper i)
  where
  lower = min_ a b
  upper = max_ a b
-- | Integer division.  Two constants are folded with 'div'; anything else
-- builds a 'Div' node.
div_ :: IntegralE a => E a -> E a -> E a
div_ x y = case (x, y) of
  (Const m, Const n) -> Const (div m n)
  _                  -> Div x y

-- | Integer remainder.  Two constants are folded with 'mod'; anything else
-- builds a 'Mod' node.
mod_ :: IntegralE a => E a -> E a -> E a
mod_ x y = case (x, y) of
  (Const m, Const n) -> Const (mod m n)
  _                  -> Mod x y
-- | The expression that reads a variable.
value :: V a -> E a
value var = VRef var

-- | Conditional expression: condition first, then the two alternatives.
mux :: Expr a => E Bool -> E a -> E a -> E a
mux cond onTrue onFalse = Mux cond onTrue onFalse
-- | Array indexing, yielding the variable for the selected element.  The
-- index expression is stored in its untyped form.
(!) :: (Expr a, IntegralE b) => A a -> E b -> V a
(!) (A arr) = \index -> V (UV (Array arr (ue index)))

-- | Array indexing, yielding the value of the selected element.
(!.) :: (Expr a, IntegralE b) => A a -> E b -> E a
(!.) arr index = value (arr ! index)
-- | The immediate sub-expressions of an untyped expression.  For an array
-- element reference this is the index expression; external variable
-- references and constants have none.
ueUpstream :: UE -> [UE]
ueUpstream (UVRef (UV (Array _ index))) = [index]
ueUpstream (UVRef (UV (External _ _)))  = []
ueUpstream (UCast _ x)                  = [x]
ueUpstream (UConst _)                   = []
ueUpstream (UAdd x y)                   = [x, y]
ueUpstream (USub x y)                   = [x, y]
ueUpstream (UMul x y)                   = [x, y]
ueUpstream (UDiv x y)                   = [x, y]
ueUpstream (UMod x y)                   = [x, y]
ueUpstream (UNot x)                     = [x]
ueUpstream (UAnd xs)                    = xs
ueUpstream (UBWNot x)                   = [x]
ueUpstream (UBWAnd x y)                 = [x, y]
ueUpstream (UBWOr x y)                  = [x, y]
ueUpstream (UShift x _)                 = [x]
ueUpstream (UEq x y)                    = [x, y]
ueUpstream (ULt x y)                    = [x, y]
ueUpstream (UMux c x y)                 = [c, x, y]
ueUpstream (UF2B x)                     = [x]
ueUpstream (UD2B x)                     = [x]
ueUpstream (UB2F x)                     = [x]
ueUpstream (UB2D x)                     = [x]
-- | The variables an expression refers to, without duplicates.  The search
-- stops at each variable reference, so variables that occur only inside an
-- array index are not included.
nearestUVs :: UE -> [UV]
nearestUVs = nub . collect
  where
  collect :: UE -> [UV]
  collect expr = case expr of
    UVRef var -> [var]
    _         -> concatMap collect (ueUpstream expr)
-- | Every array access in an expression, as (array, index expression) pairs
-- without duplicates.  Accesses nested inside index expressions are included.
arrayIndices :: UE -> [(UA, UE)]
arrayIndices = nub . collect
  where
  collect :: UE -> [(UA, UE)]
  collect expr = case expr of
    UVRef (UV (Array arr index)) -> (arr, index) : collect index
    _                            -> concatMap collect (ueUpstream expr)
-- | Convert a typed expression to its untyped form.  Logical, comparison and
-- mux nodes go through the simplifying constructors ('unot', 'uand', 'ueq',
-- 'ult', 'umux'); every other node maps directly to its untyped counterpart.
ue :: Expr a => E a -> UE
ue (VRef (V var))   = UVRef var
ue (Const c)        = UConst (constant c)
-- A cast records the type of the whole (result) expression.
ue e@(Cast x)       = UCast (eType e) (ue x)
ue (Add x y)        = UAdd (ue x) (ue y)
ue (Sub x y)        = USub (ue x) (ue y)
ue (Mul x y)        = UMul (ue x) (ue y)
ue (Div x y)        = UDiv (ue x) (ue y)
ue (Mod x y)        = UMod (ue x) (ue y)
ue (Not x)          = unot (ue x)
ue (And x y)        = uand (ue x) (ue y)
ue (BWNot x)        = UBWNot (ue x)
ue (BWAnd x y)      = UBWAnd (ue x) (ue y)
ue (BWOr x y)       = UBWOr (ue x) (ue y)
ue (Shift x amount) = UShift (ue x) amount
ue (Eq x y)         = ueq (ue x) (ue y)
ue (Lt x y)         = ult (ue x) (ue y)
ue (Mux c x y)      = umux (ue c) (ue x) (ue y)
ue (F2B x)          = UF2B (ue x)
ue (D2B x)          = UD2B (ue x)
ue (B2F x)          = UB2F (ue x)
ue (B2D x)          = UB2D (ue x)
-- | The untyped variable wrapped by a typed variable.
uv :: V a -> UV
uv var = case var of
  V untyped -> untyped
-- | An untyped boolean constant.
ubool :: Bool -> UE
ubool flag = UConst (CBool flag)
-- | Logical negation with simplification: boolean constants are folded and
-- a double negation cancels.
unot :: UE -> UE
unot expr = case expr of
  UConst (CBool flag) -> ubool (not flag)
  UNot inner          -> inner
  _                   -> UNot expr
-- | Logical conjunction with simplification.  Identical operands collapse,
-- a constant false operand wins, a constant true operand is dropped, and
-- existing conjunctions are flattened before being handed to 'reduceAnd'.
-- The equations are tried in order.
uand :: UE -> UE -> UE
uand lhs rhs | lhs == rhs                 = lhs
uand no@(UConst (CBool False)) _          = no
uand _ no@(UConst (CBool False))          = no
uand (UConst (CBool True)) rhs            = rhs
uand lhs (UConst (CBool True))            = lhs
uand (UAnd ls) (UAnd rs)                  = reduceAnd (ls ++ rs)
uand (UAnd ls) rhs                        = reduceAnd (rhs : ls)
uand lhs (UAnd rs)                        = reduceAnd (lhs : rs)
uand lhs rhs                              = reduceAnd [lhs, rhs]
-- | Build a conjunction from a list of terms, detecting some contradictions.
-- The rules are tried in order; the first three yield constant false and the
-- last builds the 'UAnd' node.
reduceAnd :: [UE] -> UE
-- Rule 1: some term equals the negation (via 'unot') of a term in the list.
reduceAnd terms | not $ null [ e | e <- terms, e' <- map unot terms, e == e' ] = ubool False
-- Rule 2: two equality terms share an operand while 'ueq' folds the equality
-- of their other operands to constant false.
reduceAnd terms | or [ f a b | a <- terms, b <- terms ]                        = ubool False
  where
  -- Only the first guard whose shared-operand test succeeds is consulted.
  f :: UE -> UE -> Bool
  f (UEq a b) (UEq x y) | a == x = yep $ ueq b y
                        | a == y = yep $ ueq b x
                        | b == x = yep $ ueq a y
                        | b == y = yep $ ueq a x
  f _ _ = False
  -- True only for the constant false expression.
  yep :: UE -> Bool
  yep (UConst (CBool False)) = True
  yep _ = False
-- Rule 3: some term is the negation of a non-empty conjunction whose members
-- all occur in the term list.
reduceAnd terms | not $ null [ e | e <- terms, not $ null $ f e, all (flip elem terms) $ f e ] = ubool False
  where
  -- The members of a negated conjunction, or [] for any other term.
  f :: UE -> [UE]
  f (UNot (UAnd a)) = a
  f _               = []
-- Otherwise: a conjunction of the terms, deduplicated and sorted.
reduceAnd terms = UAnd $ sort $ nub terms
-- | Logical disjunction, expressed as @not (not a && not b)@ through the
-- simplifying constructors.
uor :: UE -> UE -> UE
uor lhs rhs = unot $ uand (unot lhs) (unot rhs)
-- | Equality with simplification.  Structurally identical operands give
-- constant true; two constants of the same type are compared immediately;
-- anything else (including constants of different types) builds a 'UEq' node.
ueq :: UE -> UE -> UE
ueq lhs rhs
  | lhs == rhs = ubool True
  | otherwise  = case (lhs, rhs) of
      -- Same-typed constants share a constructor, so the derived 'Eq' on
      -- 'Const' compares exactly the wrapped values.
      (UConst x, UConst y) | typeOf x == typeOf y -> ubool (x == y)
      _                                           -> UEq lhs rhs
-- | Less-than with simplification.  Structurally identical operands give
-- constant false; two constants of the same type are compared immediately;
-- anything else (including constants of different types) builds a 'ULt' node.
ult :: UE -> UE -> UE
ult lhs rhs
  | lhs == rhs = ubool False
  | otherwise  = maybe (ULt lhs rhs) ubool (foldLt lhs rhs)
  where
  -- The comparison result when both operands are same-typed constants.
  foldLt :: UE -> UE -> Maybe Bool
  foldLt (UConst x) (UConst y) = case (x, y) of
    (CBool   p, CBool   q) -> Just (p < q)
    (CInt8   p, CInt8   q) -> Just (p < q)
    (CInt16  p, CInt16  q) -> Just (p < q)
    (CInt32  p, CInt32  q) -> Just (p < q)
    (CInt64  p, CInt64  q) -> Just (p < q)
    (CWord8  p, CWord8  q) -> Just (p < q)
    (CWord16 p, CWord16 q) -> Just (p < q)
    (CWord32 p, CWord32 q) -> Just (p < q)
    (CWord64 p, CWord64 q) -> Just (p < q)
    (CFloat  p, CFloat  q) -> Just (p < q)
    (CDouble p, CDouble q) -> Just (p < q)
    _                      -> Nothing
  foldLt _ _ = Nothing
-- | Conditional with simplification.  Arguments are the condition followed by
-- the two alternatives.  The equations are tried in order.
umux :: UE -> UE -> UE -> UE
-- Both alternatives are identical: the condition is irrelevant.
umux _ t f | t == f = f
-- Boolean-typed result: expand into (b && t) || (not b && f).
umux b t f | typeOf t == Bool = uor (uand b t) (uand (unot b) f)
-- Constant condition: select the alternative directly.
umux (UConst (CBool b)) t f = if b then t else f
-- Negated condition: drop the negation and swap the alternatives.
umux (UNot b) t f = umux b f t
-- A nested mux on the same condition collapses to the alternative that
-- would actually be selected.
umux b1 (UMux b2 t _) f | b1 == b2 = umux b1 t f
umux b1 t (UMux b2 _ f) | b1 == b2 = umux b1 t f
-- No simplification applies.
umux b t f = UMux b t f