module HLearn.Models.Distributions.KernelDensityEstimator
(
KDEParams (..)
, KDEBandwidth (..)
, KDE (..)
, KDE' (..)
, genSamplePoints
, module HLearn.Models.Distributions.KernelDensityEstimator.Kernels
)
where
import HLearn.Algebra
import HLearn.Models.Distributions.Common
import HLearn.Models.Distributions.Categorical
import HLearn.Models.Distributions.KernelDensityEstimator.Kernels
import qualified Control.ConstraintKinds as CK
import Control.DeepSeq
import qualified Data.Vector.Unboxed as VU
-- | Bandwidth of the smoothing kernel.
--
-- 'Constant' uses the same bandwidth at every point. 'Variable' wraps a
-- function that 'calcBandwidth' applies to the point being evaluated.
data KDEBandwidth prob = Constant prob | Variable (prob -> prob)
-- | Human-readable rendering of a bandwidth.
--
-- A 'Variable' bandwidth wraps a function, which cannot be shown, so only a
-- placeholder is printed. Without this equation, showing a 'Variable'
-- bandwidth (or a 'KDEParams' containing one, via its derived 'Show')
-- crashed with a non-exhaustive pattern error.
instance (Show prob) => Show (KDEBandwidth prob) where
  show (Constant x) = "Constant " ++ show x
  show (Variable _) = "Variable <function>"
-- | Structural equality on bandwidths.
--
-- Functions cannot be compared, so any two 'Variable' bandwidths are
-- considered equal regardless of the function they wrap.
instance (Eq prob) => Eq (KDEBandwidth prob) where
  lhs == rhs = case (lhs, rhs) of
    (Constant a, Constant b) -> a == b
    (Variable _, Variable _) -> True
    _                        -> False
-- | Ordering on bandwidths: every 'Constant' sorts before every 'Variable'.
-- Constants compare by value, and all 'Variable's compare equal, which is
-- consistent with the 'Eq' instance.
instance (Ord prob) => Ord (KDEBandwidth prob) where
  compare lhs rhs = case (lhs, rhs) of
    (Constant a, Constant b) -> compare a b
    (Constant _, Variable _) -> LT
    (Variable _, Constant _) -> GT
    (Variable _, Variable _) -> EQ
-- | Full evaluation of a bandwidth. A wrapped function can only be forced
-- to weak head normal form.
instance (NFData prob) => NFData (KDEBandwidth prob) where
  rnf bw = case bw of
    Constant h -> rnf h
    Variable f -> f `seq` ()
-- | Bandwidth to use at point @x@: the stored value for 'Constant', or the
-- wrapped function applied to @x@ for 'Variable'.
calcBandwidth :: (KDEBandwidth prob) -> prob -> prob
calcBandwidth bw x = case bw of
  Constant h -> h
  Variable f -> f x
-- | Hyper-parameters of a kernel density estimator.
--
-- The density is tabulated only at 'samplePoints'. 'pdf' interpolates
-- linearly between neighbouring sample points and returns 0 at or beyond
-- the first and last one.
data KDEParams prob = KDEParams
  { bandwidth :: KDEBandwidth prob
    -- ^ kernel bandwidth, resolved per sample point with 'calcBandwidth'
  , samplePoints :: VU.Vector prob
    -- ^ points at which the density is tabulated. 'pdf' binary-searches
    -- this vector, so it presumably must be sorted ascending (as produced
    -- by 'genSamplePoints'). NOTE(review): this is not enforced here.
  , kernel :: KernelBox prob
    -- ^ kernel function, evaluated with 'evalkernel'
  }
  deriving (Show,Eq,Ord)
-- | Fully evaluates the bandwidth and kernel. The sample-point vector is
-- only forced to weak head normal form.
instance (NFData prob) => NFData (KDEParams prob) where
  rnf (KDEParams bw pts kern) =
    rnf bw `seq` rnf kern `seq` pts `seq` ()
