module Data.Function.Memoize (
  
  Memoizable(..),
  
  
  memoize2, memoize3, memoize4, memoize5, memoize6, memoize7,
  
  memoFix, memoFix2, memoFix3, memoFix4, memoFix5, memoFix6, memoFix7,
  
  traceMemoize,
  
  memoizeFinite,
  
  deriveMemoizable, deriveMemoizableParams, deriveMemoize,
) where
import Control.Applicative
import Control.Monad
import Data.List (foldl')
import Debug.Trace

import Data.Function.Memoize.Class
import Data.Function.Memoize.TH
-- | Memoize a two-argument function.
--
-- The outer 'memoize' caches, for each first argument, an inner memoized
-- function of the second argument.
memoize2 ∷ (Memoizable a, Memoizable b) ⇒
           (a → b → v) → a → b → v
memoize2 f = memoize (\a → memoize (f a))
-- | Memoize a three-argument function: cache on the first argument and
-- hand the remaining two to 'memoize2'.
memoize3 ∷ (Memoizable a, Memoizable b, Memoizable c) ⇒
           (a → b → c → v) → a → b → c → v
memoize3 f = memoize (\a → memoize2 (f a))
-- | Memoize a four-argument function: cache on the first argument and
-- hand the remaining three to 'memoize3'.
memoize4 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d) ⇒
           (a → b → c → d → v) →
           a → b → c → d → v
memoize4 f = memoize (\a → memoize3 (f a))
-- | Memoize a five-argument function: cache on the first argument and
-- hand the remaining four to 'memoize4'.
memoize5 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e) ⇒
           (a → b → c → d → e → v) →
           a → b → c → d → e → v
memoize5 f = memoize (\a → memoize4 (f a))
-- | Memoize a six-argument function: cache on the first argument and
-- hand the remaining five to 'memoize5'.
memoize6 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e, Memoizable f) ⇒
           (a → b → c → d → e → f → v) →
           a → b → c → d → e → f → v
memoize6 g = memoize (\a → memoize5 (g a))
-- | Memoize a seven-argument function: cache on the first argument and
-- hand the remaining six to 'memoize6'.
memoize7 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e, Memoizable f, Memoizable g) ⇒
           (a → b → c → d → e → f → g → v) →
           a → b → c → d → e → f → g → v
memoize7 h = memoize (\a → memoize6 (h a))
-- | Memoizing fixed point. @ff@ is handed the memoized function itself as
-- its "recursive call", so recursive calls are answered from the cache too.
memoFix ∷ Memoizable a ⇒ ((a → v) → a → v) → a → v
memoFix ff = let memoized = memoize (ff memoized) in memoized
-- | Two-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize2'd result.
memoFix2 ∷ (Memoizable a, Memoizable b) ⇒
           ((a → b → v) → a → b → v) → a → b → v
memoFix2 ff = let memoized = memoize2 (ff memoized) in memoized
-- | Three-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize3'd result.
memoFix3 ∷ (Memoizable a, Memoizable b, Memoizable c) ⇒
           ((a → b → c → v) → a → b → c → v) → a → b → c → v
memoFix3 ff = let memoized = memoize3 (ff memoized) in memoized
-- | Four-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize4'd result.
memoFix4 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d) ⇒
           ((a → b → c → d → v) → (a → b → c → d → v)) →
           a → b → c → d → v
memoFix4 ff = let memoized = memoize4 (ff memoized) in memoized
-- | Five-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize5'd result.
memoFix5 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e) ⇒
           ((a → b → c → d → e → v) → (a → b → c → d → e → v)) →
           a → b → c → d → e → v
memoFix5 ff = let memoized = memoize5 (ff memoized) in memoized
-- | Six-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize6'd result.
memoFix6 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e, Memoizable f) ⇒
           ((a → b → c → d → e → f → v) → (a → b → c → d → e → f → v)) →
           a → b → c → d → e → f → v
