{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies     #-}

module Data.Array.Accelerate.Test ( normalize
                                  , randDist
                                  , the
                                  ) where

import qualified Data.Array.Accelerate                   as A
import           Data.Array.Accelerate.System.Random.MWC (Variate, randomArray,
                                                          uniformR)

-- | Extract the single element of a host-side 'A.Scalar' (a rank-0 array).
--
-- A scalar array always holds exactly one element, stored at the unique
-- index 'A.Z'. Reading that index directly avoids the partial 'head' of
-- 'A.toList' (and the intermediate list).
the :: A.Scalar e -> e
the s = A.indexArray s A.Z

-- | Rescale a vector by dividing every element by the vector's total, so the
-- result sums to one.
--
-- No validation is performed:
--
-- * negative entries are kept as-is;
-- * if the total is zero (e.g. an all-zero or empty vector), every element is
--   divided by zero, which for IEEE floating-point types yields NaN or
--   infinities.
normalize :: A.Floating e => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
normalize xs = A.map (\x -> x / total) xs
  where
    -- Total of all elements, as an embedded scalar expression.
    total = A.the (A.sum xs)

-- | Sample a random distribution with the given one-dimensional shape.
--
-- Each entry is drawn with 'uniformR' using the bounds @(0.0, 1.0)@. The
-- resulting array is embedded with 'A.use' and passed through 'normalize', so
-- its entries sum to one.
randDist :: (A.Shape sh, Fractional e, Variate e, A.Floating e, sh ~ A.DIM1) => sh -> IO (A.Acc (A.Vector e))
randDist size = do
    samples <- randomArray (uniformR (0.0, 1.0)) size
    pure (normalize (A.use samples))