module Goal.Core
    ( -- * Module Exports
      module Goal.Core.Plot
    , module Data.Function
    , module Data.Ord
    , module Data.Monoid
    , module Data.List
    , module Data.Maybe
    , module Data.Either
    , module Data.Default.Class
    , module Control.Applicative
    , module Control.Monad
    , module Control.Monad.ST
    , module Control.Arrow
    , module Control.Lens.Type
    , module Control.Lens.Getter
    , module Control.Lens.Setter
    , module Control.Lens.TH
    , module Control.Concurrent
    , module Numeric
    , module Debug.Trace
    -- * Lists
    , takeEvery
    , breakEvery
    -- * Low-Level
    , traceGiven
    -- * Numeric
    , roundSD
    , toPi
    -- ** Functions
    , logistic
    , logit
    -- ** Lists
    , mean
    , range
    , discretizeFunction
    ) where


--- Imports ---


-- Re-exports --

import Goal.Core.Plot hiding (empty,over)

import Data.Ord
import Data.Function
import Data.Monoid hiding (Dual)
import Data.List hiding (sum)
import Data.Maybe
import Data.Either

import Control.Applicative
import Control.Arrow hiding ((<+>))
import Control.Monad
import Control.Monad.ST
import Control.Lens.Type
import Control.Lens.Getter
import Control.Lens.Setter hiding (Identity)
import Control.Lens.TH
import Control.Concurrent

import Debug.Trace
import Data.Default.Class
import Numeric


--- General Functions ---


-- | Keep every @m@-th element of a list, beginning with the head, i.e. the
-- elements at indices @0, m, 2m, ...@.
--
-- >>> takeEvery 3 "abcdefg"
-- "adg"
--
-- The index test uses 'mod', so a step of @0@ raises a divide-by-zero
-- exception on a non-empty list.
takeEvery :: Int -> [x] -> [x]
takeEvery m xs = [ y | (i, y) <- zip [0 ..] xs, i `mod` m == 0 ]

-- | Split a list into consecutive chunks of length @n@, in order. The final
-- chunk holds whatever remains and may be shorter than @n@.
--
-- >>> breakEvery 2 [1,2,3,4,5]
-- [[1,2],[3,4],[5]]
--
-- The empty list splits into no chunks. A non-positive chunk size on a
-- non-empty list is rejected with an error, because such a size can never
-- make progress and would otherwise yield an endless stream of empty chunks.
breakEvery :: Int -> [x] -> [[x]]
breakEvery _ [] = []
breakEvery n xs
    | n <= 0 = error "Goal.Core.breakEvery: chunk size must be positive"
    | otherwise = chunk : breakEvery n rest
  where
    -- 'splitAt' produces both halves in a single traversal of the prefix.
    (chunk, rest) = splitAt n xs

-- | Identity function that, as a debugging side effect, emits the 'show'n
-- value through "Debug.Trace" when the result is evaluated.
traceGiven :: Show a => a -> a
traceGiven = traceShowId


--- Numeric ---


-- | Round a number to @n@ decimal places, i.e. to the nearest multiple of
-- @10^(-n)@. Despite the name, this counts digits after the decimal point,
-- not significant digits.
--
-- >>> roundSD 2 3.14159
-- 3.14
--
-- A negative @n@ rounds to the left of the decimal point, so
-- @roundSD (-2) 1234.5 == 1200@. Exact ties in the scaled value follow
-- 'round', which rounds half to even.
roundSD :: (Floating x, RealFrac x) => Int -> x -> x
roundSD n x
    | n >= 0 = fromIntegral (round (scale * x) :: Integer) / scale
    | otherwise = fromIntegral (round (x / coarse) :: Integer) * coarse
  where
    -- Both factors use '^' with a non-negative exponent ('^' throws on a
    -- negative one). Each is only evaluated in the branch that needs it.
    scale = 10 ^ n
    coarse = 10 ^ negate n

-- | Reduce @x@ modulo @2*pi@ into the half-open interval @[-pi, pi)@, up to
-- floating-point rounding. This is typically used to wrap an angle given in
-- radians.
--
-- @x / pi@ is split into a whole number of half-turns @k@ and a fractional
-- part. An even @k@ leaves the fraction on the non-negative side. An odd
-- @k@ is shifted back by one half-turn onto the negative side.
toPi :: (Floating x, RealFrac x) => x -> x
toPi x
    | even k = pi * frac
    | otherwise = negate (pi * (1 - frac))
  where
    halfTurns = x / pi
    k = floor halfTurns :: Integer
    frac = halfTurns - fromIntegral k

