{-# LANGUAGE MultiParamTypeClasses, ScopedTypeVariables, FlexibleContexts #-}
module Math.Root.Finder
    ( RootFinder(..)
    , getDefaultNSteps
    , runRootFinder
    , traceRoot
    , findRoot, findRootN
    , eps
    , realFloatDefaultNSteps
    ) where

import Control.Monad.Instances ()
import Data.Tagged

-- |General interface for numerical root finders.
--
-- A root finder is a state of type @r a b@ for a function of type @a -> b@.
-- It is created with 'initRootFinder', refined with 'stepRootFinder', and
-- inspected with 'estimateRoot' and 'estimateError'.
class RootFinder r a b where
    -- |@initRootFinder f x0 x1@: Initialize a root finder for the given
    -- function with the initial bracketing interval (x0,x1).
    initRootFinder :: (a -> b) -> a -> a -> r a b

    -- |Step a root finder for the given function (which should generally
    -- be the same one passed to @initRootFinder@), refining the finder's
    -- estimate of the location of a root.
    stepRootFinder :: (a -> b) -> r a b -> r a b

    -- |Extract the finder's current estimate of the position of a root.
    estimateRoot  :: r a b -> a

    -- |Extract the finder's current estimate of the upper bound of the
    -- distance from @estimateRoot@ to an actual root of the function.
    --
    -- Generally, @estimateRoot r@ +- @estimateError r@ should bracket
    -- a root of the function.
    estimateError :: r a b -> a

    -- |Test whether a root finding algorithm has converged to a given
    -- accuracy @xacc@.
    --
    -- The default implementation uses an /absolute/ tolerance: it holds
    -- exactly when @abs (estimateError r) <= abs xacc@.  Instances may
    -- override it, for example to use a relative tolerance instead.
    converged :: (Num a, Ord a) => a -> r a b -> Bool
    converged xacc r = abs (estimateError r) <= abs xacc

    -- |Default number of steps after which root finding will be deemed
    -- to have failed.  Purely a convenience used to control the behavior
    -- of built-in functions such as 'findRoot' and 'traceRoot'.  The
    -- default value is 250.  The value is tagged with the root finder's
    -- type; 'getDefaultNSteps' reads it given a root-finder value.
    defaultNSteps :: Tagged (r a b) Int
    defaultNSteps = Tagged 250

-- |Convenience function to access 'defaultNSteps' for a root finder.  The
-- 'RootFinder' instance is selected from the type of the given value.
--
-- This function does not evaluate its argument; only its type is used.
getDefaultNSteps :: forall r a b. RootFinder r a b => r a b -> Int
getDefaultNSteps _ = limit
    where
        -- The scoped type variables pin down which instance's
        -- 'defaultNSteps' is meant.
        Tagged limit = defaultNSteps :: Tagged (r a b) Int

-- |General-purpose driver for stepping a root finder.  Given a \"control\"
-- function, the function being searched, and an initial 'RootFinder' state,
-- @runRootFinder step f state@ repeatedly steps the root finder.  It hands
-- each intermediate state, together with the number of steps taken so far,
-- to @step@.  The value @step@ returns for the initial state is the result of
-- the whole run, so @step@ decides when (and whether) the iteration stops.
--
-- The @step@ function is called with the following arguments:
--
-- [@ n :: 'Int' @]
--  The number of steps taken thus far (0 for the initial state)
--
-- [@ currentState :: r a b @]
--  The current state of the root finder
--
-- [@ continue :: c @]
--  The result of the \"rest\" of the iteration, computed only if @step@
--  demands it
--
-- For example, the following function simply iterates a root finder
-- and returns every intermediate state (similar to 'traceRoot'):
--
-- > iterateRoot :: RootFinder r a b => (a -> b) -> a -> a -> [r a b]
-- > iterateRoot f a b = runRootFinder (const (:)) f (initRootFinder f a b)
--
-- The following function iterates the root finder to convergence, or
-- throws an error after a given number of steps:
--
-- > solve :: (RootFinder r a b, RealFloat a)
-- >       => Int -> (a -> b) -> a -> a -> r a b
-- > solve maxN f a b = runRootFinder step f (initRootFinder f a b)
-- >    where
-- >        step n x continue
-- >            | converged eps x   = x
-- >            | n > maxN          = error "solve: step limit exceeded"
-- >            | otherwise         = continue
--
runRootFinder :: (RootFinder r a b) =>
    (Int -> r a b -> c -> c) -> (a -> b) -> r a b -> c
runRootFinder step f start = loop 0 start
    where
        -- The counter is forced on every iteration so it never becomes a
        -- chain of (+1) thunks, even when @step@ ignores it.
        loop i current =
            i `seq` step i current (loop (i + 1) (stepRootFinder f current))

-- |@traceRoot f x0 x1 mbEps@ initializes a root finder and repeatedly
-- steps it, returning each step of the process in a list.  No step limit
-- is imposed, so the list is infinite if the termination criterion is
-- never met.
--
-- The termination criterion depends on @mbEps@.  If it is of the form
-- @Just eps@, the trace ends with the first state that has converged to
-- @eps@ (according to the root finder's 'converged' method).  Otherwise,
-- the trace ends with the first state that is equal (according to '==') to
-- its successor.  This is a stricter condition than convergence to 0:
-- subsequent states may have converged to zero, but as long as any internal
-- state changes, the trace will continue.
traceRoot :: (Eq (r a b), RootFinder r a b, Num a, Ord a) =>
             (a -> b) -> a -> a -> Maybe a -> [r a b]
traceRoot f a b mbEps = runRootFinder cons f start
    where
        start = initRootFinder f a b

        cons _n x rest = x : if done x rest then [] else rest

        -- If tracing with no convergence test, apply a naive test
        -- to bail out if the root stops changing.  This is provided
        -- because that's not always the same as convergence to 0,
        -- and the main purpose of this function is to watch what
        -- actually happens inside the root finder.
        done x rest = case mbEps of
            Just xacc   -> converged xacc x
            Nothing     -> case rest of
                next : _    -> x == next
                -- Unreachable in practice: 'cons' always puts the next
                -- state at the head of @rest@.  Handled so the match is
                -- total instead of relying on a partial lambda pattern.
                []          -> True

-- |@findRoot f x0 x1 eps@ initializes a root finder and repeatedly steps
-- it, using the root finder's 'defaultNSteps' as the step limit (see
-- 'findRootN' for the exact stopping rule).  The result is the state at
-- which iteration stopped.  The @Right@ constructor indicates successful
-- convergence to @eps@; the @Left@ constructor indicates that the step
-- limit was exceeded first.
findRoot :: forall r a b. (RootFinder r a b, Num a, Ord a) =>
            (a -> b) -> a -> a -> a -> Either (r a b) (r a b)
findRoot f a b xacc = findRootN limit f a b xacc
    where
        -- The limit depends only on the root finder's type, which the
        -- scoped type variables fix here.
        Tagged limit = defaultNSteps :: Tagged (r a b) Int

-- |Like 'findRoot' but with an explicit step limit @nSteps@ instead of
-- 'defaultNSteps'.
--
-- Each state is first tested for convergence to @xacc@, which yields
-- @Right@.  A state that has not converged and was reached after more than
-- @nSteps@ steps yields @Left@.  Otherwise the finder is stepped again.
-- Consequently, for a non-negative @nSteps@, up to @nSteps + 1@ steps are
-- taken before giving up.
findRootN :: (RootFinder r a b, Num a, Ord a) =>
            Int -> (a -> b) -> a -> a -> a -> Either (r a b) (r a b)
findRootN nSteps f a b xacc =
    runRootFinder check f (initRootFinder f a b)
    where
        check i current rest =
            if converged xacc current
                then Right current
                else if i > nSteps
                    then Left current
                    else rest

-- |A useful constant: 'eps' is the machine epsilon of a 'RealFloat' type.
-- It is the gap between @1@ and the next larger representable number,
-- @floatRadix x ^^ (1 - floatDigits x)@.  For 'Double' this is @2 ^^ (-52)@.
--
-- Note that under round-to-nearest, @1 + x /= 1@ already holds for some
-- positive @x@ smaller than 'eps' (anything larger than half of it).  So
-- 'eps' is /not/ the smallest positive number with that property.
eps :: RealFloat a => a
eps = machineEps
    where
        -- 'floatDigits' only needs the type of its argument, so the
        -- result itself can be passed to it to select the right precision.
        machineEps = encodeFloat 1 (1 - floatDigits machineEps)

-- |For 'RealFloat' types, computes a default step limit proportional to the
-- precision of the type in bits.  The limit is
-- @round (margin * floatDigits x * logBase 2 (floatRadix x))@, where
-- @floatDigits@ and @floatRadix@ are those of the type @a@.  The @margin@
-- is therefore a multiplicative safety factor applied to the number of bits
-- of precision, not an additive tolerance.
realFloatDefaultNSteps :: forall a r b. RealFloat a =>
                          Float -> Tagged (r a b) Int
realFloatDefaultNSteps margin = Tagged (round (product factors))
    where
        -- Any value of type @a@ will do: it only selects the 'RealFloat'
        -- instance whose precision is queried.
        probe :: a
        probe = 0

        -- Multiplied together in 'Float' precision before rounding.
        factors :: [Float]
        factors =
            [ margin
            , realToFrac (floatDigits probe)
            , logBase 2 (realToFrac (floatRadix probe))
            ]