{-# OPTIONS_GHC -Wno-missing-signatures #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -Wno-unused-local-binds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module MXNet.Core.Base.Internal.TH.NDArray where
import Data.Proxy
import MXNet.Core.Base.HMap
import MXNet.Core.Base.Internal
import MXNet.Core.Base.Internal.TH (registerNDArrayOps)
import MXNet.Core.NNVM.Internal (nnGetOpHandle)
import Prelude hiding (sin, sinh, cos, cosh, tan, tanh, min, max, round, floor,
abs, sum, sqrt, log, exp, flip, concat, reverse, repeat)
-- | Conversion between the raw list of 'NDArrayHandle's an operation yields
-- and the value a caller wants back from that operation.
--
-- Instances in this module cover "ignore the outputs" (@()@), "exactly one
-- output" ('NDArrayHandle') and "all outputs" (@['NDArrayHandle']@).
class NDArrayOpResult a where
  -- | Flatten a result value back into its list of handles.
  fromResult :: a -> [NDArrayHandle]
  -- | Interpret the list of handles as a result value of type @a@.
  toResult :: [NDArrayHandle] -> a
-- | Result type for callers that discard the outputs: every handle is
-- ignored when reading, and no handles are produced when flattening.
instance NDArrayOpResult () where
  toResult = const ()
  {-# INLINE toResult #-}
  fromResult = const []
  {-# INLINE fromResult #-}
-- | Single-output result: the first handle in the list is the result.
--
-- 'toResult' is partial. It calls 'error' when the list is empty, and any
-- handles after the first are silently discarded. 'fromResult' wraps the
-- handle in a singleton list.
instance NDArrayOpResult NDArrayHandle where
  toResult handles = case handles of
    h : _ -> h
    []    -> error "This operation didn't return any result."
  {-# INLINE toResult #-}
  fromResult = (: [])
  {-# INLINE fromResult #-}
-- | Multi-output result: the handle list is passed through untouched in
-- both directions.
instance NDArrayOpResult [NDArrayHandle] where
  toResult handles = handles
  {-# INLINE toResult #-}
  fromResult handles = handles
  {-# INLINE fromResult #-}
-- Template Haskell splices, expanded at compile time, that generate this
-- module's NDArray operator definitions via 'registerNDArrayOps'.
-- NOTE(review): the Bool argument selects between two generation variants;
-- its exact meaning lives in MXNet.Core.Base.Internal.TH — confirm there.
-- The Prelude names hidden in the import list (sin, cos, sum, ...) are
-- presumably hidden because generated definitions reuse those names.
$(registerNDArrayOps False)
$(registerNDArrayOps True)