module Math.Meanshift
(
meanShift,meanShiftWindow
,fixedPoint, fixedPointE
,Window,Support
,kde
) where
import qualified Data.Vector.Unboxed as V
import Data.List hiding (sum)
import qualified Data.List as L
import Prelude hiding (sum)
-- | Squared Euclidean norm: the sum of the squares of the components.
norm² :: Vector -> Double
norm² xs = V.sum (V.map (\c -> c ** 2) xs)
-- | Profile of the Gaussian kernel, applied to a squared (bandwidth-scaled)
-- distance @x@: @k(x) = exp(-x/2)@. It is 1 at @x = 0@ and decays as the
-- distance grows (the previous positive exponent made it grow instead).
normalKernel :: Double -> Double
normalKernel x = exp (-0.5 * x)

-- | Mean-shift weight function for 'normalKernel':
-- @g(x) = -k'(x) = exp(-x/2) / 2@. Callers in this module divide by the
-- total weight, so only the relative weights (not the constant) matter.
normalKernel' :: Double -> Double
normalKernel' x = 0.5 * exp (-0.5 * x)
-- | Gaussian kernel density estimate with bandwidth @h@ over the sample @vs@:
--
-- > kde h vs x = 1 / (n * (2*pi)**(d/2) * h**d)
-- >            * sum [ normalKernel (|x - xi|^2 / h^2) | xi <- vs ]
--
-- where @n@ is the number of samples and @d@ the dimension of the first one.
-- The normalising constant is bound outside the returned function, so it is
-- computed once per partial application @kde h vs@. An empty sample yields a
-- density of 0 instead of crashing on 'head'.
kde :: Double -> [Vector] -> (Vector -> Double)
kde _ [] = const 0
kde h vs@(p0:_) = \x -> c * L.sum (map (kernelAt x) vs)
  where
    n = fi (length vs)
    d = fi (V.length p0)
    c = 1 / (n * ((2 * π) ** (d / 2)) * (h ** d))
    -- Kernel contribution of sample point xi at query point x.
    kernelAt x xi = normalKernel (norm² ((x ^- xi) ./ h))