memoFix6 ff = let memoized = memoize6 (ff memoized) in memoized
-- | Seven-argument 'memoFix': recursive calls made by @ff@ go through the
-- 'memoize7'd result.
memoFix7 ∷ (Memoizable a, Memoizable b, Memoizable c, Memoizable d,
            Memoizable e, Memoizable f, Memoizable g) ⇒
           ((a → b → c → d → e → f → g → v) → (a → b → c → d → e → f → g → v)) →
           a → b → c → d → e → f → g → v
memoFix7 ff = let memoized = memoize7 (ff memoized) in memoized
-- | Like 'memoize', but the wrapped function 'traceShow's its argument.
-- The trace sits inside the memoized function, so it shows when a result
-- is actually computed rather than served from the cache.
traceMemoize ∷ (Memoizable a, Show a) ⇒ (a → b) → a → b
traceMemoize f = memoize $ \x → traceShow x $ f x
-- Memoizable instances for standard Prelude types, generated by the
-- Template Haskell helper 'deriveMemoizable'.

-- Simple algebraic types.
deriveMemoizable ''()
deriveMemoizable ''Bool
deriveMemoizable ''Ordering
deriveMemoizable ''Maybe
deriveMemoizable ''Either
deriveMemoizable ''[]
-- Tuples of arity 2 through 12.
deriveMemoizable ''(,)
deriveMemoizable ''(,,)
deriveMemoizable ''(,,,)
deriveMemoizable ''(,,,,)
deriveMemoizable ''(,,,,,)
deriveMemoizable ''(,,,,,,)
deriveMemoizable ''(,,,,,,,)
deriveMemoizable ''(,,,,,,,,)
deriveMemoizable ''(,,,,,,,,,)
deriveMemoizable ''(,,,,,,,,,,)
deriveMemoizable ''(,,,,,,,,,,,)
-- | A binary tree used as a memo table: every node holds one cached value
-- plus a left and a right subtree. The fields are deliberately lazy;
-- 'thePosInts', for instance, is an infinite self-referential tree of
-- this type.
data BinaryTreeCache v
  = BinaryTreeCache
      { btValue ∷ v
      , btLeft  ∷ BinaryTreeCache v
      , btRight ∷ BinaryTreeCache v
      }
  deriving (Functor)
-- | Integers are memoized by mapping the function over the identity table
-- 'theIntegers' and looking answers up with 'integerLookup'.
instance Memoizable Integer where
  memoize f = integerLookup (fmap f theIntegers)
-- | Memo table covering all of 'Integer': one slot for zero, and separate
-- positive-integer trees for negative and positive keys. The negative
-- tree is indexed by magnitude; see 'integerLookup'.
data IntegerCache v
  = IntegerCache
      { icZero     ∷ v
      , icNegative ∷ PosIntCache v
      , icPositive ∷ PosIntCache v
      }
  deriving (Functor)
-- | A 'BinaryTreeCache' keyed by the strictly positive integers. The root
-- is key 1. Below that, each binary digit of the key under its leading 1,
-- least significant first, picks a child (0: 'btLeft', 1: 'btRight').
type PosIntCache v = BinaryTreeCache v
-- | The identity 'IntegerCache': every slot contains its own key, so
-- @fmap f theIntegers@ is the memo table for @f@.
theIntegers ∷ IntegerCache Integer
theIntegers = IntegerCache 0 (fmap negate thePosInts) thePosInts
-- | The identity 'PosIntCache': each node stores the positive integer that
-- 'posIntLookup' routes to it. The left subtree is the whole tree doubled
-- (2k) and the right subtree the whole tree doubled plus one (2k + 1), so
-- the structure is defined in terms of itself and unfolds lazily.
thePosInts ∷ PosIntCache Integer
thePosInts = BinaryTreeCache 1 doubled doubledPlusOne
  where
    doubled        = fmap (2 *) thePosInts
    doubledPlusOne = fmap (\k → 2 * k + 1) thePosInts
-- | Find the slot for @n@ in an 'IntegerCache'. Zero has its own slot; any
-- other value is looked up by its magnitude in the tree for its sign.
integerLookup ∷ IntegerCache v → Integer → v
integerLookup cache n
  | n > 0     = posIntLookup (icPositive cache) n
  | n < 0     = posIntLookup (icNegative cache) (negate n)
  | otherwise = icZero cache
