{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  HFANN.Data
-- Copyright   :  (c) Olivier Boudry 2008
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  olivier.boudry@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks (<http://leenissen.dk/fann/>).
--
-- HFANN is a Haskell interface to this library. This module holds the
-- shared data types, enumeration constants and the training-callback
-- wrapper used by the rest of the binding.
--
-----------------------------------------------------------------------------

module HFANN.Data
  ( -- * Data Types
    FannType
  , CFannType
  , CFannTypePtr
  , ActivationFunction
  , TrainAlgorithm
  , ErrorFunction
  , StopFunction
  , Fann (..)
  , FannPtr
  , TrainData (..)
  , TrainDataPtr
  , CallbackType
  , CCallbackType

    -- * Activation Functions
  , activationLinear
  , activationThreshold
  , activationThresholdSymmetric
  , activationSigmoid
  , activationSigmoidStepwise
  , activationSigmoidSymmetric
  , activationSigmoidSymmetricStepwise
  , activationGaussian
  , activationGaussianSymmetric
  , activationGaussianStepwise
  , activationElliot
  , activationElliotSymmetric
  , activationLinearPiece
  , activationLinearPieceSymmetric

    -- * Training Algorithms
  , trainIncremental
  , trainBatch
  , trainRPROP
  , trainQuickProp

    -- * Error Functions
  , errorFunctionLinear
  , errorFunctionTanH

    -- * Stop Functions
  , stopFunctionMSE
  , stopFunctionBit

    -- * Callback
  , fannCallback
  ) where

import Data.Word
import Foreign (Ptr, FunPtr)
import Foreign.C.Types

-- The FANN header must be named here: hsc2hs needs it to expand the
-- #{type enum ...} and #{enum ...} directives in this module. The double
-- flavour is chosen so that the C fann_type agrees with FannType/CFannType
-- being Double/CDouble. NOTE(review): confirm this matches the FANN build
-- the package links against.
#include <doublefann.h>

-- | The Haskell input\/output type
--
-- This is the data type used in Haskell to represent the input\/output data.
--
type FannType = Double

-- | The C input\/output type
--
-- This is the data type used in the C library to represent the input\/output
-- data.
--
type CFannType = CDouble

-- | A pointer to the C input\/output type
--
type CFannTypePtr = Ptr CFannType

-- | The ANN structure
--
-- Only ever used behind a pointer ('FannPtr'); the constructor carries no
-- data.
--
data Fann = Fann

-- | A pointer to an ANN structure
--
type FannPtr = Ptr Fann

-- | Data type of the training data structure
--
-- Only ever used behind a pointer ('TrainDataPtr'); the constructor carries
-- no data.
--
data TrainData = TrainData

-- | A pointer to the training data structure type
--
type TrainDataPtr = Ptr TrainData

-- | The Haskell callback function type
--
-- Arguments, in order: the ANN, the training data, the maximum number of
-- epochs, the number of epochs between reports, the desired error and the
-- current epoch. Return 'True' to stop training.
--
type CallbackType =  FannPtr      -- The ANN
                  -> TrainDataPtr -- The training data
                  -> Int          -- Max number of epochs
                  -> Int          -- Number of epochs between reports
                  -> Float        -- Desired error
                  -> Int          -- Current epoch
                  -> IO Bool      -- Return True to stop training

-- | The C callback function type
--
-- The Haskell view of FANN's C training callback. The C callback returns a
-- C @int@, so the result is 'CInt'; a Haskell 'Int' is not guaranteed to
-- have the same width as C @int@. Return -1 to stop training.
--
type CCallbackType =  FannPtr      -- The ANN
                   -> TrainDataPtr -- The training data
                   -> CUInt        -- Max number of epochs
                   -> CUInt        -- Number of epochs between reports
                   -> CFloat       -- Desired error
                   -> CUInt        -- Current epoch
                   -> IO CInt      -- Return -1 to stop training

-- | The type of the @Training Algorithm@ enumeration
--
type TrainAlgorithm = #{type enum fann_train_enum}

#{enum TrainAlgorithm,
 , trainIncremental = FANN_TRAIN_INCREMENTAL
 , trainBatch = FANN_TRAIN_BATCH
 , trainRPROP = FANN_TRAIN_RPROP
 , trainQuickProp = FANN_TRAIN_QUICKPROP
 }

-- | Error function used during training.
--
-- errorFunctionLinear - Standard linear error function.
--
-- errorFunctionTanH - Tanh error function, usually better but can require a
-- lower learning rate. This error function aggressively targets outputs that
-- differ greatly from the desired values, while targeting outputs that
-- differ only slightly much less.
--
-- The tanh function is not recommended for cascade training and incremental
-- training.
--
type ErrorFunction = #{type enum fann_errorfunc_enum}

#{enum ErrorFunction,
 , errorFunctionLinear = FANN_ERRORFUNC_LINEAR
 , errorFunctionTanH = FANN_ERRORFUNC_TANH
 }