-- | One mean-shift update with a Gaussian kernel of bandwidth @h@.
--
-- Returns the weighted mean of the sample points @vs@, where point @xi@
-- gets weight @normalKernel' (|x - xi|^2 / h^2)@, normalised by the total
-- weight. The output has as many components as the first sample point.
--
-- Degenerate cases return @x@ unchanged:
--
-- * when there are no sample points, the original would crash on 'head';
-- * when every weight underflows to zero (@x@ far from all points), the
--   original would divide by zero and produce NaNs.
meanShift :: Double -> [Vector] -> (Vector -> Vector)
meanShift _ [] x = x
meanShift h vs@(p0:_) x
  | total == 0 = x
  | otherwise = sumW (V.length p0) vs weights (1 / total)
  where
    weights = V.fromList (map (normalKernel' . distPerH x) vs)
    total = V.sum weights
    -- Squared Euclidean distance between two points, scaled by the bandwidth.
    distPerH :: Vector -> Vector -> Double
    distPerH a b = V.sum (V.zipWith (\p q -> ((p - q) / h) ^ (2 :: Int)) a b)
-- | A region of space, as a (point, size) pair.
-- NOTE(review): presumably (centre, radius); 'meanShiftWindow' passes the
-- query point and twice the bandwidth.
type Support = (Vector, Double)

-- | A neighbourhood query: given a 'Support', return sample points.
-- NOTE(review): presumably the points that lie inside the region. The
-- implementation is supplied by callers and is not visible here.
type Window = Support -> [Vector]
-- | One mean-shift update that weighs only the points returned by a
-- neighbourhood query instead of the whole sample.
--
-- @d@ is the number of output components, @window@ is queried with the
-- support @(x, 2*h)@, and @h@ is the kernel bandwidth. Each returned point
-- @xi@ gets weight @normalKernel' (|x - xi|^2 / h^2)@, normalised by the total.
--
-- Degenerate cases return @x@ unchanged:
--
-- * an empty window, which previously yielded the zero vector and made the
--   point jump to the origin;
-- * all weights underflowing to zero, which previously yielded NaNs.
meanShiftWindow :: Int -> Window -> Double -> (Vector -> Vector)
meanShiftWindow d window h x
  | null w || total == 0 = x
  | otherwise = sumW d w weights (1 / total)
  where
    w = window (x, h * 2)
    weights = V.fromList (map (normalKernel' . distPerH x) w)
    total = V.sum weights
    -- Squared Euclidean distance between two points, scaled by the bandwidth.
    distPerH :: Vector -> Vector -> Double
    distPerH a b = V.sum (V.zipWith (\p q -> ((p - q) / h) ^ (2 :: Int)) a b)
-- | The iterates @x, f x, f (f x), ...@, stopping once an application of @f@
-- returns a value equal to its input. That final (equal) value is included,
-- so the list ends with two equal elements. Each element is available
-- before @f@ is applied to it.
fixedPoint :: Eq a => (a -> a) -> a -> [a]
fixedPoint f = go
  where
    go cur = cur : rest
      where
        next = f cur
        rest
          | next /= cur = go next
          | otherwise = [next]
-- | Iterates @f@ from @x@, stopping once a step moves the point by at most
-- @e@ in L1 distance (sum of absolute component differences). The returned
-- list starts with @x@ and ends with the first iterate whose step was within
-- the tolerance; a NaN step also stops the iteration.
fixedPointE :: Double -> (Vector -> Vector) -> Vector -> [Vector]
fixedPointE e f = go
  where
    go cur = cur : rest
      where
        next = f cur
        stepSize = V.sum (V.map abs (next ^- cur))
        rest
          | stepSize > e = go next
          | otherwise = [next]
-- | Point representation used throughout this module: an unboxed vector of
-- 'Double' components.
type Vector = V.Vector Double

-- | Build a 'Vector' from a list of components.
v :: [Double] -> Vector
v xs = V.fromList xs
-- | Element-wise vector arithmetic. Like 'V.zipWith', the result has the
-- length of the shorter operand.
(^+), (^-), (^/) :: Vector -> Vector -> Vector
-- Subtraction: the previous @V.zipWith ()@ passed unit, not @(-)@.
(^-) = V.zipWith (-)
(^+) = V.zipWith (+)
(^/) = V.zipWith (/)

-- | Vector–scalar arithmetic: combine every component with the scalar.
-- These have no fixity declaration, so they default to @infixl 9@.
(.+), (./), (.*) :: Vector -> Double -> Vector
a .+ b = V.map (+ b) a
a ./ b = V.map (/ b) a
a .* b = V.map (* b) a

infixl 7 ^/
infixl 6 ^+, ^-
-- | Wrap a value in a singleton list.
box :: x -> [x]
box = (: [])
-- | Build a two-element list from its arguments, in order.
box2 :: x -> x -> [x]
box2 x y = x : y : []
-- | A one-component 'Vector' holding the given scalar.
sv :: Double -> Vector
sv x = V.singleton x
-- | The first component of a 'Vector'. Partial: fails on an empty vector
-- (it is 'V.head').
fs :: Vector -> Double
fs xs = V.head xs
-- | Component-wise sum of a list of vectors, producing @d@ components.
--
-- Indexing is bounds-checked: a vector with fewer than @d@ components raises
-- an index error. Previously 'V.unsafeIndex' read out of bounds silently.
-- The sum is a strict left fold from 0, the same order as a list 'sum'.
sumD :: Int -> [Vector] -> Vector
sumD d xs = V.generate d (\i -> L.foldl' (+) 0 (map (V.! i) xs))
-- | Weighted, scaled sum of a list of points.
--
-- Component @i@ of the result is the sum over @j@ of
-- @n * (es !! j) ! i * ws ! j@. Here @d@ is the number of components to
-- produce, @es@ the points, @ws@ one weight per point (in list order), and
-- @n@ a scale factor applied to every term. The mean-shift functions pass
-- @1 / sum ws@ to obtain a weighted mean.
--
-- Indexing is bounds-checked, so this errors if a point has fewer than @d@
-- components or @ws@ has fewer entries than @es@. An empty @es@ yields the
-- zero vector.
sumW :: Int -> [Vector] -> Vector -> Double -> Vector
sumW d es ws n = V.generate d (\i -> go i 0 es 0)
  where
    -- Force the accumulator at each step. A lazy accumulator builds a chain
    -- of (+) thunks as long as the point list for every component.
    go i j (x:xs) acc =
      let acc' = acc + n * (x V.! i) * (ws V.! j)
      in acc' `seq` go i (j + 1) xs acc'
    go _ _ [] acc = acc
-- | Shorthand for the 'Double' value of pi, matching the mathematical notation.
π :: Double
π = Prelude.pi
-- | Convert an 'Int' count or dimension to 'Double' for arithmetic.
fi :: Int -> Double
fi n = fromIntegral n