{-# LANGUAGE UndecidableInstances #-}
module Feldspar.FixedPoint
    ( Fix(..), Fixable(..)
    , freezeFix, freezeFix', unfreezeFix, unfreezeFix'
    , (?!), fixFold
    )
where

import qualified Prelude
import Feldspar
import Feldspar.DSL.Network hiding (In,Out)
import Feldspar.Core.Representation
import Feldspar.Vector
import Data.Ratio

-- | Abstract real number made of an exponent and a mantissa. Throughout this
-- module the represented value is treated as @mantissa * 2^exponent@.
data Fix a = Fix
    { exponent :: Data DefaultInt  -- ^ power-of-two scale factor
    , mantissa :: Data a           -- ^ integer significand
    }
    deriving (Prelude.Eq, Prelude.Show)

-- | Arithmetic on 'Fix' values, delegating to the @fix*@ helpers below.
-- Integer literals get exponent 0. Subtraction uses the class default
-- (@x + negate y@).
instance
    ( Bounded a
    , Numeric a
    , Bits a
    , Ord a
    , Range a ~ Size a
    , Prelude.Real a
    ) => Num (Fix a)
  where
    (+)         = fixAddition
    (*)         = fixMultiplication
    negate      = fixNegate
    abs         = fixAbsolute
    signum      = fixSignum
    fromInteger = Fix 0 . Prelude.fromInteger

-- | Division, reciprocal and rational literals for 'Fix' values, delegating
-- to 'fixDiv'', 'fixRecip'' and 'fixfromRational'.
instance
    ( Bounded a
    , Numeric a
    , Bits a
    , Ord a
    , Range a ~ Size a
    , Prelude.Real a
    , Integral a
    ) => Fractional (Fix a)
  where
    fromRational = fixfromRational
    recip        = fixRecip'
    (/)          = fixDiv'

-- | Adds two fixed-point numbers. Both operands are first re-expressed (via
-- 'fix') at the larger of their two exponents; then the mantissas are summed.
fixAddition :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) => Fix a -> Fix a -> Fix a
fixAddition x@(Fix ex _) y@(Fix ey _) = Fix eMax (aligned x + aligned y)
  where
    eMax      = max ex ey
    aligned f = mantissa (fix eMax f)

-- | Multiplies two fixed-point numbers: the exponents add and the mantissas
-- multiply.
fixMultiplication :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) => Fix a -> Fix a -> Fix a
fixMultiplication (Fix ea ma) (Fix eb mb) = Fix (ea + eb) (ma * mb)
     
-- | Negates a fixed-point number by negating its mantissa; the exponent is
-- kept.
fixNegate :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) => Fix a -> Fix a 
fixNegate (Fix ex m) = Fix ex (negate m)

-- | Absolute value of a fixed-point number: 'abs' applied to the mantissa,
-- exponent unchanged.
fixAbsolute :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) => Fix a -> Fix a 
fixAbsolute (Fix ex m) = Fix ex (abs m)

-- | Sign of a fixed-point number. The result is the 'signum' of the mantissa
-- with exponent 0, because the sign does not depend on the scale.
fixSignum :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) => Fix a -> Fix a 
fixSignum (Fix _ m) = Fix 0 (signum m)

-- | Embeds an integer as a fixed-point number with exponent 0.
fixFromInteger :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a) =>  Integer -> Fix a 
fixFromInteger = Fix 0 . fromInteger

-- | Divides two fixed-point numbers: the exponents subtract and the mantissas
-- are divided with integer 'div'. Fractional bits of the quotient are lost.
fixDiv' :: (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a,Integral a) => Fix a -> Fix a -> Fix a
fixDiv' (Fix en mn) (Fix ed md) = Fix (en - ed) (mn `div` md)

-- | Reciprocal of a fixed-point number.
--
-- With value @m * 2^e@, the reciprocal is @(2^k / m) * 2^(-e-k)@. The
-- numerator @2^k@ is built as @1 << k@ with @k = wordLength - 2@:
-- 'wordLength' counts a sign bit, so @wordLength - 1@ would shift into the
-- top bit and overflow (to minBound for signed types, to 0 for unsigned
-- ones).
--
-- NOTE(review): a zero mantissa is not guarded against; @div sh 0@ is
-- whatever the backend does for integer division by zero.
fixRecip' :: forall a . (Bounded a,Numeric a, Bits a, Ord a, Range a ~ Size a, Prelude.Real a,Integral a) => Fix a -> Fix a
fixRecip' (Fix e m) = Fix (negate e - value k) (div sh m)
   where
     -- Shift amount for the numerator; the result exponent compensates for it.
     k   :: DefaultInt
     k   = wordLength (T :: T a) - 2
     sh  :: Data a
     sh  = (1::Data a) << (value $ fromInteger $ toInteger k)