-- | Walk a 'PosIntCache' to the node for the positive integer @n@. The
-- binary digits of @n@ are consumed from the least significant end: an
-- even @n@ goes left and an odd one goes right, until only the leading 1
-- remains.
--
-- NOTE(review): for @n <= 0@ the walk never reaches 1 and does not
-- terminate; 'integerLookup' only passes positive values.
posIntLookup ∷ PosIntCache v → Integer → v
posIntLookup node 1 = btValue node
posIntLookup node n =
  case n `divMod` 2 of
    (half, 0) → posIntLookup (btLeft node) half
    (half, _) → posIntLookup (btRight node) half
-- | Wrapper that selects the 'Bounded'/'Enum'-based memoization strategy:
-- a binary search tree over @[minBound .. maxBound]@.
newtype Finite a = ToFinite { fromFinite ∷ a } deriving (Bounded, Enum, Eq)
-- | Memoize over a finite enumeration. @f@ is tabulated on the search tree
-- 'theFinites', and queries are answered by binary search ('finiteLookup').
instance (Bounded a, Enum a) ⇒ Memoizable (Finite a) where
  memoize f = finiteLookup (fmap f theFinites)
-- | A balanced binary search tree over @[minBound .. maxBound]@. Each node
-- stores the midpoint ('meanFinite') of its range, with the lower part of
-- the range on the left and the upper part on the right. Subtrees are
-- lazy, so only the nodes that are visited get built.
theFinites ∷ (Bounded a, Enum a) ⇒ BinaryTreeCache a
theFinites = build minBound maxBound
  where
    build lo hi =
      let mid = meanFinite lo hi
      in BinaryTreeCache mid (build lo (pred mid)) (build (succ mid) hi)
-- | Binary-search a tree shaped like 'theFinites' for the entry belonging
-- to @a0@. The search runs on 'fromEnum' codes, recomputing the midpoint
-- of the current range at each step exactly as the tree was built.
finiteLookup ∷ (Bounded a, Enum a) ⇒ BinaryTreeCache v → a → v
finiteLookup cache0 a0 = go lo0 hi0 cache0
  where
    key = fromEnum a0
    -- 'asTypeOf' pins minBound/maxBound to the key's type.
    lo0 = fromEnum (minBound `asTypeOf` a0)
    hi0 = fromEnum (maxBound `asTypeOf` a0)
    go lo hi node
      | key < mid = go lo (pred mid) (btLeft node)
      | key > mid = go (succ mid) hi (btRight node)
      | otherwise = btValue node
      where
        mid = meanFinite lo hi
-- | Midpoint of two enumeration values, rounded down: the value whose
-- 'fromEnum' is @floor ((fromEnum a + fromEnum b) / 2)@.
--
-- The naive @ia + (ib - ia) `div` 2@ overflows 'Int' whenever the span
-- @ib - ia@ exceeds 'maxBound'. That happens for the full range of 'Int'
-- itself, and would put the root of the 'Int' search tree at 'maxBound'.
-- So each operand is halved separately, and the half lost to truncation
-- is restored when both operands are odd.
meanFinite     ∷ (Bounded a, Enum a) ⇒ a → a → a
meanFinite a b = toEnum (ia `div` 2 + ib `div` 2 + carry)
  where
    ia    = fromEnum a
    ib    = fromEnum b
    carry = if odd ia && odd ib then 1 else 0
-- | Memoize a function over any 'Bounded' 'Enum' type by routing it
-- through the 'Finite' wrapper's 'Memoizable' instance.
memoizeFinite   ∷ (Enum a, Bounded a) ⇒ (a → v) → a → v
memoizeFinite f = lookupWrapped . ToFinite
  where
    -- Bound once per @memoizeFinite f@, so every call through that partial
    -- application shares the same table.
    lookupWrapped = memoize (f . fromFinite)
-- | 'Int' is memoized as a finite enumeration (via 'memoizeFinite').
instance Memoizable Int where
  memoize f = memoizeFinite f
