module Language.Sh.Arithmetic ( runMathParser ) where
import Text.ParserCombinators.Parsec
import Text.ParserCombinators.Parsec.Language
import Text.ParserCombinators.Parsec.Expr
import qualified Text.ParserCombinators.Parsec.Token as P
import Data.Bits ( shiftL, shiftR, complement, xor, (.&.), (.|.) )
import Data.List ( unionBy )
import Data.Maybe ( fromMaybe )
import Debug.Trace ( trace )
import Language.Sh.Compat ( on )
-- | Variable bindings visible to an expression, carried as the parser user
-- state: each name maps to its value as an unparsed string.
type SS = [(String,String)]
-- | Assignments performed while evaluating an expression: name to new value.
type SI = [(String,Int)]
-- | Parsers in this module: character parsers with 'SS' as user state.
type AP a = CharParser SS a
-- | The meaning of a parsed (sub)expression: a number together with the
-- assignments made while computing it, a bare variable name that has not
-- been looked up yet, or an evaluation error message.
data Term = Literal SI Int | Variable String | Error String
  deriving ( Show )
-- | Evaluate a shell arithmetic expression.
--
-- The first argument maps variable names to their unparsed string values.
-- On success the result pairs the value of the expression with every
-- variable assignment it performed; on failure it is the rendered parse
-- (or evaluation) error.
runMathParser :: SS -> String -> Either String (Int,SI)
runMathParser subs s = either (Left . show) Right parsed
  where parsed = runParser exprSubs subs "" s
-- | Left-biased union of two association lists keyed on the first
-- component: the first list is kept verbatim, followed by the pairs of the
-- second list (first occurrence per key only) whose key does not appear in
-- the first list.
joinS :: Eq a => [(a,b)] -> [(a,b)] -> [(a,b)]
joinS xs ys = unionBy sameKey xs ys
  where sameKey (k1,_) (k2,_) = k1 == k2
-- | Apply a function to the value of every pair in an association list,
-- keeping keys and order unchanged.
mapS :: (b -> c) -> [(a,b)] -> [(a,c)]
mapS f = map onValue
  where onValue (key, val) = (key, f val)
-- | Top-level parser: one complete arithmetic expression spanning the whole
-- input, returning its value and the assignments it performed.
--
-- Leading whitespace is skipped here because the token parsers in this
-- module only consume whitespace after a token.  Input that is empty or all
-- blanks evaluates to 0, as a null value does in shell arithmetic.  A
-- result that is a bare variable name is resolved by recursively evaluating
-- the value bound to it (an unset name reads as "0").
exprSubs :: AP (Int,SI)
exprSubs = do whiteSpace
              (eof >> return (0,[])) <|> expression
  where
    expression = do
        e <- expr
        eof
        case e of
            Literal subs i -> return (i,subs)
            Variable s -> do
                ss <- getState
                either fail return $ runMathParser ss (fromMaybe "0" $ lookup s ss)
            Error err -> fail err
-- | Token parser for shell arithmetic.  Identifiers may contain letters,
-- digits and underscores.  'opStart' and 'opLetter' match no character at
-- all, so 'P.reservedOp' accepts exactly the operator string it is given;
-- rejecting prefixes of longer operators is left to the callers.
lexer :: P.TokenParser st
lexer = P.makeTokenParser arithDef
  where
    arithDef = emptyDef
        { identLetter     = alphaNum <|> char '_'
        , opStart         = oneOf []
        , opLetter        = oneOf []
        , reservedOpNames = arithOps
        }
    -- Every operator spelling used by the arithmetic grammar.
    arithOps = [ "++", "+", "--", "-", "*", "/", "%", "^"
               , "|", "||", "&", "&&", "<<", ">>"
               , "<", "<=", ">", ">=", "==", "=", "!=", "!", "~"
               , "?", ":"
               , "+=", "-=", "*=", "/=", "%=", "|=", "&="
               , "^=", "<<=", ">>=" ]
