{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module      : HFANN
-- Copyright   : (c) Olivier Boudry 2008
-- License     : BSD-style (see the file LICENSE)
--
-- Maintainer  : olivier.boudry@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks.
--
-- HFANN is a Haskell interface to this library.
--
-----------------------------------------------------------------------------
module HFANN.Train (
  setActivationFunctionHidden,
  setActivationFunctionOutput,
  setTrainingAlgorithm,
  trainOnFile,
  withTrainData,
  testData,
  ) where

import HFANN.Data (ActivationFunction, FannPtr, TrainAlgorithm, TrainDataPtr)

import Control.Exception (bracket)
import Foreign.C.String (CString, withCString)
import Foreign.C.Types (CUInt, CFloat)

-- | Train the Neural Network on the given data file.
--
-- The file name is marshalled to a temporary C string for the duration of
-- the call. The epoch counts are converted to 'CUInt' with 'fromIntegral'
-- and the desired error to 'CFloat' with 'realToFrac'; no range checking is
-- performed, so a negative epoch count wraps around to a large unsigned
-- value.
--
trainOnFile :: FannPtr  -- ^ The ANN to be trained
            -> String   -- ^ The path to the training data file
            -> Int      -- ^ The max number of epochs to train
            -> Int      -- ^ The number of epochs between reports
            -> Double   -- ^ The desired error
            -> IO ()
trainOnFile fann fileName maxEpochs epochsBetweenReports desiredError =
  withCString fileName $ \fn ->
    f_fann_train_on_file fann fn
                         (fi maxEpochs)
                         (fi epochsBetweenReports)
                         (fd desiredError)
  where
    fi = fromIntegral
    fd = realToFrac

-- | Set the hidden nodes group activation function
--
setActivationFunctionHidden :: FannPtr            -- ^ The ANN
                            -> ActivationFunction -- ^ The Activation Function
                            -> IO ()
setActivationFunctionHidden fann func =
  f_fann_set_activation_function_hidden fann func

-- | Set the output nodes group activation function
--
setActivationFunctionOutput :: FannPtr            -- ^ The ANN
                            -> ActivationFunction -- ^ The Activation Function
                            -> IO ()
setActivationFunctionOutput fann func =
  f_fann_set_activation_function_output fann func

-- | Set the training algorithm
--
setTrainingAlgorithm :: FannPtr        -- ^ The ANN
                     -> TrainAlgorithm -- ^ The Training Algorithm
                     -> IO ()
setTrainingAlgorithm fann alg =
  f_fann_set_training_algorithm fann alg

-- | Read training data from file and run the given function on that data.
--
-- The training data is released with 'destroyTrainData' once the function
-- returns, and also when it throws an exception ('bracket' guarantees the
-- release step runs in both cases).
--
-- NOTE(review): the pointer obtained from the C library is handed to the
-- callback without being checked. Presumably the C function signals a
-- missing or malformed file through its return value -- confirm against the
-- FANN documentation and consider reporting that as an 'IOError'.
--
withTrainData :: String                 -- ^ The path to the training data file
              -> (TrainDataPtr -> IO a) -- ^ A function using the training data
              -> IO a                   -- ^ The return value
withTrainData file f =
  bracket (createTrainData file) destroyTrainData f

-- | Create training data from file name
--
-- The file name is marshalled to a temporary C string that only lives for
-- the duration of the foreign call.
--
createTrainData :: String -> IO TrainDataPtr
createTrainData file =
  withCString file $ \fname -> f_fann_read_train_from_file fname

-- | Train the Neural Network on the given data file
--
-- Imported @safe@: the call reads a file and may run for many epochs. An
-- @unsafe@ call would keep the runtime from scheduling other Haskell threads
-- or running a garbage collection until the C function returns.
--
foreign import ccall safe "fann.h fann_train_on_file"
  f_fann_train_on_file :: FannPtr -> CString -> CUInt -> CUInt -> CFloat -> IO ()

-- | Set the hidden nodes group activation function
--
foreign import ccall unsafe "fann.h fann_set_activation_function_hidden"
  f_fann_set_activation_function_hidden :: FannPtr -> ActivationFunction -> IO ()

-- | Set the output nodes group activation function
--
foreign import ccall unsafe "fann.h fann_set_activation_function_output"
  f_fann_set_activation_function_output :: FannPtr -> ActivationFunction -> IO ()

-- | Set the training algorithm
--
foreign import ccall unsafe "fann.h fann_set_training_algorithm"
  f_fann_set_training_algorithm :: FannPtr -> TrainAlgorithm -> IO ()

-- | Load training data from file
--
-- Imported @safe@ because the call performs file I/O and may block.
--
foreign import ccall safe "fann.h fann_read_train_from_file"
  f_fann_read_train_from_file :: CString -> IO TrainDataPtr

-- | Destroy training data
--
foreign import ccall unsafe "fann.h fann_destroy_train"
  destroyTrainData :: TrainDataPtr -> IO ()

-- | Test ANN on training data
--
-- This function will run the ANN on the training data and return the error
-- value. It can be used to check the quality of the ANN on some test data.
--
-- Imported @safe@ because the running time grows with the size of the data
-- set, and an @unsafe@ call would stall the runtime for that whole time.
--
foreign import ccall safe "fann.h fann_test_data"
  testData :: FannPtr      -- ^ The ANN to be used
           -> TrainDataPtr -- ^ The training data
           -> IO CFloat    -- ^ The error value