-- | Converts a rational literal to a fixed-point number.
--
-- The exponent is chosen so that the integer part of the input fits into the
-- word length of @a@ with two bits of headroom; all remaining bits hold the
-- fraction. The mantissa is @floor (x * 2^frac)@ and the exponent is @-frac@.
--
-- NOTE(review): the computation goes through 'Float', so inputs needing more
-- than 24 significant bits are rounded.
fixfromRational :: forall a . (Range a ~ Size a, Numeric a, Integral a, Num a,Type a) =>
                   Prelude.Rational -> Fix a
fixfromRational inp = Fix exponent mantissa
   where
      inpAsFloat :: Float
      inpAsFloat =  fromRational inp
      -- Magnitude of the integer part. The sign is carried by the mantissa,
      -- so only the magnitude decides how many integer bits are needed.
      -- Taking 'logBase' of a negative number would yield NaN.
      intPart :: Float
      intPart =  fromRational $ toRational (Prelude.floor (Prelude.abs inpAsFloat) :: Integer)
      -- Bits reserved for the integer part. Magnitudes below 1 need none.
      -- That case is handled explicitly because 'logBase 2 0' is -Infinity.
      intPartWidth :: DefaultInt
      intPartWidth =
        if intPart Prelude.< 1
          then 0
          else Prelude.ceiling $ logBase 2 intPart
      -- Remaining bits go to the fractional part (presumably one of the two
      -- held-back bits is the sign bit).
      fracPartWith :: DefaultInt
      fracPartWith =  (wordLength (T :: T a)) - intPartWidth - 2
      mantissa = value $ Prelude.floor $ inpAsFloat * 2.0 ** fromRational (toRational fracPartWith)
      exponent =  negate $ value fracPartWith

-- | The edge information of a 'Fix' is the edge information of the edge that
-- 'toEdge' produces for it.
instance (Type a) => EdgeInfo (Fix a)
  where
    type Info (Fix a) = EdgeSize () (DefaultInt, a)
    edgeInfo f        = edgeInfo (toEdge f)

-- | A 'Fix' travels along a single edge as its frozen (exponent, mantissa)
-- pair; see 'freezeFix' and 'unfreezeFix'.
instance (Type a) => MultiEdge (Fix a) Feldspar EdgeSize
  where
    type Role     (Fix a) = ()
    type Internal (Fix a) = (DefaultInt, a)
    toEdge f              = toEdge (freezeFix f)
    fromInEdge edge       = unfreezeFix (fromInEdge edge)
    fromOutEdge info edge = unfreezeFix (fromOutEdge info edge)

-- | No methods are given here, so the class defaults apply (presumably
-- built on the 'MultiEdge' instance above; verify against the class).
instance Type a => Syntactic (Fix a)

-- | Converts an abstract real number to a pair of exponent and mantissa.
freezeFix :: (Type a) => Fix a -> Data (DefaultInt,a)
freezeFix fx = case fx of
    Fix ex man -> pair ex man

-- | Converts an abstract real number to a fixed-point integer with the given
-- exponent, by re-expressing it at that exponent and keeping the mantissa.
freezeFix' :: (Bits a) => DefaultInt -> Fix a -> Data a
freezeFix' e = mantissa . fix (value e)

-- | Converts a pair of exponent and mantissa back to an abstract real number.
unfreezeFix :: (Type a) => Data (DefaultInt,a) -> Fix a
unfreezeFix p = Fix ex man
  where
    ex  = getFst p
    man = getSnd p

-- | Interprets a fixed-point integer, together with a statically known
-- exponent, as an abstract real number.
unfreezeFix' :: DefaultInt -> Data a -> Fix a
unfreezeFix' e = Fix (value e)

-- | Estimates how many bits are needed for the largest magnitude that the
-- size information of @x@ allows: @floor (log2 maxMagnitude) + 1@.
--
-- NOTE(review): the logarithm is taken in 'Float', so the result can be off
-- near powers of two. A zero range makes 'logBase' return -Infinity.
significantBits :: forall a . (Type a, Size a ~ Range a, Num a, Ord a, Prelude.Real a) => Data a -> DefaultInt
significantBits x = DefaultInt (fromInteger (toInteger (highBit + 1)))
  where
    rng :: Range a
    rng = dataSize x
    -- Largest absolute value permitted by the range.
    biggest :: a
    biggest = Prelude.max (Prelude.abs (lowerBound rng)) (Prelude.abs (upperBound rng))
    -- Index of the highest set bit, estimated via a Float logarithm.
    highBit :: Integer
    highBit = Prelude.floor (logBase 2 (fromRational (toRational biggest)) :: Float)

-- | Overrides the size information of a value with the range @[0, sb]@.
setSignificantBits :: forall a . (Type a, Size a ~ Range a, Num a, Ord a, Prelude.Real a) => a -> Data a -> Data a
setSignificantBits sb = resizeData bounds
  where
    bounds :: Range a
    bounds = Range 0 sb

