----------------------------------------------------
-- |
-- Module     :  AI.Calculation.Activation
-- License    :  GPL
--
-- Maintainer :  Kiet Lam <ktklam9@gmail.com>
--
--
-- This module provides common activation functions
-- and their derivatives
--
--
----------------------------------------------------


module AI.Calculation.Activation (
  Activation(..),
  getActivation,
  getDerivative
  ) where

import AI.Signatures


-- | The kinds of activation a neuron in the neural network can use.
--
-- Pass a constructor to 'getActivation' or 'getDerivative' to obtain
-- the function it stands for.
data Activation
  = Sigmoid
    -- ^ The sigmoid activation function, @1 / (1 + exp (-x))@
  | HyperbolicTangent
    -- ^ The hyperbolic tangent activation function, @tanh x@


-- | Look up the activation function that an 'Activation' stands for.
--
-- 'Sigmoid' selects the sigmoid function and 'HyperbolicTangent'
-- selects the hyperbolic tangent.
getActivation :: Activation -> ActivationFunction
getActivation activation =
  case activation of
    Sigmoid           -> sigmoid
    HyperbolicTangent -> hTangent


-- | Look up the derivative of the function that an 'Activation'
-- stands for.
--
-- 'Sigmoid' selects the sigmoid derivative and 'HyperbolicTangent'
-- selects the hyperbolic tangent derivative.
getDerivative :: Activation -> DerivativeFunction
getDerivative activation =
  case activation of
    Sigmoid           -> sigmoidDeriv
    HyperbolicTangent -> hTangentDeriv


-- The sigmoid function: 1 / (1 + e^(-x))
sigmoid :: ActivationFunction
sigmoid x = 1 / denominator
  where
    denominator = 1 + exp (negate x)


-- The derivative of the sigmoid function, evaluated at the raw input x
--
-- NOTE: The derivative is (sigmoid x) * (1 - sigmoid x)
-- NOT (x * (1 - x)); the argument is the input to the sigmoid,
-- not its output.
sigmoidDeriv :: DerivativeFunction
sigmoidDeriv x = s * (1 - s)
  where
    -- Evaluate the sigmoid once and reuse it in both factors
    s = sigmoid x


-- The hyperbolic tangent activation; simply the Prelude's tanh
hTangent :: ActivationFunction
hTangent = tanh


-- The derivative of the hyperbolic tangent, evaluated at the raw input x
--
-- NOTE: The derivative is 1 - (tanh x)^2
-- NOT 1 - x^2; the argument is the input to tanh, not its output.
hTangentDeriv :: DerivativeFunction
hTangentDeriv x = 1 - tanhSquared
  where
    tanhSquared = tanh x ** 2