-- | Kernel density estimator models: the parameter and bandwidth types, the
-- trained model types, and a helper for building the grid of sample points.
module HLearn.Models.Distributions.KernelDensityEstimator
    ( 
    -- * Parameters
    KDEParams (..)
    , KDEBandwidth (..)
    
    -- * Data types
    , KDE (..)
    , KDE' (..)
    
    -- * Easy creation of parameters
    , genSamplePoints
    
    -- * Re-exported kernels
    , module HLearn.Models.Distributions.KernelDensityEstimator.Kernels
    )
    where
          
import HLearn.Algebra
import HLearn.Models.Distributions.Common
import HLearn.Models.Distributions.Categorical
import HLearn.Models.Distributions.KernelDensityEstimator.Kernels
import qualified Control.ConstraintKinds as CK
import Control.DeepSeq
import qualified Data.Vector.Unboxed as VU
-- | Bandwidth of a kernel density estimator.
data KDEBandwidth prob
    = Constant prob             -- ^ the same bandwidth at every point
    | Variable (prob -> prob)   -- ^ bandwidth computed from the point it is evaluated at

-- | Functions cannot be printed, so a 'Variable' bandwidth is shown as an
-- opaque placeholder instead of failing with a pattern-match error.
instance (Show prob) => Show (KDEBandwidth prob) where
    show (Constant x) = "Constant " ++ show x
    show (Variable _) = "Variable <function>"
-- | Two 'Constant' bandwidths are equal when their values are equal.  Functions
-- cannot be compared, so any two 'Variable' bandwidths are treated as equal.
instance (Eq prob) => Eq (KDEBandwidth prob) where
    lhs == rhs = case (lhs, rhs) of
        (Constant a, Constant b) -> a == b
        (Variable _, Variable _) -> True
        _                        -> False
-- | Every 'Constant' bandwidth sorts before every 'Variable' one; constants
-- are ordered by value and all 'Variable' bandwidths compare as equal.
instance (Ord prob) => Ord (KDEBandwidth prob) where
    compare (Constant x) (Constant y) = compare x y
    compare (Constant _) (Variable _) = LT
    compare (Variable _) (Constant _) = GT
    compare (Variable _) (Variable _) = EQ
-- | A 'Constant' is fully evaluated; a 'Variable' only has its function forced
-- to weak head normal form.
instance (NFData prob) => NFData (KDEBandwidth prob) where
    rnf bw = case bw of
        Constant x -> rnf x
        Variable f -> f `seq` ()
    
-- | Bandwidth to use at a given point: a 'Constant' bandwidth ignores the
-- point, a 'Variable' one is applied to it.
calcBandwidth :: (KDEBandwidth prob) -> prob -> prob
calcBandwidth bw x = case bw of
    Constant h -> h
    Variable f -> f x
-- | Everything needed to train a kernel density estimate.
data KDEParams prob = KDEParams
    { bandwidth    :: KDEBandwidth prob  -- ^ bandwidth, constant or per-point
    , samplePoints :: VU.Vector prob     -- ^ x positions at which kernel values are accumulated
    , kernel       :: KernelBox prob     -- ^ kernel passed to 'evalkernel' during training
    } deriving (Show, Eq, Ord)
-- | Deeply evaluates the bandwidth and the kernel; the sample-point vector is
-- only forced to weak head normal form.
instance (NFData prob) => NFData (KDEParams prob) where
    rnf kdeparams = bw `deepseq` kern `deepseq` pts `seq` ()
        where
            bw   = bandwidth kdeparams
            kern = kernel kdeparams
            pts  = samplePoints kdeparams
-- | Generates an evenly spaced grid of sample points between two bounds.
--
-- The grid holds @samples + 1@ points: it starts at the lower bound and
-- advances by @(upper - lower) / samples@, so the last point lands on the
-- upper bound (up to floating-point rounding).
genSamplePoints :: 
    ( Fractional prob
    , VU.Unbox prob
    ) => Int -- ^ lower bound of the grid
      -> Int -- ^ upper bound of the grid
      -> Int -- ^ number of intervals the range is split into
      -> VU.Vector prob
genSamplePoints lo hi samples =
    VU.fromList [ fromIntegral lo + fromIntegral i * step | i <- [0..samples] ]
    where
        -- Width of one interval; the range is computed on the Int bounds
        -- before converting to the target numeric type.
        step = fromIntegral (hi - lo) / fromIntegral samples
-- | Trained kernel density estimate in its raw form; the stored values are
-- only divided by 'n' when the density is queried.
data KDE' prob = KDE'
    { params     :: KDEParams prob  -- ^ parameters the model was trained with
    , n          :: prob            -- ^ accumulated weight of the training data
    , sampleVals :: VU.Vector prob  -- ^ summed kernel values, one per sample point
    } deriving (Show, Eq, Ord)
