{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  HFANN.Train
-- Copyright   :  (c) Olivier Boudry 2008
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  olivier.boudry@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks (<http://leenissen.dk/fann/>).
--
-- HFANN is a Haskell interface to this library.
--
-----------------------------------------------------------------------------

module HFANN.Train (
  -- * Training
  train,
  trainEpoch,
  trainOnFile,
  trainOnData,
  testData,
  test,
  getMSE,
  resetMSE,
  getBitFail,

  -- * Training data manipulation
  withTrainData,
  loadTrainData,
  destroyTrainData,
  shuffleTrainData,
  scaleInputTrainData,
  scaleOutputTrainData,
  scaleTrainData,
  mergeTrainData,
  duplicateTrainData,
  subsetTrainData,
  trainDataLength,
  getTrainDataInputNodesCount,
  getTrainDataOutputNodesCount,
  saveTrainData,

  -- * Parameters
  getTrainingAlgorithm,
  setTrainingAlgorithm,
  getLearningRate,
  setLearningRate,
  getLearningMomentum,
  setLearningMomentum,
  setActivationFunction,
  setActivationFunctionLayer,
  setActivationFunctionHidden,
  setActivationFunctionOutput,
  setActivationSteepness,
  setActivationSteepnessLayer,
  setActivationSteepnessHidden,
  setActivationSteepnessOutput,
  getTrainErrorFunction,
  setTrainErrorFunction,
  getTrainStopFunction,
  setTrainStopFunction,
  getBitFailLimit,
  setBitFailLimit,
  setCallback,
  getQuickPropDecay,
  setQuickPropDecay,
  getQuickPropMu,
  setQuickPropMu,
  getRPROPIncreaseFactor,
  setRPROPIncreaseFactor,
  getRPROPDecreaseFactor,
  setRPROPDecreaseFactor,
  getRPROPDeltaMin,
  setRPROPDeltaMin,
  getRPROPDeltaMax,
  setRPROPDeltaMax,
  ) where

import           HFANN.Base            (getOutputNodesCount)
import           HFANN.Data            (ActivationFunction, CCallbackType,
                                        CFannType, CFannTypePtr, CallbackType,
                                        ErrorFunction, FannPtr, FannType,
                                        StopFunction, TrainAlgorithm,
                                        TrainDataPtr, fannCallback)

import           Control.Exception     (bracket)

import           Foreign               (FunPtr)
import           Foreign.C.String      (CString)
import           Foreign.C.String      (withCString)
import           Foreign.C.Types
import           Foreign.Marshal.Array (peekArray, withArray)

-- | Run one training iteration of the network on a single input/output pair.
--
-- Both lists are marshalled into temporary C arrays that live only for the
-- duration of the call.
--
train :: FannPtr    -- ^ The ANN to be trained
      -> [FannType] -- ^ The input
      -> [FannType] -- ^ The expected output
      -> IO ()
train fann input output =
  withArray (toC input) $ \inputArr ->
    withArray (toC output) $ \expectedArr ->
      f_fann_train fann inputArr expectedArr
  where
    -- Convert Haskell values to the C element type expected by FANN.
    toC = map realToFrac

-- | Run the network on one input, passing the expected output along to the
-- C @fann_test@ call.
--
-- The result holds one value per output neuron. The count comes from
-- 'getOutputNodesCount'.
--
test :: FannPtr      -- ^ The ANN to be tested
     -> [FannType]   -- ^ The input
     -> [FannType]   -- ^ The expected output
     -> IO [FannType] -- ^ The values computed by the network
test fann input output = do
  outputCount <- getOutputNodesCount fann
  withArray (map realToFrac input) $ \inputArr ->
    withArray (map realToFrac output) $ \expectedArr -> do
      resultPtr <- f_fann_test fann inputArr expectedArr
      -- Copy the results out while the temporary arrays are still alive.
      results <- peekArray outputCount resultPtr
      return (map realToFrac results)

-- | Train the network on the training data stored in the given file.
--
-- The file format is the one described in 'loadTrainData'.
--
trainOnFile :: FannPtr -- ^ The ANN to be trained
            -> String  -- ^ The path to the training data file
            -> Int     -- ^ The max number of epochs to train
            -> Int     -- ^ The number of epochs between reports
            -> Double  -- ^ The desired error
            -> IO ()
trainOnFile fann fileName maxEpochs epochsBetweenReports desiredError =
  withCString fileName $ \cFileName ->
    f_fann_train_on_file fann cFileName
                         (fromIntegral maxEpochs)
                         (fromIntegral epochsBetweenReports)
                         (realToFrac desiredError)

-- | Train the network on an already loaded training data set.
--
-- Instead of printing a report every \"epochs between reports\" epochs, a
-- callback function can be invoked (see 'setCallback').
--
-- Passing zero as the \"epochs between reports\" value means that no
-- reports are printed.
--
trainOnData :: FannPtr      -- ^ The ANN to be trained
            -> TrainDataPtr -- ^ The training data
            -> Int          -- ^ The max number of epochs to train
            -> Int          -- ^ The number of epochs between reports
            -> Double       -- ^ The desired error
            -> IO ()
trainOnData fann tdata maxEpochs epochsBetweenReports desiredError =
  f_fann_train_on_data fann tdata
                       (fromIntegral maxEpochs)
                       (fromIntegral epochsBetweenReports)
                       (realToFrac desiredError)

-- | Train the network for a single epoch.
--
-- During one epoch every pattern of the training data is presented exactly
-- once.
--
-- The returned MSE is the one computed before or during this epoch's
-- training, not after it. Recomputing the error after the epoch would need
-- another full pass over the data, so this value is the practical one to
-- monitor while training.
--
-- The algorithm used is the one selected with 'setTrainingAlgorithm'.
--
-- See also:
--   'trainOnData', 'testData'
--
trainEpoch :: FannPtr      -- ^ The ANN to be trained
           -> TrainDataPtr -- ^ The training data
           -> IO Float     -- ^ The MSE reported by FANN for this epoch
trainEpoch fann tdata = fmap realToFrac (f_fann_train_epoch fann tdata)


--------------------------------------------------------------------------------
-- Training data manipulation
--------------------------------------------------------------------------------

-- | Load training data from a file, run an action on it, and release it.
--
-- The data is freed with 'destroyTrainData' even if the action throws,
-- because acquisition and release are wrapped in 'bracket'.
--
withTrainData :: String                 -- ^ The path to the training data file
              -> (TrainDataPtr -> IO a) -- ^ A function using the training data
              -> IO a                   -- ^ The return value
withTrainData file f = bracket acquire destroyTrainData f
  where
    acquire = loadTrainData file

-- | Reads training data from a file.
--
-- The file must be formatted like:
--
-- > num_records num_input num_output
-- > inputdata separated by space
-- > outputdata separated by space
-- >
-- > ...
-- > ...
-- >
-- > inputdata separated by space
-- > outputdata separated by space
--
-- The returned data must be released with 'destroyTrainData' when no longer
-- needed. Alternatively, use 'withTrainData', which releases it for you.
--
-- NOTE(review): the pointer returned by the C call is passed through without
-- any check here. Confirm how a failed read is reported (FANN's C API is
-- believed to return NULL) and whether callers need to test for it.
--
-- See also:
--   'trainOnData', 'destroyTrainData', 'saveTrainData', 'withTrainData'
--
loadTrainData :: String          -- ^ Path to the data file
              -> IO TrainDataPtr -- ^ The loaded training data
loadTrainData file = withCString file f_fann_read_train_from_file

-- | Destroy training data
--
-- Destroys the training data and properly deallocates its memory.
--
-- Be sure to call this function once you have finished using the training
-- data, unless the data was obtained through 'withTrainData', which calls it
-- automatically.
--
foreign import ccall safe "doublefann.h fann_destroy_train"
  destroyTrainData :: TrainDataPtr -- ^ The data to destroy
                   -> IO ()

-- | Shuffles training data, randomizing the order.
--
-- This is recommended for incremental training. It has no influence on
-- batch training.
--
foreign import ccall safe "doublefann.h fann_shuffle_train_data"
  shuffleTrainData :: TrainDataPtr -- ^ The data to randomly reorder
                   -> IO ()

-- | Scales the inputs in the training data to the specified range.
--
-- The data is modified in place; only the input values are affected.
--
-- See also:
--   'scaleOutputTrainData', 'scaleTrainData'
--
scaleInputTrainData :: TrainDataPtr -- ^ The data to be scaled
                    -> FannType     -- ^ The minimum bound
                    -> FannType     -- ^ The maximum bound
                    -> IO ()
scaleInputTrainData tdata mini maxi =
  f_fann_scale_input_train_data tdata (realToFrac mini) (realToFrac maxi)

-- | Scales the outputs in the training data to the specified range.
--
-- Only the output values of the given data set are affected.
--
-- See also:
--   'scaleInputTrainData', 'scaleTrainData'
--
scaleOutputTrainData :: TrainDataPtr -- ^ The data to be scaled
                     -> FannType     -- ^ The minimum bound
                     -> FannType     -- ^ The maximum bound
                     -> IO ()
scaleOutputTrainData tdata mini maxi =
  f_fann_scale_output_train_data tdata (realToFrac mini) (realToFrac maxi)

-- | Scales the inputs and outputs in the training data to the specified range.
--
-- See also:
--   'scaleOutputTrainData', 'scaleInputTrainData'
--
scaleTrainData :: TrainDataPtr -- ^ The data to be scaled
               -> FannType     -- ^ The minimum bound
               -> FannType     -- ^ The maximum bound
               -> IO ()
scaleTrainData tdata mini maxi =
  f_fann_scale_train_data tdata (realToFrac mini) (realToFrac maxi)

-- | Merges two training data sets into a new one.
--
-- Both arguments are left as they are; the result is a new data set.
-- NOTE(review): presumably the result must be released with
-- 'destroyTrainData' like any other training data; confirm.
--
foreign import ccall safe "doublefann.h fann_merge_train_data"
  mergeTrainData :: TrainDataPtr    -- ^ training data set 1
                 -> TrainDataPtr    -- ^ training data set 2
                 -> IO TrainDataPtr -- ^ a copy of the merged data sets 1 and 2

-- | Returns an exact copy of a training data set.
--
-- NOTE(review): presumably the copy must be released separately with
-- 'destroyTrainData'; confirm.
--
foreign import ccall safe "doublefann.h fann_duplicate_train_data"
  duplicateTrainData :: TrainDataPtr    -- ^ The training data
                     -> IO TrainDataPtr -- ^ A new copy

-- | Returns a copy of a contiguous slice of the training data. The slice
-- starts at the given offset and contains the given number of patterns.
--
-- > len <- trainDataLength tdata
-- > newtdata <- subsetTrainData tdata 0 len
--
-- Will do the same as 'duplicateTrainData'
--
-- See also:
--   'trainDataLength'
--
subsetTrainData :: TrainDataPtr    -- ^ The source training data
                -> Int             -- ^ Offset of the first pattern to copy
                -> Int             -- ^ Number of patterns to copy
                -> IO TrainDataPtr -- ^ The copied subset
subsetTrainData tdata offset len =
  f_fann_subset_train_data tdata (fromIntegral offset) (fromIntegral len)

-- | Returns how many training patterns the training data contains.
--
trainDataLength :: TrainDataPtr -- ^ The training data
                -> IO Int       -- ^ The number of patterns
trainDataLength = fmap fromIntegral . f_fann_length_train_data

-- | Returns how many input values each pattern of the training data holds.
--
getTrainDataInputNodesCount :: TrainDataPtr -- ^ The training data
                            -> IO Int       -- ^ The number of input nodes
getTrainDataInputNodesCount = fmap fromIntegral . f_fann_num_input_train_data

-- | Returns how many output values each pattern of the training data holds.
--
getTrainDataOutputNodesCount :: TrainDataPtr -- ^ The training data
                             -> IO Int       -- ^ The number of output nodes
getTrainDataOutputNodesCount = fmap fromIntegral . f_fann_num_output_train_data

-- | Save the training data to a file, using the format described in
-- 'loadTrainData'.
--
-- Calls 'error' if the C library reports that the file could not be
-- written.
--
-- See also:
--   'loadTrainData'
--
saveTrainData :: TrainDataPtr -- ^ The training data to save
              -> String       -- ^ Path of the file to write
              -> IO ()
saveTrainData tdata file =
  withCString file $ \cstr -> do
    status <- f_fann_save_train tdata cstr
    -- Treat any non-zero status as a failure instead of matching -1 exactly,
    -- so the check does not depend on the signedness of the FFI result type.
    if status /= 0
      then error $ "Could not save train data to file " ++ file
      else return ()


--------------------------------------------------------------------------------
-- Parameters manipulation
--------------------------------------------------------------------------------

-- | Return the training algorithm. This training algorithm is used by
-- 'trainOnData' and associated functions.
--
-- Note that this algorithm is also used during 'cascadeTrainOnData', although
-- only 'fannTrainRPROP' and 'fannTrainQuickProp' are allowed during cascade
-- training.
--
-- See also:
--  'setTrainingAlgorithm', 'TrainAlgorithm'
--
foreign import ccall safe "doublefann.h fann_get_training_algorithm"
  getTrainingAlgorithm :: FannPtr           -- ^ The ANN
                       -> IO TrainAlgorithm -- ^ The training algorithm

-- | Set the training algorithm used by 'trainOnData' and associated
-- functions.
--
-- See also:
--   'getTrainingAlgorithm', 'TrainAlgorithm'
--
foreign import ccall safe "doublefann.h fann_set_training_algorithm"
  setTrainingAlgorithm :: FannPtr        -- ^ The ANN
                                 -> TrainAlgorithm -- ^ The training algorithm
                                 -> IO ()

-- | Return the learning rate.
--
-- Some training algorithms ('fannTrainIncremental', 'fannTrainBatch',
-- 'fannTrainQuickProp') use the learning rate to decide how aggressively
-- the weights are adjusted. 'fannTrainRPROP' ignores it.
--
-- The default learning rate is 0.7.
--
-- See also:
--  'setLearningRate', 'setTrainingAlgorithm'
--
getLearningRate :: FannPtr  -- ^ The ANN
                -> IO Float -- ^ The current learning rate
getLearningRate = fmap realToFrac . f_fann_get_learning_rate

-- | Set the learning rate.
--
-- See 'getLearningRate' for more information about the learning rate.
--
-- See also:
--   'getLearningRate'
--
setLearningRate :: FannPtr -- ^ The ANN
                -> Float   -- ^ The new learning rate
                -> IO ()
setLearningRate fann rate = f_fann_set_learning_rate fann (realToFrac rate)

-- | Return the learning momentum.
--
-- Momentum can speed up training with the 'fannTrainIncremental'
-- algorithm.
--
-- Too much momentum does not help training, and a momentum of 0 is the same
-- as not using the parameter at all. Values between 0.0 and 1.0 are
-- recommended.
--
-- The default momentum is 0.
--
-- See also:
--   'setLearningMomentum', 'setTrainingAlgorithm'
--
getLearningMomentum :: FannPtr  -- ^ The ANN
                    -> IO Float -- ^ The current learning momentum
getLearningMomentum = fmap realToFrac . f_fann_get_learning_momentum

-- | Set the learning momentum.
--
-- See 'getLearningMomentum' for details about the parameter.
--
setLearningMomentum :: FannPtr -- ^ The ANN
                    -> Float   -- ^ The new learning momentum
                    -> IO ()
setLearningMomentum fann = f_fann_set_learning_momentum fann . realToFrac

-- | Set the activation function of a single neuron. The neuron is identified
-- by its layer (the input layer counts as layer 0) and its position within
-- that layer.
--
-- Neurons of the input layer cannot be given an activation function.
--
-- Keep in mind that activation functions differ in output range:
-- 'fannSigmoid' ranges over 0 .. 1, 'fannSigmoidSymmetric' over -1 .. 1, and
-- 'fannLinear' is unbounded.
--
-- The default activation function is 'fannSigmoidStepwise'.
--
-- See also:
--   'setActivationFunctionLayer', 'setActivationFunctionHidden',
--   'setActivationFunctionOutput', 'setActivationSteepness'
--
setActivationFunction :: FannPtr            -- ^ The ANN
                      -> ActivationFunction -- ^ The activation function
                      -> Int                -- ^ The layer
                      -> Int                -- ^ The neuron
                      -> IO ()
setActivationFunction fann func layer neuron =
  f_fann_set_activation_function fann func
                                 (fromIntegral layer)
                                 (fromIntegral neuron)

-- | Set the activation function of every neuron in one layer. The input layer
-- counts as layer 0.
--
-- The input layer itself cannot be given an activation function.
--
-- See also:
--   'setActivationFunction', 'setActivationFunctionHidden',
--   'setActivationFunctionOutput', 'setActivationSteepnessLayer'
--
setActivationFunctionLayer :: FannPtr            -- ^ The ANN
                           -> ActivationFunction -- ^ The activation function
                           -> Int                -- ^ The layer
                           -> IO ()
setActivationFunctionLayer fann func =
  f_fann_set_activation_function_layer fann func . fromIntegral

-- | Set the activation function of every neuron in every hidden layer.
--
-- See also:
--   'setActivationFunction', 'setActivationFunctionLayer',
--   'setActivationFunctionOutput'
--
setActivationFunctionHidden :: FannPtr            -- ^ The ANN
                            -> ActivationFunction -- ^ The Activation Function
                            -> IO ()
setActivationFunctionHidden = f_fann_set_activation_function_hidden

-- | Set the activation function of every neuron in the output layer.
--
-- See also:
--   'setActivationFunction', 'setActivationFunctionLayer',
--   'setActivationFunctionHidden'
--
setActivationFunctionOutput :: FannPtr            -- ^ The ANN
                            -> ActivationFunction -- ^ The Activation Function
                            -> IO ()
setActivationFunctionOutput = f_fann_set_activation_function_output

-- | Set the activation steepness of a single neuron. The neuron is identified
-- by its layer (the input layer counts as layer 0) and its position within
-- that layer.
--
-- Neurons of the input layer cannot be given an activation steepness.
--
-- The steepness describes how quickly the activation function moves from
-- its minimum to its maximum. A higher steepness also makes training more
-- aggressive.
--
-- When the outputs should end up at the extremes (usually 0 and 1,
-- depending on the activation function), a steep activation such as 1.0
-- can be used.
--
-- The default activation steepness is 0.5
--
-- See also:
--   'setActivationSteepnessLayer', 'setActivationSteepnessHidden',
--   'setActivationSteepnessOutput', 'setActivationFunction'
--
setActivationSteepness :: FannPtr  -- ^ The ANN
                       -> FannType -- ^ The steepness
                       -> Int      -- ^ The layer
                       -> Int      -- ^ The neuron
                       -> IO ()
setActivationSteepness fann steep layer neuron =
  f_fann_set_activation_steepness fann (realToFrac steep)
                                  (fromIntegral layer)
                                  (fromIntegral neuron)

-- | Set the activation steepness of every neuron in one layer. The input layer
-- counts as layer 0.
--
-- The input layer itself cannot be given an activation steepness.
--
-- See also:
--   'setActivationSteepness', 'setActivationSteepnessHidden',
--   'setActivationSteepnessOutput', 'setActivationFunction'.
--
setActivationSteepnessLayer :: FannPtr  -- ^ The ANN
                            -> FannType -- ^ The steepness
                            -> Int      -- ^ The layer
                            -> IO ()
setActivationSteepnessLayer fann steep layer =
  f_fann_set_activation_steepness_layer fann (realToFrac steep)
                                        (fromIntegral layer)

-- | Set the activation steepness of every neuron in every hidden layer.
--
-- See also:
--   'setActivationSteepness', 'setActivationSteepnessLayer',
--   'setActivationSteepnessOutput', 'setActivationFunction'
--
setActivationSteepnessHidden :: FannPtr  -- ^ The ANN
                             -> FannType -- ^ The steepness
                             -> IO ()
setActivationSteepnessHidden fann =
  f_fann_set_activation_steepness_hidden fann . realToFrac

-- | Set the activation steepness of every neuron in the output layer.
--
-- See also:
--   'setActivationSteepness', 'setActivationSteepnessLayer',
--   'setActivationSteepnessHidden', 'setActivationFunction'
--
setActivationSteepnessOutput :: FannPtr  -- ^ The ANN
                             -> FannType -- ^ The steepness
                             -> IO ()
setActivationSteepnessOutput fann =
  f_fann_set_activation_steepness_output fann . realToFrac

-- | Return the error function used during training.
--
-- The error functions are described in 'ErrorFunction'.
--
-- The default error function is 'errorFunctionTanH'.
--
-- See also:
--   'setTrainErrorFunction'
--
foreign import ccall safe "doublefann.h fann_get_train_error_function"
  getTrainErrorFunction :: FannPtr          -- ^ The ANN
                        -> IO ErrorFunction -- ^ The error function

-- | Set the error function used during training.
--
-- The error functions are described in 'ErrorFunction'.
--
-- See also:
--   'getTrainErrorFunction'
--
foreign import ccall safe "doublefann.h fann_set_train_error_function"
  setTrainErrorFunction :: FannPtr       -- ^ The ANN
                        -> ErrorFunction -- ^ The error function
                        -> IO ()

-- | Returns the stop function used during training.
--
-- The stop functions are described in 'StopFunction'.
--
-- The default stop function is 'stopFunctionMSE'.
--
-- See also:
--   'setTrainStopFunction', 'setBitFailLimit'
--
foreign import ccall safe "doublefann.h fann_get_train_stop_function"
  getTrainStopFunction :: FannPtr -> IO StopFunction

-- | Set the stop function used during training.
--
-- The stop functions are described in 'StopFunction'.
--
-- The default stop function is 'stopFunctionMSE'.
--
-- See also:
--   'getTrainStopFunction', 'getBitFailLimit'
--
foreign import ccall safe "doublefann.h fann_set_train_stop_function"
  setTrainStopFunction :: FannPtr -> StopFunction -> IO ()

-- | Returns the bit fail limit used during training.
--
-- The limit is used when the 'StopFunction' is set to 'stopFunctionBit'.
--
-- It is the largest accepted difference between the desired and the actual
-- output during training. Every output that diverges by more than this
-- counts as a failed bit.
--
-- For symmetric activation functions the difference is divided by two, so
-- that symmetric and non-symmetric activation functions can share the same
-- limit.
--
-- The default bit fail limit is 0.35.
--
-- See also:
--   'setBitFailLimit'
--
getBitFailLimit :: FannPtr     -- ^ The ANN
                -> IO FannType -- ^ The current bit fail limit
getBitFailLimit = fmap realToFrac . f_fann_get_bit_fail_limit

-- | Set the bit fail limit used during training.
--
-- See also:
--   'getBitFailLimit'
--
setBitFailLimit :: FannPtr  -- ^ The ANN
                -> FannType -- ^ The new bit fail limit
                -> IO ()
setBitFailLimit fann = f_fann_set_bit_fail_limit fann . realToFrac

-- | Returns the quickprop decay
--
-- The decay is a small negative number. It is the factor by which the
-- weights shrink in each quickprop iteration, which keeps them from growing
-- too large during training.
--
-- The default decay is -0.0001
--
-- See also:
--   'setQuickPropDecay'
--
getQuickPropDecay :: FannPtr  -- ^ The ANN
                  -> IO Float -- ^ The current quickprop decay
getQuickPropDecay = fmap realToFrac . f_fann_get_quickprop_decay

-- | Sets the quickprop decay factor
--
-- See also:
--   'getQuickPropDecay'
--
setQuickPropDecay :: FannPtr -- ^ The ANN
                  -> Float   -- ^ The new quickprop decay
                  -> IO ()
setQuickPropDecay fann = f_fann_set_quickprop_decay fann . realToFrac

-- | Returns the quickprop mu factor
--
-- During quickprop training the mu factor is used to grow and shrink the
-- step-size. It should always be above 1; otherwise it would shrink the
-- step-size exactly when it is supposed to grow it.
--
-- The default mu factor is 1.75
--
-- See also:
--   'setQuickPropMu'
--
getQuickPropMu :: FannPtr  -- ^ The ANN
               -> IO Float -- ^ The current quickprop mu factor
getQuickPropMu = fmap realToFrac . f_fann_get_quickprop_mu

-- | Sets the quickprop mu factor
--
-- See also:
--   'getQuickPropMu'
--
setQuickPropMu :: FannPtr -- ^ The ANN
               -> Float   -- ^ The new quickprop mu factor
               -> IO ()
setQuickPropMu fann = f_fann_set_quickprop_mu fann . realToFrac

-- | Returns the RPROP increase factor
--
-- The increase factor is a value larger than 1. It is used to grow the
-- step-size during RPROP training.
--
-- The default increase factor is 1.2
--
-- See also:
--   'setRPROPIncreaseFactor'
--
getRPROPIncreaseFactor :: FannPtr  -- ^ The ANN
                       -> IO Float -- ^ The current increase factor
getRPROPIncreaseFactor = fmap realToFrac . f_fann_get_rprop_increase_factor

-- | Sets the RPROP increase factor
--
-- See also:
--   'getRPROPIncreaseFactor'
--
setRPROPIncreaseFactor :: FannPtr -- ^ The ANN
                       -> Float   -- ^ The new increase factor
                       -> IO ()
setRPROPIncreaseFactor fann = f_fann_set_rprop_increase_factor fann . realToFrac

-- | Returns the RPROP decrease factor
--
-- The RPROP decrease factor is a value smaller than 1. It is used to shrink
-- the step-size during RPROP training.
--
-- The default decrease factor is 0.5
--
-- See also:
--   'setRPROPDecreaseFactor'
--
getRPROPDecreaseFactor :: FannPtr  -- ^ The ANN
                       -> IO Float -- ^ The current decrease factor
getRPROPDecreaseFactor fann =
  realToFrac `fmap` f_fann_get_rprop_decrease_factor fann

-- | Sets the RPROP decrease factor
--
-- See also:
--   'getRPROPDecreaseFactor'
--
setRPROPDecreaseFactor :: FannPtr -- ^ The ANN
                       -> Float   -- ^ The new decrease factor
                       -> IO ()
setRPROPDecreaseFactor fann decfac =
  f_fann_set_rprop_decrease_factor fann (realToFrac decfac)

-- | Returns the RPROP delta min factor
--
-- Delta min is a small positive number that bounds how small the step-size
-- may become.
--
-- The default value delta min is 0.0
--
-- See also:
--   'setRPROPDeltaMin'
--
getRPROPDeltaMin :: FannPtr  -- ^ The ANN
                 -> IO Float -- ^ The current delta min
getRPROPDeltaMin = fmap realToFrac . f_fann_get_rprop_delta_min

-- | Sets the RPROP delta min
--
-- See also:
--   'getRPROPDeltaMin'
--
setRPROPDeltaMin :: FannPtr -- ^ The ANN
                 -> Float   -- ^ The new delta min
                 -> IO ()
setRPROPDeltaMin fann = f_fann_set_rprop_delta_min fann . realToFrac

-- | Returns the RPROP delta max factor
--
-- Delta max is a positive number that bounds how large the step-size may
-- become.
--
-- The default value delta max is 50.0
--
-- See also:
--   'setRPROPDeltaMax'
--
getRPROPDeltaMax :: FannPtr  -- ^ The ANN
                 -> IO Float -- ^ The current delta max
getRPROPDeltaMax = fmap realToFrac . f_fann_get_rprop_delta_max

-- | Sets the RPROP delta max
--
-- See also:
--   'getRPROPDeltaMax'
--
setRPROPDeltaMax :: FannPtr -- ^ The ANN
                 -> Float   -- ^ The new delta max
                 -> IO ()
setRPROPDeltaMax fann = f_fann_set_rprop_delta_max fann . realToFrac

-- | Test ANN on training data
--
-- Runs the ANN on the given training data and returns the error value. It
-- can be used to check the quality of the ANN on some test data.
--
-- Note that, unlike 'getMSE', the result is returned as a raw 'CFloat'.
--
foreign import ccall safe "doublefann.h fann_test_data"
  testData :: FannPtr      -- ^ The ANN to be used
           -> TrainDataPtr -- ^ The training data
           -> IO CFloat    -- ^ The error value

-- | Install a callback that is used for reporting and that can stop training.
--
-- The callback is invoked at the \"epochs between reports\" frequency given
-- to the training function.
--
-- As documented here, the callback has the shape below (see 'CallbackType'
-- for the authoritative definition):
--
-- > callback :: FannPtr      -- The ANN being trained
-- >          -> TrainDataPtr -- The training data in use
-- >          -> Int          -- Max number of epochs
-- >          -> Int          -- Number of epochs between reports
-- >          -> Float        -- Desired error
-- >          -> Int          -- Current epoch
-- >          -> Bool         -- True to terminate training, False to continue
--
-- NOTE(review): the C function pointer produced by 'fannCallback' is never
-- released here. If it is created by a \"wrapper\" import, every call to
-- 'setCallback' leaks one FunPtr; confirm.
--
setCallback :: FannPtr      -- ^ The ANN
            -> CallbackType -- ^ The Haskell callback to install
            -> IO ()
setCallback fann func = fannCallback func >>= f_fann_set_callback fann

-- | Get the mean square error from the ANN
--
-- The value is updated during training or testing. It can therefore be
-- slightly off if the weights were changed after it was last computed.
--
getMSE :: FannPtr  -- ^ The ANN
       -> IO Float -- ^ The mean square error
getMSE = fmap realToFrac . f_fann_get_MSE

-- | Get the number of fail bits
--
-- This is the number of output neurons whose output differs from the desired
-- one by more than the bit fail limit (see 'getBitFailLimit',
-- 'setBitFailLimit').
--
-- The value is reset by 'resetMSE'. It is updated by the same functions that
-- update the MSE value, such as 'testData' and 'trainEpoch'.
--
getBitFail :: FannPtr -- ^ The ANN
           -> IO Int  -- ^ The number of fail bits
getBitFail = fmap fromIntegral . f_fann_get_bit_fail

-- | Reset the mean square error of the network.
--
-- This function also resets the number of failed bits (see 'getMSE',
-- 'getBitFail').
--
foreign import ccall safe "doublefann.h fann_reset_MSE"
  resetMSE :: FannPtr -> IO ()


--------------------------------------------------------------------------------
-- Non exported functions
--------------------------------------------------------------------------------

-- | Train one iteration (raw binding used by 'train').
--
-- Arguments: the ANN, a pointer to the input array, and a pointer to the
-- expected output array.
--
foreign import ccall safe "doublefann.h fann_train"
  f_fann_train :: FannPtr -> CFannTypePtr -> CFannTypePtr -> IO ()

-- | Test one iteration (raw binding used by 'test').
--
-- Returns a pointer to the network outputs. 'test' reads it with
-- 'peekArray' while the input arrays are still allocated.
-- NOTE(review): the pointed-to memory is presumably owned by the ANN and
-- must not be freed by the caller; confirm.
--
foreign import ccall safe "doublefann.h fann_test"
  f_fann_test :: FannPtr -> CFannTypePtr -> CFannTypePtr -> IO CFannTypePtr

-- | Train one epoch (raw binding used by 'trainEpoch').
--
foreign import ccall safe "doublefann.h fann_train_epoch"
  f_fann_train_epoch :: FannPtr  -- The ANN
             -> TrainDataPtr     -- The training data
             -> IO CFloat        -- The MSE before this training

-- | Train the Neural Network on the given data file (raw binding used by
-- 'trainOnFile').
--
-- Arguments: the ANN, the file name, the max number of epochs, the epochs
-- between reports, and the desired error.
--
foreign import ccall safe "doublefann.h fann_train_on_file"
  f_fann_train_on_file :: FannPtr -> CString -> CUInt -> CUInt -> CFloat -> IO ()

-- | Scales the inputs in the training data to the specified range (raw
-- binding used by 'scaleInputTrainData'). Arguments: data, min, max.
--
foreign import ccall safe "doublefann.h fann_scale_input_train_data"
  f_fann_scale_input_train_data :: TrainDataPtr -> CFannType -> CFannType -> IO ()

-- | Scales the outputs in the training data to the specified range (raw
-- binding used by 'scaleOutputTrainData'). Arguments: data, min, max.
--
foreign import ccall safe "doublefann.h fann_scale_output_train_data"
  f_fann_scale_output_train_data :: TrainDataPtr -> CFannType -> CFannType -> IO ()

-- | Scales the inputs and outputs in the training data to the specified range
-- (raw binding used by 'scaleTrainData'). Arguments: data, min, max.
--
foreign import ccall safe "doublefann.h fann_scale_train_data"
  f_fann_scale_train_data :: TrainDataPtr -> CFannType -> CFannType -> IO ()

-- | Returns a copy of a subset of the training data (raw binding used by
-- 'subsetTrainData'). Arguments: data, start position, length.
--
foreign import ccall safe "doublefann.h fann_subset_train_data"
  f_fann_subset_train_data :: TrainDataPtr -> CUInt -> CUInt -> IO TrainDataPtr

-- | Returns the number of training patterns in the training data (raw
-- binding used by 'trainDataLength').
--
foreign import ccall safe "doublefann.h fann_length_train_data"
  f_fann_length_train_data :: TrainDataPtr -> IO CUInt

-- | Returns the number of input nodes in the training data (raw binding used
-- by 'getTrainDataInputNodesCount').
--
foreign import ccall safe "doublefann.h fann_num_input_train_data"
  f_fann_num_input_train_data :: TrainDataPtr -> IO CUInt

-- | Returns the number of output nodes in the training data (raw binding used
-- by 'getTrainDataOutputNodesCount').
--
-- Declared with a 'CUInt' result, like 'f_fann_num_input_train_data', to
-- match the @unsigned int@ returned by the C function.
--
foreign import ccall safe "doublefann.h fann_num_output_train_data"
  f_fann_num_output_train_data :: TrainDataPtr -> IO CUInt

-- | Save the training data to a file (raw binding used by 'saveTrainData').
--
-- The C function returns a signed @int@ status (0 on success, -1 on
-- failure), so the result is declared as 'CInt' rather than 'CUInt'.
--
foreign import ccall safe "doublefann.h fann_save_train"
  f_fann_save_train :: TrainDataPtr -> CString -> IO CInt

-- | Return the learning rate (raw binding used by 'getLearningRate').
--
foreign import ccall safe "doublefann.h fann_get_learning_rate"
  f_fann_get_learning_rate :: FannPtr -> IO CFloat

-- | Set the learning rate (raw binding used by 'setLearningRate').
--
foreign import ccall safe "doublefann.h fann_set_learning_rate"
  f_fann_set_learning_rate :: FannPtr -> CFloat -> IO ()

-- | Return the learning momentum (raw binding used by
-- 'getLearningMomentum').
--
foreign import ccall safe "doublefann.h fann_get_learning_momentum"
  f_fann_get_learning_momentum :: FannPtr -> IO CFloat

-- | Set the learning momentum (raw binding used by 'setLearningMomentum').
--
foreign import ccall safe "doublefann.h fann_set_learning_momentum"
  f_fann_set_learning_momentum :: FannPtr -> CFloat -> IO ()

-- | Set the activation function for one neuron (raw binding used by
-- 'setActivationFunction'). Arguments: ANN, function, layer, neuron.
--
foreign import ccall safe "doublefann.h fann_set_activation_function"
  f_fann_set_activation_function :: FannPtr -> ActivationFunction -> CInt -> CInt -> IO ()

-- | Set the activation function for all neurons in a layer (raw binding used
-- by 'setActivationFunctionLayer'). Arguments: ANN, function, layer.
--
foreign import ccall safe "doublefann.h fann_set_activation_function_layer"
  f_fann_set_activation_function_layer :: FannPtr -> ActivationFunction -> CInt -> IO ()

-- | Set the activation function of all hidden nodes (raw binding used by
-- 'setActivationFunctionHidden').
--
foreign import ccall safe "doublefann.h fann_set_activation_function_hidden"
  f_fann_set_activation_function_hidden :: FannPtr -> ActivationFunction -> IO ()

-- | Assign one activation function to the whole group of output nodes.
--
-- Thin binding to @fann_set_activation_function_output@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_activation_function_output"
  f_fann_set_activation_function_output
    :: FannPtr             -- ^ network to modify
    -> ActivationFunction  -- ^ activation function to install
    -> IO ()

-- | Set the activation steepness of one neuron in one layer.
--
-- Thin binding to @fann_set_activation_steepness@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_activation_steepness"
  f_fann_set_activation_steepness
    :: FannPtr    -- ^ network to modify
    -> CFannType  -- ^ steepness value
    -> CInt       -- ^ layer index
    -> CInt       -- ^ neuron index within that layer
    -> IO ()

-- | Set the activation steepness of every node in one layer.
--
-- Thin binding to @fann_set_activation_steepness_layer@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_activation_steepness_layer"
  f_fann_set_activation_steepness_layer
    :: FannPtr    -- ^ network to modify
    -> CFannType  -- ^ steepness value
    -> CInt       -- ^ layer index
    -> IO ()

-- | Set the activation steepness of every node in the hidden layers.
--
-- Thin binding to @fann_set_activation_steepness_hidden@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_activation_steepness_hidden"
  f_fann_set_activation_steepness_hidden
    :: FannPtr    -- ^ network to modify
    -> CFannType  -- ^ steepness value
    -> IO ()

-- | Set the activation steepness of every node in the output layer.
--
-- Thin binding to @fann_set_activation_steepness_output@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_activation_steepness_output"
  f_fann_set_activation_steepness_output
    :: FannPtr    -- ^ network to modify
    -> CFannType  -- ^ steepness value
    -> IO ()

-- | Read the bit fail limit configured on the network.
--
-- Thin binding to @fann_get_bit_fail_limit@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_bit_fail_limit"
  f_fann_get_bit_fail_limit
    :: FannPtr       -- ^ network to query
    -> IO CFannType  -- ^ current bit fail limit

-- | Change the bit fail limit configured on the network.
--
-- Thin binding to @fann_set_bit_fail_limit@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_bit_fail_limit"
  f_fann_set_bit_fail_limit
    :: FannPtr    -- ^ network to modify
    -> CFannType  -- ^ new bit fail limit
    -> IO ()

-- | Read the quickprop decay parameter of the network.
--
-- Thin binding to @fann_get_quickprop_decay@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_quickprop_decay"
  f_fann_get_quickprop_decay
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ quickprop decay

-- | Change the quickprop decay parameter of the network.
--
-- Thin binding to @fann_set_quickprop_decay@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_quickprop_decay"
  f_fann_set_quickprop_decay
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new quickprop decay
    -> IO ()

-- | Read the quickprop mu factor of the network.
--
-- Thin binding to @fann_get_quickprop_mu@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_quickprop_mu"
  f_fann_get_quickprop_mu
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ quickprop mu factor

-- | Change the quickprop mu factor of the network.
--
-- Thin binding to @fann_set_quickprop_mu@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_quickprop_mu"
  f_fann_set_quickprop_mu
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new quickprop mu factor
    -> IO ()

-- | Change the RPROP increase factor of the network.
--
-- Thin binding to @fann_set_rprop_increase_factor@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_rprop_increase_factor"
  f_fann_set_rprop_increase_factor
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new RPROP increase factor
    -> IO ()

-- | Read the RPROP increase factor of the network.
--
-- Thin binding to @fann_get_rprop_increase_factor@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_rprop_increase_factor"
  f_fann_get_rprop_increase_factor
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ RPROP increase factor

-- | Change the RPROP decrease factor of the network.
--
-- Thin binding to @fann_set_rprop_decrease_factor@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_rprop_decrease_factor"
  f_fann_set_rprop_decrease_factor
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new RPROP decrease factor
    -> IO ()

-- | Read the RPROP decrease factor of the network.
--
-- Thin binding to @fann_get_rprop_decrease_factor@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_rprop_decrease_factor"
  f_fann_get_rprop_decrease_factor
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ RPROP decrease factor

-- | Change the RPROP @delta_min@ parameter of the network.
--
-- Thin binding to @fann_set_rprop_delta_min@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_rprop_delta_min"
  f_fann_set_rprop_delta_min
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new delta_min value
    -> IO ()

-- | Read the RPROP @delta_min@ parameter of the network.
--
-- Thin binding to @fann_get_rprop_delta_min@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_rprop_delta_min"
  f_fann_get_rprop_delta_min
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ current delta_min value

-- | Change the RPROP @delta_max@ parameter of the network.
--
-- Thin binding to @fann_set_rprop_delta_max@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_set_rprop_delta_max"
  f_fann_set_rprop_delta_max
    :: FannPtr  -- ^ network to modify
    -> CFloat   -- ^ new delta_max value
    -> IO ()

-- | Read the RPROP @delta_max@ parameter of the network.
--
-- Thin binding to @fann_get_rprop_delta_max@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_rprop_delta_max"
  f_fann_get_rprop_delta_max
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ current delta_max value

-- | Load a training data set from a file.
--
-- Thin binding to @fann_read_train_from_file@ from @doublefann.h@.
-- NOTE(review): the C library presumably returns a null pointer when the
-- file cannot be read; callers should check for it. Confirm against the
-- FANN reference.
foreign import ccall safe "doublefann.h fann_read_train_from_file"
  f_fann_read_train_from_file
    :: CString          -- ^ name of the file to read
    -> IO TrainDataPtr  -- ^ handle to the loaded training data

-- | Train the network on an already-loaded training data set.
--
-- Thin binding to @fann_train_on_data@ from @doublefann.h@.
-- The call is kept @safe@ on purpose. The callback installed with
-- 'f_fann_set_callback' is used during training. If that callback is a
-- Haskell 'FunPtr' wrapper, it re-enters the Haskell runtime, which
-- @unsafe@ foreign calls do not permit.
foreign import ccall safe "doublefann.h fann_train_on_data"
  f_fann_train_on_data
    :: FannPtr       -- ^ network to train
    -> TrainDataPtr  -- ^ training data set
    -> CUInt         -- ^ maximum number of epochs
    -> CUInt         -- ^ epochs between reports
    -> CFloat        -- ^ desired error
    -> IO ()

-- | Install the callback function used by the C library during training.
--
-- Thin binding to @fann_set_callback@ from @doublefann.h@.
-- NOTE(review): the 'FunPtr' presumably must stay valid (not released with
-- 'freeHaskellFunPtr') for as long as the network may still train. Confirm
-- who owns it.
foreign import ccall safe "doublefann.h fann_set_callback"
  f_fann_set_callback
    :: FannPtr               -- ^ network to modify
    -> FunPtr CCallbackType  -- ^ callback function pointer
    -> IO ()

-- | Read the mean square error currently recorded by the network.
--
-- Thin binding to @fann_get_MSE@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_MSE"
  f_fann_get_MSE
    :: FannPtr    -- ^ network to query
    -> IO CFloat  -- ^ mean square error

-- | Read the number of fail bits recorded by the network.
--
-- Thin binding to @fann_get_bit_fail@ from @doublefann.h@.
foreign import ccall safe "doublefann.h fann_get_bit_fail"
  f_fann_get_bit_fail
    :: FannPtr   -- ^ network to query
    -> IO CUInt  -- ^ number of fail bits