-- Token parsers instantiated from 'lexer'.  'parens', 'reservedOp' and
-- 'identifier' skip trailing whitespace themselves; 'decimal' and
-- 'hexadecimal' do not, which is why 'natural' calls 'whiteSpace' after them.
parens = P.parens lexer
whiteSpace = P.whiteSpace lexer
-- NOTE(review): in the Parsec Token implementation this expects the x/X
-- marker followed by hex digits, without a leading zero; confirm against
-- the Parsec version in use.
hexadecimal = P.hexadecimal lexer
decimal = P.decimal lexer
reservedOp = P.reservedOp lexer
identifier = P.identifier lexer
-- | An unsigned integer constant in shell syntax, followed by any trailing
-- whitespace:
--
--   * hexadecimal with a @0x@ / @0X@ prefix (@0x2a@),
--   * octal with a leading zero (@052@; a lone @0@ is zero),
--   * decimal otherwise (@42@).
--
-- The parser only commits after seeing a digit, so a name such as @x1@ is
-- left for the identifier parser instead of being read as a hex literal.
natural :: CharParser st Integer
natural = do n <- zeroPrefixed <|> decimal
             whiteSpace
             return n
  where
    -- After the leading zero: hex when an x/X marker follows (consumed by
    -- 'hexadecimal' itself), otherwise a possibly empty run of octal digits.
    zeroPrefixed = do _ <- char '0'
                      hexadecimal <|> octalDigits 0
    -- Accumulate octal digits left to right until a non-octal character.
    octalDigits acc = (do d <- octDigit
                          octalDigits (8 * acc + toInteger (fromEnum d - fromEnum '0')))
                      <|> return acc
-- | Apply a function to the value of an evaluated term.  Errors pass
-- through unchanged; a bare variable here means the caller forgot to
-- expand it, which is reported as an error term.
mapT :: (Int -> Int) -> Term -> Term
mapT f t = case t of
    Literal s i -> Literal s (f i)
    Variable v  -> Error ("impossible: unexpanded variable: " ++ v)
    Error err   -> Error err
-- | Combine the values of two evaluated terms.  The left operand is checked
-- first: its error (or unexpanded variable) wins over the right operand.
-- The assignments recorded by both literals are merged with 'joinS'.
mapT2 :: (Int -> Int -> Int) -> Term -> Term -> Term
mapT2 f lhs rhs = case (lhs, rhs) of
    (Error err, _)                 -> Error err
    (Variable v, _)                -> unexpanded v
    (_, Error err)                 -> Error err
    (_, Variable v)                -> unexpanded v
    (Literal s1 i1, Literal s2 i2) -> Literal (joinS s1 s2) (f i1 i2)
  where unexpanded v = Error ("impossible: unexpanded variable: " ++ v)
-- | Expressions without conditionals or assignments: the unary and binary
-- operators of 'table1' applied over 'term' operands.
expr1 :: AP Term
expr1 = buildExpressionParser table1 term
-- | A conditional expression @c ? t : e@, or a plain 'expr1' when no @?@
-- follows the condition.
--
-- The condition is parsed exactly once; the previous @try@-and-reparse
-- approach re-parsed it on every non-conditional expression, which is
-- exponential in parenthesis nesting depth.  As in C and bash, the middle
-- operand is a full expression and the else-branch is itself a conditional,
-- so @a ? b : c ? d : e@ nests to the right.  Only the chosen branch is
-- evaluated, with the assignments made by the condition visible to it.
expr2 :: AP Term
expr2 = do eIf <- expr1
           option eIf (conditional eIf)
  where
    -- Having parsed the condition, parse "? then : else" and pick a branch.
    conditional eIf = do
        reservedOp "?"
        eThen <- expr
        reservedOp ":"
        eElse <- expr2
        ss <- getState
        return $ case expand ss eIf of
            Error err -> Error err
            Literal si i
                | i /= 0    -> expandWith ss si eThen
                | otherwise -> expandWith ss si eElse
            Variable v -> unexpanded v
    -- Evaluate a branch with the condition assignments layered over the
    -- substitutions, then merge those assignments into the branch result.
    expandWith ss si t = case expand (mapS show si `joinS` ss) t of
        Error err -> Error err
        Literal si' i -> Literal (si `joinS` si') i
        Variable v -> unexpanded v
    unexpanded v = Error $ "impossible: unexpanded variable: " ++ v
