{-# LINE 1 "src/HFANN/Data.hsc" #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LINE 2 "src/HFANN/Data.hsc" #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  HFANN.Data
-- Copyright   :  (c) Olivier Boudry 2008
-- License     :  BSD-style (see the file LICENSE)
-- 
-- Maintainer  :  olivier.boudry@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks (<http://leenissen.dk/fann/>).
--
-- HFANN is a Haskell interface to this library.
--
-----------------------------------------------------------------------------

module HFANN.Data (
  -- * Data Types
  FannType,
  CFannType,
  CFannTypePtr,
  ActivationFunction,
  TrainAlgorithm,
  ErrorFunction,
  StopFunction,
  Fann (..),
  FannPtr,
  TrainData (..),
  TrainDataPtr,
  CallbackType,
  CCallbackType,

  -- * Activation Functions
  activationLinear,
  activationThreshold,
  activationThresholdSymmetric,
  activationSigmoid,
  activationSigmoidStepwise,
  activationSigmoidSymmetric,
  activationSigmoidSymmetricStepwise,
  activationGaussian,
  activationGaussianSymmetric,
  activationGaussianStepwise,
  activationElliot,
  activationElliotSymmetric,
  activationLinearPiece,
  activationLinearPieceSymmetric,

  -- * Training Algorithms
  trainIncremental,
  trainBatch,
  trainRPROP,
  trainQuickProp,

  -- * Error Functions
  errorFunctionLinear,
  errorFunctionTanH,

  -- * Stop Functions
  stopFunctionMSE,
  stopFunctionBit,

  -- * Callback
  fannCallback,
  ) where

import Data.Word
import Foreign (Ptr, FunPtr)
import Foreign.C.Types


{-# LINE 75 "src/HFANN/Data.hsc" #-}

-- | The Haskell input\/output type.
--
-- This is the data type used on the Haskell side to represent the
-- input\/output data.
--
type FannType = Double

-- | The C input\/output type.
--
-- This is the data type used in the C library to represent the input\/output
-- data.
--
type CFannType = CDouble

-- | A pointer to the C input\/output type.
--
-- Defined in terms of 'CFannType' (rather than repeating 'CDouble') so the
-- pointer type always follows the element type if the latter ever changes.
--
type CFannTypePtr = Ptr CFannType

-- | Haskell-side tag for the ANN structure.
--
-- The single constructor has no fields, so a value of this type carries no
-- data; in this module the type serves as the pointee of 'FannPtr'.
data Fann = Fann

-- | Pointer to an ANN structure ('Fann').
type FannPtr = Ptr Fann

-- | Haskell-side tag for the training data structure.
--
-- Like 'Fann', its only constructor has no fields; in this module the type
-- serves as the pointee of 'TrainDataPtr'.
data TrainData = TrainData

-- | Pointer to a training data structure ('TrainData').
type TrainDataPtr = Ptr TrainData

-- | The Haskell callback function type.
--
-- Arguments, in order: the ANN, the training data, the maximum number of
-- epochs, the number of epochs between reports, the desired error and the
-- current epoch. The callback returns 'True' to stop training.
--
type CallbackType
  =  FannPtr       -- the ANN
  -> TrainDataPtr  -- the training data
  -> Int           -- maximum number of epochs
  -> Int           -- number of epochs between reports
  -> Float         -- desired error
  -> Int           -- current epoch
  -> IO Bool       -- True stops training

-- | The C callback function type.
--
-- The FANN C callback (@fann_callback_type@) returns a C @int@, so the result
-- is marshalled as 'CInt'. Haskell's 'Int' corresponds to @HsInt@, which is
-- 64 bits wide on 64-bit targets and does not match the C return type.
--
type CCallbackType = FannPtr      -- The ANN
                  -> TrainDataPtr -- The training data
                  -> CUInt        -- Max number of epochs
                  -> CUInt        -- Number epochs between reports
                  -> CFloat       -- Desired error
                  -> CUInt        -- Current epoch
                  -> IO CInt      -- Return -1 to stop training

-- | The type of the @Training Algorithm@ enumeration.
--
type TrainAlgorithm = Word32
{-# LINE 132 "src/HFANN/Data.hsc" #-}

-- | Training algorithm codes: incremental (0), batch (1), RPROP (2) and
-- QuickProp (3).
trainIncremental, trainBatch, trainRPROP, trainQuickProp :: TrainAlgorithm
trainIncremental = 0
trainBatch = 1
trainRPROP = 2
trainQuickProp = 3

{-# LINE 139 "src/HFANN/Data.hsc" #-}

-- | Error function used during training.
--
-- * 'errorFunctionLinear' - standard linear error function.
--
-- * 'errorFunctionTanH' - tanh error function. It is usually better, but it
--   can require a lower learning rate. It aggressively targets outputs that
--   differ greatly from the desired value, while paying less attention to
--   outputs that differ only slightly.
--
-- The tanh function is not recommended for cascade training or incremental
-- training.
--
type ErrorFunction = Word32
{-# LINE 152 "src/HFANN/Data.hsc" #-}

-- | Error function codes: linear (0) and tanh (1).
errorFunctionLinear, errorFunctionTanH :: ErrorFunction
errorFunctionLinear = 0
errorFunctionTanH = 1

{-# LINE 157 "src/HFANN/Data.hsc" #-}

-- | Stop function used during training.
--
-- * 'stopFunctionMSE' - the stop criterion is the Mean Square Error value.
--
-- * 'stopFunctionBit' - the stop criterion is the number of bits that fail.
--   The bits are counted over all of the training data, so this number can
--   be higher than the number of training data.
--
-- See 'getBitFailLimit', 'setBitFailLimit'.
--
type StopFunction = Word32
{-# LINE 169 "src/HFANN/Data.hsc" #-}

-- | Stop function codes: MSE (0) and bit-fail count (1).
stopFunctionMSE, stopFunctionBit :: StopFunction
stopFunctionMSE = 0
stopFunctionBit = 1

{-# LINE 174 "src/HFANN/Data.hsc" #-}

-- | The type of the @Activation Function@ enumeration.
--
type ActivationFunction = Word32
{-# LINE 178 "src/HFANN/Data.hsc" #-}

-- | Activation function codes, numbered consecutively from 0 to 13 in the
-- order listed here.
activationLinear, activationThreshold, activationThresholdSymmetric,
  activationSigmoid, activationSigmoidStepwise, activationSigmoidSymmetric,
  activationSigmoidSymmetricStepwise, activationGaussian,
  activationGaussianSymmetric, activationGaussianStepwise, activationElliot,
  activationElliotSymmetric, activationLinearPiece,
  activationLinearPieceSymmetric
  :: ActivationFunction
activationLinear = 0
activationThreshold = 1
activationThresholdSymmetric = 2
activationSigmoid = 3
activationSigmoidStepwise = 4
activationSigmoidSymmetric = 5
activationSigmoidSymmetricStepwise = 6
activationGaussian = 7
activationGaussianSymmetric = 8
activationGaussianStepwise = 9
activationElliot = 10
activationElliotSymmetric = 11
activationLinearPiece = 12
activationLinearPieceSymmetric = 13

{-# LINE 195 "src/HFANN/Data.hsc" #-}

-- | Wrap a Haskell 'CallbackType' as a C function pointer that can be used
-- during training for reporting and to stop the training.
--
-- The C arguments are converted before the Haskell callback is called
-- ('CUInt' to 'Int' via 'fromIntegral', 'CFloat' to 'Float' via
-- 'realToFrac'). A 'True' result is returned to C as @-1@ (stop training),
-- and 'False' as @0@.
--
-- NOTE(review): the 'FunPtr' comes from a @\"wrapper\"@ import and is never
-- released here. Presumably the caller should call
-- 'Foreign.Ptr.freeHaskellFunPtr' once C no longer needs it; confirm at the
-- call sites.
--
fannCallback :: CallbackType -> IO (FunPtr CCallbackType)
fannCallback f = mkCallback adapter
  where
    adapter ann trainData maxEpochs reportEvery desiredErr epoch = do
      stop <- f ann trainData (toInt maxEpochs) (toInt reportEvery)
                (realToFrac desiredErr) (toInt epoch)
      return (if stop then -1 else 0)
    toInt = fromIntegral

-- | FFI @\"wrapper\"@ stub: turns a Haskell function of type 'CCallbackType'
-- into a C-callable function pointer.
foreign import ccall "wrapper" mkCallback
  :: CCallbackType -> IO (FunPtr CCallbackType)