module Language.Egison.Primitives (primitiveEnv, primitiveEnvNoIO) where
import Control.Applicative
import Control.Arrow
import Control.Exception (bracket)
import Control.Monad
import Data.IORef
import Data.Ratio
import System.IO
import System.IO.Unsafe

import Control.Monad.Error
import qualified Data.Array as A
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BC
import qualified Data.Sequence as Sq
import qualified Database.MySQL.Base as MySQL
import System.Random

import Language.Egison.Types
import Language.Egison.Core
-- | The full primitive environment: the constants plus every pure and IO
-- primitive, each bound to a fresh 'IORef' that already holds its value in
-- WHNF.
primitiveEnv :: IO Env
primitiveEnv = extendEnv nullEnv <$> mapM bind (constants ++ ops)
  where
    ops = map (second PrimitiveFunc) (primitives ++ ioPrimitives)
    bind (name, op) = (,) name <$> newIORef (WHNF (Value op))
-- | Like 'primitiveEnv' but without the IO primitives: only the constants
-- and the pure 'primitives' are bound.
primitiveEnvNoIO :: IO Env
primitiveEnvNoIO = do
  let pureOps = [ (name, PrimitiveFunc fn) | (name, fn) <- primitives ]
      bind (name, v) = newIORef (WHNF (Value v)) >>= \ref -> return (name, ref)
  bindings <- mapM bind (constants ++ pureOps)
  return $ extendEnv nullEnv bindings
-- | Adapt a nullary primitive body: run it when called with no arguments,
-- otherwise fail with @ArgumentsNum 0 <actual count>@.
noArg :: (MonadError EgisonError m) =>
         m EgisonValue ->
         [WHNFData] -> m EgisonValue
noArg f [] = f
noArg _ args = throwError $ ArgumentsNum 0 $ length args
-- | Adapt a unary primitive body.  With no arguments the body receives the
-- empty tuple @Value (Tuple [])@; with more than one it fails with
-- @ArgumentsNum 1 <actual count>@.
oneArg :: (MonadError EgisonError m) =>
          (WHNFData -> m EgisonValue) ->
          [WHNFData] -> m EgisonValue
oneArg f [x] = f x
oneArg f [] = f (Value (Tuple []))
oneArg _ args = throwError $ ArgumentsNum 1 $ length args
-- | Adapt a binary primitive body; any other argument count fails with
-- @ArgumentsNum 2 <actual count>@.
twoArgs :: (MonadError EgisonError m) =>
           (WHNFData -> WHNFData -> m EgisonValue) ->
           [WHNFData] -> m EgisonValue
twoArgs f [x, y] = f x y
twoArgs _ args = throwError $ ArgumentsNum 2 $ length args
-- | Adapt a ternary primitive body; any other argument count fails with
-- @ArgumentsNum 3 <actual count>@.
threeArgs :: (MonadError EgisonError m) =>
             (WHNFData -> WHNFData -> WHNFData -> m EgisonValue) ->
             [WHNFData] -> m EgisonValue
threeArgs f [x, y, z] = f x y z
threeArgs _ args = throwError $ ArgumentsNum 3 $ length args
-- | Named constant values bound in every primitive environment.
-- Prelude's 'pi' at 'Double' is the same value as the literal
-- 3.141592653589793.
constants :: [(String, EgisonValue)]
constants =
  [ ("pi", Float pi)
  ]