-- | The standard logistic sigmoid, @1 / (1 + exp (-x))@. Mathematically it
-- maps the real line onto the open interval @(0, 1)@, with @logistic 0 == 0.5@.
logistic :: Floating x => x -> x
logistic = recip . (1 +) . exp . negate

-- | The inverse of 'logistic': the log-odds @log (p / (1 - p))@.
--
-- Meaningful for @0 < p < 1@. For IEEE floating types, the endpoints give
-- infinities and inputs outside @[0, 1]@ give NaN.
logit :: Floating x => x -> x
logit p = log odds
  where
    odds = p / (1 - p)

-- Lists --

-- | Arithmetic mean of a list of numbers, computed in a single pass.
--
-- The running sum and count are accumulated left to right with a strict
-- left fold, so long lists run in constant space. A lazy 'foldr' over a tuple
-- accumulator would build a chain of pending additions as long as the list.
-- An empty list yields @0 / 0@, which is NaN for IEEE floating types.
mean :: Fractional x => [x] -> x
mean = finish . foldl' step (0, 0)
  where
    -- Force both components so no thunks accumulate inside the tuple.
    step (s, c) e =
        let s' = s + e
            c' = c + 1
        in s' `seq` c' `seq` (s', c')
    finish (s, c) = s / c

-- | @range mn mx n@ returns @n@ evenly spaced numbers from @mn@ to @mx@,
-- including both endpoints.
--
-- >>> range 0 1 3
-- [0.0,0.5,1.0]
--
-- A single requested point is the midpoint @(mn + mx) / 2@. Zero or a
-- negative count gives the empty list.
range :: Double -> Double -> Int -> [Double]
range mn mx n = case n of
    0 -> []
    1 -> [(mn + mx) / 2]
    _ -> map interpolate [0 .. n - 1]
  where
    denom = fromIntegral n - 1
    -- Convex combination of the endpoints at fraction k / (n - 1).
    interpolate k = let t = fromIntegral k / denom in t * mx + (1 - t) * mn

-- | Sample a function over a range. Given a minimum @mn@, a maximum @mx@, a
-- sample count @n@ and a function @f@, return the pairs @(x, f x)@ for every
-- @x@ in @range mn mx n@. The sample points are @n@ evenly spaced points
-- including both endpoints.
discretizeFunction :: Double -> Double -> Int -> (Double -> Double) -> [(Double,Double)]
discretizeFunction mn mx n f = [ (x, f x) | x <- range mn mx n ]

-- Graveyard --

{-
parMapWH :: (a -> b) -> [a] -> [b]
-- | ParMap using rseq. WH stands for Weak Head (normal form).
parMapWH = parMap rseq

parMapDS :: NFData b => (a -> b) -> [a] -> [b]
parMapDS = parMap rdeepseq

gridSearch
    :: (a -> Double) -- ^ The error function on the model
    -> (x -> a) -- ^ A constructor from the parameter being tested to the model
    -> [x] -- ^ The list of parameter values to test
    -> (x,[(x,Double)]) -- ^ The best parameter, and the error at each tested parameter

-- | A general implementation of a grid search. Returns a pair where the first
-- element is the best (lowest-error) parameter, and the second is a list of the
-- errors calculated at each parameter value.
gridSearch errorfun constructor xs =
    let as = constructor <$> xs
        errs = parMapDS errorfun as
        x = fst . minimumBy (comparing snd) $ zip xs errs
    in (x,zip xs errs)

iterativeOptimization
    :: (a -> a -> Double) -- ^ Difference measure
    -> Double -- ^ Differential error threshold
    -> Int -- ^ Maximum number of iterations (< 1 is interpreted as infinity)
    -> (a -> a) -- ^ Iterator
    -> a -- ^ Initial Value
    -> (a,[(a,Double)]) -- ^ The final value and associated descent
-- | Iterates a value, stopping when a stopping criterion is satisfied, or the maximum
-- number of iterations is reached. The algorithm returns the resulting value as well as
-- the descent preceding it.
--
-- The error accompanying each element of the descent corresponds to the error between
-- that element and the following element. The last error in the descent will therefore
-- correspond to the measured difference between the last element of the descent, and the
-- returned singleton in the pair.
--
-- If the last error in the returned descent is less than the given threshold, then the
-- threshold was reached. Otherwise, n will have been reached as the terminal condition.
iterativeOptimization difference thrsh n iterator a =
    let itrs = iterate iterator a
        itrs' = if n > 0 then take (n+1) itrs else itrs
        zps = zip itrs' $ tail itrs'
        zps' = case break ((< thrsh) . snd) . zip zps $ uncurry difference <$> zps of
                   (zps0,[]) -> zps0
                   (zps0,rst) -> zps0 ++ [head rst]
    in (snd . fst . last $ zps', first fst <$> zps')
    -}