// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {

/** \class TensorIndexTupleOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor + Index Tuple class.
  *
  */
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
  typedef const TensorIndexTupleOp<XprType>& type;
};

template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
              typename eval<TensorIndexTupleOp<XprType> >::type>
{
  typedef TensorIndexTupleOp<XprType> type;
};

}  // end namespace internal

template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
    typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
        : m_xpr(expr) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
};

// Eval as rvalue
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
{
  typedef TensorIndexTupleOp<ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  static const int NumDims = internal::array_size<Dimensions>::value;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    return CoeffReturnType(index, m_impl.coeff(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  TensorEvaluator<ArgType, Device> m_impl;
};
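// A minimal usage sketch (untested, not part of the original header):
// TensorIndexTupleOp is normally created via TensorBase::index_tuples(),
// which pairs each coefficient with its flat (linearized) index. The
// tensor type and sizes below are hypothetical, chosen only to illustrate
// the (index, value) pairing.
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   // Each coefficient of `tuples` is a Tuple<DenseIndex, float>
//   // holding (flat index, value).
//   Eigen::Tensor<Eigen::Tuple<Eigen::DenseIndex, float>, 2> tuples =
//       t.index_tuples();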
namespace internal {

/** \class TensorTupleReducerOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
  *
  */
template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Index Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
  static const int Layout = XprTraits::Layout;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
};

}  // end namespace internal

template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
    typedef Index CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
                                                               const ReduceOp& reduce_op,
                                                               const int return_dim,
                                                               const Dims& reduce_dims)
        : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

    EIGEN_DEVICE_FUNC
    const ReduceOp& reduce_op() const { return m_reduce_op; }

    EIGEN_DEVICE_FUNC
    const Dims& reduce_dims() const { return m_reduce_dims; }

    EIGEN_DEVICE_FUNC
    int return_dim() const { return m_return_dim; }

  protected:
    typename XprType::Nested m_xpr;
    const ReduceOp m_reduce_op;
    const int m_return_dim;
    const Dims m_reduce_dims;
};

// Eval as rvalue
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>::Dimensions InputDimensions;
  static const int NumDims = internal::array_size<InputDimensions>::value;
  typedef array<Index, NumDims> StrideDims;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_orig_impl(op.expression(), device),
        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
        m_return_dim(op.return_dim()) {

    gen_strides(m_orig_impl.dimensions(), m_strides);
    if (Layout == static_cast<int>(ColMajor)) {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
    } else {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
    }
    m_stride_div = m_strides[m_return_dim];
  }
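  // Worked example of the stride arithmetic set up above (hypothetical
  // sizes, for illustration only): a column-major input of dimensions
  // (2, 3, 4) gets strides {1, 2, 6} from gen_strides(). With
  // m_return_dim == 1 this yields m_stride_mod == m_strides[2] == 6 and
  // m_stride_div == m_strides[1] == 2, so the flat index 11, i.e.
  // coordinates (1, 2, 1), recovers its coordinate along dimension 1 as
  // (11 % 6) / 2 == 2 in coeff() below.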
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    const TupleType v = m_impl.coeff(index);
    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 1.0 +
        (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
    return m_orig_impl.costPerCoeff(vectorized) +
           m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

 private:
  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
    if (m_return_dim < 0) {
      return;  // Won't be using the strides.
    }
    eigen_assert(m_return_dim < NumDims &&
                 "Asking to convert index to a dimension outside of the rank");

    // Calculate m_stride_div and m_stride_mod, which are used to
    // calculate the value of an index w.r.t. the m_return_dim.
    if (Layout == static_cast<int>(ColMajor)) {
      strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        strides[i] = strides[i-1] * dims[i-1];
      }
    } else {
      strides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        strides[i] = strides[i+1] * dims[i+1];
      }
    }
  }

 protected:
  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
  const int m_return_dim;
  StrideDims m_strides;
  Index m_stride_mod;
  Index m_stride_div;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H