{-|
This module provides safe bindings to libsvm functions and structures with implicit memory handling.
-}
module Data.SVM
  ( Vector
  , Problem
  , KernelType (..)
  , Algorithm (..)
  , ExtraParam (..)
  , Model
  , train
  , train'
  , crossValidate
  , crossValidate'
  , loadModel
  , saveModel
  , predict
  , withPrintFn
  , CSvmPrintFn
  ) where

import           Control.Exception
import           Control.Monad         (liftM, when)
import           Data.IntMap           (IntMap, toList)
import qualified Data.IntMap           as M
import           Data.SVM.Raw          (CSvmModel, CSvmNode (..), CSvmParameter,
                                        CSvmPrintFn, CSvmProblem (..),
                                        c_clone_model_support_vectors,
                                        c_svm_check_parameter,
                                        c_svm_cross_validation,
                                        c_svm_destroy_model, c_svm_load_model,
                                        c_svm_predict, c_svm_save_model,
                                        c_svm_set_print_string_function,
                                        c_svm_train, createSvmPrintFnPtr,
                                        defaultCParam)
import qualified Data.SVM.Raw          as R
import           Foreign.C.String
import           Foreign.ForeignPtr
import           Foreign.Marshal.Alloc (alloca, free, malloc)
import           Foreign.Marshal.Array
import           Foreign.Ptr           (Ptr, freeHaskellFunPtr, nullFunPtr,
                                        nullPtr)
import           Foreign.Storable      (peek, poke)

-- | A sparse feature vector backed by an 'IntMap': each key is a feature
-- index and each value is that feature's value. Indices that are absent are
-- treated as zero.
type Vector = IntMap Double

-- | An SVM training problem: a list of @(label, vector)@ pairs.
--
-- The first component of each pair becomes libsvm's target value for that
-- instance (a class label for classification, the target for regression),
-- and the second component is the instance's sparse feature 'Vector'.
type Problem = [(Double, Vector)]

-- | A trained SVM model. It wraps a 'ForeignPtr' to the underlying
-- 'CSvmModel', so the C-side model is released by the attached finalizer
-- once the 'Model' is no longer reachable.
newtype Model = Model (ForeignPtr CSvmModel)

-- | The kernel function used by the SVM algorithm.
data KernelType
  = Linear
    -- ^ Linear kernel, i.e. the plain dot product.
  | RBF
      { gamma :: Double
      }
    -- ^ Gaussian radial basis function with parameter 'gamma'.
  | Sigmoid
      { gamma :: Double
      , coef0 :: Double
      }
    -- ^ Sigmoid kernel with parameters 'gamma' and 'coef0'.
  | Poly
      { gamma  :: Double
      , coef0  :: Double
      , degree :: Int
      }
    -- ^ Inhomogeneous polynomial kernel of the given 'degree'.

-- | The SVM formulation to train, together with its hyper-parameters.
data Algorithm
  = CSvc
      { c :: Double
      }
    -- ^ C-SVC classification.
  | NuSvc
      { nu :: Double
      }
    -- ^ nu-SVC classification.
  | NuSvr
      { nu :: Double
      , c  :: Double
      }
    -- ^ nu-SVR regression.
  | EpsilonSvr
      { epsilon :: Double
      , c       :: Double
      }
    -- ^ epsilon-SVR regression.
  | OneClassSvm
      { nu :: Double
      }
    -- ^ One-class SVM.

