{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  HFANN
-- Copyright   :  (c) Olivier Boudry 2008
-- License     :  BSD-style (see the file LICENSE)
-- 
-- Maintainer  :  olivier.boudry@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- The Fast Artificial Neural Network Library (FANN) is a free open source
-- neural network library written in C with support for both fully connected
-- and sparsely connected networks (<http://leenissen.dk/fann/>).
--
-- HFANN is a Haskell interface to this library.
--
-----------------------------------------------------------------------------

module HFANN.Base (

  -- * ANN Creation
  withStandardFann,
  withSparseFann,
  withShortcutFann,

  -- * ANN Initialization
  randomizeWeights,
  initWeights,

  -- * ANN Use
  runFann,
  printConnections,
  printParameters,

  -- * ANN Information
  getOutputNodesCount,
  getInputNodesCount,
  getTotalNodesCount,
  getConnectionsCount,

  ) where

import HFANN.Data (FannType, CFannType, CFannTypePtr, FannPtr, TrainDataPtr)

import Control.Exception (bracket)

import Foreign (Ptr)
import Foreign.C.Types (CUInt, CFloat)
import Foreign.Marshal.Array (peekArray, withArray)

-- | Run the trained Neural Network on provided input
--
runFann :: FannPtr       -- ^ The ANN
        -> [FannType]    -- ^ A list of inputs
        -> IO [FannType] -- ^ A list of outputs
runFann fann input = do
  len <- getOutputNodesCount fann
  withArray (map realToFrac input) $ \arr -> do
    res <- f_fann_run fann arr
    a <- peekArray len res
    return $ map realToFrac a

-- | Create a new standard fully connected Neural Network and call the
--   given function with the ANN as argument.
--   When finished destroy the Neural Network.
--
--   The structure of the ANN is given by the first parameter. It's an
--   Int list giving the number of nodes per layer from input layer to
--   output layer.
--
--   Example: @[2,3,1]@ would describe an ANN with 2 nodes in the input layer,
--   one hidden layer of 3 nodes and 1 node in the output layer.
--
--   The function provided as second argument will be called with the created
--   ANN as parameter.
--
withStandardFann :: [Int]             -- ^ The ANN structure
                 -> (FannPtr -> IO a) -- ^ A function using the ANN
                 -> IO a              -- ^ The return value
withStandardFann nodes f = do
  bracket
    (createStandardFann nodes)
    destroyFann
    f

-- | Create a new sparse not fully connected Neural Network and call the
--   given function with the ANN as argument. When finished destroy the ANN.
--
withSparseFann :: Float             -- ^ The ratio of connections
               -> [Int]             -- ^ The ANN structure
               -> (FannPtr -> IO a) -- ^ A function using the ANN
               -> IO a              -- ^ The return value
withSparseFann ratio nodes f = do
  bracket
    (createSparseFann ratio nodes)
    destroyFann
    f

-- | Create a new sparse not fully connected Neural Network with shortcut
--   connections between layers and call the given function with the ANN 
--   as argument. When finished destroy the Neural Network
--
withShortcutFann :: [Int]             -- ^ The ANN structure
                 -> (FannPtr -> IO a) -- ^ A function using the ANN
                 -> IO a              -- ^ The return value
withShortcutFann nodes f = do
  bracket
    (createShortcutFann nodes)
    destroyFann
    f

-- | Randomize weights to values in the given range
--
-- Weights in a newly created ANN are already initialized to random values.
-- You can use this function if you want to customize the random weights
-- upper and lower bounds.
--
randomizeWeights :: FannPtr              -- ^ The ANN
                 -> (FannType, FannType) -- ^ min and max bounds for weights
                                         -- initialization
                 -> IO ()
randomizeWeights fann (low, high) = do
  f_fann_randomize_weights fann l h
  where
    l = realToFrac low
    h = realToFrac high

-- | Initialize the weights using Widrow + Nguyen’s algorithm.
-- 
-- This function behaves similarly to fann_randomize_weights.  It will use the
-- algorithm developed by Derrick Nguyen and Bernard Widrow to set the weights
-- in such a way as to speed up training.  This technique is not always
-- successful, and in some cases can be less efficient than a purely random
-- initialization.
--
-- The algorithm requires access to the range of the input data (ie, largest
-- and smallest input), and therefore accepts a second argument, data, which
-- is the training data that will be used to train the network.
--
foreign import ccall unsafe "doublefann.h fann_init_weights"
  initWeights :: FannPtr      -- ^ The ANN
              -> TrainDataPtr -- ^ The training data used to calibrate the
                              -- weights
              -> IO ()

-- | Return the number of input nodes of the Neural Network
--
getInputNodesCount :: FannPtr -- ^ The ANN
                    -> IO Int  -- ^ The number of input nodes
getInputNodesCount fann = do
  n <- f_fann_get_num_input fann
  return $ fromIntegral n

-- | Return the number of output nodes of the Neural Network
--
getOutputNodesCount :: FannPtr -- ^ The ANN
                    -> IO Int  -- ^ The number of output nodes
getOutputNodesCount fann = do
  n <- f_fann_get_num_output fann
  return $ fromIntegral n

-- | Return the total number of nodes of the Neural Network
--
getTotalNodesCount :: FannPtr -- ^ The ANN
                    -> IO Int  -- ^ The number of nodes
getTotalNodesCount fann = do
  n <- f_fann_get_total_neurons fann
  return $ fromIntegral n

-- | Return the total number of connections of the Neural Network
--
getConnectionsCount :: FannPtr -- ^ The ANN
                    -> IO Int  -- ^ The number of connections
getConnectionsCount fann = do
  n <- f_fann_get_total_connections fann
  return $ fromIntegral n

-- | Print the ANN connections
--
foreign import ccall safe "doublefann.h fann_print_connections"
  printConnections :: FannPtr -- The ANN
                   -> IO ()

-- | Print the ANN parameters
--
foreign import ccall safe "doublefann.h fann_print_parameters"
  printParameters :: FannPtr -- The ANN
                   -> IO ()

--------------------------------------------------------------------------------
-- Non exported functions
--------------------------------------------------------------------------------

-- | Create a new standard fully connected Neural Network

createStandardFann :: [Int] -> IO FannPtr
createStandardFann nodes = do
  let len = fromIntegral $ length nodes
  withArray (map fromIntegral nodes) $ f_fann_create_standard_array len

-- | Create a sparse not fully connected Neural Network

createSparseFann :: Float -> [Int] -> IO FannPtr
createSparseFann ratio nodes = do
  let len = fromIntegral $ length nodes
  withArray (map fromIntegral nodes)
              $ f_fann_create_sparse_array (realToFrac ratio) len

-- | Create a sparse not fully connected Neural Network with shortcut
--   connections between layers

createShortcutFann :: [Int] -> IO FannPtr
createShortcutFann nodes = do
  let len = fromIntegral $ length nodes
  withArray (map fromIntegral nodes) $ f_fann_create_shortcut_array len

-- | Destroy the Neural Network, releasing memory.
--
foreign import ccall unsafe "doublefann.h fann_destroy"
  destroyFann :: FannPtr -- ^ The ANN to destroy
              -> IO ()

-- | Create a standard fully connected Neural Network

foreign import ccall unsafe "doublefann.h fann_create_standard_array"
  f_fann_create_standard_array :: CUInt -> Ptr CUInt -> IO FannPtr

-- | Create a sparse not fully connected Neural Network

foreign import ccall unsafe "doublefann.h fann_create_sparse_array"
  f_fann_create_sparse_array :: CFloat -> CUInt -> Ptr CUInt -> IO FannPtr

-- | Create a sparse not fully connected Neural Network with shortcuts between
--   layers

foreign import ccall unsafe "doublefann.h fann_create_shortcut_array"
  f_fann_create_shortcut_array :: CUInt -> Ptr CUInt -> IO FannPtr

-- | Run the trained Neural Network with a specific input

foreign import ccall safe "doublefann.h fann_run"
  f_fann_run :: FannPtr -> CFannTypePtr -> IO (CFannTypePtr)

-- | Randomize the weights to values in the given range

foreign import ccall unsafe "doublefann.h fann_randomize_weights"
  f_fann_randomize_weights :: FannPtr -> CFannType -> CFannType -> IO ()

-- | Return the number of input nodes
--
foreign import ccall safe "doublefann.h fann_get_num_input"
  f_fann_get_num_input :: FannPtr -> IO (CUInt)

-- | Return the total number of nodes
--
foreign import ccall safe "doublefann.h fann_get_total_neurons"
  f_fann_get_total_neurons :: FannPtr -> IO (CUInt)

-- | Return the total number of connections
--
foreign import ccall safe "doublefann.h fann_get_total_connections"
  f_fann_get_total_connections :: FannPtr -> IO (CUInt)

-- | Return the number of output nodes
--
foreign import ccall safe "doublefann.h fann_get_num_output"
  f_fann_get_num_output :: FannPtr -> IO (CUInt)