-- | Stop function used during training
--
-- stopFunctionMSE - stop criterion is the Mean Square Error value.
-- stopFunctionBit - stop criterion is the number of bits that fail.
--
-- See 'getBitFailLimit', 'setBitFailLimit'.
--
-- The bits are counted in all of the training data, so this number can be
-- higher than the number of training data.
--
type StopFunction = #{type enum fann_stopfunc_enum}

#{enum StopFunction,
 , stopFunctionMSE = FANN_STOPFUNC_MSE
 , stopFunctionBit = FANN_STOPFUNC_BIT
 }

-- | The type of the @Activation Function@ enumeration
--
type ActivationFunction = #{type enum fann_activationfunc_enum}

#{enum ActivationFunction,
 , activationLinear = FANN_LINEAR
 , activationThreshold = FANN_THRESHOLD
 , activationThresholdSymmetric = FANN_THRESHOLD_SYMMETRIC
 , activationSigmoid = FANN_SIGMOID
 , activationSigmoidStepwise = FANN_SIGMOID_STEPWISE
 , activationSigmoidSymmetric = FANN_SIGMOID_SYMMETRIC
 , activationSigmoidSymmetricStepwise = FANN_SIGMOID_SYMMETRIC_STEPWISE
 , activationGaussian = FANN_GAUSSIAN
 , activationGaussianSymmetric = FANN_GAUSSIAN_SYMMETRIC
 , activationGaussianStepwise = FANN_GAUSSIAN_STEPWISE
 , activationElliot = FANN_ELLIOT
 , activationElliotSymmetric = FANN_ELLIOT_SYMMETRIC
 , activationLinearPiece = FANN_LINEAR_PIECE
 , activationLinearPieceSymmetric = FANN_LINEAR_PIECE_SYMMETRIC
 }

-- | Create a callback function to be used during training for reporting and
-- to stop the training.
--
-- The C arguments are converted to their Haskell counterparts before the
-- Haskell callback runs. A 'True' result is reported to C as -1 (stop
-- training); 'False' is reported as 0.
--
-- The returned 'FunPtr' comes from a @\"wrapper\"@ import, so it stays
-- allocated until it is released with 'Foreign.Ptr.freeHaskellFunPtr'. The
-- caller should free it once the C side no longer uses the callback.
--
fannCallback :: CallbackType -> IO (FunPtr CCallbackType)
fannCallback f = mkCallback $ \fann tdata me ebr de ep -> do
    stop <- f fann tdata (fi me) (fi ebr) (fd de) (fi ep)
    return (if stop then -1 else 0)
  where
    fi = fromIntegral
    fd = realToFrac

-- | Wrap a Haskell function as a C function pointer of type 'CCallbackType'.
foreign import ccall "wrapper"
  mkCallback :: CCallbackType -> IO (FunPtr CCallbackType)