module Data.Cluster.ROC(
ROCConfig
, rocThreshold
, rocMaxClusters
, defaultROCConfig
, Prototype
, newPrototype
, prototypeValue
, prototypeWeight
, ClusterSpace(..)
, ROCContext
, emptyROCContext
, loadROCContext
, rocPrototypes
, clusterize
, clusterizeAddMerge
, clusterizeSingle
, clusterizeMerge
, clusterizeNewPrototype
, clusterizePostprocess
) where
import Data.Data
import Data.Monoid
import Data.Ord
import Data.Vector (Vector)
import GHC.Generics
import qualified Data.Foldable as F
import qualified Data.Vector as V
-- | Tuning parameters of the clustering algorithm.
data ROCConfig = ROCConfig
  { rocThreshold   :: !Double
    -- ^ Minimum weight a prototype must have to be kept by
    -- 'clusterizePostprocess'.
  , rocMaxClusters :: !Int
    -- ^ Prototype count at or above which 'clusterizeAddMerge' merges the
    -- closest pair before appending a new prototype.
  } deriving (Generic, Data)
-- | Default configuration: 'rocMaxClusters' is @10@ and 'rocThreshold' is
-- @0@, so post-processing keeps every prototype whose weight is @>= 0@.
defaultROCConfig :: ROCConfig
defaultROCConfig = ROCConfig
  { rocMaxClusters = 10
  , rocThreshold   = 0
  }
-- | Operations the clustering algorithm needs on points: a zero point,
-- addition, scalar multiplication and a kernel (similarity) function.
class ClusterSpace a where
  -- | Zero point; used as the position of 'mempty' for 'Prototype'.
  -- Presumably the identity of 'pointAdd' — instances should ensure this.
  pointZero :: a
  -- | Add two points.
  pointAdd :: a -> a -> a
  -- | Multiply a point by a scalar.
  pointScale :: Double -> a -> a
  -- | Kernel value of two points. 'clusterizeSingle' adds it to the winning
  -- prototype's weight.
  pointKernel :: a -> a -> Double
  -- | Squared distance between two points, used to pick the closest
  -- prototype. The default derives it from the kernel:
  -- @k(x,x) - 2 * k(x,y) + k(y,y)@.
  pointDistanceSquared :: a -> a -> Double
  pointDistanceSquared x y = pointKernel x x - 2 * pointKernel x y + pointKernel y y
-- | A cluster prototype: a position in the cluster space plus the weight it
-- has accumulated.
data Prototype a = Prototype
  { prototypeValue  :: !a
    -- ^ Position of the prototype.
  , prototypeWeight :: !Double
    -- ^ Accumulated weight; compared against 'rocThreshold' by
    -- 'clusterizePostprocess'.
  } deriving (Eq, Show, Generic, Functor)
-- | Wrap a point as a prototype that carries no weight yet.
newPrototype :: a -> Prototype a
newPrototype value = Prototype { prototypeValue = value, prototypeWeight = 0 }
-- | Merging two prototypes gives a prototype whose weight is the sum of both
-- weights and whose position is the weight-averaged position of the two.
--
-- When the total weight is exactly zero the weighted average would divide by
-- zero (@0/0@, NaN for floating-point spaces). In that case the unweighted
-- midpoint of the two positions is used instead. This can happen, for
-- example, when merging two prototypes created with 'newPrototype'.
instance ClusterSpace a => Semigroup (Prototype a) where
  p1 <> p2 = Prototype pos w
    where
      w1 = prototypeWeight p1
      w2 = prototypeWeight p2
      w = w1 + w2
      pos
        | w == 0 =
            0.5 `pointScale` (prototypeValue p1 `pointAdd` prototypeValue p2)
        | otherwise =
            (1 / w) `pointScale`
              ((w1 `pointScale` prototypeValue p1)
                 `pointAdd` (w2 `pointScale` prototypeValue p2))

-- | The empty prototype sits at 'pointZero' with zero weight.
-- 'Semigroup' is a superclass of 'Monoid' (base >= 4.11), so the combining
-- operation is defined by the instance above.
instance ClusterSpace a => Monoid (Prototype a) where
  mempty = Prototype pointZero 0
-- | Clustering state: the current prototypes together with the
-- configuration that controls how they are updated.
data ROCContext a = ROCContext
  { cntxPrototypes :: !(Vector (Prototype a))
    -- ^ Current prototypes, in insertion order.
  , cntxConfig     :: !ROCConfig
    -- ^ Parameters used by the update steps.
  } deriving (Generic, Functor)
-- | A context with the given configuration and no prototypes.
emptyROCContext :: ROCConfig -> ROCContext a
emptyROCContext = ROCContext V.empty
-- | A context with the given configuration, seeded with the given
-- prototypes in 'F.toList' order.
loadROCContext :: Foldable f => ROCConfig -> f (Prototype a) -> ROCContext a
loadROCContext cfg ps = ROCContext
  { cntxPrototypes = V.fromList (F.toList ps)
  , cntxConfig     = cfg
  }