-- | Deeply evaluates the parameters and the count; the vector of sample
-- values is only forced to weak head normal form.
instance (NFData prob) => NFData (KDE' prob) where
    rnf kde = ps `deepseq` count `deepseq` vals `seq` ()
        where
            ps    = params kde
            count = n kde
            vals  = sampleVals kde
-- | The model type used by the training and density instances in this
-- module: a 'KDE'' wrapped in 'RegSG2Group'.  Values are built and taken
-- apart with the 'SGJust' constructor.
type KDE prob = RegSG2Group (KDE' prob)
-- | Merges two estimates by adding their counts and adding their sample
-- values point-wise.  Both estimates must share the same parameters;
-- otherwise this calls 'error'.
instance (Eq prob, Num prob, VU.Unbox prob) => Semigroup (KDE' prob) where
    kde1 <> kde2
        | params kde1 /= params kde2 = error "KDE.(<>): different params"
        | otherwise = kde1 { n = total, sampleVals = summed }
        where
            total  = n kde1 + n kde2
            summed = VU.zipWith (+) (sampleVals kde1) (sampleVals kde2)
-- | The inverse negates the count and every sample value, leaving the
-- parameters untouched.
instance (Eq prob, Num prob, VU.Unbox prob) => RegularSemigroup (KDE' prob) where
    inverse kde = kde { n = negatedCount, sampleVals = negatedVals }
        where
            negatedCount = negate (n kde)
            negatedVals  = VU.map negate (sampleVals kde)
-- | Scalar multiplication from the left: scales the count and every sample
-- value by the given factor.
instance (Num prob, VU.Unbox prob) => LeftOperator prob (KDE' prob) where
    p .* kde = kde { n = scaledCount, sampleVals = scaledVals }
        where
            scaledCount = p * n kde
            scaledVals  = VU.map (* p) (sampleVals kde)
-- | Scalar multiplication from the right; defined as the left operator with
-- its arguments swapped.
instance (Num prob, VU.Unbox prob) => RightOperator prob (KDE' prob) where
    kde *. p = p .* kde
    
    
-- | The parameters of a model are the ones stored inside the wrapped 'KDE''.
-- Only the 'SGJust' case is handled.
instance (Eq prob, Num prob, VU.Unbox prob) => Model (KDEParams prob) (KDE prob) where
    getparams model = case model of
        SGJust kde -> params kde
-- | Trains on an 'Int' data point by converting it to 'Double' and
-- delegating to the 'Double' trainer.
instance HomTrainer (KDEParams Double) Int (KDE Double) where
    train1dp' kdeparams dp = train1dp' kdeparams asDouble
        where
            asDouble :: Double
            asDouble = fromIntegral dp
-- | Trains a model from a single data point by evaluating the kernel at
-- every sample point, using the bandwidth-scaled offset of the sample point
-- from the data point as the kernel argument.
instance (Eq prob, Fractional prob, VU.Unbox prob) => HomTrainer (KDEParams prob) prob (KDE prob) where
    train1dp' kdeparams dp = SGJust $ KDE'
        { params = kdeparams
        , n = 1
        -- Offset from the data point, scaled by the bandwidth at the sample point.
        , sampleVals = VU.map (\x -> k ((x - dp) / h x)) (samplePoints kdeparams)
        }
        where
            -- NOTE(review): the normalising bandwidth is looked up at the
            -- scaled offset rather than at the sample point; this only makes
            -- a difference for 'Variable' bandwidths -- confirm it is intended.
            k u = evalkernel (kernel kdeparams) u / h u
            h   = calcBandwidth (bandwidth kdeparams)
    
-- | The density is 0 at or outside the first and last sample points.  In
-- between, it is the linear interpolation of the two stored values (each
-- divided by 'n') whose sample points bracket the query point.
--
-- Relies on 'binsearch', so the sample points are expected to be sorted in
-- ascending order.
instance (Ord prob, Fractional prob, VU.Unbox prob) => Distribution (KDE prob) prob prob where
    pdf (SGJust kde) dp 
        | dp <= xs VU.! 0 = 0 
        | dp >= xs VU.! l = 0 
        | otherwise = (y2 - y1) / (x2 - x1) * (dp - x1) + y1
        where
            xs = samplePoints $ params kde
            -- Index of the last sample point.
            l = VU.length xs - 1
            -- First sample point that is >= dp; the guards above ensure it
            -- is at least 1, so @index - 1@ is always a valid index.
            index = binsearch xs dp
            x1 = xs VU.! (index - 1)
            x2 = xs VU.! index
            y1 = (sampleVals kde VU.! (index - 1)) / n kde
            y2 = (sampleVals kde VU.! index) / n kde
-- | Binary search over a vector that is assumed to be sorted in ascending
-- order.
--
-- Returns the smallest index whose element is @>= dp@, or the last index
-- when every element is smaller than @dp@.  An empty vector yields 0.
binsearch :: (Ord a, VU.Unbox a) => VU.Vector a -> a -> Int
binsearch vec dp = go 0 (VU.length vec - 1)
    where 
        go low high
            -- @>=@ rather than @==@ so an empty vector (high == -1) stops
            -- immediately instead of indexing out of bounds.
            | low >= high          = low
            | dp <= (vec VU.! mid) = go low mid
            -- Catch-all keeps the guards exhaustive, including for values
            -- that compare neither <= nor > (e.g. NaN).
            | otherwise            = go (mid + 1) high
            where 
                -- Integer midpoint; no round trip through floating point.
                mid = (low + high) `div` 2
-- | Turns a categorical distribution over 'Int' into a kernel density
-- estimate: the distribution is first mapped through 'FreeModParams', its
-- 'Int' values are converted to 'Double', and the result is handed to
-- 'train''.
instance Morphism (Categorical Int Double) (KDEParams Double) (KDE Double) where
    cat $> kdeparams = train' kdeparams (CK.fmap toDouble (cat $> FreeModParams))
        where
            toDouble :: Int -> Double
            toDouble = fromIntegral