{-# Language FlexibleContexts #-} module Language.Egison.Primitives (primitiveEnv, primitiveEnvNoIO) where import Control.Arrow import Control.Applicative import Control.Monad.Error import Data.IORef import qualified Data.Array as A import Data.Ratio import System.IO import System.Random import qualified Data.Sequence as Sq import System.IO.Unsafe import qualified Database.MySQL.Base as MySQL import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BC import Control.Monad import Language.Egison.Types import Language.Egison.Core primitiveEnv :: IO Env primitiveEnv = do let ops = map (second PrimitiveFunc) (primitives ++ ioPrimitives) bindings <- forM (constants ++ ops) $ \(name, op) -> do ref <- newIORef . WHNF $ Value op return (name, ref) return $ extendEnv nullEnv bindings primitiveEnvNoIO :: IO Env primitiveEnvNoIO = do let ops = map (second PrimitiveFunc) primitives bindings <- forM (constants ++ ops) $ \(name, op) -> do ref <- newIORef . WHNF $ Value op return (name, ref) return $ extendEnv nullEnv bindings {-# INLINE noArg #-} noArg :: (MonadError EgisonError m) => m EgisonValue -> [WHNFData] -> m EgisonValue noArg f = \vals -> case vals of [] -> f _ -> throwError $ ArgumentsNum 0 $ length vals {-# INLINE oneArg #-} oneArg :: (MonadError EgisonError m) => (WHNFData -> m EgisonValue) -> [WHNFData] -> m EgisonValue oneArg f = \vals -> case vals of [val] -> f val [] -> f $ Value $ Tuple [] _ -> throwError $ ArgumentsNum 1 $ length vals {-# INLINE twoArgs #-} twoArgs :: (MonadError EgisonError m) => (WHNFData -> WHNFData -> m EgisonValue) -> [WHNFData] -> m EgisonValue twoArgs f = \vals -> case vals of [val, val'] -> f val val' _ -> throwError $ ArgumentsNum 2 $ length vals {-# INLINE threeArgs #-} threeArgs :: (MonadError EgisonError m) => (WHNFData -> WHNFData -> WHNFData -> m EgisonValue) -> [WHNFData] -> m EgisonValue threeArgs f = \vals -> case vals of [val, val', val''] -> f val val' val'' _ -> throwError $ ArgumentsNum 3 $ length vals -- -- Constants -- constants :: [(String, EgisonValue)] constants = [ ("pi", Float 3.141592653589793) ] -- -- Primitives -- primitives :: [(String, PrimitiveFunc)] primitives = [ ("+i", integerBinaryOp (+)) , ("-i", integerBinaryOp (-)) , ("*i", integerBinaryOp (*)) , ("modulo", integerBinaryOp mod) , ("quotient", integerBinaryOp quot) , ("remainder", integerBinaryOp rem) , ("eq-i?", integerBinaryPred (==)) , ("lt-i?", integerBinaryPred (<)) , ("lte-i?", integerBinaryPred (<=)) , ("gt-i?", integerBinaryPred (>)) , ("gte-i?", integerBinaryPred (>=)) , ("+f", floatBinaryOp (+)) , ("-f", floatBinaryOp (-)) , ("*f", floatBinaryOp (*)) , ("/f", floatBinaryOp (/)) , ("eq-f?", floatBinaryPred (==)) , ("lt-f?", floatBinaryPred (<)) , ("lte-f?", floatBinaryPred (<=)) , ("gt-f?", floatBinaryPred (>)) , ("gte-f?", floatBinaryPred (>=)) , ("neg", integerUnaryOp negate) , ("abs", integerUnaryOp abs) , ("sqrt", floatUnaryOp sqrt) , ("exp", floatUnaryOp exp) , ("log", floatUnaryOp log) , ("sin", floatUnaryOp sin) , ("cos", floatUnaryOp cos) , ("tan", floatUnaryOp tan) , ("asin", floatUnaryOp asin) , ("acos", floatUnaryOp acos) , ("atan", floatUnaryOp atan) , ("sinh", floatUnaryOp sinh) , ("cosh", floatUnaryOp cosh) , ("tanh", floatUnaryOp tanh) , ("asinh", floatUnaryOp asinh) , ("acosh", floatUnaryOp acosh) , ("atanh", floatUnaryOp atanh) , ("round", floatToIntegerOp round) , ("floor", floatToIntegerOp floor) , ("ceiling", floatToIntegerOp ceiling) , ("truncate", floatToIntegerOp truncate) , ("itof", integerToFloat) , ("rtof", rationalToFloat) , ("itos", integerToString) , ("stoi", stringToInteger) , ("eq?", eq) , ("lt?", lt) , ("lte?", lte) , ("gt?", gt) , ("gte?", gte) , ("+", plus) , ("-", minus) , ("*", multiply) , ("/", divide) , ("/-inverse", divideInverse) , ("assert", assert) , ("assert-equal", assertEqual) , ("pure-mysql", pureMySQL) ] integerUnaryOp :: (Integer -> Integer) -> PrimitiveFunc integerUnaryOp op = (liftError .) $ oneArg $ \val -> Integer . op <$> fromIntegerValue val integerBinaryOp :: (Integer -> Integer -> Integer) -> PrimitiveFunc integerBinaryOp op = (liftError .) $ twoArgs $ \val val' -> (Integer .) . op <$> fromIntegerValue val <*> fromIntegerValue val' integerBinaryPred :: (Integer -> Integer -> Bool) -> PrimitiveFunc integerBinaryPred pred = (liftError .) $ twoArgs $ \val val' -> (Bool .) . pred <$> fromIntegerValue val <*> fromIntegerValue val' floatUnaryOp :: (Double -> Double) -> PrimitiveFunc floatUnaryOp op = (liftError .) $ oneArg $ \val -> Float . op <$> fromFloatValue val floatBinaryOp :: (Double -> Double -> Double) -> PrimitiveFunc floatBinaryOp op = (liftError .) $ twoArgs $ \val val' -> (Float .) . op <$> fromFloatValue val <*> fromFloatValue val' floatBinaryPred :: (Double -> Double -> Bool) -> PrimitiveFunc floatBinaryPred pred = (liftError .) $ twoArgs $ \val val' -> (Bool .) . pred <$> fromFloatValue val <*> fromFloatValue val' floatToIntegerOp :: (Double -> Integer) -> PrimitiveFunc floatToIntegerOp op = (liftError .) $ oneArg $ \val -> Integer . op <$> fromFloatValue val integerToFloat :: PrimitiveFunc integerToFloat = (liftError .) $ oneArg $ \val -> Float . fromInteger <$> fromIntegerValue val rationalToFloat :: PrimitiveFunc rationalToFloat = (liftError .) $ oneArg $ \val -> Float . fromRational <$> fromRationalValue val integerToString :: PrimitiveFunc integerToString = (liftError .) $ oneArg $ \val -> makeStringValue . show <$> fromIntegerValue val stringToInteger :: PrimitiveFunc stringToInteger = (liftError .) $ oneArg $ \val -> do numStr <- fromStringValue val return $ Integer (read numStr :: Integer) eq :: PrimitiveFunc eq = (liftError .) $ twoArgs $ \val val' -> (Bool .) . (==) <$> fromPrimitiveValue val <*> fromPrimitiveValue val' lt :: PrimitiveFunc lt = (liftError .) $ twoArgs lt' where lt' (Value (Integer i)) (Value (Integer i')) = return $ Bool $ i < i' lt' (Value (Integer i)) (Value (Float f)) = return $ Bool $ fromInteger i < f lt' (Value (Float f)) (Value (Integer i)) = return $ Bool $ f < fromInteger i lt' (Value (Float f)) (Value (Float f')) = return $ Bool $ f < f' lt' (Value (Integer _)) val = throwError $ TypeMismatch "number" val lt' (Value (Float _)) val = throwError $ TypeMismatch "number" val lt' val _ = throwError $ TypeMismatch "number" val lte :: PrimitiveFunc lte = (liftError .) $ twoArgs lte' where lte' (Value (Integer i)) (Value (Integer i')) = return $ Bool $ i <= i' lte' (Value (Integer i)) (Value (Float f)) = return $ Bool $ fromInteger i <= f lte' (Value (Float f)) (Value (Integer i)) = return $ Bool $ f <= fromInteger i lte' (Value (Float f)) (Value (Float f')) = return $ Bool $ f <= f' lte' (Value (Integer _)) val = throwError $ TypeMismatch "number" val lte' (Value (Float _)) val = throwError $ TypeMismatch "number" val lte' val _ = throwError $ TypeMismatch "number" val gt :: PrimitiveFunc gt = (liftError .) $ twoArgs gt' where gt' (Value (Integer i)) (Value (Integer i')) = return $ Bool $ i > i' gt' (Value (Integer i)) (Value (Float f)) = return $ Bool $ fromInteger i > f gt' (Value (Float f)) (Value (Integer i)) = return $ Bool $ f > fromInteger i gt' (Value (Float f)) (Value (Float f')) = return $ Bool $ f > f' gt' (Value (Integer _)) val = throwError $ TypeMismatch "number" val gt' (Value (Float _)) val = throwError $ TypeMismatch "number" val gt' val _ = throwError $ TypeMismatch "number" val gte :: PrimitiveFunc gte = (liftError .) $ twoArgs gte' where gte' (Value (Integer i)) (Value (Integer i')) = return $ Bool $ i >= i' gte' (Value (Integer i)) (Value (Float f)) = return $ Bool $ fromInteger i >= f gte' (Value (Float f)) (Value (Integer i)) = return $ Bool $ f >= fromInteger i gte' (Value (Float f)) (Value (Float f')) = return $ Bool $ f >= f' gte' (Value (Integer _)) val = throwError $ TypeMismatch "number" val gte' (Value (Float _)) val = throwError $ TypeMismatch "number" val gte' val _ = throwError $ TypeMismatch "number" val plus :: PrimitiveFunc plus = (liftError .) $ twoArgs plus' where plus' (Value (Integer x)) (Value (Integer x')) = return $ Integer $ (x + x') plus' (Value (Integer i)) val = plus' (Value (Rational (i % 1))) val plus' val (Value (Integer i)) = plus' val (Value (Rational (i % 1))) plus' (Value (Rational x)) (Value (Rational x')) = let y = (x + x') in if denominator y == 1 then return $ Integer $ numerator y else return $ Rational y plus' (Value (Float f)) (Value (Float f')) = return $ Float $ f + f' plus' (Value (Rational i)) (Value (Float f)) = return $ Float $ (fromRational i) + f plus' (Value (Float f)) (Value (Rational i)) = return $ Float $ f + (fromRational i) plus' (Value (Rational _)) val = throwError $ TypeMismatch "number" val plus' (Value (Float _)) val = throwError $ TypeMismatch "number" val plus' val _ = throwError $ TypeMismatch "number" val minus :: PrimitiveFunc minus = (liftError .) $ twoArgs minus' where minus' (Value (Integer x)) (Value (Integer x')) = return $ Integer $ (x - x') minus' (Value (Integer i)) val = minus' (Value (Rational (i % 1))) val minus' val (Value (Integer i)) = minus' val (Value (Rational (i % 1))) minus' (Value (Rational x)) (Value (Rational x')) = let y = (x - x') in if denominator y == 1 then return $ Integer $ numerator y else return $ Rational y minus' (Value (Float f)) (Value (Float f')) = return $ Float $ f - f' minus' (Value (Rational i)) (Value (Float f)) = return $ Float $ (fromRational i) - f minus' (Value (Float f)) (Value (Rational i)) = return $ Float $ f - (fromRational i) minus' (Value (Rational _)) val = throwError $ TypeMismatch "number" val minus' (Value (Float _)) val = throwError $ TypeMismatch "number" val minus' val _ = throwError $ TypeMismatch "number" val multiply :: PrimitiveFunc multiply = (liftError .) $ twoArgs multiply' where multiply' (Value (Integer x)) (Value (Integer x')) = return $ Integer $ (x * x') multiply' (Value (Integer i)) val = multiply' (Value (Rational (i % 1))) val multiply' val (Value (Integer i)) = multiply' val (Value (Rational (i % 1))) multiply' (Value (Rational x)) (Value (Rational x')) = let y = (x * x') in if denominator y == 1 then return $ Integer $ numerator y else return $ Rational y multiply' (Value (Float f)) (Value (Float f')) = return $ Float $ f * f' multiply' (Value (Rational i)) (Value (Float f)) = return $ Float $ (fromRational i) * f multiply' (Value (Float f)) (Value (Rational i)) = return $ Float $ f * (fromRational i) multiply' (Value (Rational _)) val = throwError $ TypeMismatch "number" val multiply' (Value (Float _)) val = throwError $ TypeMismatch "number" val multiply' val _ = throwError $ TypeMismatch "number" val divide :: PrimitiveFunc divide = (liftError .) $ twoArgs divide' where divide' (Value (Integer x)) (Value (Integer x')) = return $ Rational $ (x % x') divide' (Value (Integer i)) val = divide' (Value (Rational (i % 1))) val divide' val (Value (Integer i)) = divide' val (Value (Rational (i % 1))) divide' (Value (Rational x)) (Value (Rational x')) = let m = numerator x' in let n = denominator x' in let y = (x * (n % m)) in if denominator y == 1 then return $ Integer $ numerator y else return $ Rational y divide' (Value (Rational x)) (Value (Float f)) = return $ Float $ (fromRational x) / f divide' (Value (Float f)) (Value (Rational x)) = return $ Float $ f / (fromRational x) divide' (Value (Rational _)) val = throwError $ TypeMismatch "number" val divide' (Value (Float f)) (Value (Float f')) = return $ Float $ f / f' divide' (Value (Float _)) val = throwError $ TypeMismatch "number" val divide' val _ = throwError $ TypeMismatch "number" val divideInverse :: PrimitiveFunc divideInverse = (liftError .) $ oneArg $ divideInverse' where divideInverse' (Value (Rational rat)) = do return $ Tuple [Integer (numerator rat), Integer (denominator rat)] divideInverse' (Value (Integer x)) = do return $ Tuple [Integer x, Integer 1] divideInverse' val = throwError $ TypeMismatch "rational" val assert :: PrimitiveFunc assert = (liftError .) $ twoArgs $ \label test -> do test <- fromBoolValue test if test then return $ Bool True else throwError $ Assertion $ show label assertEqual :: PrimitiveFunc assertEqual = threeArgs $ \label actual expected -> do actual <- evalDeep actual expected <- evalDeep expected if actual == expected then return $ Bool True else throwError $ Assertion $ show label ++ "\n expected: " ++ show expected ++ "\n but found: " ++ show actual pureMySQL :: PrimitiveFunc pureMySQL = (liftError .) $ twoArgs $ \val val' -> do dbName <- fromStringValue val qStr <- fromStringValue val' let ret = unsafePerformIO $ query' dbName $ BC.pack qStr return $ Collection $ Sq.fromList $ map (\r -> Tuple (map makeStringValue r)) ret where query' :: String -> ByteString -> IO [[String]] query' dbName q = do conn <- MySQL.connect MySQL.defaultConnectInfo { MySQL.connectDatabase = dbName } MySQL.query conn q ret <- MySQL.storeResult conn fetchAllRows ret fetchAllRows :: MySQL.Result -> IO [[String]] fetchAllRows ret = do row <- MySQL.fetchRow ret case row of [] -> return [] _ -> do row' <- forM row (\mcol -> case mcol of Just col -> return $ BC.unpack col Nothing -> return "null") rows' <- fetchAllRows ret return $ row':rows' -- -- IO Primitives -- ioPrimitives :: [(String, PrimitiveFunc)] ioPrimitives = [ ("return", return') , ("open-input-file", makePort ReadMode) , ("open-output-file", makePort WriteMode) , ("close-input-port", closePort) , ("close-output-port", closePort) , ("read-char", readChar) , ("read-line", readLine) -- , ("read", readFromStdin) , ("write-char", writeChar) , ("write-string", writeString) , ("write", write) -- , ("print", writeStringLine) , ("eof?", isEOFStdin) , ("flush", flushStdout) , ("read-char-from-port", readCharFromPort) , ("read-line-from-port", readLineFromPort) -- , ("read-from-port", readFromPort) , ("write-char-to-port", writeCharToPort) , ("write-string-to-port", writeStringToPort) , ("write-to-port", writeToPort) , ("eof-port?", isEOFPort) -- , ("print-to-port", writeStringLineToPort) , ("flush-port", flushPort) , ("rand", randRange) ] -- , ("get-lib-dir-name", getLibDirName) ] makeIO :: EgisonM EgisonValue -> EgisonValue makeIO m = IOFunc $ liftM (Value . Tuple . (World :) . (:[])) m makeIO' :: EgisonM () -> EgisonValue makeIO' m = IOFunc $ m >> return (Value $ Tuple [World, Tuple []]) return' :: PrimitiveFunc return' = oneArg $ return . makeIO . evalDeep makePort :: IOMode -> PrimitiveFunc makePort mode = (liftError .) $ oneArg $ \val -> do filename <- fromStringValue val return . makeIO . liftIO $ Port <$> openFile filename mode closePort :: PrimitiveFunc closePort = (liftError .) $ oneArg $ \val -> makeIO' . liftIO . hClose <$> fromPortValue val writeChar :: PrimitiveFunc writeChar = (liftError .) $ oneArg $ \val -> makeIO' . liftIO . putChar <$> fromCharValue val writeString :: PrimitiveFunc writeString = (liftError .) $ oneArg $ \val -> makeIO' . liftIO . putStr <$> fromStringValue val write :: PrimitiveFunc write = oneArg $ \val -> makeIO' . liftIO . putStr . show <$> evalDeep val readChar :: PrimitiveFunc readChar = noArg $ return $ makeIO $ liftIO $ liftM Char getChar readLine :: PrimitiveFunc readLine = noArg $ return $ makeIO $ liftIO $ liftM makeStringValue getLine flushStdout :: PrimitiveFunc flushStdout = noArg $ return $ makeIO' $ liftIO $ hFlush stdout isEOFStdin :: PrimitiveFunc isEOFStdin = noArg $ return $ makeIO $ liftIO $ liftM Bool isEOF writeCharToPort :: PrimitiveFunc writeCharToPort = (liftError .) $ twoArgs $ \val val' -> ((makeIO' . liftIO) .) . hPutChar <$> fromPortValue val <*> fromCharValue val' writeStringToPort :: PrimitiveFunc writeStringToPort = (liftError .) $ twoArgs $ \val val' -> ((makeIO' . liftIO) .) . hPutStr <$> fromPortValue val <*> fromStringValue val' writeToPort :: PrimitiveFunc writeToPort = twoArgs $ \val val' -> do ((makeIO' . liftIO) .) . hPutStr <$> liftError (fromPortValue val) <*> (show <$> evalDeep val') readCharFromPort :: PrimitiveFunc readCharFromPort = (liftError .) $ oneArg $ \val -> makeIO . liftIO . liftM Char . hGetChar <$> fromPortValue val readLineFromPort :: PrimitiveFunc readLineFromPort = (liftError .) $ oneArg $ \val -> makeIO . liftIO . liftM makeStringValue . hGetLine <$> fromPortValue val flushPort :: PrimitiveFunc flushPort = (liftError .) $ oneArg $ \val -> makeIO' . liftIO . hFlush <$> fromPortValue val isEOFPort :: PrimitiveFunc isEOFPort = (liftError .) $ oneArg $ \val -> makeIO . liftIO . liftM Bool . hIsEOF <$> fromPortValue val --rand :: PrimitiveFunc --rand = noArg $ return $ makeIO $ liftIO $ liftM Integer $ getStdRandom random randRange :: PrimitiveFunc randRange = (liftError .) $ twoArgs $ \val val' -> return . makeIO . liftIO . liftM Integer . getStdRandom . randomR =<< liftM2 (,) (fromIntegerValue val) (fromIntegerValue val')