-- | The context's prototypes as a list, in storage order.
rocPrototypes :: ROCContext a -> [Prototype a]
rocPrototypes ctx = V.toList (cntxPrototypes ctx)
-- | Run every point of @xs@, left to right, through 'clusterizeAddMerge'
-- starting from @cntx0@. Then drop the low-weight prototypes with
-- 'clusterizePostprocess'.
clusterize :: forall a f . (ClusterSpace a, Foldable f)
  => f a
  -> ROCContext a
  -> ROCContext a
clusterize xs cntx0 = clusterizePostprocess (F.foldl' step cntx0 xs)
  where
    -- strict left fold: the context is updated once per point
    step acc x = clusterizeAddMerge x acc
-- | Process a single point in three steps:
--
-- 1. Move the closest prototype towards it ('clusterizeSingle').
-- 2. If the prototype count has reached 'rocMaxClusters', merge the closest
--    pair ('clusterizeMerge').
-- 3. Append the point as a new zero-weight prototype
--    ('clusterizeNewPrototype').
--
-- The merge check happens before the new prototype is appended.
clusterizeAddMerge :: forall a . (ClusterSpace a)
  => a
  -> ROCContext a
  -> ROCContext a
clusterizeAddMerge x cntx = clusterizeNewPrototype x reduced
  where
    updated = clusterizeSingle x cntx
    count = V.length (cntxPrototypes updated)
    limit = rocMaxClusters (cntxConfig updated)
    reduced
      | count >= limit = clusterizeMerge updated
      | otherwise      = updated
-- | Move the prototype closest to @x@ (by 'pointDistanceSquared') towards
-- @x@. A context without prototypes is returned unchanged.
--
-- With @y@ the winner's position and @w@ its weight:
--
-- * the new weight is @c = w + pointKernel x y@;
-- * the new position is @y + (1 / c) * (x - y)@.
--
-- If @c@ is exactly zero the step size @1 / c@ is undefined. In that case
-- the position is left unchanged, so a NaN/infinite value cannot leak into
-- the context.
clusterizeSingle :: forall a . (ClusterSpace a)
  => a
  -> ROCContext a
  -> ROCContext a
clusterizeSingle x ctx
  | V.null protos = ctx
  | otherwise = ctx { cntxPrototypes = protos V.// [(winnerIndex, winner')] }
  where
    protos = cntxPrototypes ctx
    winnerIndex = V.minIndex . fmap (pointDistanceSquared x . prototypeValue) $ protos
    -- only forced in the non-empty branch, so (V.!) is safe here
    Prototype yOld wOld = protos V.! winnerIndex
    cNew = wOld + pointKernel x yOld
    -- x - y, expressed with the class operations
    delta = x `pointAdd` pointScale (-1) yOld
    yNew
      | cNew == 0 = yOld
      | otherwise = yOld `pointAdd` ((1 / cNew) `pointScale` delta)
    winner' = Prototype yNew cNew
-- | Merge the two prototypes whose positions are closest (by
-- 'pointDistanceSquared') into one, using the 'Semigroup' instance of
-- 'Prototype'. Contexts with fewer than two prototypes are returned
-- unchanged.
--
-- The merged prototype is stored at the higher of the two indices, and the
-- lower one is removed.
clusterizeMerge :: forall a . (ClusterSpace a)
  => ROCContext a
  -> ROCContext a
clusterizeMerge ctx
  | V.length protos <= 1 = ctx
  | otherwise = ctx { cntxPrototypes = cntxPrototypes' }
  where
    protos = cntxPrototypes ctx
    -- Every pair (xi, yi) with yi < xi, with its squared distance. There is
    -- at least one pair because the context holds two or more prototypes.
    (minxi, minyi, _) = V.minimumBy (comparing $ \(_, _, d) -> d) $ do
      (xi, xv) <- V.indexed protos
      (yi, yv) <- V.take xi $ V.indexed protos
      pure (xi, yi, prototypeValue yv `pointDistanceSquared` prototypeValue xv)
    x = protos V.! minxi
    y = protos V.! minyi
    x' = x <> y
    -- Remove index i. take/drop are total, so no out-of-bounds slicing.
    removeAt i v = V.take i v V.++ V.drop (i + 1) v
    -- minyi < minxi, so the write at minxi happens before the removal at
    -- minyi shifts it one place left.
    cntxPrototypes' = removeAt minyi $ protos V.// [(minxi, x')]
-- | Append @a@ to the end of the prototypes as a fresh prototype with zero
-- weight.
clusterizeNewPrototype :: forall a . (ClusterSpace a)
  => a
  -> ROCContext a
  -> ROCContext a
clusterizeNewPrototype a ctx =
  ctx { cntxPrototypes = V.snoc (cntxPrototypes ctx) (newPrototype a) }
-- | Keep only the prototypes whose weight is at least 'rocThreshold'
-- (inclusive); order is preserved.
clusterizePostprocess :: forall a . (ClusterSpace a)
  => ROCContext a
  -> ROCContext a
clusterizePostprocess ctx =
  ctx { cntxPrototypes = V.filter keep (cntxPrototypes ctx) }
  where
    keep = (>= rocThreshold (cntxConfig ctx)) . prototypeWeight