-- | Evenly spaced grid of sample points covering the closed interval from
-- the lower to the upper bound.
--
-- With @samples > 0@ the result has @samples + 1@ points,
-- @lo, lo + step, ..., lo + samples * step@, where
-- @step = (hi - lo) / samples@.
--
-- With @samples == 0@ the result is the single point @lo@. That case used to
-- divide by zero (NaN for 'Double', an exception for exact types).
--
-- A negative @samples@ yields an empty vector.
genSamplePoints ::
  ( Fractional prob
  , VU.Unbox prob
  ) => Int -- ^ lower bound of the grid
  -> Int   -- ^ upper bound of the grid
  -> Int   -- ^ number of intervals; the grid has one more point than this
  -> VU.Vector prob
genSamplePoints lo hi samples
  | samples == 0 = VU.singleton (fromIntegral lo)
  | otherwise =
      VU.fromList [ fromIntegral lo + fromIntegral i * step | i <- [0 .. samples] ]
  where
    step = fromIntegral (hi - lo) / fromIntegral samples
-- | Unnormalised state of a kernel density estimator.
data KDE' prob = KDE'
  { params :: KDEParams prob
    -- ^ parameters shared by every data point in this estimator; '(<>)'
    -- refuses to combine estimators whose params differ
  , n :: prob
    -- ^ total weight of the data: 1 per trained point, summed by '(<>)',
    -- scaled by '(.*)' and negated by 'inverse'
  , sampleVals :: VU.Vector prob
    -- ^ summed kernel contributions, one per entry of 'samplePoints';
    -- 'pdf' divides these by 'n'
  }
  deriving (Show,Eq,Ord)