-- | Pure primitive functions, bound by both 'primitiveEnv' and
-- 'primitiveEnvNoIO'.
--
-- The @-i@ and @-f@ entries were previously bound to @integerBinaryOp ()@ /
-- @floatBinaryOp ()@ (the subtraction operator had been lost); they now use
-- @(-)@.
primitives :: [(String, PrimitiveFunc)]
primitives =
  [ ("+i", integerBinaryOp (+))
  , ("-i", integerBinaryOp (-))
  , ("*i", integerBinaryOp (*))
  , ("modulo", integerBinaryOp mod)
  , ("quotient", integerBinaryOp quot)
  , ("remainder", integerBinaryOp rem)
  , ("eq-i?", integerBinaryPred (==))
  , ("lt-i?", integerBinaryPred (<))
  , ("lte-i?", integerBinaryPred (<=))
  , ("gt-i?", integerBinaryPred (>))
  , ("gte-i?", integerBinaryPred (>=))

  , ("+f", floatBinaryOp (+))
  , ("-f", floatBinaryOp (-))
  , ("*f", floatBinaryOp (*))
  , ("/f", floatBinaryOp (/))
  , ("eq-f?", floatBinaryPred (==))
  , ("lt-f?", floatBinaryPred (<))
  , ("lte-f?", floatBinaryPred (<=))
  , ("gt-f?", floatBinaryPred (>))
  , ("gte-f?", floatBinaryPred (>=))

  , ("neg", integerUnaryOp negate)
  , ("abs", integerUnaryOp abs)

  , ("sqrt", floatUnaryOp sqrt)
  , ("exp", floatUnaryOp exp)
  , ("log", floatUnaryOp log)
  , ("sin", floatUnaryOp sin)
  , ("cos", floatUnaryOp cos)
  , ("tan", floatUnaryOp tan)
  , ("asin", floatUnaryOp asin)
  , ("acos", floatUnaryOp acos)
  , ("atan", floatUnaryOp atan)
  , ("sinh", floatUnaryOp sinh)
  , ("cosh", floatUnaryOp cosh)
  , ("tanh", floatUnaryOp tanh)
  , ("asinh", floatUnaryOp asinh)
  , ("acosh", floatUnaryOp acosh)
  , ("atanh", floatUnaryOp atanh)

  , ("round", floatToIntegerOp round)
  , ("floor", floatToIntegerOp floor)
  , ("ceiling", floatToIntegerOp ceiling)
  , ("truncate", floatToIntegerOp truncate)

  , ("itof", integerToFloat)
  , ("rtof", rationalToFloat)
  , ("itos", integerToString)
  , ("stoi", stringToInteger)

  , ("eq?", eq)
  , ("lt?", lt)
  , ("lte?", lte)
  , ("gt?", gt)
  , ("gte?", gte)

  , ("+", plus)
  , ("-", minus)
  , ("*", multiply)
  , ("/", divide)
  , ("/-inverse", divideInverse)

  , ("assert", assert)
  , ("assert-equal", assertEqual)
  , ("pure-mysql", pureMySQL)
  ]
-- | Lift a unary 'Integer' function to a one-argument primitive.  The
-- argument is decoded with 'fromIntegerValue' and the result is wrapped
-- as an @Integer@ value.
integerUnaryOp :: (Integer -> Integer) -> PrimitiveFunc
integerUnaryOp op = (liftError .) $ oneArg $ \val -> do
  n <- fromIntegerValue val
  return $ Integer $ op n
-- | Lift a binary 'Integer' function to a two-argument primitive; both
-- arguments are decoded (left first) with 'fromIntegerValue'.
integerBinaryOp :: (Integer -> Integer -> Integer) -> PrimitiveFunc
integerBinaryOp op = (liftError .) $ twoArgs $ \lhs rhs ->
  liftM2 (\a b -> Integer (op a b)) (fromIntegerValue lhs) (fromIntegerValue rhs)
-- | Lift a binary 'Integer' predicate to a two-argument primitive returning
-- a @Bool@ value; both arguments are decoded (left first) with
-- 'fromIntegerValue'.
integerBinaryPred :: (Integer -> Integer -> Bool) -> PrimitiveFunc
integerBinaryPred p = (liftError .) $ twoArgs $ \lhs rhs -> do
  a <- fromIntegerValue lhs
  b <- fromIntegerValue rhs
  return $ Bool $ p a b
-- | Lift a unary 'Double' function to a one-argument primitive.  The
-- argument is decoded with 'fromFloatValue' and the result is wrapped as a
-- @Float@ value.
floatUnaryOp :: (Double -> Double) -> PrimitiveFunc
floatUnaryOp op = (liftError .) $ oneArg $ \val ->
  liftM (Float . op) (fromFloatValue val)
