module SVM (DataSet (..), SVMSolution (..), KernelFunction (..), SVM (..), LSSVM (..), KernelMatrix (..),
reciprocalKernelFunction, radialKernelFunction, linearKernelFunction, splineKernelFunction,
polyKernelFunction, mlpKernelFunction) where
import Data.Array.Unboxed
import Data.List (foldl')
-- | A labelled training set.
--
-- 'solve' builds its working vectors over @1 .. n@, where @n@ is the upper
-- bound of 'values', so both arrays should be indexed from 1 over the same
-- range.
data DataSet = DataSet
  { points :: Array Int [Double]  -- ^ input vectors, one per index
  , values :: UArray Int Double   -- ^ target value for the point at the same index
  }
-- | A trained model: a prediction for a point @p@ is
-- @sum_i alpha_i * K(p, sv_i) + bias@.
data SVMSolution = SVMSolution
  { alpha :: UArray Int Double   -- ^ weight of each support vector
  , sv    :: Array Int [Double]  -- ^ support vectors (the training points)
  , bias  :: Double              -- ^ constant offset added to every prediction
  }
-- | A kernel matrix, treated as symmetric, stored in packed form.
--
-- Only the upper triangle is kept, column by column: entry @(i, j)@ with
-- @i <= j@ lives at index @j * (j - 1) `quot` 2 + i@ (1-based).
newtype KernelMatrix
  = KernelMatrix (UArray Int Double)
-- | A kernel function.  The first argument is the list of kernel
-- parameters; the second and third are the two points being compared.
newtype KernelFunction
  = KernelFunction ([Double] -> [Double] -> [Double] -> Double)
-- | Product over dimensions of @1 / (x + y + 2a)@, where @a@ is the
-- parameter for that dimension.  Only as many dimensions as the shortest of
-- the three lists are used; with no dimensions the result is 1.
reciprocalKernelFunction :: [Double] -> [Double] -> [Double] -> Double
reciprocalKernelFunction params xs ys =
  -- foldr divides by the last denominator first, then works back to the first
  foldr (flip (/)) 1 (zipWith3 denominator params xs ys)
  where
    denominator a x y = x + y + 2 * a
-- | Gaussian radial-basis kernel @exp (-||x - y||^2 / a)@, where @a@ is the
-- first kernel parameter (further parameters are ignored).  The squared
-- distance is taken over the common prefix of the two points.
radialKernelFunction :: [Double] -> [Double] -> [Double] -> Double
radialKernelFunction (width : _) xs ys = exp (negate sqDist / width)
  where
    -- strict left fold so a long point does not build a chain of thunks
    sqDist = foldl' (+) 0 (zipWith (\x y -> let d = x - y in d * d) xs ys)
radialKernelFunction [] _ _ =
  error "radialKernelFunction: expected a width parameter, got []"
-- | Linear kernel: the plain dot product of the two points, taken over their
-- common prefix.
--
-- The kernel has no parameters.  The parameter list is accepted only so the
-- function fits 'KernelFunction', and it deliberately does not limit the
-- length of the dot product: the polynomial, MLP and spline kernels pass
-- whatever parameters they have left over, which is often an empty list.
linearKernelFunction :: [Double] -> [Double] -> [Double] -> Double
linearKernelFunction _ xs ys = foldl' (+) 0 (zipWith (*) xs ys)
-- | Cubic B-spline applied to the dot product @t@ of the two points
-- (computed by 'linearKernelFunction' with the same parameter list):
--
-- * @2/3 - t^2 + |t|^3 / 2@ when @|t| <= 1@
-- * @(2 - |t|)^3 / 6@ when @1 < |t| <= 2@
-- * @0@ otherwise
splineKernelFunction :: [Double] -> [Double] -> [Double] -> Double
splineKernelFunction a x y
  | t <= 1.0 = (2 / 3) - t * t + 0.5 * cube t
  | t <= 2.0 = (1 / 6) * cube (2 - t)
  | otherwise = 0.0
  where
    -- The B-spline is an even function.  Without 'abs', a negative dot
    -- product would fall into the first branch and yield large negative values.
    t = abs (linearKernelFunction a x y)
    cube u = u * u * u
-- | Polynomial kernel @(a0 + <x, y>) ** a1@.
--
-- The first parameter is the offset and the second the exponent.  Any
-- remaining parameters are handed to 'linearKernelFunction', which computes
-- the dot product.  Fewer than two parameters is a usage error.
polyKernelFunction :: [Double] -> [Double] -> [Double] -> Double
polyKernelFunction (offset : power : rest) x y =
  (offset + linearKernelFunction rest x y) ** power
polyKernelFunction _ _ _ =
  error "polyKernelFunction: expected at least two parameters (offset, exponent)"
-- | Multilayer-perceptron (sigmoid) kernel @tanh (a0 * <x, y> - a1)@.
--
-- The first parameter scales the dot product and the second is subtracted
-- from it.  Any remaining parameters are handed to 'linearKernelFunction'.
-- Fewer than two parameters is a usage error.
mlpKernelFunction :: [Double] -> [Double] -> [Double] -> Double
mlpKernelFunction (scale : shift : rest) x y =
  tanh (scale * linearKernelFunction rest x y - shift)
mlpKernelFunction _ _ _ =
  error "mlpKernelFunction: expected at least two parameters (scale, shift)"
-- | Interface for a kernel support vector machine.
--
-- An instance only has to supply 'dcost' and 'evalKernel'.  Kernel-matrix
-- construction, prediction and training have default implementations built
-- on those two.  All arrays are assumed to be indexed from 1.
class SVM a where
  -- | Build the packed upper-triangular kernel matrix for the given points,
  -- with 'dcost' added to every diagonal entry.
  createKernelMatrix :: a -> Array Int [Double] -> KernelMatrix

  -- | Value added to each diagonal entry of the kernel matrix.
  dcost :: a -> Double

  -- | Evaluate the kernel on two points.
  evalKernel :: a -> [Double] -> [Double] -> Double

  -- | Predict the output for each of the given points, in index order.
  simulate :: a -> SVMSolution -> Array Int [Double] -> [Double]

  -- | Train on a data set.  The 'Double' is the relative convergence
  -- tolerance and the 'Int' the iteration limit, both handed to the
  -- conjugate-gradient solver.
  solve :: a -> DataSet -> Double -> Int -> SVMSolution

  createKernelMatrix svm xs = KernelMatrix packed
    where
      n = snd (bounds xs)
      size = ((n + 1) * n) `quot` 2
      -- column by column, rows 1 .. j of column j (upper triangle only)
      packed = listArray (1, size) [entry i j | j <- indices xs, i <- range (1, j)]
      entry i j
        | i /= j = evalKernel svm (xs ! i) (xs ! j)
        | otherwise = evalKernel svm (xs ! i) (xs ! j) + dcost svm

  simulate svm (SVMSolution weights vectors b) queries =
    [predict q + b | q <- elems queries]
    where
      predict q =
        mDot weights $
          listArray (bounds vectors) [evalKernel svm q v | v <- elems vectors]

  solve svm (DataSet xs ys) epsilon maxIter = SVMSolution weights xs b
    where
      n = snd (bounds ys)
      kernel = createKernelMatrix svm xs
      zeros = listArray (1, n) (replicate n 0)
      ones = listArray (1, n) (replicate n 1)
      -- Solve H nu = 1 and H v = y with H the kernel matrix.  Starting from
      -- x = 0, the initial residual and search direction both equal the
      -- right-hand side.
      nu = cga zeros ones ones kernel epsilon maxIter
      v = cga zeros ys ys kernel epsilon maxIter
      -- This choice of b makes the weights sum to zero:
      -- sum (v - b * nu) = sum v - b * sum nu = 0.
      b = mSum v / mSum nu
      weights = mZipWith (\vi nui -> vi - b * nui) v nu
-- | A least-squares SVM configuration.
data LSSVM = LSSVM
  { kf     :: KernelFunction  -- ^ kernel used to compare points
  , cost   :: Double          -- ^ cost parameter; the instance's 'dcost' is @0.5 / cost@
  , params :: [Double]        -- ^ parameters passed to the kernel function
  }
-- | Least-squares SVM: the diagonal term is half the reciprocal of the cost,
-- and the kernel is the configured function applied to the configured
-- parameters.
instance SVM LSSVM where
  dcost svm = 0.5 / cost svm
  evalKernel svm = case kf svm of
    KernelFunction f -> f (params svm)
-- | Vector type used by the conjugate-gradient solver.
type CGArray
  = UArray Int Double
-- | Conjugate-gradient solver for @K x = b@, with @K@ the packed kernel matrix.
--
-- Arguments, in order:
--
-- * the initial guess @x0@;
-- * the initial search direction @p0@;
-- * the initial residual @r0@ (callers starting from @x0 = 0@ pass @b@ for both);
-- * the matrix;
-- * a relative tolerance @epsilon@;
-- * an iteration limit.
--
-- Iteration stops when either of these happens first:
--
-- * the squared residual norm drops below @epsilon@ times its initial value;
-- * @maxIter@ iterations have been performed.
--
-- A non-positive limit or a zero initial residual returns @x0@ unchanged.
cga :: CGArray -> CGArray -> CGArray -> KernelMatrix -> Double -> Int -> CGArray
cga x0 p0 r0 k epsilon maxIter = go x0 p0 r0 initialNorm maxIter
  where
    initialNorm = mDot r0 r0
    tolerance = epsilon * initialNorm
    go x p r delta remaining
      -- A zero residual means x is already exact; another step would
      -- compute 0 / 0.  Stop as well once the iteration budget is spent.
      | delta == 0 || remaining <= 0 = x
      | nextDelta < tolerance = nextX
      | otherwise = go nextX nextP nextR nextDelta (remaining - 1)
      where
        kp = matmult k p
        step = delta / mDot p kp
        nextX = mAdd x (scalarmult step p)
        nextR = mAdd r (scalarmult (negate step) kp)
        nextDelta = mDot nextR nextR
        nextP = mAdd nextR (scalarmult (nextDelta / delta) p)
-- | Multiply the packed symmetric matrix by a vector.
--
-- Entry @(i, j)@ with @i <= j@ is read from index @j * (j - 1) `quot` 2 + i@.
-- Entries below the diagonal are read from their mirror image.  The vector
-- is assumed to be indexed @1 .. d@ and the matrix to be @d x d@.
matmult :: KernelMatrix -> (UArray Int Double) -> (UArray Int Double)
matmult (KernelMatrix k) v = listArray (1, d) (map row [1 .. d])
  where
    d = snd (bounds v)
    -- position of (i, j) in the packed upper triangle
    packed i j
      | i <= j = j * (j - 1) `quot` 2 + i
      | otherwise = i * (i - 1) `quot` 2 + j
    -- row i times v, accumulated left to right over j = 1 .. d
    row i = foldl' (\acc j -> acc + k ! packed i j * v ! j) 0 [1 .. d]
-- | Multiply every element of the vector by the given scalar.
scalarmult :: Double -> (UArray Int Double) -> (UArray Int Double)
scalarmult factor vec = amap (factor *) vec
-- | Combine two vectors elementwise.  The result takes the index range of the
-- first vector; the second must cover at least that range.
mZipWith :: (Double -> Double -> Double) -> (UArray Int Double) -> (UArray Int Double) -> (UArray Int Double)
mZipWith f v1 v2 = listArray rng [f (v1 ! i) (v2 ! i) | i <- range rng]
  where
    rng = bounds v1
-- | Sum of all elements, accumulated strictly from the lowest index upward.
mSum :: (UArray Int Double) -> Double
mSum vec = foldl' (+) 0 (elems vec)
-- | Dot product of two vectors over the index range of the first.
mDot :: (UArray Int Double) -> (UArray Int Double) -> Double
mDot u w = mSum (mZipWith (*) u w)
-- | Elementwise sum of two vectors over the index range of the first.
mAdd :: (UArray Int Double) -> (UArray Int Double) -> (UArray Int Double)
mAdd u w = mZipWith (+) u w