// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H namespace Eigen { /** \class TensorExpr * \ingroup CXX11_Tensor_Module * * \brief Tensor expression classes. * * The TensorCwiseNullaryOp class applies a nullary operators to an expression. * This is typically used to generate constants. * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * * The TensorCwiseBinaryOp class represents an expression where a binary * operator (e.g. addition) is applied to a lhs and a rhs expression. * */ namespace internal { template struct traits > : traits { typedef traits XprTraits; typedef typename XprType::Scalar Scalar; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference::type _XprTypeNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; enum { Flags = 0 }; }; } // end namespace internal template class TensorCwiseNullaryOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef TensorCwiseNullaryOp Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) : m_xpr(xpr), m_functor(func) {} EIGEN_DEVICE_FUNC const typename internal::remove_all::type& nestedExpression() const { return m_xpr; } EIGEN_DEVICE_FUNC const NullaryOp& functor() const { return m_functor; } protected: typename XprType::Nested m_xpr; const NullaryOp m_functor; }; namespace internal { template struct traits > : traits { // TODO(phli): Add InputScalar, InputPacket. Check references to // current Scalar/Packet to see if the intent is Input or Output. typedef typename result_of::type Scalar; typedef traits XprTraits; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference::type _XprTypeNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; }; template struct eval, Eigen::Dense> { typedef const TensorCwiseUnaryOp& type; }; template struct nested, 1, typename eval >::type> { typedef TensorCwiseUnaryOp type; }; } // end namespace internal template class TensorCwiseUnaryOp : public TensorBase, ReadOnlyAccessors> { public: // TODO(phli): Add InputScalar, InputPacket. Check references to // current Scalar/Packet to see if the intent is Input or Output. typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef Scalar CoeffReturnType; typedef typename Eigen::internal::nested::type Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp()) : m_xpr(xpr), m_functor(func) {} EIGEN_DEVICE_FUNC const UnaryOp& functor() const { return m_functor; } /** \returns the nested expression */ EIGEN_DEVICE_FUNC const typename internal::remove_all::type& nestedExpression() const { return m_xpr; } protected: typename XprType::Nested m_xpr; const UnaryOp m_functor; }; namespace internal { template struct traits > { // Type promotion to handle the case where the types of the lhs and the rhs // are different. // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to // current Scalar/Packet to see if the intent is Inputs or Output. typedef typename result_of< BinaryOp(typename LhsXprType::Scalar, typename RhsXprType::Scalar)>::type Scalar; typedef traits XprTraits; typedef typename promote_storage_type< typename traits::StorageKind, typename traits::StorageKind>::ret StorageKind; typedef typename promote_index_type< typename traits::Index, typename traits::Index>::type Index; typedef typename LhsXprType::Nested LhsNested; typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; enum { Flags = 0 }; }; template struct eval, Eigen::Dense> { typedef const TensorCwiseBinaryOp& type; }; template struct nested, 1, typename eval >::type> { typedef TensorCwiseBinaryOp type; }; } // end namespace internal template class TensorCwiseBinaryOp : public TensorBase, ReadOnlyAccessors> { public: // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to // current Scalar/Packet to see if the intent is Inputs or Output. typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef Scalar CoeffReturnType; typedef typename Eigen::internal::nested::type Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} EIGEN_DEVICE_FUNC const BinaryOp& functor() const { return m_functor; } /** \returns the nested expressions */ EIGEN_DEVICE_FUNC const typename internal::remove_all::type& lhsExpression() const { return m_lhs_xpr; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& rhsExpression() const { return m_rhs_xpr; } protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const BinaryOp m_functor; }; namespace internal { template struct traits > { // Type promotion to handle the case where the types of the args are different. typedef typename result_of< TernaryOp(typename Arg1XprType::Scalar, typename Arg2XprType::Scalar, typename Arg3XprType::Scalar)>::type Scalar; typedef traits XprTraits; typedef typename traits::StorageKind StorageKind; typedef typename traits::Index Index; typedef typename Arg1XprType::Nested Arg1Nested; typedef typename Arg2XprType::Nested Arg2Nested; typedef typename Arg3XprType::Nested Arg3Nested; typedef typename remove_reference::type _Arg1Nested; typedef typename remove_reference::type _Arg2Nested; typedef typename remove_reference::type _Arg3Nested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; enum { Flags = 0 }; }; template struct eval, Eigen::Dense> { typedef const TensorCwiseTernaryOp& type; }; template struct nested, 1, typename eval >::type> { typedef TensorCwiseTernaryOp type; }; } // end namespace internal template class TensorCwiseTernaryOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef Scalar CoeffReturnType; typedef typename Eigen::internal::nested::type Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} EIGEN_DEVICE_FUNC const TernaryOp& functor() const { return m_functor; } /** \returns the nested expressions */ EIGEN_DEVICE_FUNC const typename internal::remove_all::type& arg1Expression() const { return m_arg1_xpr; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& arg2Expression() const { return m_arg2_xpr; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& arg3Expression() const { return m_arg3_xpr; } protected: typename Arg1XprType::Nested m_arg1_xpr; typename Arg2XprType::Nested m_arg2_xpr; typename Arg3XprType::Nested m_arg3_xpr; const TernaryOp m_functor; }; namespace internal { template struct traits > : traits { typedef typename traits::Scalar Scalar; typedef traits XprTraits; typedef typename promote_storage_type::StorageKind, typename traits::StorageKind>::ret StorageKind; typedef typename promote_index_type::Index, typename traits::Index>::type Index; typedef typename IfXprType::Nested IfNested; typedef typename ThenXprType::Nested ThenNested; typedef typename ElseXprType::Nested ElseNested; static const int NumDimensions = XprTraits::NumDimensions; static const int Layout = XprTraits::Layout; }; template struct eval, Eigen::Dense> { typedef const TensorSelectOp& type; }; template struct nested, 1, typename eval >::type> { typedef TensorSelectOp type; }; } // end namespace internal template class TensorSelectOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; typedef typename internal::promote_storage_type::ret CoeffReturnType; typedef typename Eigen::internal::nested::type Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC TensorSelectOp(const IfXprType& a_condition, const ThenXprType& a_then, const ElseXprType& a_else) : m_condition(a_condition), m_then(a_then), m_else(a_else) { } EIGEN_DEVICE_FUNC const IfXprType& ifExpression() const { return m_condition; } EIGEN_DEVICE_FUNC const ThenXprType& thenExpression() const { return m_then; } EIGEN_DEVICE_FUNC const ElseXprType& elseExpression() const { return m_else; } protected: typename IfXprType::Nested m_condition; typename ThenXprType::Nested m_then; typename ElseXprType::Nested m_else; }; } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H