-- | Lift a binary 'Double' function to a two-argument primitive; both
-- arguments are decoded (left first) with 'fromFloatValue'.
floatBinaryOp :: (Double -> Double -> Double) -> PrimitiveFunc
floatBinaryOp op = (liftError .) $ twoArgs $ \lhs rhs -> do
  x <- fromFloatValue lhs
  y <- fromFloatValue rhs
  return $ Float $ op x y
-- | Lift a binary 'Double' predicate to a two-argument primitive returning
-- a @Bool@ value; both arguments are decoded (left first) with
-- 'fromFloatValue'.
floatBinaryPred :: (Double -> Double -> Bool) -> PrimitiveFunc
floatBinaryPred p = (liftError .) $ twoArgs $ \lhs rhs ->
  liftM2 (\x y -> Bool (p x y)) (fromFloatValue lhs) (fromFloatValue rhs)
-- | Lift a 'Double' -> 'Integer' conversion (e.g. 'round', 'floor') to a
-- one-argument primitive.
floatToIntegerOp :: (Double -> Integer) -> PrimitiveFunc
floatToIntegerOp op = (liftError .) $ oneArg $ \val ->
  fromFloatValue val >>= return . Integer . op
-- | @itof n@: convert an integer value to a @Float@ via 'fromInteger'.
integerToFloat :: PrimitiveFunc
integerToFloat = (liftError .) $ oneArg $ \val -> do
  n <- fromIntegerValue val
  return $ Float $ fromInteger n
-- | @rtof r@: convert a value accepted by 'fromRationalValue' to a @Float@
-- via 'fromRational'.
rationalToFloat :: PrimitiveFunc
rationalToFloat = (liftError .) $ oneArg $ \val ->
  liftM (Float . fromRational) (fromRationalValue val)
-- | @itos n@: render an integer value as an Egison string using 'show'.
integerToString :: PrimitiveFunc
integerToString = (liftError .) $ oneArg $ \val -> do
  n <- fromIntegerValue val
  return $ makeStringValue $ show n
-- | @stoi s@: parse the Egison string @s@ as an integer.
--
-- Previously this used 'read', which is partial: a malformed string produced
-- a Haskell 'error' ("Prelude.read: no parse") when the result was forced,
-- escaping the interpreter's 'EgisonError' handling.  It now uses the
-- Haskell Report definition of 'read' ('reads' followed by 'lex' to allow
-- only trailing whitespace), so it accepts exactly what 'read' accepted.
-- Anything else raises @TypeMismatch "integer"@ on the argument.
stringToInteger :: PrimitiveFunc
stringToInteger = (liftError .) $ oneArg $ \val -> do
  numStr <- fromStringValue val
  case [ n | (n, rest) <- reads numStr, ("", "") <- lex rest ] of
    [n] -> return $ Integer n
    _ -> throwError $ TypeMismatch "integer" val
-- | @eq? x y@: compare two values decoded with 'fromPrimitiveValue' (left
-- first) using '=='.
eq :: PrimitiveFunc
eq = (liftError .) $ twoArgs $ \lhs rhs ->
  liftM2 (\a b -> Bool (a == b)) (fromPrimitiveValue lhs) (fromPrimitiveValue rhs)
