{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Data.Massiv.Core.Operations
-- Copyright   : (c) Alexey Kuleshevich 2018-2019
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
module Data.Massiv.Core.Operations
  ( Numeric(..)
  , NumericFloat(..)
  ) where

import Data.Massiv.Core.Common


-- | Numeric operations on arrays of type @Array r ix e@.
--
-- Only 'foldArray', 'unsafeLiftArray' and 'unsafeLiftArray2' have to be supplied by an
-- instance (see the @MINIMAL@ pragma). Every other method has a default that is written
-- purely in terms of those three and may be overridden by an instance.
class Num e => Numeric r e where

  {-# MINIMAL foldArray, unsafeLiftArray, unsafeLiftArray2 #-}

  -- | Compute sum of all elements in the array
  --
  -- The default folds the array with @(+)@ starting from @0@.
  --
  -- @since 0.5.6
  sumArray :: Index ix => Array r ix e -> e
  sumArray = foldArray (+) 0
  {-# INLINE sumArray #-}

  -- | Compute product of all elements in the array
  --
  -- The default folds the array with @(*)@ starting from @1@.
  --
  -- @since 0.5.6
  productArray :: Index ix => Array r ix e -> e
  productArray = foldArray (*) 1
  {-# INLINE productArray #-}

  -- | Raise each element in the array to some non-negative power and sum the results
  --
  -- The default uses the Prelude operator @(^)@ for exponentiation, which raises an
  -- error when the exponent is negative.
  --
  -- @since 0.5.7
  powerSumArray :: Index ix => Array r ix e -> Int -> e
  powerSumArray arr p = sumArray $ unsafeLiftArray (^ p) arr
  {-# INLINE powerSumArray #-}

  -- | Compute dot product without any extraneous checks
  --
  -- The default multiplies the two arrays with 'unsafeLiftArray2' and sums the result
  -- with 'sumArray'. Nothing in this method verifies that both arguments have the same
  -- size; behaviour on a mismatch is whatever the instance's 'unsafeLiftArray2' does.
  --
  -- @since 0.5.6
  unsafeDotProduct :: Index ix => Array r ix e -> Array r ix e -> e
  unsafeDotProduct v1 v2 = sumArray $ unsafeLiftArray2 (*) v1 v2
  {-# INLINE unsafeDotProduct #-}


  -- | Add a scalar to the array: the default lifts @(+ e)@ with 'unsafeLiftArray'.
  plusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  plusScalar arr e = unsafeLiftArray (+ e) arr
  {-# INLINE plusScalar #-}

  -- | Subtract a scalar from the array: the default lifts @subtract e@, i.e. an
  -- element @x@ is turned into @x - e@.
  minusScalar :: Index ix => Array r ix e -> e -> Array r ix e
  minusScalar arr e = unsafeLiftArray (subtract e) arr
  {-# INLINE minusScalar #-}

  -- | Subtract the array from a scalar: the default lifts @(e -)@, i.e. an element @x@
  -- is turned into @e - x@.
  scalarMinus :: Index ix => e -> Array r ix e -> Array r ix e
  scalarMinus e arr = unsafeLiftArray (e -) arr
  {-# INLINE scalarMinus #-}

  -- | Multiply the array by a scalar: the default lifts @(* e)@ with 'unsafeLiftArray'.
  multiplyScalar :: Index ix => Array r ix e -> e -> Array r ix e
  multiplyScalar arr e = unsafeLiftArray (* e) arr
  {-# INLINE multiplyScalar #-}

  -- | Absolute value: the default lifts 'abs' with 'unsafeLiftArray'.
  absPointwise :: Index ix => Array r ix e -> Array r ix e
  absPointwise = unsafeLiftArray abs
  {-# INLINE absPointwise #-}

  -- | Addition of two arrays: the default lifts @(+)@ with 'unsafeLiftArray2'.
  additionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  additionPointwise = unsafeLiftArray2 (+)
  {-# INLINE additionPointwise #-}

  -- | Subtraction of two arrays: the default lifts @(-)@ with 'unsafeLiftArray2', so
  -- the second argument supplies the subtrahend.
  subtractionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  subtractionPointwise = unsafeLiftArray2 (-)
  {-# INLINE subtractionPointwise #-}

  -- | Multiplication of two arrays: the default lifts @(*)@ with 'unsafeLiftArray2'.
  multiplicationPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  multiplicationPointwise = unsafeLiftArray2 (*)
  {-# INLINE multiplicationPointwise #-}

  -- TODO:
  --  - rename to powerScalar
  --  - add? powerPointwise :: Array r ix e -> Array r ix Int -> Array r ix e
  -- | Raise each element of the array to the power
  --
  -- The default lifts @(^ pow)@; the Prelude operator @(^)@ raises an error when the
  -- exponent is negative.
  powerPointwise :: Index ix => Array r ix e -> Int -> Array r ix e
  powerPointwise arr pow = unsafeLiftArray (^ pow) arr
  {-# INLINE powerPointwise #-}

  -- | Fold over an array
  --
  -- Receives the combining function, the initial value and the array, in that order.
  -- 'sumArray' and 'productArray' are defined through it by default.
  --
  -- @since 0.5.6
  foldArray :: Index ix => (e -> e -> e) -> e -> Array r ix e -> e

  -- | Lift a unary function on elements to a function on arrays.
  --
  -- NOTE(review): the default methods above rely on this applying the function to
  -- every element; the precise contract, and what makes it \"unsafe\", is defined by
  -- each instance - confirm against the instances.
  unsafeLiftArray :: Index ix => (e -> e) -> Array r ix e -> Array r ix e

  -- | Lift a binary function on elements to a function on two arrays.
  --
  -- NOTE(review): the default methods above rely on this combining the two arrays
  -- element by element; presumably it is \"unsafe\" because the sizes of the arguments
  -- are not checked for agreement - confirm against the instances.
  unsafeLiftArray2 :: Index ix => (e -> e -> e) -> Array r ix e -> Array r ix e -> Array r ix e



-- | Operations on arrays that need division or square root on the element type.
--
-- The class adds a 'Floating' constraint on top of 'Numeric'. No method is mandatory:
-- every default is written in terms of 'unsafeLiftArray' or 'unsafeLiftArray2' from
-- 'Numeric', and can be overridden by an instance.
class (Numeric r e, Floating e) => NumericFloat r e where

  -- | Divide the array by a scalar: by default an element @x@ is turned into the
  -- quotient of @x@ and @e@.
  divideScalar :: Index ix => Array r ix e -> e -> Array r ix e
  divideScalar arr e = unsafeLiftArray (/ e) arr
  {-# INLINE divideScalar #-}

  -- | Divide a scalar by the array: by default an element @x@ is turned into the
  -- quotient of @e@ and @x@.
  scalarDivide :: Index ix => e -> Array r ix e -> Array r ix e
  scalarDivide e arr = unsafeLiftArray (e /) arr
  {-# INLINE scalarDivide #-}

  -- | Division of two arrays: the default lifts the division operator with
  -- 'unsafeLiftArray2', so the second argument supplies the divisors.
  divisionPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix e
  divisionPointwise = unsafeLiftArray2 (/)
  {-# INLINE divisionPointwise #-}

  -- | Reciprocal: the default lifts 'recip' with 'unsafeLiftArray'.
  recipPointwise :: Index ix => Array r ix e -> Array r ix e
  recipPointwise = unsafeLiftArray recip
  {-# INLINE recipPointwise #-}

  -- | Square root: the default lifts 'sqrt' with 'unsafeLiftArray'.
  sqrtPointwise :: Index ix => Array r ix e -> Array r ix e
  sqrtPointwise = unsafeLiftArray sqrt
  {-# INLINE sqrtPointwise #-}

  -- floorPointwise :: (Index ix, Integral a) => Array r ix e -> Array r ix a
  -- floorPointwise = unsafeLiftArray floor
  -- {-# INLINE floorPointwise #-}

  -- ceilingPointwise :: (Index ix, Integral a) => Array r ix e -> Array r ix a
  -- ceilingPointwise = unsafeLiftArray ceiling
  -- {-# INLINE ceilingPointwise #-}


-- class Equality r e where

--   unsafeEq :: Index ix => Array r ix e -> Array r ix e -> Bool

--   unsafeEqPointwise :: Index ix => Array r ix e -> Array r ix e -> Array r ix Bool


-- class Relation r e where

--   unsafePointwiseLT :: Array r ix e -> Array r ix e -> Array r ix Bool
--   unsafePointwiseLTE :: Array r ix e -> Array r ix e -> Array r ix Bool

--   unsafePointwiseGT :: Array r ix e -> Array r ix e -> Array r ix Bool
--   unsafePointwiseGTE :: Array r ix e -> Array r ix e -> Array r ix Bool

--   unsafePointwiseMin :: Array r ix e -> Array r ix e -> Array r ix e
--   unsafePointwiseMax :: Array r ix e -> Array r ix e -> Array r ix e

--   unsafeMinimum :: Array r ix e -> e

--   unsafeMaximum :: Array r ix e -> e