// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H namespace Eigen { /** \class TensorCustomUnaryOp * \ingroup CXX11_Tensor_Module * * \brief Tensor custom class. * * */ namespace internal { template struct traits > { typedef typename XprType::Scalar Scalar; typedef typename XprType::StorageKind StorageKind; typedef typename XprType::Index Index; typedef typename XprType::Nested Nested; typedef typename remove_reference::type _Nested; static const int NumDimensions = traits::NumDimensions; static const int Layout = traits::Layout; }; template struct eval, Eigen::Dense> { typedef const TensorCustomUnaryOp& type; }; template struct nested > { typedef TensorCustomUnaryOp type; }; } // end namespace internal template class TensorCustomUnaryOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename internal::nested::type Nested; typedef typename internal::traits::StorageKind StorageKind; typedef typename internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func) : m_expr(expr), m_func(func) {} EIGEN_DEVICE_FUNC const CustomUnaryFunc& func() const { return m_func; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& expression() const { return m_expr; } protected: typename XprType::Nested m_expr; const CustomUnaryFunc m_func; }; // Eval as rvalue template struct TensorEvaluator, Device> { typedef TensorCustomUnaryOp ArgType; typedef typename internal::traits::Index Index; static const int NumDims = internal::traits::NumDimensions; typedef DSizes Dimensions; typedef typename internal::remove_const::type Scalar; typedef typename internal::remove_const::type CoeffReturnType; typedef typename PacketType::type PacketReturnType; static const int PacketSize = internal::unpacket_traits::size; enum { IsAligned = false, PacketAccess = (internal::packet_traits::size > 1), BlockAccess = false, Layout = TensorEvaluator::Layout, CoordAccess = false, // to be implemented RawAccess = false }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device) : m_op(op), m_device(device), m_result(NULL) { m_dimensions = op.func().dimensions(op.expression()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { if (data) { evalTo(data); return false; } else { m_result = static_cast( m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); evalTo(m_result); return true; } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { if (m_result != NULL) { m_device.deallocate(m_result); m_result = NULL; } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_result[index]; } template EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { return internal::ploadt(m_result + index); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); } EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } protected: EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { TensorMap > result( data, m_dimensions); m_op.func().eval(m_op.expression(), result, m_device); } Dimensions m_dimensions; const ArgType m_op; const Device& m_device; CoeffReturnType* m_result; }; /** \class TensorCustomBinaryOp * \ingroup CXX11_Tensor_Module * * \brief Tensor custom class. * * */ namespace internal { template struct traits > { typedef typename internal::promote_storage_type::ret Scalar; typedef typename internal::promote_storage_type::ret CoeffReturnType; typedef typename promote_storage_type::StorageKind, typename traits::StorageKind>::ret StorageKind; typedef typename promote_index_type::Index, typename traits::Index>::type Index; typedef typename LhsXprType::Nested LhsNested; typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; static const int NumDimensions = traits::NumDimensions; static const int Layout = traits::Layout; }; template struct eval, Eigen::Dense> { typedef const TensorCustomBinaryOp& type; }; template struct nested > { typedef TensorCustomBinaryOp type; }; } // end namespace internal template class TensorCustomBinaryOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef typename internal::traits::CoeffReturnType CoeffReturnType; typedef typename internal::nested::type Nested; typedef typename internal::traits::StorageKind StorageKind; typedef typename internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func) : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {} EIGEN_DEVICE_FUNC const CustomBinaryFunc& func() const { return m_func; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& lhsExpression() const { return m_lhs_xpr; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& rhsExpression() const { return m_rhs_xpr; } protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const CustomBinaryFunc m_func; }; // Eval as rvalue template struct TensorEvaluator, Device> { typedef TensorCustomBinaryOp XprType; typedef typename internal::traits::Index Index; static const int NumDims = internal::traits::NumDimensions; typedef DSizes Dimensions; typedef typename XprType::Scalar Scalar; typedef typename internal::remove_const::type CoeffReturnType; typedef typename PacketType::type PacketReturnType; static const int PacketSize = internal::unpacket_traits::size; enum { IsAligned = false, PacketAccess = (internal::packet_traits::size > 1), BlockAccess = false, Layout = TensorEvaluator::Layout, CoordAccess = false, // to be implemented RawAccess = false }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_op(op), m_device(device), m_result(NULL) { m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { if (data) { evalTo(data); return false; } else { m_result = static_cast(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); evalTo(m_result); return true; } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { if (m_result != NULL) { m_device.deallocate(m_result); m_result = NULL; } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_result[index]; } template EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { return internal::ploadt(m_result + index); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); } EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } protected: EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { TensorMap > result(data, m_dimensions); m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); } Dimensions m_dimensions; const XprType m_op; const Device& m_device; CoeffReturnType* m_result; }; } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H