-- | Implementation-level knobs that are passed straight through to libsvm.
data ExtraParam = ExtraParam
  { cacheSize   :: Double
    -- ^ Kernel cache size (libsvm's @cache_size@, in MB per the libsvm docs).
  , shrinking   :: Int
    -- ^ Shrinking-heuristics flag (libsvm's @shrinking@).
  , probability :: Int
    -- ^ Probability-estimates flag (libsvm's @probability@).
  }

-- | Default extra parameters of the SVM implementation: a cache size of
-- 1000, shrinking enabled (1) and probability estimates disabled (0).
defaultExtra :: ExtraParam
defaultExtra = ExtraParam 1000 1 0

-- Overlay the kernel selection and its hyper-parameters onto a libsvm
-- parameter record. Fields that do not belong to the chosen kernel are left
-- untouched.
mergeKernel :: KernelType -> CSvmParameter -> CSvmParameter
mergeKernel kern cp = case kern of
    Linear ->
        cp { R.kernel_type = R.linear }
    RBF g ->
        cp { R.kernel_type = R.rbf
           , R.gamma       = realToFrac g
           }
    Sigmoid g cf ->
        cp { R.kernel_type = R.sigmoid
           , R.gamma       = realToFrac g
           , R.coef0       = realToFrac cf
           }
    Poly g cf d ->
        cp { R.kernel_type = R.poly
           , R.gamma       = realToFrac g
           , R.coef0       = realToFrac cf
           , R.degree      = fromIntegral d
           }

-- Overlay the SVM formulation and its hyper-parameters onto a libsvm
-- parameter record. Only the fields relevant to the chosen algorithm are set.
mergeAlgo :: Algorithm -> CSvmParameter -> CSvmParameter
mergeAlgo algo cp = case algo of
    CSvc cf ->
        cp { R.svm_type = R.cSvc
           , R.c        = realToFrac cf
           }
    NuSvc n ->
        cp { R.svm_type = R.nuSvc
           , R.nu       = realToFrac n
           }
    NuSvr n cf ->
        cp { R.svm_type = R.nuSvr
           , R.nu       = realToFrac n
           , R.c        = realToFrac cf
           }
    -- NOTE(review): in libsvm's svm_parameter, `eps` is the solver's stopping
    -- tolerance, while epsilon-SVR's epsilon lives in a separate field `p`.
    -- Writing 'epsilon' into R.eps may therefore be a bug; confirm against
    -- the CSvmParameter definition in Data.SVM.Raw.
    EpsilonSvr e cf ->
        cp { R.svm_type = R.epsilonSvr
           , R.eps      = realToFrac e
           , R.c        = realToFrac cf
           }
    OneClassSvm n ->
        cp { R.svm_type = R.oneClass
           , R.nu       = realToFrac n
           }

-- Copy the implementation-level settings from 'ExtraParam' into the
-- corresponding libsvm parameter fields.
mergeExtra :: ExtraParam -> CSvmParameter -> CSvmParameter
mergeExtra (ExtraParam cacheMb shrinkFlag probFlag) cp =
    cp { R.cache_size  = realToFrac cacheMb
       , R.shrinking   = fromIntegral shrinkFlag
       , R.probability = fromIntegral probFlag
       }

-------------------------------------------------------------------------------

-- Turn a sparse 'Vector' into libsvm nodes, in ascending key order
-- ('toList' order). Zero-valued entries are dropped because libsvm treats
-- absent indices as zero anyway.
convertToNodeArray :: Vector -> [CSvmNode]
convertToNodeArray v =
    [ CSvmNode (fromIntegral key) (realToFrac val)
    | (key, val) <- toList v
    , val /= 0
    ]

-- Terminator appended to every node array: libsvm recognises the end of a
-- sparse vector by a node whose index is -1.
endMarker :: CSvmNode
endMarker = CSvmNode (negate 1) 0

-- Allocate (with 'newArray0', so the caller must 'free' it) an
-- 'endMarker'-terminated C array of nodes for the given vector.
newCSvmNodeArray :: Vector -> IO (Ptr CSvmNode)
newCSvmNodeArray = newArray0 endMarker . convertToNodeArray

-- Run an action with a temporary, 'endMarker'-terminated C node array for
-- the vector. The array is valid only for the duration of the action.
withCSvmNodeArray :: Vector -> (Ptr CSvmNode -> IO a) -> IO a
withCSvmNodeArray = withArray0 endMarker . convertToNodeArray

-- Marshal a 'Problem' into a freshly allocated C svm_problem. Everything
-- allocated here (the per-instance node arrays, the array of row pointers,
-- the label array and the struct itself) must be released with
-- 'freeCSVmProblem'.
newCSvmProblem :: Problem -> IO (Ptr CSvmProblem)
newCSvmProblem lvs = do
    let (labels, vectors) = unzip lvs
    rowPtrs  <- mapM newCSvmNodeArray vectors
    rowArray <- newArray rowPtrs
    yArray   <- newArray (map realToFrac labels)
    probPtr  <- malloc
    poke probPtr (CSvmProblem (fromIntegral (length lvs)) yArray rowArray)
    return probPtr

-- Release every allocation made by 'newCSvmProblem': the label array, each
-- row's node array, the row-pointer array and finally the struct itself.
freeCSVmProblem :: Ptr CSvmProblem -> IO ()
freeCSVmProblem probPtr = do
    CSvmProblem count labels rows <- peek probPtr
    free labels
    rowPtrs <- peekArray (fromIntegral count) rows
    mapM_ free rowPtrs
    free rows
    free probPtr

-- Scoped marshalling of a 'Problem': the C copy exists only while the action
-- runs and is freed afterwards, even if the action throws.
withProblem :: Problem -> (Ptr CSvmProblem -> IO a) -> IO a
withProblem prob action = bracket acquire freeCSVmProblem action
  where
    acquire = newCSvmProblem prob

---

-- Build a libsvm parameter record from 'defaultCParam' by applying the extra
-- settings, then the kernel, then the algorithm. Hand a pointer to a
-- temporary copy of it to the continuation; the pointer is valid only inside
-- the continuation.
withParam :: ExtraParam
             -> Algorithm
             -> KernelType
             -> (Ptr CSvmParameter -> IO a)
             -> IO a
withParam extra algo kern f =
    alloca $ \paramPtr -> do
        poke paramPtr param
        f paramPtr
  where
    param = mergeAlgo algo (mergeKernel kern (mergeExtra extra defaultCParam))

-- Validate the problem/parameter pair with libsvm. A non-null result from
-- c_svm_check_parameter is an error message; it is raised as an 'error'
-- prefixed with "svm: ".
checkParam :: Ptr CSvmProblem -> Ptr CSvmParameter -> IO ()
checkParam probPtr paramPtr =
    when (errStr /= nullPtr) $ do
        msg <- peekCString errStr
        error ("svm: " ++ msg)
  where
    errStr = c_svm_check_parameter probPtr paramPtr

--

-- | Like 'train', but with explicit 'ExtraParam' settings.
--
-- Throws an 'ErrorCall' if libsvm rejects the parameters.
train' :: ExtraParam -> Algorithm -> KernelType -> Problem -> IO Model
train' extra algo kern prob =
    withProblem prob $ \probPtr ->
    withParam extra algo kern $ \paramPtr -> do
        checkParam probPtr paramPtr
        modelPtr <- c_svm_train probPtr paramPtr
        -- The libsvm README warns that a model returned by svm_train points
        -- into its svm_problem, and that problem is freed when 'withProblem'
        -- returns. Presumably this call copies the support vectors so the
        -- model can outlive the problem.
        -- NOTE(review): its CInt status is ignored; confirm what a non-zero
        -- value means in Data.SVM.Raw.
        _ <- c_clone_model_support_vectors modelPtr
        Model <$> newForeignPtr c_svm_destroy_model modelPtr


-- | Train a 'Model' on a 'Problem' with the given 'Algorithm' and
-- 'KernelType', using the default extra parameters.
train :: Algorithm -> KernelType -> Problem -> IO Model
train algo kern = train' defaultExtra algo kern

-- | Like 'crossValidate', but with explicit 'ExtraParam' settings.
--
-- Returns one predicted value per problem entry, read back from libsvm's
-- target array. Throws an 'ErrorCall' if libsvm rejects the parameters.
crossValidate' :: ExtraParam
                  -> Algorithm
                  -> KernelType
                  -> Problem
                  -> Int
                  -> IO [Double]
crossValidate' extra algo kern prob nFold =
    withProblem prob $ \probPtr ->
    withParam extra algo kern $ \paramPtr -> do
        -- Read the instance count back from the marshalled C struct instead
        -- of walking the Haskell list again with 'length'.
        probLen <- fromIntegral . R.l <$> peek probPtr
        allocaArray probLen $ \targetPtr -> do
            checkParam probPtr paramPtr
            c_svm_cross_validation probPtr paramPtr (fromIntegral nFold) targetPtr
            map realToFrac <$> peekArray probLen targetPtr

-- | n-fold cross validation with the default extra parameters. Per the
-- libsvm docs, the folds are stratified for classification. The result
-- holds one predicted value per 'Problem' entry.
crossValidate :: Algorithm -> KernelType -> Problem -> Int -> IO [Double]
crossValidate algo kern prob nFold = crossValidate' defaultExtra algo kern prob nFold

-----------------------------------------------------------------------

-- | Save a model to the given file.
--
-- Throws an 'ErrorCall' if libsvm reports a non-zero status.
saveModel :: Model -> FilePath -> IO ()
saveModel (Model modelForeignPtr) path =
    withForeignPtr modelForeignPtr $ \modelPtr ->
    -- 'withCString' frees the temporary C path afterwards; the previous
    -- 'newCString' leaked it on every call.
    withCString path $ \pathString -> do
        ret <- c_svm_save_model pathString modelPtr
        when (ret /= 0) $
            error $ "svm: error saving the model:" ++ show ret

-- | Load a model from the given file.
--
-- Throws an 'ErrorCall' if libsvm cannot load the model. Per the libsvm
-- docs, svm_load_model returns a null pointer in that case.
loadModel :: FilePath -> IO Model
loadModel path = do
    -- 'withCString' releases the temporary C path once libsvm has read the
    -- file; the previous 'newCString' leaked it.
    modelPtr <- withCString path c_svm_load_model
    -- Fail fast instead of wrapping NULL in a ForeignPtr, which would only
    -- surface later as a crash in 'predict' or the finalizer.
    when (modelPtr == nullPtr) $
        error ("svm: error loading the model from " ++ path)
    Model `liftM` newForeignPtr c_svm_destroy_model modelPtr

-- | Predict the value (class label or regression target) for a 'Vector'
-- using a trained 'Model'.
predict :: Model -> Vector -> IO Double
predict (Model modelForeignPtr) vector =
    withForeignPtr modelForeignPtr $ \modelPtr ->
        withCSvmNodeArray vector $ \nodesPtr ->
            fmap realToFrac (c_svm_predict modelPtr nodesPtr)

-- |Wrapper to change the libsvm output reporting function.
--
-- libsvm by default writes some statistics to stdout. If you don't
-- want any output from libsvm, you can do e.g.:
--
-- >>> withPrintFn (\_ -> return ()) $ train (NuSvc 0.25) (RBF 1) feats
--
-- When the body finishes (normally or by exception), libsvm's print hook is
-- reset to its default stdout printer. It is not restored to whatever an
-- enclosing 'withPrintFn' had installed, because libsvm offers no way to
-- query the current hook.
withPrintFn :: CSvmPrintFn -> IO a -> IO a
withPrintFn printfn body = bracket install uninstall (const body)
  where
    install = do
        c_printfn <- createSvmPrintFnPtr printfn
        c_svm_set_print_string_function c_printfn
        return c_printfn
    -- Reset libsvm's global hook *before* freeing the callback. Otherwise
    -- libsvm keeps a dangling pointer and the next libsvm call that prints
    -- jumps into freed code. Passing NULL selects libsvm's default printer.
    uninstall c_printfn = do
        c_svm_set_print_string_function nullFunPtr
        freeHaskellFunPtr c_printfn