module Data.Tensor.Statistics where

import           Data.Tensor.Tensor
import           Data.Tensor.Type

-- | Arithmetic mean of all elements of a tensor.
--
-- The sum of the elements is divided by the total element count, which is
-- the product of the tensor's 'shape'. If any dimension is 0 the divisor
-- is 0, so the result is @0 / 0@. That is NaN for IEEE floating types and
-- a runtime error for types such as 'Rational'.
--
-- > λ> average (identity :: Tensor '[3,3] Float)
-- > 0.33333334
average :: forall s n. (HasShape s, Fractional n) => Tensor s n -> n
average t = sum t / fromIntegral (product (shape t))

-- | Population variance of all elements of a tensor.
--
-- This is the 'average' of the squared deviations from the tensor's
-- 'average'. It divides by the element count @N@, not @N - 1@, so it is
-- not the sample variance.
--
-- > λ> var ([1,2,3,4] :: Vector 4 Double )
-- > 1.25
var :: forall s n. (HasShape s, Fractional n) => Tensor s n -> n
var t = average (fmap squaredDeviation t)
  where
    -- The mean is computed once and shared by every element.
    m = average t
    squaredDeviation v = let d = v - m in d * d

-- | Population standard deviation of all elements of a tensor.
--
-- This is the square root of 'var', so it also divides by @N@ rather than
-- @N - 1@.
--
-- > λ> std ([1,2,3,4] :: Vector 4 Double )
-- > 1.118033988749895
std :: forall s n. (HasShape s, Floating n) => Tensor s n -> n
std = sqrt . var