-- | @lt? x y@: numeric @x < y@.
--
-- Accepts any mix of Integer, Rational and Float operands.  Rationals
-- (which @/@ produces) used to be rejected with @TypeMismatch "number"@.
-- If either side is a Float both are compared as 'Double'.  Otherwise the
-- comparison is exact on 'Rational'.  A non-numeric operand raises
-- @TypeMismatch "number"@ on the first offending argument.
lt :: PrimitiveFunc
lt = (liftError .) $ twoArgs lt'
  where
    lt' (Value (Float f)) val' = Bool . (f <) <$> toDouble val'
    lt' val (Value (Float f')) = Bool . (< f') <$> toDouble val
    lt' val val' = (\x y -> Bool (x < y)) <$> toExact val <*> toExact val'
    -- Numeric operand as a Double (used when a Float is involved).
    toDouble (Value (Float f)) = return f
    toDouble (Value (Integer i)) = return $ fromInteger i
    toDouble (Value (Rational r)) = return $ fromRational r
    toDouble val = throwError $ TypeMismatch "number" val
    -- Integer/Rational operand as an exact Rational.
    toExact (Value (Integer i)) = return $ toRational i
    toExact (Value (Rational r)) = return r
    toExact val = throwError $ TypeMismatch "number" val
-- | @lte? x y@: numeric @x <= y@.
--
-- Accepts any mix of Integer, Rational and Float operands.  Rationals
-- (which @/@ produces) used to be rejected with @TypeMismatch "number"@.
-- If either side is a Float both are compared as 'Double'.  Otherwise the
-- comparison is exact on 'Rational'.  A non-numeric operand raises
-- @TypeMismatch "number"@ on the first offending argument.
lte :: PrimitiveFunc
lte = (liftError .) $ twoArgs lte'
  where
    lte' (Value (Float f)) val' = Bool . (f <=) <$> toDouble val'
    lte' val (Value (Float f')) = Bool . (<= f') <$> toDouble val
    lte' val val' = (\x y -> Bool (x <= y)) <$> toExact val <*> toExact val'
    -- Numeric operand as a Double (used when a Float is involved).
    toDouble (Value (Float f)) = return f
    toDouble (Value (Integer i)) = return $ fromInteger i
    toDouble (Value (Rational r)) = return $ fromRational r
    toDouble val = throwError $ TypeMismatch "number" val
    -- Integer/Rational operand as an exact Rational.
    toExact (Value (Integer i)) = return $ toRational i
    toExact (Value (Rational r)) = return r
    toExact val = throwError $ TypeMismatch "number" val
-- | @gt? x y@: numeric @x > y@.
--
-- Accepts any mix of Integer, Rational and Float operands.  Rationals
-- (which @/@ produces) used to be rejected with @TypeMismatch "number"@.
-- If either side is a Float both are compared as 'Double'.  Otherwise the
-- comparison is exact on 'Rational'.  A non-numeric operand raises
-- @TypeMismatch "number"@ on the first offending argument.
gt :: PrimitiveFunc
gt = (liftError .) $ twoArgs gt'
  where
    gt' (Value (Float f)) val' = Bool . (f >) <$> toDouble val'
    gt' val (Value (Float f')) = Bool . (> f') <$> toDouble val
    gt' val val' = (\x y -> Bool (x > y)) <$> toExact val <*> toExact val'
    -- Numeric operand as a Double (used when a Float is involved).
    toDouble (Value (Float f)) = return f
    toDouble (Value (Integer i)) = return $ fromInteger i
    toDouble (Value (Rational r)) = return $ fromRational r
    toDouble val = throwError $ TypeMismatch "number" val
    -- Integer/Rational operand as an exact Rational.
    toExact (Value (Integer i)) = return $ toRational i
    toExact (Value (Rational r)) = return r
    toExact val = throwError $ TypeMismatch "number" val
-- | @gte? x y@: numeric @x >= y@.
--
-- Accepts any mix of Integer, Rational and Float operands.  Rationals
-- (which @/@ produces) used to be rejected with @TypeMismatch "number"@.
-- If either side is a Float both are compared as 'Double'.  Otherwise the
-- comparison is exact on 'Rational'.  A non-numeric operand raises
-- @TypeMismatch "number"@ on the first offending argument.
gte :: PrimitiveFunc
gte = (liftError .) $ twoArgs gte'
  where
    gte' (Value (Float f)) val' = Bool . (f >=) <$> toDouble val'
    gte' val (Value (Float f')) = Bool . (>= f') <$> toDouble val
    gte' val val' = (\x y -> Bool (x >= y)) <$> toExact val <*> toExact val'
    -- Numeric operand as a Double (used when a Float is involved).
    toDouble (Value (Float f)) = return f
    toDouble (Value (Integer i)) = return $ fromInteger i
    toDouble (Value (Rational r)) = return $ fromRational r
    toDouble val = throwError $ TypeMismatch "number" val
    -- Integer/Rational operand as an exact Rational.
    toExact (Value (Integer i)) = return $ toRational i
    toExact (Value (Rational r)) = return r
    toExact val = throwError $ TypeMismatch "number" val
-- | @+ x y@: generic numeric addition.
--
-- Integer + Integer stays an Integer.  Otherwise Integers are promoted to
-- Rationals; Rational sums whose denominator is 1 collapse back to Integer.
-- Any Float operand makes the result a Float.  Non-numbers raise
-- @TypeMismatch "number"@.
plus :: PrimitiveFunc
plus = (liftError .) $ twoArgs go
  where
    go (Value (Integer a)) (Value (Integer b)) = return $ Integer (a + b)
    -- Promote a lone Integer operand to Rational and retry.
    go (Value (Integer a)) other = go (Value (Rational (a % 1))) other
    go other (Value (Integer b)) = go other (Value (Rational (b % 1)))
    go (Value (Rational a)) (Value (Rational b)) = return $ collapse (a + b)
    go (Value (Float a)) (Value (Float b)) = return $ Float (a + b)
    go (Value (Rational a)) (Value (Float b)) = return $ Float (fromRational a + b)
    go (Value (Float a)) (Value (Rational b)) = return $ Float (a + fromRational b)
    go (Value (Rational _)) bad = throwError $ TypeMismatch "number" bad
    go (Value (Float _)) bad = throwError $ TypeMismatch "number" bad
    go bad _ = throwError $ TypeMismatch "number" bad
    collapse r
      | denominator r == 1 = Integer (numerator r)
      | otherwise = Rational r
-- | @- x y@: generic numeric subtraction.
--
-- Every subtraction in this block had lost its @-@ operator (e.g.
-- @(x x')@), so numbers were being applied as functions.  The arithmetic is
-- restored here with the same promotion rules as @+@:
--
-- * Integer - Integer stays an Integer.
-- * Otherwise Integers are promoted to Rationals, and Rational results with
--   denominator 1 collapse back to Integer.
-- * Any Float operand makes the result a Float.
-- * Non-numbers raise @TypeMismatch "number"@.
minus :: PrimitiveFunc
minus = (liftError .) $ twoArgs minus'
  where
    minus' (Value (Integer x)) (Value (Integer x')) = return $ Integer (x - x')
    minus' (Value (Integer i)) val = minus' (Value (Rational (i % 1))) val
    minus' val (Value (Integer i)) = minus' val (Value (Rational (i % 1)))
    minus' (Value (Rational x)) (Value (Rational x')) = return $ collapse (x - x')
    minus' (Value (Float f)) (Value (Float f')) = return $ Float (f - f')
    minus' (Value (Rational r)) (Value (Float f)) = return $ Float (fromRational r - f)
    minus' (Value (Float f)) (Value (Rational r)) = return $ Float (f - fromRational r)
    minus' (Value (Rational _)) val = throwError $ TypeMismatch "number" val
    minus' (Value (Float _)) val = throwError $ TypeMismatch "number" val
    minus' val _ = throwError $ TypeMismatch "number" val
    -- Collapse a whole-number Rational back to an Integer value.
    collapse y
      | denominator y == 1 = Integer (numerator y)
      | otherwise = Rational y
-- | @* x y@: generic numeric multiplication.
--
-- Integer * Integer stays an Integer.  Otherwise Integers are promoted to
-- Rationals; Rational products whose denominator is 1 collapse back to
-- Integer.  Any Float operand makes the result a Float.  Non-numbers raise
-- @TypeMismatch "number"@.
multiply :: PrimitiveFunc
multiply = (liftError .) $ twoArgs mul
  where
    mul (Value (Integer a)) (Value (Integer b)) = return $ Integer (a * b)
    -- Promote a lone Integer operand to Rational and retry.
    mul (Value (Integer a)) other = mul (Value (Rational (a % 1))) other
    mul other (Value (Integer b)) = mul other (Value (Rational (b % 1)))
    mul (Value (Rational a)) (Value (Rational b)) =
      let prod = a * b
      in return $ if denominator prod == 1
                    then Integer (numerator prod)
                    else Rational prod
    mul (Value (Float a)) (Value (Float b)) = return $ Float (a * b)
    mul (Value (Rational a)) (Value (Float b)) = return $ Float (fromRational a * b)
    mul (Value (Float a)) (Value (Rational b)) = return $ Float (a * fromRational b)
    mul (Value (Rational _)) bad = throwError $ TypeMismatch "number" bad
    mul (Value (Float _)) bad = throwError $ TypeMismatch "number" bad
    mul bad _ = throwError $ TypeMismatch "number" bad
-- | @/ x y@: generic numeric division.
--
-- * Integer / Integer and Rational / Rational give an exact quotient.
--   Whole-number results collapse to an Integer.  Previously Integer /
--   Integer always returned a 'Rational' (e.g. @4/2@ gave @Rational 2@),
--   unlike every other arithmetic path in this module.
-- * Mixed Integer operands are promoted to Rational.
-- * Any Float operand makes the result a Float.
-- * Non-numbers raise @TypeMismatch "number"@.
--
-- NOTE(review): exact division by zero is not an 'EgisonError'.  The
-- Rational @/@ raises Haskell's "Ratio has zero denominator" exception when
-- the result is forced.  Float division by zero follows IEEE semantics.
divide :: PrimitiveFunc
divide = (liftError .) $ twoArgs divide'
  where
    divide' (Value (Integer x)) (Value (Integer x')) = return $ collapse (x % x')
    divide' (Value (Integer i)) val = divide' (Value (Rational (i % 1))) val
    divide' val (Value (Integer i)) = divide' val (Value (Rational (i % 1)))
    divide' (Value (Rational x)) (Value (Rational x')) = return $ collapse (x / x')
    divide' (Value (Rational x)) (Value (Float f)) = return $ Float (fromRational x / f)
    divide' (Value (Float f)) (Value (Rational x)) = return $ Float (f / fromRational x)
    divide' (Value (Rational _)) val = throwError $ TypeMismatch "number" val
    divide' (Value (Float f)) (Value (Float f')) = return $ Float (f / f')
    divide' (Value (Float _)) val = throwError $ TypeMismatch "number" val
    divide' val _ = throwError $ TypeMismatch "number" val
    -- Collapse a whole-number Rational back to an Integer value.
    collapse y
      | denominator y == 1 = Integer (numerator y)
      | otherwise = Rational y
-- | @/-inverse q@: split a number into the tuple @[numerator, denominator]@.
-- An Integer @n@ yields @[n, 1]@; anything other than an Integer or
-- Rational raises @TypeMismatch "rational"@.
divideInverse :: PrimitiveFunc
divideInverse = (liftError .) $ oneArg split
  where
    split (Value (Integer n)) = return $ Tuple [Integer n, Integer 1]
    split (Value (Rational q)) =
      return $ Tuple [Integer (numerator q), Integer (denominator q)]
    split other = throwError $ TypeMismatch "rational" other
-- | @assert label test@: return @#t@ when @test@ is true.  Otherwise raise an
-- 'Assertion' error whose message is @show label@.
assert :: PrimitiveFunc
assert = (liftError .) $ twoArgs $ \label test -> do
  ok <- fromBoolValue test
  unless ok $ throwError $ Assertion $ show label
  return $ Bool True
-- | @assert-equal label actual expected@: fully evaluate both values (actual
-- first) and return @#t@ if they are equal.  Otherwise raise an 'Assertion'
-- error that shows the label, the expected value and the actual value.
assertEqual :: PrimitiveFunc
assertEqual = threeArgs $ \label actual expected -> do
  actualVal <- evalDeep actual
  expectedVal <- evalDeep expected
  unless (actualVal == expectedVal) $
    throwError $ Assertion $ show label ++ "\n expected: " ++ show expectedVal ++
                             "\n but found: " ++ show actualVal
  return $ Bool True
-- | @pure-mysql dbName query@: run the raw SQL @query@ against database
-- @dbName@.  All other connection settings come from
-- 'MySQL.defaultConnectInfo'.  The rows are returned as a collection of
-- tuples of strings, and SQL NULL columns become the string "null".
--
-- The connection is now opened with 'bracket' and closed with
-- 'MySQL.close' even if the query or a fetch throws.  Previously it was
-- never closed.
--
-- NOTE(review): the database access runs through 'unsafePerformIO', so it
-- happens when the result thunk is forced, outside the interpreter's IO
-- sequencing.  Any exception the IO action throws surfaces at that point
-- rather than as an 'EgisonError'.
pureMySQL :: PrimitiveFunc
pureMySQL = (liftError .) $ twoArgs $ \val val' -> do
  dbName <- fromStringValue val
  qStr <- fromStringValue val'
  let rows = unsafePerformIO $ query' dbName $ BC.pack qStr
  return $ Collection $ Sq.fromList $ map (Tuple . map makeStringValue) rows
  where
    query' :: String -> ByteString -> IO [[String]]
    query' dbName q = do
      let connInfo = MySQL.defaultConnectInfo { MySQL.connectDatabase = dbName }
      bracket (MySQL.connect connInfo) MySQL.close $ \conn -> do
        MySQL.query conn q
        fetchAllRows =<< MySQL.storeResult conn
    -- Read rows until 'MySQL.fetchRow' returns an empty row.
    fetchAllRows :: MySQL.Result -> IO [[String]]
    fetchAllRows res = do
      row <- MySQL.fetchRow res
      case row of
        [] -> return []
        _ -> do
          let row' = map (maybe "null" BC.unpack) row
          rows' <- fetchAllRows res
          return $ row' : rows'
-- | Primitives with IO effects, bound only by 'primitiveEnv'.  The groups
-- are concatenated in the original binding order.
ioPrimitives :: [(String, PrimitiveFunc)]
ioPrimitives = concat [monadic, files, console, ports, misc]
  where
    monadic = [ ("return", return') ]
    files =
      [ ("open-input-file", makePort ReadMode)
      , ("open-output-file", makePort WriteMode)
      , ("close-input-port", closePort)
      , ("close-output-port", closePort)
      ]
    console =
      [ ("read-char", readChar)
      , ("read-line", readLine)
      , ("write-char", writeChar)
      , ("write-string", writeString)
      , ("write", write)
      , ("eof?", isEOFStdin)
      , ("flush", flushStdout)
      ]
    ports =
      [ ("read-char-from-port", readCharFromPort)
      , ("read-line-from-port", readLineFromPort)
      , ("write-char-to-port", writeCharToPort)
      , ("write-string-to-port", writeStringToPort)
      , ("write-to-port", writeToPort)
      , ("eof-port?", isEOFPort)
      , ("flush-port", flushPort)
      ]
    misc = [ ("rand", randRange) ]
-- | Wrap an action as an 'IOFunc' whose result is the tuple
-- @[World, result]@.
makeIO :: EgisonM EgisonValue -> EgisonValue
makeIO m = IOFunc $ do
  result <- m
  return $ Value $ Tuple [World, result]
-- | Wrap a unit-returning action as an 'IOFunc' whose result is
-- @[World, ()]@ (the empty tuple).
makeIO' :: EgisonM () -> EgisonValue
makeIO' m = IOFunc $ do
  m
  return $ Value $ Tuple [World, Tuple []]
-- | Egison's @return@: an IO action that deep-evaluates its argument and
-- yields it.
return' :: PrimitiveFunc
return' = oneArg $ \val -> return (makeIO (evalDeep val))
-- | Build a primitive that opens the named file in the given 'IOMode' and
-- returns the handle as a @Port@ (as an IO action).
makePort :: IOMode -> PrimitiveFunc
makePort mode = (liftError .) $ oneArg $ \val ->
  (\path -> makeIO $ liftIO $ Port <$> openFile path mode) <$> fromStringValue val
-- | IO action that closes a port handle with 'hClose'.
closePort :: PrimitiveFunc
closePort = (liftError .) $ oneArg $ \val -> do
  port <- fromPortValue val
  return $ makeIO' $ liftIO $ hClose port
-- | IO action that writes a character to stdout.
writeChar :: PrimitiveFunc
writeChar = (liftError .) $ oneArg $ \val -> do
  c <- fromCharValue val
  return $ makeIO' $ liftIO $ putChar c
-- | IO action that writes a string to stdout (no trailing newline).
writeString :: PrimitiveFunc
writeString = (liftError .) $ oneArg $ \val -> do
  s <- fromStringValue val
  return $ makeIO' $ liftIO $ putStr s
-- | IO action that deep-evaluates any value and prints its 'show' form to
-- stdout.
write :: PrimitiveFunc
write = oneArg $ \val -> do
  v <- evalDeep val
  return $ makeIO' $ liftIO $ putStr $ show v
-- | IO action that reads one character from stdin.
readChar :: PrimitiveFunc
readChar = noArg . return . makeIO . liftIO $ Char <$> getChar
-- | IO action that reads one line from stdin as an Egison string.
readLine :: PrimitiveFunc
readLine = noArg . return . makeIO . liftIO $ makeStringValue <$> getLine
-- | IO action that flushes stdout.
flushStdout :: PrimitiveFunc
flushStdout = noArg . return . makeIO' . liftIO $ hFlush stdout
-- | IO action that reports whether stdin is at end-of-file.
isEOFStdin :: PrimitiveFunc
isEOFStdin = noArg . return . makeIO . liftIO $ Bool <$> isEOF
-- | IO action that writes a character to a port.
writeCharToPort :: PrimitiveFunc
writeCharToPort = (liftError .) $ twoArgs $ \val val' -> do
  port <- fromPortValue val
  c <- fromCharValue val'
  return $ makeIO' $ liftIO $ hPutChar port c
-- | IO action that writes a string to a port.
writeStringToPort :: PrimitiveFunc
writeStringToPort = (liftError .) $ twoArgs $ \val val' -> do
  port <- fromPortValue val
  s <- fromStringValue val'
  return $ makeIO' $ liftIO $ hPutStr port s
-- | IO action that deep-evaluates any value and writes its 'show' form to a
-- port.  The port is decoded before the value is evaluated.
writeToPort :: PrimitiveFunc
writeToPort = twoArgs $ \val val' -> do
  port <- liftError $ fromPortValue val
  rendered <- show <$> evalDeep val'
  return $ makeIO' $ liftIO $ hPutStr port rendered
-- | IO action that reads one character from a port.
readCharFromPort :: PrimitiveFunc
readCharFromPort = (liftError .) $ oneArg $ \val -> do
  port <- fromPortValue val
  return $ makeIO $ liftIO $ Char <$> hGetChar port
-- | IO action that reads one line from a port as an Egison string.
readLineFromPort :: PrimitiveFunc
readLineFromPort = (liftError .) $ oneArg $ \val -> do
  port <- fromPortValue val
  return $ makeIO $ liftIO $ makeStringValue <$> hGetLine port
-- | IO action that flushes a port.
flushPort :: PrimitiveFunc
flushPort = (liftError .) $ oneArg $ \val -> do
  port <- fromPortValue val
  return $ makeIO' $ liftIO $ hFlush port
-- | IO action that reports whether a port is at end-of-file.
isEOFPort :: PrimitiveFunc
isEOFPort = (liftError .) $ oneArg $ \val -> do
  port <- fromPortValue val
  return $ makeIO $ liftIO $ Bool <$> hIsEOF port
-- | @rand lo hi@: IO action returning a random integer in the inclusive range
-- @[lo, hi]@, drawn from the global standard generator.
randRange :: PrimitiveFunc
randRange = (liftError .) $ twoArgs $ \val val' -> do
  lo <- fromIntegerValue val
  hi <- fromIntegerValue val'
  return $ makeIO $ liftIO $ Integer <$> getStdRandom (randomR (lo, hi))