```{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
-- |
-- Module    : Numeric.Tools.Equation
-- Copyright : (c) 2011 Aleksey Khudyakov
--
-- Maintainer  : Aleksey Khudyakov <alexey.skladnoy@gmail.com>
-- Stability   : experimental
-- Portability : portable
--
-- Numerical solution of ordinary equations.
module Numeric.Tools.Equation (
-- * Data type
Root(..)
, fromRoot
-- * Equation solvers
, solveBisection
, solveRidders
, solveNewton
-- * References
-- $references
) where

import Numeric.ApproxEq

import Control.Applicative
import Control.Monad       (MonadPlus(..), ap)
import Data.Typeable       (Typeable)

-- | The result of searching for a root of a mathematical function.
--
-- The derived 'Eq', 'Show' and 'Read' instances allow results to be
-- compared, printed and parsed; 'Typeable' is derived via the
-- @DeriveDataTypeable@ extension enabled for this module.
data Root a = NotBracketed
            -- ^ The function does not have opposite signs when
            -- evaluated at the lower and upper bounds of the search.
            | SearchFailed
            -- ^ The search failed to converge to within the given
            -- error tolerance after the given number of iterations.
            | Root a
            -- ^ A root was successfully found.
              deriving (Eq, Read, Show, Typeable)

-- | Apply a function to the root if one was found; both failure
-- constructors are passed through unchanged.
instance Functor Root where
    fmap g result = case result of
        Root x       -> Root (g x)
        NotBracketed -> NotBracketed
        SearchFailed -> SearchFailed

-- | Sequencing of searches. A failed search ('NotBracketed' or
-- 'SearchFailed') short-circuits and is returned as is; a found root is
-- fed to the continuation.
instance Monad Root where
    NotBracketed >>= _ = NotBracketed
    SearchFailed >>= _ = SearchFailed
    Root a       >>= m = m a

    return = pure

-- | Left-biased choice with 'SearchFailed' as the zero element: the
-- first argument wins when it holds a root, otherwise the second
-- argument is returned whatever it is.
instance MonadPlus Root where
    mzero = SearchFailed

    r@(Root _) `mplus` _ = r
    _          `mplus` p = p

-- | 'pure' wraps a value as a found root; application is derived from
-- the 'Monad' instance via 'ap'.
instance Applicative Root where
    pure  = Root
    (<*>) = ap

-- | Left-biased choice: the left operand is kept when it holds a root,
-- otherwise the right operand is returned. 'empty' is 'SearchFailed'.
instance Alternative Root where
    empty = SearchFailed
    lhs <|> rhs = case lhs of
        Root _ -> lhs
        _      -> rhs

-- | Extract the value of a successful search, falling back to the
-- supplied default for 'NotBracketed' and 'SearchFailed'.
fromRoot :: a                   -- ^ Default value.
         -> Root a              -- ^ Result of search for a root.
         -> a
fromRoot fallback result =
    case result of
      Root x -> x
      _      -> fallback

-- | Find a root of a function by repeated halving of the bracketing
-- interval.
--
-- The root must be bracketed: if the function values at the two ends
-- of the range have a positive product the result is 'NotBracketed'.
-- If the function is exactly zero at one of the ends, that end is
-- returned. The search gives up with 'SearchFailed' once 100 halving
-- steps have been taken without meeting a stopping criterion.
solveBisection :: Double             -- ^ Required absolute precision
               -> (Double,Double)    -- ^ Range
               -> (Double -> Double) -- ^ Equation
               -> Root Double
solveBisection eps (lo,hi) f
  | fLo * fHi > 0 = NotBracketed
  | fLo == 0      = Root lo
  | fHi == 0      = Root hi
    -- Orient the bracket so that the function is negative at the
    -- first bound and positive at the second one.
  | otherwise     = if fLo < 0 then halve 0 lo hi
                               else halve 0 hi lo
  where
    fLo = f lo
    fHi = f hi
    -- One halving step. Invariants:
    --   f neg < 0
    --   f pos > 0
    halve :: Int -> Double -> Double -> Root Double
    halve iter neg pos
      -- Bounds pass the 'within 1' closeness test; stop here
      | within 1 neg pos       = Root neg
      -- Bracket is not wider than the requested precision
      | abs (pos - neg) <= eps = Root mid
      -- Midpoint is an exact root
      | fMid == 0              = Root mid
      -- Iteration budget exhausted
      | iter >= 100            = SearchFailed
      -- Keep the half of the bracket which still contains the sign change
      | fMid < 0               = halve (iter + 1) mid pos
      | otherwise              = halve (iter + 1) neg mid
      where
        mid  = 0.5 * (neg + pos)
        fMid = f mid

-- | Find a root of a function using Ridders' method.
--
-- The root must be bracketed: unless the function is exactly zero at
-- one of the bounds (in which case that bound is returned), function
-- values with a positive product at the two bounds yield
-- 'NotBracketed'. 'SearchFailed' is returned when the iteration budget
-- of 100 steps is exhausted.
solveRidders :: Double               -- ^ Absolute error tolerance.
             -> (Double,Double)      -- ^ Lower and upper bounds for the search.
             -> (Double -> Double)   -- ^ Function to find the roots of.
             -> Root Double
solveRidders eps (lo,hi) f
  | fLo == 0      = Root lo
  | fHi == 0      = Root hi
  | fLo * fHi > 0 = NotBracketed -- root is not bracketed
  | otherwise     = iter lo fLo hi fHi 0
  where
    !fLo = f lo
    !fHi = f hi
    -- One iteration on the bracket (l,r) with known function values
    -- (fl,fr); k counts the iterations performed so far.
    iter !l !fl !r !fr !k
      -- Root is bracketed within 1 ulp. No improvement can be made
      | within 1 l r         = Root l
      -- Root is found. The checks for an exact zero are necessary to
      -- ensure that a root is never passed on to 'iter'
      | fMid == 0            = Root mid
      | fNew == 0            = Root new
      | width < eps          = Root new
      -- Too many iterations performed. Fail
      | k >= (100 :: Int)    = SearchFailed
      -- Ridders' approximation coincides with one of the old
      -- bounds. Revert to bisection
      | new == l || new == r =
          if fMid * fl < 0
            then iter l fl mid fMid (k + 1)
            else iter mid fMid r fr (k + 1)
      -- Proceed as usual: keep the pair of points with a sign change
      | fNew * fMid < 0      = iter new fNew mid fMid (k + 1)
      | fNew * fl < 0        = iter l fl new fNew (k + 1)
      | otherwise            = iter new fNew r fr (k + 1)
      where
        width = abs (r - l)
        half  = (r - l) * 0.5
        !mid  = l + half
        !fMid = f mid
        !step = signum (fr - fl) * half * fMid / sqrt (fMid * fMid - fl * fr)
        -- The correction is capped at (abs half - 0.5 * eps)
        !new  = mid - signum step * min (abs step) (abs half - 0.5 * eps)
        !fNew = f new

-- | Solve an equation using the Newton-Raphson method. The root must
--   be bracketed. If a Newton step jumps outside of the bracket or
--   does not converge sufficiently fast, the function reverts to
--   bisection.
--
--   Returns 'NotBracketed' when the function values at the bounds have
--   a positive product, and 'SearchFailed' when the iteration limit is
--   exceeded. If the function is exactly zero at one of the bounds,
--   that bound is returned.
solveNewton :: Double             -- ^ Absolute error tolerance
            -> (Double,Double)    -- ^ Lower and upper bounds for root
            -> (Double -> Double) -- ^ Function
            -> (Double -> Double) -- ^ Function's derivative
            -> Root Double
solveNewton eps (lo,hi) f f'
  | flo == 0      = Root lo
  | fhi == 0      = Root hi
  | flo * fhi > 0 = NotBracketed
  | otherwise     = let d   = 0.5*(hi-lo)
                        mid = 0.5*(lo+hi)
                        fun = worker 0 d mid (f mid) (f' mid)
                    -- Orient the bracket: the worker receives the bound
                    -- with a negative function value first.
                    in if flo < 0 then fun lo hi
                                  else fun hi lo
  where
    flo = f lo
    fhi = f hi
    -- Worker function. Arguments:
    --   i      - number of iterations performed so far
    --   dxOld  - size of the previous step
    --   x      - current approximation, fx = f x, fx' = f' x
    --   a, b   - bracket; f is negative at a and non-negative at b
    worker i dxOld x fx fx' a b
      -- Convergence achieved
      | fx == 0           = Root x  -- x is a root
      | within 1 a b      = Root x  -- Bracket is too tight. No improvements could be made
      | abs (a - b) < eps = Root x  -- Bracket is narrower than the tolerance
      | abs dx      < eps = Root x' -- Precision achieved
      | within 0 x x'     = Root x' -- Newton step doesn't improve solution
      -- Too many iterations
      | i > (100::Int)    = SearchFailed
      -- Newton step jumped out of range or step decreases too slowly.
      -- Revert to bisection.
      -- NOTE this guard is selected if dx evaluated to NaN or ±∞.
      | not ((a - x')*(b - x') < 0) || abs (dxOld / dx) < 2 =
          step (0.5 * (b - a)) (0.5 * (a + b))
      -- Normal Newton step
      | otherwise = step dx x'
      where
        dx = fx / fx'
        x' = x - dx
        -- Perform one step: evaluate the function and its derivative
        -- at y and replace the bound which has the same sign as f y.
        step dy y
          | fy < 0    = worker (i+1) dy  y fy fy'  y b
          | otherwise = worker (i+1) dy  y fy fy'  a y
          where fy  = f  y
                fy' = f' y
{-# INLINABLE solveNewton #-}

-- $references
--
-- * Ridders, C.F.J. (1979) A new algorithm for computing a single
--   root of a real continuous function.
--   /IEEE Transactions on Circuits and Systems/ 26:979&#8211;980.

```