-- | A complete expression: the assignment operators of 'table2' applied
-- over conditional expressions ('expr2').
expr :: AP Term
expr = buildExpressionParser table2 expr2
-- | An atomic operand: a parenthesised expression, a numeric constant
-- (converted to 'Int'), or a variable name left unexpanded for the
-- operator that consumes it.
term :: AP Term
term = choice [ parens expr
              , fmap (Literal [] . fromIntegral) natural
              , fmap Variable identifier ]
       <?> "simple expression"
-- | Operator table for everything below the conditional operator, highest
-- precedence first.  Every operator expands variable operands (via
-- 'expand') before applying its function, and the assignment lists of the
-- operands are merged with 'joinS'.  Division and remainder by zero produce
-- an error term ("division by 0") instead of a runtime exception.
table1 :: OperatorTable Char SS Term
#ifdef HAVE_PARSEC_POSTFIX
table1 = [ [postfix "++" $ postinc (+1), postfix "--" $ postinc (subtract 1)]
         , [prefix "+" $ e1 id, prefix "-" $ e1 negate]
         , [prefix "++" $ preinc (+1), prefix "--" $ preinc (subtract 1)]
#else
table1 = [ [prefix "+" $ e1 id, prefix "-" $ e1 negate]
#endif
         , [prefix "~" $ e1 complement, prefix "!" $ e1 $ b2i . not . i2b]
         , [binary "*" $ e2 (*), binary "/" $ d2 div, binary "%" $ d2 mod]
         , [binary "+" $ e2 (+), binary "-" $ e2 (-)]
         , [binary "<<" $ e2 shiftL, binary ">>" $ e2 shiftR]
         , [ binary "<"  $ e2 $ b2i .: (<),  binary "<=" $ e2 $ b2i .: (<=)
           , binary ">"  $ e2 $ b2i .: (>),  binary ">=" $ e2 $ b2i .: (>=)
           , binary "==" $ e2 $ b2i .: (==), binary "!=" $ e2 $ b2i .: (/=) ]
         , [binary "&" $ e2 (.&.)]
         , [binary "^" $ e2 xor]
         , [binary "|" $ e2 (.|.)]
         , [binary "&&" $ e2 $ b2 (&&)]
         , [binary "||" $ e2 $ b2 (||)] ]
  where
    -- Lift a unary Int operation to terms, expanding the operand first.
    e1 :: (Int -> Int) -> AP (Term -> Term)
    e1 f = do ss <- getState
              return $ mapT f . expand ss
    -- Lift a binary Int operation to terms, expanding both operands first.
    e2 :: (Int -> Int -> Int) -> AP (Term -> Term -> Term)
    e2 f = do ss <- getState
              return $ \t1 t2 -> mapT2 f (expand ss t1) (expand ss t2)
    -- Like e2 for div and mod: a zero right operand becomes an error term
    -- (errors in the left operand still take priority).
    d2 :: (Int -> Int -> Int) -> AP (Term -> Term -> Term)
    d2 f = do ss <- getState
              return $ \t1 t2 -> case (expand ss t1, expand ss t2) of
                  (Literal _ _, Literal _ 0) -> Error "division by 0"
                  (l, r)                     -> mapT2 f l r
    -- Shell truth values: zero is false, anything else is true; results
    -- of logical operators are 0 or 1.
    b2 :: (Bool -> Bool -> Bool) -> Int -> Int -> Int
    b2 f i j = b2i $ f (i2b i) (i2b j)
    i2b :: Int -> Bool
    i2b = (/= 0)
    b2i :: Bool -> Int
    b2i b = if b then 1 else 0
    (.:) f g a b = f $ g a b
    -- Match an operator only when it is not the start of a longer one.
    -- A following "=" always means a compound assignment, and the single
    -- characters "&", "|", "<", ">" must not be split off from "&&", "||",
    -- "<<=" or ">>=".
    ro name = try (reservedOp name >> notFollowedBy (oneOf (continuations name)))
    continuations name
        | name `elem` ["&", "|", "<", ">"] = '=' : name
        | otherwise                        = "="
    binary name fun = Infix (ro name >> fun) AssocLeft
    prefix name fun = Prefix (reservedOp name >> fun)
#ifdef HAVE_PARSEC_POSTFIX
    postfix name fun = Postfix (ro name >> fun)
#endif
-- | Resolve a bare variable into a literal by evaluating its bound string
-- value as an arithmetic expression.  An unbound name evaluates to 0 and a
-- value that fails to evaluate becomes an error term.  Literals and errors
-- are returned unchanged.
expand :: SS -> Term -> Term
expand subs term = case term of
    Variable name -> maybe (Literal [] 0) evalValue (lookup name subs)
    other         -> other
  where evalValue str = either Error (\(i,si) -> Literal si i) (runMathParser subs str)
-- | Parsers for increment and decrement style operators.  Given the update
-- to apply, each builds a term transformer that stores the updated value
-- back into the variable.  'postinc' yields the old value as the result of
-- the expression; 'preinc' yields the new one.
postinc,preinc :: (Int -> Int) -> AP (Term -> Term)
postinc f = assignReturn (\old -> (f old, old))
preinc f = assignReturn (\old -> let new = f old in (new, new))
-- | Perform an assignment to a variable term.
--
-- The current value of the variable is evaluated from the substitutions
-- (unset reads as "0"); the supplied function maps it to a pair of
-- (value to store, value of the whole expression).  The new binding is
-- recorded first, then the assignments made while evaluating the old
-- value, then the extra assignments passed in.  Non-variable targets and
-- error terms yield error terms.
assignReturn' :: SS -> SI -> (Int -> (Int,Int)) -> (Term -> Term)
assignReturn' ss si f term = case term of
    Error err   -> Error err
    Literal _ i -> Error $ "assignment to non-variable: " ++ show i
    Variable v  -> either Error (store v) $ runMathParser ss (currentValue v)
  where
    -- Present textual value of a variable.
    currentValue v = fromMaybe "0" (lookup v ss)
    -- Record the new value and produce the expression result.
    store v (old, oldSubs) =
        let (new, result) = f old
        in Literal (([(v, new)] `joinS` oldSubs) `joinS` si) result
-- | Lift an assignment function into the parser: uses the substitutions
-- from the parser state and starts with no prior assignments.
assignReturn :: (Int -> (Int,Int)) -> AP (Term -> Term)
assignReturn f = fmap (\ss -> assignReturn' ss [] f) getState
-- | Build a binary assignment operator such as @x += e@.  The right-hand
-- side is expanded first (its error wins); then the variable on the left
-- is updated to @f old rhs@, which is also the value of the expression.
-- Assignments made by the right-hand side are carried along.
assignReturn2 :: (Int -> Int -> Int) -> AP (Term -> Term -> Term)
assignReturn2 f = do
    ss <- getState
    return $ \lhs rhs -> case expand ss rhs of
        Error err -> Error err
        Literal si j ->
            assignReturn' ss si (\old -> let new = f old j in (new, new)) lhs
-- | Assignment operators, applied by 'expr' over conditional expressions.
-- As in shell arithmetic they all share one precedence level and associate
-- to the right, so @a = b = 1@ assigns 1 to both variables and
-- @a = b <<= 1@ parses as @a = (b <<= 1)@.  A zero divisor in "/=" or
-- "%=" yields an error term ("division by 0") instead of a runtime
-- exception.
table2 :: OperatorTable Char SS Term
table2 = [ [ op "=" $ flip const, op "*=" (*), divOp "/=" div, divOp "%=" mod
           , op "+=" (+), op "-=" (-)
           , op "<<=" shiftL, op ">>=" shiftR
           , op "&=" (.&.), op "^=" xor, op "|=" (.|.) ] ]
  where
    a2 :: (Int -> Int -> Int) -> AP (Term -> Term -> Term)
    a2 = assignReturn2
    op name fun = Infix (reservedOp name >> a2 fun) AssocRight
    -- Division-like assignments check the right-hand side for zero first.
    divOp name fun = Infix (reservedOp name >> divAssign fun) AssocRight
    divAssign fun = do ss <- getState
                       assign <- a2 fun
                       return $ \lhs rhs -> case expand ss rhs of
                           Literal _ 0 -> Error "division by 0"
                           _           -> assign lhs rhs