{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ExtendedDefaultRules #-}
{-# OPTIONS_GHC -Wall #-}

-- | A @'Range' a@ is a pair representing an interval of a number space. A Range can be thought of as consisting of a low and a high value, though low <= high isn't strictly enforced, allowing a "negative" (inverted) space, so to speak.
-- The library uses the 'NumHask' classes, so most of the usual arithmetic operators can be used.

module NumHask.Range
  ( Range(..)
  , (...)
  , low
  , high
  , mid
  , width
  , element
  , singleton
  , singular
  , intersection
  , contains
  , range
  , project
  , LinearPos(..)
  , linearSpace
  , linearSpaceSensible
  , fromLinearSpace
 ) where

import NumHask.Prelude
import Control.Category (id)
import Control.Lens hiding (Magma, singular, element, contains, (...))
import qualified Control.Foldl as L
import Test.QuickCheck

-- | An interval of a number space, stored as a @(low, high)@ pair.
--
-- The raw constructor does not order the two components; use @(...)@ to
-- build a range whose low value is never greater than its high value.
newtype Range a = Range
    { range_ :: (a, a) -- ^ the underlying @(low, high)@ pair
    }
    deriving (Eq, Ord, Show, Functor)

-- | Alternative constructor: builds a 'Range' from two end points, swapping
-- them when needed so the first component is never greater than the second.
(...) :: Ord a => a -> a -> Range a
(...) a b = if a <= b then Range (a, b) else Range (b, a)

-- | Lens onto the first (low) component of the underlying pair.
low :: Lens' (Range a) a
low = lens getLow setLow
  where
    getLow (Range (l, _)) = l
    setLow (Range (_, u)) l = Range (l, u)

-- | Lens onto the second (high) component of the underlying pair.
high :: Lens' (Range a) a
high = lens getHigh setHigh
  where
    getHigh (Range (_, u)) = u
    setHigh (Range (l, _)) u = Range (l, u)

-- | mid-value lens
--
-- Viewing returns the mid-point, @(low + high) / two@.  Setting moves the
-- range so that its mid-point becomes the new value while keeping its width
-- (and orientation) unchanged, so @set mid (view mid r) r == r@ holds up to
-- rounding.
mid ::
    (BoundedField a) =>
    Lens' (Range a) a
mid =
    lens
    plushom
    -- re-centre on m, keeping the existing half-width on either side
    (\(Range (l,u)) m -> Range (m - (u - l)/two, m + (u - l)/two))

-- | Range width lens.
--
-- Viewing returns @high - low@, which is negative for an inverted range.
-- Setting keeps the mid-point fixed and spreads the new width evenly around it.
width ::
    (BoundedField a) =>
    Lens' (Range a) a
width =
    lens
        (\(Range (l, u)) -> u - l)
        (\r w -> let c = plushom r in Range (c - w/two, c + w/two))

-- | Generates a 'Range' from two independent arbitrary end points; the result
-- may be inverted (low greater than high).
instance (Arbitrary a) => Arbitrary (Range a) where
    arbitrary = curry Range <$> arbitrary <*> arbitrary

-- | Addition is the convex hull: the lesser of the lows paired with the
-- greater of the highs. This is a natural choice given that 'zero' is the
-- inverted infinite range.
instance (Ord a) => AdditiveMagma (Range a) where
    plus (Range (lowA, highA)) (Range (lowB, highB)) =
        Range (lowA `min` lowB, highA `max` highB)

-- | 'zero' is the inverted infinite range @(infinity, neginfinity)@. For
-- non-NaN end points it is an identity for the convex-hull 'plus', since
-- @min infinity l == l@ and @max neginfinity u == u@.
instance (Ord a, BoundedField a) => AdditiveUnital (Range a) where
    zero = Range (lo, hi)
      where
        lo = infinity
        hi = neginfinity

-- Law-only instances: the convex hull ('min' of lows, 'max' of highs) is
-- commutative and associative, so no methods need defining here.
instance (Ord a) => AdditiveCommutative (Range a)
instance (Ord a) => AdditiveAssociative (Range a)
instance (BoundedField a, Ord a) => Additive (Range a)

-- | '<>' is the same convex hull as 'plus'.
instance (Ord a) => Semigroup (Range a) where
    a <> b = plus a b

-- | The convex-hull 'Semigroup' with 'zero' (the inverted infinite range) as
-- its identity.
--
-- The context is stated on the element type: the @AdditiveUnital (Range a)@
-- and @Semigroup (Range a)@ instances exist exactly when
-- @(Ord a, BoundedField a)@ holds, and this form does not need
-- FlexibleContexts or UndecidableInstances.
instance (Ord a, BoundedField a) => Monoid (Range a) where
    mempty = zero
    mappend = (<>)

-- | Negation swaps the end points, turning a range into its inverted form and
-- back again.
instance (Ord a) => AdditiveInvertible (Range a) where
    negate = Range . flipEnds . range_
      where
        flipEnds (l, u) = (u, l)

-- NOTE(review): with convex-hull 'plus', @r + negate r@ is @r@ itself for a
-- non-inverted @r@ rather than 'zero', so the group laws do not hold.
instance (Ord a, BoundedField a) => AdditiveGroup (Range a)

-- | The natural interpretation of a `Range a` as an `a` is its mid-point,
-- @(low + high) / two@.
instance (BoundedField a) => AdditiveHomomorphic (Range a) a where
    plushom r = case range_ r of
        (l, u) -> (l + u) / two

-- | The natural interpretation of an `a` as a `Range a` is the zero-width
-- (singular) range at that value.
instance (Ord a) => AdditiveHomomorphic a (Range a) where
    plushom = singleton

-- | Multiplication treats the second range as a frame and maps the first range
-- into it. The product's mid-point is @mid b + mid a * width b@ and its width
-- is @width a * width b@.
--
-- NOTE(review): this looks like an affine projection lurking under the hood.
instance (BoundedField a) => MultiplicativeMagma (Range a) where
    times a b =
        let wb = view width b
            centre = view mid b + view mid a * wb
            halfSpan = view width a * wb / two
        in Range (centre - halfSpan, centre + halfSpan)

-- | The unit range is chosen so that
--
-- > view width one == one
-- > view mid one == zero
--
-- i.e. the range (-0.5, 0.5).
instance (BoundedField a) => MultiplicativeUnital (Range a) where
    one = Range (negate h, h)
      where
        h = half

-- Law-only instance: composing the mid/width maps of 'times' is algebraically
-- associative (up to floating-point rounding).
instance BoundedField a => MultiplicativeAssociative (Range a)

-- | Inverse under 'times'. For a non-zero width @w@, the result has mid-point
-- @negate (mid a) / w@ and width @1 / w@. A zero-width range maps to 'theta'.
instance (Ord a, BoundedField a) => MultiplicativeInvertible (Range a) where
    recip a
        | w == zero = theta
        | otherwise = Range (c - hr, c + hr)
      where
        w = view width a
        c = negate (view mid a) * recip w
        hr = recip w / two

-- Law-only instances; no methods are defined here.
instance (Ord a, BoundedField a) => MultiplicativeLeftCancellative (Range a)
instance (Ord a, BoundedField a) => MultiplicativeRightCancellative (Range a)

-- | A range counts as positive when it is not inverted (@low <= high@).
-- 'sign' returns 'one' or its negation, and 'abs' puts the end points in order.
instance (BoundedField a, Ord a) => Signed (Range a) where
    sign (Range (l, u))
        | l <= u = one
        | otherwise = negate one
    abs r@(Range (l, u))
        | l <= u = r
        | otherwise = Range (u, l)

-- | The size of a range is its width, @high - low@.
--
-- NOTE(review): this is negative for an inverted range.
instance (AdditiveGroup a) => Normed (Range a) a where
    size r = case range_ r of
        (l, u) -> u - l

-- | The distance between two ranges is the gap separating them, or 'zero' when
-- they overlap or touch. The gap computation presumes non-inverted ranges.
instance (Ord a, AdditiveGroup a) => Metric (Range a) a where
    distance (Range (l0, u0)) (Range (l1, u1))
        | l1 > u0 = l1 - u0
        | l0 > u1 = l0 - u1
        | otherwise = zero

-- | theta is a bit like 1/infinity: the zero-width range sitting at 'zero'.
theta :: (AdditiveUnital a) => Range a
theta = singleton zero

-- | The constant two, built as @one + one@.
two :: (MultiplicativeUnital a, Additive a) => a
two = o + o
  where
    o = one

-- | The constant one half, @one / (one + one)@.
half :: (BoundedField a) => a
half = one / two

-- | A zero-width range whose two end points are both the given value.
singleton :: a -> Range a
singleton x = curry Range x x

-- | Is the value within the range, end points included?
-- An inverted range (low > high) contains no values.
element :: (Ord a) => a -> Range a -> Bool
element x (Range (l, u)) = l <= x && x <= u

-- | Does the range collapse to a single point (equal end points)?
singular :: (Eq a) => Range a -> Bool
singular r = case range_ r of
    (l, u) -> l == u

-- | The overlap of two ranges: the greater of the lows and the lesser of the
-- highs. Disjoint ranges produce an inverted range (low > high).
intersection :: (Ord a) => Range a -> Range a -> Range a
intersection a b = Range (lo, hi)
  where
    lo = max (view low a) (view low b)
    hi = min (view high a) (view high b)

-- | Does the first range contain all of the second (end points included)?
contains :: (Ord a) => Range a -> Range a -> Bool
contains outer inner =
    view low outer <= view low inner && view high outer >= view high inner

-- | The range of a foldable: the convex hull of all its elements. It is built
-- by adding each element's singleton range to an accumulator that starts at
-- 'zero', so an empty foldable yields 'zero' (@Range (infinity, neginfinity)@).
range :: (Foldable f, Ord a, BoundedField a) => f a -> Range a
range = L.fold (L.Fold step zero id)
  where
    step acc x = acc + singleton x

-- | Project a data point from an old range to a new range by linear
-- interpolation. Up to rounding:
--
-- > project o n (view low o) == view low n
-- > project o n (view high o) == view high n
-- > project r r == id
--
-- NOTE(review): a zero-width old range divides by zero.
project :: (Field b) => Range b -> Range b -> b -> b
project (Range (fromLo, fromHi)) (Range (toLo, toHi)) p =
    fraction * (toHi - toLo) + toLo
  where
    -- position of p within the old range, 0 at its low end and 1 at its high end
    fraction = (p - fromLo) / (fromHi - fromLo)

-- * linear
-- | Governs where data points are placed on a range (see 'linearSpace').
data LinearPos
    = OuterPos  -- ^ both end points and every step boundary in between
    | InnerPos  -- ^ interior step boundaries only, excluding both end points
    | LowerPos  -- ^ the lower boundary of each step
    | UpperPos  -- ^ the upper boundary of each step
    | MidPos    -- ^ the mid-point of each step
    deriving (Eq)

-- | Turn a range into a list of equally-spaced points, using a step of
-- @(high - low) / n@.
--
-- The number of points depends on the 'LinearPos':
--
-- * 'OuterPos': both end points and every step in between (@n + 1@ points)
-- * 'InnerPos': interior points only (@n - 1@ points)
-- * 'LowerPos': the lower end of each step (@n@ points)
-- * 'UpperPos': the upper end of each step (@n@ points)
-- * 'MidPos': the mid-point of each step (@n@ points)
--
-- NOTE(review): @n = 0@ divides by zero when computing the step.
linearSpace :: (Field a, FromInteger a) => LinearPos -> Range a -> Int -> [a]
linearSpace o (Range (l, u)) n =
    [l + step * fromIntegral i + offset | i <- [firstIdx .. lastIdx]]
  where
    step = (u - l) / fromIntegral n
    -- only MidPos shifts the points (by half a step); other positions add zero
    offset = if o == MidPos then step / two else zero
    (firstIdx, lastIdx) = case o of
        OuterPos -> (0, n)
        InnerPos -> (1, n - 1)
        LowerPos -> (0, n - 1)
        UpperPos -> (1, n)
        MidPos -> (0, n - 1)

-- | Turn a range into roughly @n@ points that are pleasing to human sense and
-- sensibility.
--
-- The step is 1, 2, 5 or 10 times a power of ten, chosen to be close to the
-- naive step @(high - low) / n@. Points lie on a grid of whole multiples of
-- that step, running from the first multiple at or above @low@ to the last
-- multiple at or below @high@. Positions are then taken from that grid
-- according to the 'LinearPos', as in 'linearSpace'. Because the step is
-- rounded, the number of points returned need not be @n@.
--
-- NOTE(review): assumes @low < high@ and @n > 0@. A non-positive extent makes
-- the logarithm below NaN, and @n = 0@ divides by zero.
linearSpaceSensible :: (Fractional a, Ord a, FromInteger a, QuotientField a, ExpField a) =>
    LinearPos -> Range a -> Int -> [a]
linearSpaceSensible tp (Range (l, u)) n =
    (+ if tp==MidPos then step/two else zero) <$> posns
  where
    posns = (first' +) . (step *) . fromIntegral <$> [i0..i1]
    -- named 'extent' rather than 'span' so it does not shadow the prelude's
    -- list function 'span' (a -Wall name-shadowing warning)
    extent = u - l
    -- largest power of ten not exceeding the naive step extent/n
    step' = 10 ^^ floor (logBase 10 (extent/fromIntegral n))
    -- ratio of that power of ten to the naive step; roughly in (0.1, 1]
    err = fromIntegral n / extent * step'
    step
      | err <= 0.15 = 10 * step'
      | err <= 0.35 = 5 * step'
      | err <= 0.75 = 2 * step'
      | otherwise = step'
    -- first and last multiples of the step that fall inside the range
    first' = step * fromIntegral (ceiling (l/step))
    last' = step * fromIntegral (floor (u/step))
    -- number of whole steps between them
    n' = round ((last' - first')/step)
    (i0,i1) = case tp of
                OuterPos -> (0,n')
                InnerPos -> (1,n' - 1)
                LowerPos -> (0,n' - 1)
                UpperPos -> (1,n')
                MidPos -> (0,n' - 1)

-- | Pair each element of a list with its successor, giving the ranges between
-- consecutive points. A list of @k@ points yields @k - 1@ ranges, and none at
-- all when there are fewer than two points.
--
-- Applied to the 'OuterPos' points of 'linearSpace', this recovers its
-- sub-ranges. For @n >= 0@:
--
-- > length (fromLinearSpace (linearSpace OuterPos r n)) == n
fromLinearSpace :: [a] -> [Range a]
fromLinearSpace xs = zipWith (\lo hi -> Range (lo, hi)) xs (drop 1 xs)