{-# LANGUAGE ConstraintKinds, RankNTypes, ScopedTypeVariables #-}
module Hpp.Expr (Expr(..), readLitInt, parseExpr, renderExpr, evalExpr) where
import Control.Applicative
import Control.Monad ((>=>))
import Data.Bits (complement, (.&.), (.|.), xor, shiftL, shiftR)
import Data.List (foldl')
import Text.Read (readMaybe)
import Hpp.Tokens
import Data.Char (digitToInt, toLower)
import Data.Proxy (Proxy(..))
import Data.Bifunctor (bimap)
import Data.Bits (Bits)
data BinOp = Add | Sub | Mul | Div | Mod
| BitAnd | BitOr | BitXor | ShiftL | ShiftR
| LessThan | GreaterThan | EqualTo | NotEqualTo
| GreaterOrEqualTo | LessOrEqualTo
| And | Or
deriving (Eq, Ord, Show)
data UnaryOp = Neg | BitNot | Not | Defined deriving (Eq,Ord,Show)
data Lit = LitInt Int | LitUInt Word | LitStr String | LitChar Char | LitID String
deriving (Eq,Ord,Show)
data Parsed = PBinOp BinOp | PUnaryOp UnaryOp | PLit Lit deriving (Eq,Ord,Show)
newtype CppInt = CppInt { getCppInt :: Either Word Int }
onCommonType :: forall p c. (c Word, c Int)
=> p c
-> (forall a. c a => a -> a -> a)
-> CppInt -> CppInt -> CppInt
onCommonType _ f (CppInt (Right x)) (CppInt (Right y)) = CppInt (Right (f x y))
onCommonType _ f (CppInt (Left x)) (CppInt (Left y)) = CppInt (Left (f x y))
onCommonType _ f (CppInt (Right x)) (CppInt (Left y)) =
CppInt (Left $ f (fromIntegral x) y)
onCommonType _ f (CppInt (Left x)) (CppInt (Right y)) =
CppInt (Left $ f x (fromIntegral y))
instance Eq CppInt where
CppInt (Left x) == CppInt (Right y) = x == fromIntegral y
CppInt (Right x) == CppInt (Left y) = fromIntegral x == y
CppInt x == CppInt y = x == y
onCommonTypeB :: forall p c. (c Word, c Int)
=> p c
-> (forall a. c a => a -> a -> Bool)
-> CppInt -> CppInt -> Bool
onCommonTypeB _ f (CppInt (Right x)) (CppInt (Right y)) = f x y
onCommonTypeB _ f (CppInt (Left x)) (CppInt (Left y)) = f x y
onCommonTypeB _ f (CppInt (Right x)) (CppInt (Left y)) = f (fromIntegral x) y
onCommonTypeB _ f (CppInt (Left x)) (CppInt (Right y)) = f x (fromIntegral y)
instance Ord CppInt where
x < y = onCommonTypeB (Proxy::Proxy Ord) (<) x y
x > y = onCommonTypeB (Proxy::Proxy Ord) (>) x y
x <= y = onCommonTypeB (Proxy::Proxy Ord) (<=) x y
x >= y = onCommonTypeB (Proxy::Proxy Ord) (>=) x y
max x y = onCommonType (Proxy::Proxy Ord) max x y
min x y = onCommonType (Proxy::Proxy Ord) min x y
instance Num CppInt where
x + y = onCommonType (Proxy::Proxy Num) (+) x y
x - y = onCommonType (Proxy::Proxy Num) (-) x y
x * y = onCommonType (Proxy::Proxy Num) (*) x y
negate (CppInt x) = CppInt (Right $ either (negate . fromIntegral) negate x)
abs (CppInt x) = CppInt (bimap abs abs x)
signum (CppInt x) = CppInt (bimap signum signum x)
fromInteger = CppInt . Right . fromInteger
integralCpp :: (forall a. Integral a => a -> a -> a) -> CppInt -> CppInt -> CppInt
integralCpp = onCommonType (Proxy::Proxy Integral)
bitsCpp :: (forall a. Bits a => a -> a -> a) -> CppInt -> CppInt -> CppInt
bitsCpp = onCommonType (Proxy::Proxy Bits)
cppShiftL,cppShiftR :: CppInt -> Int -> CppInt
cppShiftL (CppInt x) s = CppInt $ bimap (`shiftL` s) (`shiftL` s) x
cppShiftR (CppInt x) s = CppInt $ bimap (`shiftR` s) (`shiftR` s) x
cppComplement :: CppInt -> CppInt
cppComplement (CppInt x) = CppInt $ bimap complement complement x
data Assoc = RightLeft | LeftRight deriving (Eq, Ord, Show)
associativity :: Either BinOp UnaryOp -> Assoc
associativity = either (const LeftRight) (const RightLeft)
precedence :: Either BinOp UnaryOp -> Int
precedence (Right _) = 10
precedence (Left x) = precedenceBin x
precedenceBin :: BinOp -> Int
precedenceBin Or = 0
precedenceBin And = 1
precedenceBin BitOr = 2
precedenceBin BitXor = 3
precedenceBin BitAnd = 4
precedenceBin EqualTo = 5
precedenceBin NotEqualTo = 5
precedenceBin LessThan = 6
precedenceBin GreaterThan = 6
precedenceBin GreaterOrEqualTo = 6
precedenceBin LessOrEqualTo = 6
precedenceBin ShiftL = 7
precedenceBin ShiftR = 7
precedenceBin Add = 8
precedenceBin Sub = 8
precedenceBin Mul = 9
precedenceBin Div = 9
precedenceBin Mod = 9
fixStringLits :: [Token String] -> Maybe [String]
fixStringLits [] = Just []
fixStringLits (Important h@('"':_):xs) =
let (hs,ys) = break ((== '"') . last . detok) xs
in case ys of
[] -> Nothing
(y:ys') -> fmap ((h ++ detokenize hs ++ detok y) :)
(fixStringLits ys')
fixStringLits (Important x:xs) = fmap (x :) (fixStringLits xs)
fixStringLits (Other _ : xs) = fixStringLits xs
onFirstImportant :: (String -> String)
-> ([Token String] -> [Token String])
-> [Token String]
-> [Token String]
onFirstImportant f k = go
where go [] = k []
go (Important s : ts') = Important (f s) : k ts'
go (o@(Other _) : ts') = o : go ts'
fixUnaryOps :: [Token String] -> [Token String]
fixUnaryOps [] = []
fixUnaryOps (t0:ts0) =
case t0 of
Important "+" -> onFirstImportant ('+':) (go False) ts0
Important "-" -> onFirstImportant ('-':) (go False) ts0
_ -> go False (t0:ts0)
where go _ [] = []
go _ (Important "(":ts) = Important "(" : go True ts
go True (Important "+":ts) = onFirstImportant ('+':) (go False) ts
go True (Important "-":ts) = onFirstImportant ('-':) (go False) ts
go _ (Important t:ts) = let isOp = maybe False (const True) $ parseOp t
in Important t : go isOp ts
go _ (o@(Other _):ts) = o : go False ts
fixBinaryOps :: [Token String] -> [Token String]
fixBinaryOps [] = []
fixBinaryOps (h@(Important t1) : ts@(Important t2 : ts')) =
case parseBinOp (t1++t2) of
Just o -> Important (renderBinOp o) : fixBinaryOps ts'
Nothing -> h : fixBinaryOps ts
fixBinaryOps (t:ts) = t : fixBinaryOps ts
renderBinOp :: BinOp -> String
renderBinOp Add = "+"
renderBinOp Sub = "-"
renderBinOp Mul = "*"
renderBinOp Div = "/"
renderBinOp Mod = "%"
renderBinOp BitAnd = "&"
renderBinOp BitOr = "|"
renderBinOp BitXor = "^"
renderBinOp ShiftL = "<<"
renderBinOp ShiftR = ">>"
renderBinOp LessThan = "<"
renderBinOp GreaterThan = ">"
renderBinOp EqualTo = "=="
renderBinOp NotEqualTo = "!="
renderBinOp GreaterOrEqualTo = ">="
renderBinOp LessOrEqualTo = "<="
renderBinOp And = "&&"
renderBinOp Or = "||"
renderUnaryOp :: UnaryOp -> String
renderUnaryOp Neg = "-"
renderUnaryOp BitNot = "~"
renderUnaryOp Not = "!"
renderUnaryOp Defined = "defined "
lexExpr :: [Token String] -> Maybe [String]
lexExpr = fixStringLits . fixUnaryOps . fixBinaryOps
parseUnaryOp :: String -> Maybe UnaryOp
parseUnaryOp "-" = Just Neg
parseUnaryOp "~" = Just BitNot
parseUnaryOp "!" = Just Not
parseUnaryOp "defined" = Just Defined
parseUnaryOp _ = Nothing
parseBinOp :: String -> Maybe BinOp
parseBinOp "+" = Just Add
parseBinOp "-" = Just Sub
parseBinOp "*" = Just Mul
parseBinOp "/" = Just Div
parseBinOp "%" = Just Mod
parseBinOp "&" = Just BitAnd
parseBinOp "|" = Just BitOr
parseBinOp "^" = Just BitXor
parseBinOp "<<" = Just ShiftL
parseBinOp ">>" = Just ShiftR
parseBinOp "<" = Just LessThan
parseBinOp ">" = Just GreaterThan
parseBinOp "==" = Just EqualTo
parseBinOp "!=" = Just NotEqualTo
parseBinOp ">=" = Just GreaterOrEqualTo
parseBinOp "<=" = Just LessOrEqualTo
parseBinOp "&&" = Just And
parseBinOp "||" = Just Or
parseBinOp _ = Nothing
parseOp :: String -> Maybe (Either BinOp UnaryOp)
parseOp s = Left <$> parseBinOp s <|> Right <$> parseUnaryOp s
readWideChar :: String -> Maybe Char
readWideChar ('L':'\'':cs0) = go 0 cs0
where go n ['\''] = Just $ toEnum n
go n (c:cs) = go (n*256 + fromEnum c) cs
go _ [] = Nothing
readWideChar _ = Nothing
readNarrowChar :: String -> Maybe Char
readNarrowChar ['\'',c,'\''] = Just c
readNarrowChar _ = Nothing
parseLit :: String -> Maybe Lit
parseLit s = case readLitInt s of
Just (CppInt i) -> Just (either LitUInt LitInt i)
Nothing -> case readNarrowChar s <|> readWideChar s of
Just c -> Just (LitChar c)
Nothing -> case readMaybe s of
Just str -> Just (LitStr str)
Nothing -> Nothing
digitsFromBase :: Word -> [Word] -> Word
digitsFromBase base = foldl' aux 0
where aux acc d = base*acc + d
readLitInt' :: String -> Maybe Word
readLitInt' ('0':x:hexDigits)
| x == 'x' || x == 'X' = fmap (digitsFromBase 16)
(mapM (fmap fromIntegral . hexDigit) hexDigits)
where hexDigit c
| c >= '0' && c <= '9' = Just $ digitToInt c
| c >= 'a' && c <= 'f' = Just $ (fromEnum c - fromEnum 'a') + 10
| c >= 'A' && c <= 'F' = Just $ (fromEnum c - fromEnum 'A') + 10
| otherwise = Nothing
readLitInt' ('0':octalDigits) = fmap (digitsFromBase 8)
(mapM octalDigit octalDigits)
where octalDigit c
| c >= '0' && c < '8' = Just . fromIntegral $ digitToInt c
| otherwise = Nothing
readLitInt' s = readMaybe s
readLitInt :: String -> Maybe CppInt
readLitInt s = case map toLower . take 2 $ reverse s of
"lu" -> fmap (CppInt . Left) (readLitInt' (init (init s)))
'l':_ -> readLitInt (init s)
'u':_ -> fmap (CppInt . Left . asWord) (readLitInt (init s))
_ -> fmap (CppInt . Right . fromIntegral) (readLitInt' s)
where asWord = either id fromIntegral . getCppInt
data FunLike = FunBin BinOp | FunUnary UnaryOp | FunParen deriving (Eq,Ord,Show)
finishShunting :: [Parsed] -> [FunLike] -> Maybe [Parsed]
finishShunting q [] = Just (reverse q)
finishShunting _ (FunParen:_) = Nothing
finishShunting q (FunBin op:ops) = finishShunting (PBinOp op : q) ops
finishShunting q (FunUnary op:ops) = finishShunting (PUnaryOp op : q) ops
opParsed :: Either BinOp UnaryOp -> Parsed
opParsed (Left x) = PBinOp x
opParsed (Right x) = PUnaryOp x
opFun :: Either BinOp UnaryOp -> FunLike
opFun = either FunBin FunUnary
juggleBinOps :: Either BinOp UnaryOp -> [Parsed] -> [FunLike]
-> ([Parsed] -> [FunLike] -> r)
-> r
juggleBinOps o q [] k = k q [opFun o]
juggleBinOps o1 q s@(o2:ss) k =
case o2 of
FunBin o -> aux $ Left o
FunUnary o -> aux $ Right o
FunParen -> done
where a1 = associativity o1
p1 = precedence o1
aux o2'
| a1 == LeftRight && p1 <= p2 = juggleBinOps o1 (opParsed o2':q) ss k
| a1 == RightLeft && p1 < p2 = juggleBinOps o1 (opParsed o2':q) ss k
| otherwise = done
where p2 = precedence o2'
done = k q (opFun o1 : s)
shuntRPN :: [Parsed] -> [FunLike] -> [String] -> Maybe [Parsed]
shuntRPN q s [] = finishShunting q s
shuntRPN q s ("(":es) = shuntRPN q (FunParen:s) es
shuntRPN q s (")":es) = let go _ [] = Nothing
go q' (FunParen:s') = shuntRPN q' s' es
go q' (FunBin op:s') = go (PBinOp op : q') s'
go q' (FunUnary op:s') = go (PUnaryOp op : q') s'
in go q s
shuntRPN q s (('+':e@(_:_)):es) = shuntRPN q s (e:es)
shuntRPN q s (('-':e@(_:_)):es) = shuntRPN q (FunUnary Neg : s) (e:es)
shuntRPN q s ("!":es) = shuntRPN q (FunUnary Not : s) es
shuntRPN q s ("~":es) = shuntRPN q (FunUnary BitNot : s) es
shuntRPN q s ("defined":es) = shuntRPN q (FunUnary Defined : s) es
shuntRPN q s (e:es) =
case parseLit e of
Just l -> shuntRPN (PLit l : q) s es
Nothing -> case parseOp e of
Just o -> juggleBinOps o q s $
\q' s' -> shuntRPN q' s' es
Nothing -> shuntRPN (PLit (LitID e) : q) s es
data Expr = ELit Lit | EBinOp BinOp Expr Expr | EUnaryOp UnaryOp Expr
deriving (Eq, Ord, Show)
rpnToExpr :: [Expr] -> [Parsed] -> Maybe Expr
rpnToExpr [e] [] = Just e
rpnToExpr _ [] = Nothing
rpnToExpr s (PLit e:es) = rpnToExpr (ELit e:s) es
rpnToExpr (s:ss) (PUnaryOp o : es) = rpnToExpr (EUnaryOp o s:ss) es
rpnToExpr (s1:s2:ss) (PBinOp o : es) = rpnToExpr (EBinOp o s2 s1 : ss) es
rpnToExpr _ _ = Nothing
parseExpr :: [Token String] -> Maybe Expr
parseExpr = lexExpr >=> shuntRPN [] [] >=> rpnToExpr []
renderExpr :: Expr -> String
renderExpr (ELit (LitInt e)) = show e
renderExpr (ELit (LitUInt e)) = show e ++ "U"
renderExpr (ELit (LitStr e)) = '"':e++"\""
renderExpr (ELit (LitChar c)) = [c]
renderExpr (ELit (LitID e)) = e
renderExpr (EBinOp o e1 e2) = concat [ "(", renderExpr e1," ",renderBinOp o
, " ", renderExpr e2, ")" ]
renderExpr (EUnaryOp o e1) = renderUnaryOp o ++ renderExpr e1
evalExpr :: Expr -> Int
evalExpr = either fromIntegral id . getCppInt . evalExpr'
evalExpr' :: Expr -> CppInt
evalExpr' = go
where int = CppInt . Right
word = CppInt . Left
asInt = either fromIntegral id . getCppInt
go (ELit (LitInt e)) = int e
go (ELit (LitUInt e)) = word e
go (ELit (LitStr _)) = int 1
go (ELit (LitID _)) = int 0
go (ELit (LitChar c)) = int $ fromEnum c
go (EBinOp Add e2 e3) = go e2 + go e3
go (EBinOp Sub e2 e3) = go e2 - go e3
go (EBinOp Mul e2 e3) = go e2 * go e3
go (EBinOp Div e2 e3) = integralCpp div (go e2) (go e3)
go (EBinOp Mod e2 e3) = integralCpp mod (go e2) (go e3)
go (EBinOp BitAnd e2 e3) = bitsCpp (.&.) (go e2) (go e3)
go (EBinOp BitOr e2 e3) = bitsCpp (.|.) (go e2) (go e3)
go (EBinOp BitXor e2 e3) = bitsCpp xor (go e2) (go e3)
go (EBinOp ShiftL e2 e3) = go e2 `cppShiftL` asInt (go e3)
go (EBinOp ShiftR e2 e3) = go e2 `cppShiftR` asInt (go e3)
go (EBinOp LessThan e2 e3) = int . fromEnum $ go e2 < go e3
go (EBinOp GreaterThan e2 e3) = int . fromEnum $ go e2 > go e3
go (EBinOp EqualTo e2 e3) = int . fromEnum $ go e2 == go e3
go (EBinOp NotEqualTo e2 e3) = int . fromEnum $ go e2 /= go e3
go (EBinOp GreaterOrEqualTo e2 e3) = int . fromEnum $ go e2 >= go e3
go (EBinOp LessOrEqualTo e2 e3) = int . fromEnum $ go e2 <= go e3
go (EBinOp And e2 e3) = int . fromEnum $ go e2 /= 0 && go e3 /= 0
go (EBinOp Or e2 e3) = int . fromEnum $ go e2 /= 0 || go e3 /= 0
go (EUnaryOp Neg e2) = negate (go e2)
go (EUnaryOp BitNot e2) = cppComplement $ go e2
go (EUnaryOp Not e2) = int . fromEnum $ go e2 == 0
go (EUnaryOp Defined e2) =
case e2 of
ELit (LitInt 1) -> 1
_ -> 0