-- | Number of bits of the type selected by the proxy:
-- @ceiling (log2 maxBound) + 1@, where the extra bit accounts for the sign.
-- The logarithm is evaluated in 'Double', which is the type the original
-- expression defaulted to.
wordLength :: forall a . (Prelude.Bounded a,Type a,Size a ~ Range a,Num a,Ord a,Prelude.Real a) => T a -> DefaultInt
wordLength _ = magnitudeBits + 1
  where
    largest :: Double
    largest = fromRational (toRational (maxBound :: a))
    magnitudeBits :: DefaultInt
    magnitudeBits = Prelude.ceiling (logBase 2 largest)

-- | Like 'wordLength', but the type is selected by an (unused) value
-- instead of a proxy. The logarithm is evaluated in 'Double', the type the
-- original expression defaulted to.
wordLength' :: forall a . (Prelude.Bounded a,Prelude.Real a) => a -> DefaultInt
wordLength' _ = bitsFor (maxBound :: a) + 1
  where
    bitsFor :: a -> DefaultInt
    bitsFor v = Prelude.ceiling (logBase 2 (fromRational (toRational v) :: Double))

-- | Values whose exponent can be read and adjusted.
class (Splittable t) => Fixable t where
    -- | Reads the exponent of a value.
    getExp :: t -> Data DefaultInt
    -- | Re-expresses a value at the requested exponent (an instance may
    -- ignore the request).
    fix :: Data DefaultInt -> t -> t

-- | Moving a 'Fix' to a larger exponent shifts its mantissa right; moving it
-- to a smaller (or equal) exponent shifts the mantissa left.
instance (Bits a) => Fixable (Fix a) where
    getExp = Feldspar.FixedPoint.exponent
    fix target (Fix cur m) = Fix target shifted
      where
        shifted = target > cur ? (m >> i2n (target - cur), m << i2n (cur - target))

-- | A floating-point value needs no exponent bookkeeping: 'fix' returns it
-- unchanged, and 'getExp' is the constant @exponent 0.0@ (presumably the
-- 'RealFloat' exponent, which is 0; confirm).
instance Fixable (Data Float) where
    getExp _ = fromInteger (toInteger (Feldspar.exponent (0.0 :: Float)))
    fix _ v  = v

-- | Value-less proxy for the type @a@. It lets functions such as
-- 'wordLength' be asked about a type, e.g. @wordLength (T :: T a)@.
data T a = T

-- | Values that can be split into a static part and a dynamic part. In
-- 'fixFold' and '(?!)' the static part is held fixed while only the dynamic
-- part goes through 'fold' or '?'.
class (Syntactic (Dynamic t)) => Splittable t where
    type Static t
    type Dynamic t
    -- | Splits a value into its two parts.
    store       :: t -> (Static t, Dynamic t)
    -- | Reassembles a value from its two parts.
    retrieve    :: (Static t, Dynamic t) -> t
    -- | Adjusts a value so that it uses the given static part.
    patch       :: Static t -> t -> t
    -- | Picks a static part that both values can be patched to.
    common      :: t -> t -> Static t

-- | A plain 'Data' value has no static part: the whole value is dynamic.
instance (Type a) => Splittable (Data a) where
    type Static (Data a)  = ()
    type Dynamic (Data a) = Data a
    store x         = ((), x)
    retrieve (_, x) = x
    patch _ x       = x
    common _ _      = ()

-- | For 'Fix', the exponent is the static part and the mantissa the dynamic
-- part. Two values share the larger of their exponents.
instance (Type a, Bits a) => Splittable (Fix a) where
    type Static (Fix a)  = Data DefaultInt
    type Dynamic (Fix a) = Data a
    store f    = (getExp f, mantissa f)
    retrieve p = Fix (fst p) (snd p)
    patch      = fix
    common f g = max (getExp f) (getExp g)

-- | Vector fold for fixed-point accumulators. The static part of the initial
-- value is fixed for the whole fold. Only the dynamic part is threaded
-- through 'fold', and after every step the accumulator is patched back to
-- that static part.
fixFold :: forall a b . (Splittable a) => (a -> b -> a) -> a -> Vector b -> a
fixFold fun ini vec = retrieve (static, fold step dyn0 vec)
  where
    (static, dyn0) = store ini
    step acc el    = snd (store (patch static (fun (retrieve (static, acc)) el)))

-- | Conditional for fixed-point values. Both branches are first patched to a
-- common static part; then only their dynamic parts are selected with '?'.
infix 1 ?!
(?!) :: forall a . (Syntactic a, Splittable a) => Data Bool -> (a,a) -> a
cond ?! (x, y) = retrieve (shared, cond ? (dynamicOf x, dynamicOf y))
  where
    shared      = common x y
    dynamicOf v = snd (store (patch shared v))