{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

module TBit.Numerical.Integration (integrate) where

import Numeric.Integration.TanhSinh

-- | Integration range of a single variable, written @(lower, upper)@.
type Bound = (Double, Double)

-- | Outcome of an integration, written @(value, error estimate)@.
type Res = (Double, Double)

-- | One-dimensional Tanh-Sinh integration of @f@ from @xmin@ to @xmax@.
--
-- Runs 'parTrap' on the interval, lets 'absolute' select a result using the
-- fixed tolerance @1e-3@, and returns that selection as
-- @(value, error estimate)@.
ttrap :: (Double -> Double) -> Double -> Double -> Res
ttrap f xmin xmax = summarize (absolute tolerance (parTrap f xmin xmax))
  where
    -- Tolerance handed to 'absolute'; the same value is used at every
    -- nesting level of a multi-dimensional integral.
    tolerance :: Double
    tolerance = 1e-3
    -- Project the selected estimate onto a @(value, error estimate)@ pair.
    summarize r = (result r, errorEstimate r)

-- | Values that can be integrated numerically. This covers a plain 'Double',
-- which is already a value, and curried functions of 'Double' arguments.
class Integrable a where
  -- | Bounds needed to integrate an @a@: one @(lower, upper)@ pair per
  --   'Double' argument, nested as right-associated pairs ending in @()@.
  type Bounds a :: *
  -- | Integrate over the given bounds, giving @(value, error estimate)@.
  integrate' :: a -> Bounds a -> Res

-- | Base case: a 'Double' is the value of a zero-dimensional integral. It
-- takes no bounds (the @()@ argument is never inspected) and reports zero
-- error.
instance Integrable Double where
  type Bounds Double = ()
  integrate' value = const (value, 0)

-- | Inductive case: integrate the first argument over its bounds with
-- 'ttrap'. At every sample point, the remaining arguments are integrated
-- recursively.
--
-- NOTE(review): only the value of each inner integral is kept ('fst'). The
-- inner error estimates are discarded, so the error component of the result
-- comes from the outermost 'ttrap' call alone.
instance Integrable r => Integrable (Double -> r) where
  type Bounds (Double -> r) = (Bound, Bounds r)
  integrate' f ((lo, hi), rest) =
    ttrap (\x -> fst (integrate' (f x) rest)) lo hi

-- | Nested bound tuples that can be turned into separate curried arguments.
-- For example, a function of @(b1, (b2, ()))@ becomes @b1 -> b2 -> a@.
class CurryBounds t where
  -- | The curried function type: one argument per tuple component, then @a@.
  type Curried t a :: *
  curryBounds :: (t -> a) -> Curried t a

-- | No bounds left: feed the unit value and return the result directly.
instance CurryBounds () where
  type Curried () a = a
  curryBounds = ($ ())

-- | Take the head of the bounds tuple as its own argument, then curry the
-- function of the remaining bounds.
instance CurryBounds bs => CurryBounds (b, bs) where
  type Curried (b, bs) a = b -> Curried bs a
  curryBounds f = \first -> curryBounds (f . (,) first)

-- | Integrate a function of @n@ 'Double' arguments over a rectangular domain.
-- Pass one @(lower, upper)@ bound per argument, /i.e./
--
--  > let f :: Double -> Double -> Double -> Double
--  >     f x y z = x^2 + log (y-z)
--  > integrate f (0,1) (3,4) (1,2)
--
-- Bounds are matched to arguments in order: the first pair is the range of
-- the first argument, the second pair that of the second argument, and so
-- on. The result is @(value, error estimate)@. The error estimate comes from
-- the outermost one-dimensional quadrature only, because inner error
-- estimates are discarded.
--
-- This code was borrowed from the excellent answer at
-- <http://stackoverflow.com/questions/23703360/using-numeric-integration-tanhsinh-for-n-dimensional-integration>.
--
-- The integration uses the Tanh-Sinh quadrature method and relies on Kmett's
-- @integration@ library.
integrate :: (Integrable r, CurryBounds (Bounds r)) => r -> Curried (Bounds r) Res
integrate f = curryBounds (integrate' f)