-- | Fully evaluates the parameters and weight. The sample-value vector is
-- only forced to weak head normal form.
instance (NFData prob) => NFData (KDE' prob) where
  rnf (KDE' ps weight vals) =
    rnf ps `seq` rnf weight `seq` vals `seq` ()
-- | A kernel density estimator as exposed to users: 'KDE'' wrapped in
-- 'RegSG2Group'. Trained models in this module are 'SGJust' values.
-- NOTE(review): the wrapper presumably adds an identity element so that
-- estimators form a group; confirm against HLearn.Algebra.
type KDE prob = RegSG2Group (KDE' prob)
-- | Merge two estimators by adding their weights and their sample values
-- pointwise. Both must have been built with equal 'KDEParams'; otherwise
-- this raises an error.
instance (Eq prob, Num prob, VU.Unbox prob) => Semigroup (KDE' prob) where
  lhs <> rhs
    | params lhs /= params rhs = error "KDE.(<>): different params"
    | otherwise = lhs
        { n          = n lhs + n rhs
        , sampleVals = VU.zipWith (+) (sampleVals lhs) (sampleVals rhs)
        }
-- | Group inverse: negate the weight and every sample value. Merging an
-- estimator with its inverse therefore gives zero weight and all-zero
-- sample values.
instance (Eq prob, Num prob, VU.Unbox prob) => RegularSemigroup (KDE' prob) where
  inverse kde = kde { n = negatedWeight, sampleVals = negatedVals }
    where
      negatedWeight = negate (n kde)
      negatedVals   = VU.map negate (sampleVals kde)
-- | Scale an estimator by the weight @p@: the total weight and every sample
-- value are multiplied by @p@.
instance (Num prob, VU.Unbox prob) => LeftOperator prob (KDE' prob) where
  p .* kde = kde { n = p * weight, sampleVals = VU.map (* p) vals }
    where
      weight = n kde
      vals   = sampleVals kde
-- | Right scaling, defined as left scaling with the operands swapped.
instance (Num prob, VU.Unbox prob) => RightOperator prob (KDE' prob) where
  kde *. p = p .* kde
-- | Recover the parameters a KDE was trained with.
instance (Eq prob, Num prob, VU.Unbox prob) => Model (KDEParams prob) (KDE prob) where
  getparams (SGJust kde) = params kde
  -- Any other constructor of the 'RegSG2Group' wrapper carries no 'KDE'' and
  -- therefore no parameters. Fail with a descriptive message rather than an
  -- anonymous pattern-match failure.
  getparams _ = error "KDE.getparams: model is not an SGJust value, so it has no params"
-- | Train a 'Double' KDE on an 'Int' data point by converting it to
-- 'Double' and delegating to the generic instance.
instance HomTrainer (KDEParams Double) Int (KDE Double) where
  train1dp' kdeparams i = train1dp' kdeparams asDouble
    where
      asDouble = fromIntegral i :: Double
-- | Train an estimator on a single data point @dp@ with weight 1.
--
-- Each sample point @x@ receives the contribution
-- @evalkernel k ((x - dp) / h) / h@, where @h@ is the bandwidth at @x@.
-- The scaling and the @1/h@ normalisation use the same @h@. The previous
-- code evaluated the normaliser's bandwidth at the scaled distance instead
-- of at @x@, which only agreed for a 'Constant' bandwidth.
instance (Eq prob, Fractional prob, VU.Unbox prob) => HomTrainer (KDEParams prob) prob (KDE prob) where
  train1dp' kdeparams dp = SGJust $ KDE'
    { params     = kdeparams
    , n          = 1
    , sampleVals = VU.map contribution (samplePoints kdeparams)
    }
    where
      bandwidthAt = calcBandwidth (bandwidth kdeparams)
      contribution x =
        let hx = bandwidthAt x
        in  evalkernel (kernel kdeparams) ((x - dp) / hx) / hx
-- | Density at @dp@, linearly interpolated from the tabulated sample values.
--
-- Returns 0 when there are no sample points, or when @dp@ lies at or beyond
-- the first or last sample point. Otherwise the two sample points bracketing
-- @dp@ are located with 'binsearch', and their values divided by the total
-- weight 'n' are interpolated.
instance (Ord prob, Fractional prob, VU.Unbox prob) => Distribution (KDE prob) prob prob where
  pdf (SGJust kde) dp
    | VU.null xs             = 0
    | dp <= xs VU.! 0        = 0
    | dp >= xs VU.! lastIx   = 0
    | otherwise              = (y2 - y1) / (x2 - x1) * (dp - x1) + y1
    where
      xs     = samplePoints (params kde)
      ys     = sampleVals kde
      lastIx = VU.length xs - 1
      -- Assuming xs is sorted ascending, ix is the first index with
      -- xs ! ix >= dp. The guards above ensure 1 <= ix <= lastIx, so
      -- ix - 1 is also a valid index.
      ix = binsearch xs dp
      x1 = xs VU.! (ix - 1)
      x2 = xs VU.! ix
      y1 = (ys VU.! (ix - 1)) / n kde
      y2 = (ys VU.! ix) / n kde
  -- Any other constructor of the wrapper holds no estimator.
  pdf _ _ = error "KDE.pdf: model is not an SGJust value, so it has no density"
-- | Lower-bound binary search over a vector assumed to be sorted ascending.
--
-- Returns the smallest index @i@ with @dp <= vec ! i@. If every element is
-- smaller than @dp@, the last index is returned. An empty vector yields 0
-- without indexing the vector.
binsearch :: (Ord a, VU.Unbox a) => VU.Vector a -> a -> Int
binsearch vec dp = go 0 (VU.length vec - 1)
  where
    -- Invariant: if the vector has an element >= dp, the first such element
    -- lies in [low, high]. `low >= high` also stops the empty case
    -- (0, -1) before any indexing.
    go low high
      | low >= high          = low
      | dp <= vec VU.! mid   = go low mid
      | otherwise            = go (mid + 1) high
      where
        -- Integer midpoint; equals floor ((low + high) / 2) for these
        -- non-negative bounds.
        mid = (low + high) `div` 2
-- | Build a 'Double' KDE from a categorical distribution over 'Int's.
--
-- The categorical is first converted with 'FreeModParams' (presumably into
-- a weighted collection of its points; confirm in HLearn). Its 'Int' points
-- are mapped to 'Double', and the result is batch-trained with the given
-- KDE parameters.
instance Morphism (Categorical Int Double) (KDEParams Double) (KDE Double) where
  cat $> kdeparams = train' kdeparams doublePoints
    where
      doublePoints = CK.fmap intToDouble (cat $> FreeModParams)
      intToDouble  = fromIntegral :: Int -> Double