{- |
Module      : Language.Egison.Primitives.Utils
Licence     : MIT
-}

module Language.Egison.Primitives.Utils
  ( noArg
  , oneArg
  , oneArg'
  , twoArgs
  , twoArgs'
  , threeArgs'
  , lazyOneArg
  , unaryOp
  , binaryOp
  ) where

import qualified Data.Vector            as V

import           Language.Egison.Data
import           Language.Egison.Tensor

{-# INLINE noArg #-}
-- | Wrap a nullary primitive.
--
-- The wrapped action @f@ runs when the primitive is applied to no arguments
-- or to the empty tuple @()@. Any other argument list raises
-- 'ArgumentsNumPrimitive'.
noArg :: EvalM EgisonValue -> String -> PrimitiveFunc
noArg f name args =
  case args of
    []         -> f
    [Tuple []] -> f
    _          ->
      -- This primitive takes no arguments, so the expected count is 0.
      throwErrorWithTrace (ArgumentsNumPrimitive name 0 (length args))

{-# INLINE oneArg #-}
-- | Wrap a unary primitive, lifting it element-wise over tensors.
--
-- If the single argument is a tensor, @f@ is applied to every element and the
-- result keeps the original shape and indices. Any other single argument is
-- handed to @f@ directly. A different argument count raises
-- 'ArgumentsNumPrimitive'.
oneArg :: (EgisonValue -> EvalM EgisonValue) -> String -> PrimitiveFunc
oneArg f name args = dispatch args
  where
    dispatch [TensorData (Tensor shape elems idxs)] =
      (\elems' -> TensorData (Tensor shape elems' idxs)) <$> V.mapM f elems
    dispatch [x] = f x
    dispatch _   =
      throwErrorWithTrace (ArgumentsNumPrimitive name 1 (length args))

{-# INLINE oneArg' #-}
-- | Wrap a unary primitive without any tensor lifting.
--
-- The single argument, whatever it is, is passed straight to @f@. A different
-- argument count raises 'ArgumentsNumPrimitive'.
oneArg' :: (EgisonValue -> EvalM EgisonValue) -> String -> PrimitiveFunc
oneArg' f name args
  | [arg] <- args = f arg
  | otherwise     =
      throwErrorWithTrace (ArgumentsNumPrimitive name 1 (length args))

{-# INLINE twoArgs #-}
-- | Wrap a binary primitive, lifting it over tensors and supporting
-- partial application.
--
-- * Two tensors are combined with 'tProduct' and converted back with
--   'fromTensor'.
-- * A tensor paired with any other value has @f@ applied element-wise. The
--   other value stays on the side where it was given.
-- * Any other pair of values is passed to @f@ directly.
-- * A single argument yields a new primitive, built with 'oneArg', that
--   waits for the second argument.
--
-- Any other argument count raises 'ArgumentsNumPrimitive'.
twoArgs :: (EgisonValue -> EgisonValue -> EvalM EgisonValue) -> String -> PrimitiveFunc
twoArgs f name args =
  case args of
    [TensorData lhs@Tensor{}, TensorData rhs@Tensor{}] ->
      fromTensor =<< tProduct f lhs rhs
    [TensorData (Tensor shape elems idxs), y] ->
      rebuild shape idxs <$> V.mapM (\x -> f x y) elems
    [x, TensorData (Tensor shape elems idxs)] ->
      rebuild shape idxs <$> V.mapM (f x) elems
    [x, y] -> f x y
    [x]    -> pure (PrimitiveFunc (oneArg (f x) name))
    _      -> throwErrorWithTrace (ArgumentsNumPrimitive name 2 (length args))
  where
    -- Reassemble a tensor value from mapped elements, keeping the original
    -- shape and indices.
    rebuild shape idxs elems' = TensorData (Tensor shape elems' idxs)

{-# INLINE twoArgs' #-}
-- | Wrap a binary primitive without tensor lifting.
--
-- Two arguments are passed to @f@ directly. A single argument yields a new
-- primitive, built with 'oneArg'', that waits for the second one. Any other
-- argument count raises 'ArgumentsNumPrimitive'.
twoArgs' :: (EgisonValue -> EgisonValue -> EvalM EgisonValue) -> String -> PrimitiveFunc
twoArgs' f name args = go args
  where
    go [x, y] = f x y
    go [x]    = curried x
    go _      =
      throwErrorWithTrace (ArgumentsNumPrimitive name 2 (length args))
    -- Partial application: return a one-argument primitive that closes
    -- over the argument already supplied.
    curried x = pure (PrimitiveFunc (oneArg' (f x) name))

{-# INLINE threeArgs' #-}
-- | Wrap a ternary primitive without tensor lifting.
--
-- Three arguments are passed to @f@ directly. With fewer arguments, a new
-- primitive is returned that waits for the rest:
--
-- * two arguments yield one built with 'oneArg''
-- * one argument yields one built with 'twoArgs''
--
-- Any other argument count raises 'ArgumentsNumPrimitive'.
threeArgs' :: (EgisonValue -> EgisonValue -> EgisonValue -> EvalM EgisonValue) -> String -> PrimitiveFunc
threeArgs' f name args =
  case args of
    [x, y, z] -> f x y z
    [x, y]    -> partial (oneArg' (f x y))
    [x]       -> partial (twoArgs' (f x))
    _         ->
      throwErrorWithTrace (ArgumentsNumPrimitive name 3 (length args))
  where
    -- Package a partially applied wrapper (still needing the primitive's
    -- name) as a first-class primitive-function value.
    partial mk = pure (PrimitiveFunc (mk name))

-- | Wrap a unary lazy primitive that takes and returns 'WHNFData'.
--
-- Exactly one argument is accepted and handed to @f@. Any other argument
-- count raises 'ArgumentsNumPrimitive'.
lazyOneArg :: (WHNFData -> EvalM WHNFData) -> String -> LazyPrimitiveFunc
lazyOneArg f name args
  | [arg] <- args = f arg
  | otherwise     =
      throwErrorWithTrace (ArgumentsNumPrimitive name 1 (length args))

-- | Build a one-argument primitive from a pure Haskell function.
--
-- The argument is decoded with 'fromEgison', @op@ is applied, and the result
-- is encoded with 'toEgison'. The primitive is built with 'oneArg', so a
-- tensor argument is processed element by element.
unaryOp :: (EgisonData a, EgisonData b) => (a -> b) -> String -> PrimitiveFunc
unaryOp op = oneArg (fmap (toEgison . op) . fromEgison)

-- | Build a two-argument primitive from a pure Haskell function.
--
-- Both arguments are decoded with 'fromEgison', left first, then right.
-- @op@ is applied, and the result is encoded with 'toEgison'. The primitive
-- is built with 'twoArgs', so it inherits its tensor lifting and partial
-- application.
binaryOp :: (EgisonData a, EgisonData b) => (a -> a -> b) -> String -> PrimitiveFunc
binaryOp op = twoArgs combine
  where
    combine lhs rhs =
      (\l r -> toEgison (op l r)) <$> fromEgison lhs <*> fromEgison rhs