-----------------------------------------------------------
-- |
-- module:                      MXNet.Core.Base.Internal.TH.NDArray
-- copyright:                   (c) 2016-2017 Tao He
-- license:                     MIT
-- maintainer:                  sighingnow@gmail.com
--
-- Functions on NDArray that are generated by Template Haskell.
--
{-# OPTIONS_GHC -Wno-missing-signatures #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -Wno-unused-local-binds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module MXNet.Core.Base.Internal.TH.NDArray where

import Data.Proxy

import MXNet.Core.Base.HMap
import MXNet.Core.Base.Internal
import MXNet.Core.Base.Internal.TH (registerNDArrayOps)
import MXNet.Core.NNVM.Internal (nnGetOpHandle)
import Prelude hiding (sin, sinh, cos, cosh, tan, tanh, min, max, round, floor,
                       abs, sum, sqrt, log, exp, flip, concat, reverse, repeat)

-- | Result representation for a generic NDArray operator: conversion
-- between the raw list of 'NDArrayHandle's and a caller-facing type @a@.
class NDArrayOpResult a where
    -- | Flatten a result value back into a list of handles.
    fromResult :: a -> [NDArrayHandle]
    -- | Build a result value from a list of handles.
    toResult :: [NDArrayHandle] -> a

-- | Unit result: every handle is discarded, and unit flattens to no handles.
instance NDArrayOpResult () where
    fromResult = const []
    {-# INLINE fromResult #-}
    toResult = const ()
    {-# INLINE toResult #-}

-- | Single-handle result: the first handle in the list.
--
-- Any handles after the first are ignored. 'toResult' is partial: an empty
-- list is reported with 'error', because the class signature gives no way
-- to signal failure in the result type. The message names the function and
-- instance so the failure can be located.
instance NDArrayOpResult NDArrayHandle where
    toResult (x:_) = x
    toResult []    = error $
        "NDArrayOpResult NDArrayHandle: toResult: this operation didn't "
        ++ "return any result, but at least one NDArrayHandle was expected."
    {-# INLINE toResult #-}
    fromResult a = [a]
    {-# INLINE fromResult #-}

-- | List result: the handles pass through unchanged in both directions.
instance NDArrayOpResult [NDArrayHandle] where
    fromResult handles = handles
    {-# INLINE fromResult #-}
    toResult handles = handles
    {-# INLINE toResult #-}

-- | Register the immutable version of the NDArray operators.
--
-- The declarations are generated at compile time by 'registerNDArrayOps'
-- from "MXNet.Core.Base.Internal.TH".
--
-- NOTE(review): the 'False' argument is assumed to select the immutable
-- variant, as this comment has always stated; confirm against
-- 'registerNDArrayOps'.
--
-- NOTE(review): the Prelude import hides names such as 'sin', 'sum' and
-- 'concat'. Presumably the generated operators reuse those names. Likewise,
-- the otherwise-unused imports (Data.Proxy, HMap, nnGetOpHandle) are
-- presumably referenced by the generated code. Confirm before pruning any
-- import.
$(registerNDArrayOps False)

-- | Register the mutable version of the NDArray operators.
--
-- These are generated by the same 'registerNDArrayOps' splice as the
-- immutable set, with the flag set to 'True'.
--
-- NOTE(review): the 'True' argument is assumed to select the mutable
-- variant, as this comment has always stated; confirm against
-- 'registerNDArrayOps' in "MXNet.Core.Base.Internal.TH".
$(registerNDArrayOps True)