{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module : HFANN.Data
-- Copyright : (c) Olivier Boudry 2008
-- License : BSD-style (see the file LICENSE)
--
-- Maintainer : olivier.boudry@gmail.com
-- Stability : experimental
-- Portability : portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks.
--
-- HFANN is a Haskell interface to this library.
--
-----------------------------------------------------------------------------
module HFANN.Data (
-- * Data Types
FannType,
CFannType,
CFannTypePtr,
ActivationFunction,
TrainAlgorithm,
ErrorFunction,
StopFunction,
Fann (..),
FannPtr,
TrainData (..),
TrainDataPtr,
CallbackType,
CCallbackType,
-- * Activation Functions
activationLinear,
activationThreshold,
activationThresholdSymmetric,
activationSigmoid,
activationSigmoidStepwise,
activationSigmoidSymmetric,
activationSigmoidSymmetricStepwise,
activationGaussian,
activationGaussianSymmetric,
activationGaussianStepwise,
activationElliot,
activationElliotSymmetric,
activationLinearPiece,
activationLinearPieceSymmetric,
-- * Training Algorithms
trainIncremental,
trainBatch,
trainRPROP,
trainQuickProp,
-- * Error Functions
errorFunctionLinear,
errorFunctionTanH,
-- * Stop Functions
stopFunctionMSE,
stopFunctionBit,
-- * Callback
fannCallback,
) where
import Data.Word
import Foreign (Ptr, FunPtr)
import Foreign.C.Types
#include
-- | The Haskell input\/output type.
--
-- An alias for 'Double': the type used on the Haskell side to represent the
-- input\/output data.
--
type FannType = Double
-- | The C input\/output type.
--
-- An alias for 'CDouble': the type used in the C library to represent the
-- input\/output data.
--
type CFannType = CDouble
-- | A pointer to the C input\/output type.
--
-- Written out as @Ptr CDouble@, which is the same type as @Ptr 'CFannType'@.
--
type CFannTypePtr = Ptr CDouble
-- | The ANN structure.
--
-- The single nullary constructor carries no data; within this module the type
-- is used as the pointee type of 'FannPtr'.
--
data Fann = Fann
-- | A pointer to an ANN structure (see 'Fann').
--
type FannPtr = Ptr Fann
-- | Data type of the training data structure.
--
-- The single nullary constructor carries no data; within this module the type
-- is used as the pointee type of 'TrainDataPtr'.
--
data TrainData = TrainData
-- | A pointer to the training data structure (see 'TrainData').
--
type TrainDataPtr = Ptr TrainData
-- | The Haskell callback function type.
--
-- A callback receives, in order: the ANN, the training data, the maximum
-- number of epochs, the number of epochs between reports, the desired error
-- and the current epoch. Its 'Bool' result tells the trainer whether to stop.
--
type CallbackType =
     FannPtr       -- The ANN
  -> TrainDataPtr  -- The training data
  -> Int           -- Max number of epochs
  -> Int           -- Number of epochs between reports
  -> Float         -- Desired error
  -> Int           -- Current epoch
  -> IO Bool       -- Return True to stop training
-- | The C callback function type.
--
-- Same argument order as the Haskell callback type, but expressed with the
-- C-side numeric types ('CUInt' for the epoch counters, 'CFloat' for the
-- desired error). The integer result is the stop flag handed back to the
-- trainer.
--
-- NOTE(review): the result is declared as Haskell 'Int'; the C callback
-- presumably returns a C @int@, for which 'CInt' would be the matching FFI
-- type. Confirm against the FANN headers before changing this exported alias.
--
type CCallbackType =
     FannPtr       -- The ANN
  -> TrainDataPtr  -- The training data
  -> CUInt         -- Max number of epochs
  -> CUInt         -- Number of epochs between reports
  -> CFloat        -- Desired error
  -> CUInt         -- Current epoch
  -> IO Int        -- Return -1 to stop training
-- | The type of the @Training Algorithm@ enumeration.
--
-- The alias resolves to the Haskell type that hsc2hs selects for the C type
-- @enum fann_train_enum@. The constants 'trainIncremental', 'trainBatch',
-- 'trainRPROP' and 'trainQuickProp' defined below are bound to the
-- corresponding C enumerators.
--
type TrainAlgorithm = #{type enum fann_train_enum}
#{enum TrainAlgorithm,
,trainIncremental = FANN_TRAIN_INCREMENTAL
,trainBatch = FANN_TRAIN_BATCH
,trainRPROP = FANN_TRAIN_RPROP
,trainQuickProp = FANN_TRAIN_QUICKPROP
}
-- | Error function used during training.
--
-- ['errorFunctionLinear'] Standard linear error function.
--
-- ['errorFunctionTanH'] Tanh error function, usually better but can require a
-- lower learning rate. This error function aggressively targets outputs that
-- differ much from the desired value, while not targeting outputs that only
-- differ a little as much.
--
-- The tanh function is not recommended for cascade training and incremental
-- training.
--
-- The alias resolves to the Haskell type that hsc2hs selects for the C type
-- @enum fann_errorfunc_enum@.
--
type ErrorFunction = #{type enum fann_errorfunc_enum}
#{enum ErrorFunction,
,errorFunctionLinear = FANN_ERRORFUNC_LINEAR
,errorFunctionTanH = FANN_ERRORFUNC_TANH
}
-- | Stop function used during training.
--
-- ['stopFunctionMSE'] The stop criterion is the Mean Square Error value.
--
-- ['stopFunctionBit'] The stop criterion is the number of bits that fail.
-- See 'getBitFailLimit', 'setBitFailLimit'.
--
-- The bits are counted over all of the training data, so this number can be
-- higher than the number of training data.
--
-- The alias resolves to the Haskell type that hsc2hs selects for the C type
-- @enum fann_stopfunc_enum@.
--
type StopFunction = #{type enum fann_stopfunc_enum}
#{enum StopFunction,
,stopFunctionMSE = FANN_STOPFUNC_MSE
,stopFunctionBit = FANN_STOPFUNC_BIT
}
-- | The type of the @Activation Function@ enumeration.
--
-- The alias resolves to the Haskell type that hsc2hs selects for the C type
-- @enum fann_activationfunc_enum@. Each @activation...@ constant defined below
-- is bound to the C enumerator named on the right-hand side of its entry.
--
type ActivationFunction = #{type enum fann_activationfunc_enum}
#{enum ActivationFunction,
, activationLinear = FANN_LINEAR
, activationThreshold = FANN_THRESHOLD
, activationThresholdSymmetric = FANN_THRESHOLD_SYMMETRIC
, activationSigmoid = FANN_SIGMOID
, activationSigmoidStepwise = FANN_SIGMOID_STEPWISE
, activationSigmoidSymmetric = FANN_SIGMOID_SYMMETRIC
, activationSigmoidSymmetricStepwise = FANN_SIGMOID_SYMMETRIC_STEPWISE
, activationGaussian = FANN_GAUSSIAN
, activationGaussianSymmetric = FANN_GAUSSIAN_SYMMETRIC
, activationGaussianStepwise = FANN_GAUSSIAN_STEPWISE
, activationElliot = FANN_ELLIOT
, activationElliotSymmetric = FANN_ELLIOT_SYMMETRIC
, activationLinearPiece = FANN_LINEAR_PIECE
, activationLinearPieceSymmetric = FANN_LINEAR_PIECE_SYMMETRIC
}
-- | Create a callback function to be used during training for reporting and
-- to stop the training.
--
-- The supplied Haskell callback is wrapped in an adapter that converts the
-- C-side numeric arguments to their Haskell counterparts and translates the
-- 'Bool' result into the integer the C side receives: @-1@ when the callback
-- returns 'True' (stop training), @0@ otherwise.
--
fannCallback :: CallbackType -> IO (FunPtr CCallbackType)
fannCallback userCallback = mkCallback adapter
  where
    -- C-typed adapter around the user's Haskell callback.
    adapter :: CCallbackType
    adapter ann trainData maxEpochs reportEvery desiredError epoch = do
      stop <- userCallback ann trainData
                (fromIntegral maxEpochs)
                (fromIntegral reportEvery)
                (realToFrac desiredError)
                (fromIntegral epoch)
      return (if stop then (-1) else 0)
-- | Turn a Haskell function of type 'CCallbackType' into a C function pointer
-- by means of an FFI @wrapper@ import.
--
-- NOTE(review): function pointers produced by @wrapper@ imports are normally
-- released with @freeHaskellFunPtr@; no release is visible in this module, so
-- confirm where (or whether) callers free them.
--
foreign import ccall "wrapper"
  mkCallback :: CCallbackType -> IO (FunPtr CCallbackType)