-- | 'Char' is memoized as a finite enumeration (via 'memoizeFinite').
instance Memoizable Char where
  memoize f = memoizeFinite f
-- | Functions out of a finite, enumerable domain are memoizable. Such a
-- function is identified with its results on @[minBound .. maxBound]@,
-- and that list of results is looked up in the trie built by
-- 'theFunctions'.
instance (Eq a, Bounded a, Enum a, Memoizable b) ⇒ Memoizable (a → b) where
  memoize g = functionLookup (theFunctions g)
-- | Look up the memoized result for the function @f@ in a 'FunctionCache'
-- built by 'theFunctions'.
--
-- @f@ is tabulated on @[minBound .. maxBound]@. Those results are fed in
-- order down the chain of 'fcCons' nodes, and the node reached at the end
-- holds the answer in 'fcNil'.
--
-- The strict 'foldl'' forces each node as it is reached. The lazy 'foldl'
-- would instead build a thunk chain as long as the domain of @a@ before
-- evaluating any of it.
functionLookup ∷ (Eq a, Bounded a, Enum a, Memoizable b) ⇒
                 FunctionCache b v → (a → b) → v
functionLookup cache f =
  fcNil (foldl' fcCons cache (f <$> [minBound .. maxBound]))
-- | Build the trie that memoizes @f@, a function whose argument is itself a
-- function @a → b@.
--
-- Each 'fcCons' level is memoized on @b@ and consumes the result for the
-- next domain element, starting at 'minBound'. 'fcNil' applies @f@ to the
-- function assembled from all results consumed so far.
theFunctions ∷ (Eq a, Bounded a, Enum a, Memoizable b) ⇒
               ((a → b) → v) → FunctionCache b v
theFunctions f =
  FunctionCache {
    fcNil  = f outOfDomain,
    fcCons = memoize (\b → theFunctions (f . extend b))
  }
    where
      -- Put @b@ at 'minBound' and shift the rest of @g@ up by one position:
      -- @extend b g x@ is @b@ at 'minBound' and @g (pred x)@ elsewhere.
      extend b g a
        | a == minBound = b
        | otherwise     = g (pred a)
      -- Innermost function of the 'extend' chain. It is only reached if
      -- the assembled function is applied to a value further (in 'pred'
      -- steps) from 'minBound' than the number of results supplied. That
      -- does not happen when the 'Bounded'/'Enum' instances agree, but a
      -- descriptive error beats an anonymous 'undefined' if it ever does.
      outOfDomain =
        error "Data.Function.Memoize.theFunctions: argument outside [minBound .. maxBound]"
-- | Trie keyed by a list of @b@ values: the results of a function on its
-- whole domain, in order. 'fcNil' is the entry for the elements consumed
-- so far, and 'fcCons' descends by one more element. Both fields are lazy.
data FunctionCache b v
  = FunctionCache
      { fcNil  ∷ v
      , fcCons ∷ b → FunctionCache b v
      }
-- | Example: Fibonacci numbers (with @_fib 0 = _fib 1 = 1@) via 'memoFix'.
-- The recursive calls go through the memoized function, so each
-- @_fib k@ is computed at most once instead of exponentially often.
--
-- NOTE(review): negative arguments never reach a base case.
_fib ∷ Integer → Integer
_fib = memoFix $ \fib n → case n of
  0 → 1
  1 → 1
  _ → fib (n - 1) + fib (n - 2)
-- | Example: recognise 'not' among functions @Bool → Bool@, memoized on the
-- function argument. The 'trace' is inside the memoized function, so it
-- shows when a result is computed rather than served from the cache.
_isNot       ∷ (Bool → Bool) → Bool
_isNot       = memoize $ \f →
  trace "_isNot" $
    not (f True) && f False
-- | Example: count how many of the four argument pairs a binary Boolean
-- function maps to 'True', memoized on the function argument. The 'trace'
-- marks each time a count is actually computed.
_countTrue ∷ (Bool → Bool → Bool) → Integer
_countTrue = memoize $ \f →
  trace "_countTrue" $
    let hits = do
          x ← [False, True]
          y ← [False, True]
          guard (f x y)
    in toInteger (length hits)