--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

module Feldspar.FixedPoint where

import qualified Prelude
import Feldspar.Prelude
import Feldspar.Core.Types
import Feldspar.Core.Expr
import Feldspar.Core
import Data.Ratio

import System.IO.Unsafe
import Feldspar.Core.Functions

-- | A fixed-point number is a pair @(e, m)@ of a Haskell-level exponent @e@
-- and a Feldspar mantissa @m@; it denotes @m * 2^e@ (cf. 'fixToFloat').
-- The alias name states the mantissa's width and signedness.
type Fix32  = (Int, Data Signed32)
type UFix32 = (Int, Data Unsigned32)
type Fix16  = (Int, Data Signed16)
type UFix16 = (Int, Data Unsigned16)
type Fix8   = (Int, Data Signed8)
type UFix8  = (Int, Data Unsigned8)
-- | Fixed-point number with a plain 'Int' mantissa.
type Fix    = (Int, Data Int)

-- | Pair an exponent with a mantissa, giving the fixed-point value
-- @mantissa * 2^exponent@. Nothing is shifted or normalised; each variant
-- merely fixes the mantissa type.
intToFix :: Int -> Data Int -> Fix
intToFix = (,)

intToFix32 :: Int -> Data Signed32 -> Fix32
intToFix32 = (,)

intToUFix32 :: Int -> Data Unsigned32 -> UFix32
intToUFix32 = (,)

intToFix16 :: Int -> Data Signed16 -> Fix16
intToFix16 = (,)

intToUFix16 :: Int -> Data Unsigned16 -> UFix16
intToUFix16 = (,)

intToFix8 :: Int -> Data Signed8 -> Fix8
intToFix8 = (,)

intToUFix8 :: Int -> Data Unsigned8 -> UFix8
intToUFix8 = (,)

-- | Re-express a fixed-point value as a bare mantissa at exponent
-- @target@. The mantissa is shifted left by @e - target@; when that amount
-- is negative it is shifted right instead, discarding the low bits.
fixToInt :: Int -> Fix -> Data Int
fixToInt target (e, m) = m `leftShift` (e - target)

fix32ToInt :: Int -> Fix32 -> Data Signed32
fix32ToInt target (e, m) = m `leftShift` (e - target)

uFix32ToInt :: Int -> UFix32 -> Data Unsigned32
uFix32ToInt target (e, m) = m `leftShift` (e - target)

fix16ToInt :: Int -> Fix16 -> Data Signed16
fix16ToInt target (e, m) = m `leftShift` (e - target)

uFix16ToInt :: Int -> UFix16 -> Data Unsigned16
uFix16ToInt target (e, m) = m `leftShift` (e - target)

fix8ToInt :: Int -> Fix8 -> Data Signed8
fix8ToInt target (e, m) = m `leftShift` (e - target)

uFix8ToInt :: Int -> UFix8 -> Data Unsigned8
uFix8ToInt target (e, m) = m `leftShift` (e - target)

-- | Turn a Haskell 'Float' into a fixed-point constant at exponent 0.
-- The float is rounded to an integer mantissa with Prelude 'round'
-- (ties go to the even integer).
floatToFix :: Float -> Fix
floatToFix f = (0, value (Prelude.round f))

floatToFix32 :: Float -> Fix32
floatToFix32 f = (0, value (Prelude.round f))

floatToUFix32 :: Float -> UFix32
floatToUFix32 f = (0, value (Prelude.round f))

floatToFix16 :: Float -> Fix16
floatToFix16 f = (0, value (Prelude.round f))

floatToUFix16 :: Float -> UFix16
floatToUFix16 f = (0, value (Prelude.round f))

floatToFix8 :: Float -> Fix8
floatToFix8 f = (0, value (Prelude.round f))

floatToUFix8 :: Float -> UFix8
floatToUFix8 f = (0, value (Prelude.round f))


-- | Turn a Haskell 'Float' into a fixed-point constant at exponent @ex@.
-- The mantissa is @round (fl / 2^ex)@ (Prelude 'round', ties to even).
floatToFix32' :: Int -> Float -> Fix32
floatToFix32' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float

floatToUFix32' :: Int -> Float -> UFix32
floatToUFix32' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float

floatToFix16' :: Int -> Float -> Fix16
floatToFix16' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float

floatToUFix16' :: Int -> Float -> UFix16
floatToUFix16' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float

floatToFix8' :: Int -> Float -> Fix8
floatToFix8' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float

floatToUFix8' :: Int -> Float -> UFix8
floatToUFix8' ex fl = (ex, value (Prelude.round scaled))
  where
    scaled = fl Prelude./ (2.0 Prelude.** fromInteger (toInteger ex)) :: Float


-- | Move a fixed-point value to exponent @target@. The mantissa is shifted
-- by @e - target@ so the denoted value is kept. A right shift (moving to a
-- larger exponent) drops low bits; a left shift can overflow the mantissa
-- type.
toExp32 :: Int -> Fix32 -> Fix32
toExp32 target (e, m) = (target, m `leftShift` (e - target))

toExpU32 :: Int -> UFix32 -> UFix32
toExpU32 target (e, m) = (target, m `leftShift` (e - target))

toExp16 :: Int -> Fix16 -> Fix16
toExp16 target (e, m) = (target, m `leftShift` (e - target))

toExpU16 :: Int -> UFix16 -> UFix16
toExpU16 target (e, m) = (target, m `leftShift` (e - target))

toExp8 :: Int -> Fix8 -> Fix8
toExp8 target (e, m) = (target, m `leftShift` (e - target))

toExpU8 :: Int -> UFix8 -> UFix8
toExpU8 target (e, m) = (target, m `leftShift` (e - target))

-- | Compute the 'Float' denoted by a fixed-point pair, @m * 2^e@.
-- The mantissa is evaluated at Haskell level with 'evalD'.
fixToFloat :: (Integral a,Integral b) => ( a , Data b ) -> Float
fixToFloat (e, m) = scale Prelude.* mantissa
  where
    scale    = 2.0 Prelude.** fromInteger (toInteger e)
    mantissa = fromInteger (toInteger (evalD m)) :: Float

-- | Type-specialised versions of 'fixToFloat'.
fix32ToFloat :: Fix32-> Float
fix32ToFloat = fixToFloat

uFix32ToFloat :: UFix32-> Float
uFix32ToFloat = fixToFloat

fix16ToFloat :: Fix16-> Float
fix16ToFloat = fixToFloat

uFix16ToFloat :: UFix16-> Float
uFix16ToFloat = fixToFloat

fix8ToFloat :: Fix8-> Float
fix8ToFloat = fixToFloat

uFix8ToFloat :: UFix8-> Float
uFix8ToFloat = fixToFloat

-- | Check whether @i@ fits the range implied by signedness @s@ and a bit
-- budget @wbits@.
--
-- * Signed: @[-(2^(wbits-1) - 1), 2^(wbits-1) - 1]@, a symmetric range.
-- * Unsigned: @[0, 2^wbits - 1]@.
--
-- NOTE(review): the signed range excludes @-2^(wbits-1)@. This is presumably
-- intentional (symmetric range) — confirm.
inBounds :: Bool -> Int -> Int -> Bool
inBounds s wbits i = i Prelude.<= hi Prelude.&& lo Prelude.<= i
  where
    hi | s         = 2 Prelude.^ (wbits Prelude.- 1) Prelude.- 1
       | otherwise = 2 Prelude.^ wbits Prelude.- 1
    lo | s         = Prelude.negate hi
       | otherwise = 0

-- | Approximate a fraction @fl@ (intended range @[0, 1)@) as a fixed-point
-- value by binary refinement. The search starts from the seed
-- @fix = (e, m)@, which denotes @m * 2^e@.
--
-- * Descending phase (@gt == False@): while @m * 2^(e-1)@ exceeds @fl@, the
--   exponent is decremented. An exact hit returns @(e-1, m)@. Once the probe
--   drops below @fl@, the search moves to @(e-1, m)@ and switches to the
--   refining phase.
-- * Refining phase (@gt == True@): each step decrements the exponent and
--   appends one mantissa bit. It picks @2m+1@ when @(2m+1) * 2^(e-1) <= fl@
--   and @2m@ otherwise, so the value never exceeds @fl@. Refinement stops
--   as soon as the candidate @2m+1@ fails @inBounds s bts@, and the
--   current value is returned.
--
-- @s@ (signedness) and @bts@ (bit budget) are only passed on to 'inBounds'.
-- Inputs that are not strictly positive (zero, negative, NaN) yield @(0, 0)@.
fl01toFix :: (Integral a,Integral b) => Bool ->Int-> Float
                                     -> (a,Data b) -> Bool -> (a,Data b)
fl01toFix s bts fl fix gt
  -- Nothing to approximate. Without this guard, zero would descend until
  -- the probe underflows to 0.0 in Float (leaving an exponent around
  -- -150). A negative input would never terminate, because the probe
  -- can never fall below it.
  | Prelude.not (fl Prelude.> 0) = (0, 0)
  | (Prelude.not gt) Prelude.&& ( fl1 Prelude.> fl   ) =
        fl01toFix s bts fl ((fst fix) Prelude.- 1, snd fix  ) Prelude.False
  | (Prelude.not gt) Prelude.&& ( fl1 Prelude.< fl   ) =
        fl01toFix s bts fl ((fst fix) Prelude.- 1, snd fix  ) Prelude.True
  | (Prelude.not gt) Prelude.&& ( fl1 Prelude.== fl   ) =
        ((fst fix) Prelude.- 1, snd fix  )
  | gt Prelude.&& ( (inBounds s bts val') Prelude.&& ( fl2 Prelude.> fl )  ) =
      fl01toFix s bts fl ((fst fix) Prelude.- 1, 2 * (snd fix) ) Prelude.True
  | gt Prelude.&& ( (inBounds s bts val') Prelude.&& ( fl2 Prelude.< fl )  ) =
      fl01toFix s bts fl ((fst fix) Prelude.- 1,2 * ( snd fix) + 1) Prelude.True
  | gt Prelude.&& ( (inBounds s bts val') Prelude.&& ( fl2 Prelude.== fl )  ) =
      fl01toFix s bts fl ((fst fix) Prelude.- 1, 2 * (snd fix) +1 ) Prelude.True
  -- Refining phase with the next mantissa out of bounds: stop here.
  | otherwise = fix
    where
      -- Value of the seed with one more bit set: (2m+1) * 2^(e-1).
      fl2 = (2.0 Prelude.* (fromInteger val) Prelude.+ 1.0 ) Prelude.*
            (2.0 Prelude.** ( (fromInteger exp) Prelude.- 1.0 ))
      -- Descending probe: m * 2^(e-1).
      fl1 =( fromInteger val ) Prelude.*
           (2.0 Prelude.** ( (fromInteger exp) Prelude.- 1.0 ))
      -- Candidate mantissa checked against the bit budget.
      val'= 2 Prelude.* (fromInteger val) Prelude.+ 1
      val = toInteger $ evalD $ snd fix
      exp = toInteger $ fst fix

-- | Type-specialised versions of 'fl01toFix'. The signed variants pass a
-- bit budget of 31/15/7 to 'inBounds'; the unsigned ones pass 32/16/8.
-- NOTE(review): for signed types, 'inBounds' then caps the mantissa at
-- @2^(budget-1) - 1@, e.g. 30 magnitude bits for 'Fix32'. Confirm this
-- margin is intended.
fl01toFix' :: Float -> Fix -> Bool -> Fix
fl01toFix' fl seed gt = fl01toFix True 31 fl seed gt

fl01toUFix32 :: Float -> UFix32 -> Bool -> UFix32
fl01toUFix32 fl seed gt = fl01toFix False 32 fl seed gt

fl01toFix32  :: Float -> Fix32 -> Bool -> Fix32
fl01toFix32 fl seed gt = fl01toFix True 31 fl seed gt

fl01toUFix16 :: Float -> UFix16 -> Bool -> UFix16
fl01toUFix16 fl seed gt = fl01toFix False 16 fl seed gt

fl01toFix16  :: Float -> Fix16 -> Bool -> Fix16
fl01toFix16 fl seed gt = fl01toFix True 15 fl seed gt

fl01toUFix8 :: Float -> UFix8 -> Bool -> UFix8
fl01toUFix8 fl seed gt = fl01toFix False 8 fl seed gt

fl01toFix8  :: Float -> Fix8 -> Bool -> Fix8
fl01toFix8 fl seed gt = fl01toFix True 7 fl seed gt

-- | Approximate a 'Float' in @[0, 1)@ as a fixed-point value. The search
-- starts in descending mode from the seed @(1, 1)@, which denotes 2.
zeroOneToFix :: Float -> Fix
zeroOneToFix fl = fl01toFix' fl seed Prelude.False
  where seed = (1, 1)

zeroOneToFix32 :: Float -> Fix32
zeroOneToFix32 fl = fl01toFix32 fl seed Prelude.False
  where seed = (1, 1)

zeroOneToUFix32 :: Float -> UFix32
zeroOneToUFix32 fl = fl01toUFix32 fl seed Prelude.False
  where seed = (1, 1)

zeroOneToFix16 :: Float -> Fix16
zeroOneToFix16 fl = fl01toFix16 fl seed Prelude.False
  where seed = (1, 1)

zeroOneToUFix16 :: Float -> UFix16
zeroOneToUFix16 fl = fl01toUFix16 fl seed Prelude.False
  where seed = (1, 1)

zeroOneToFix8 :: Float -> Fix8
zeroOneToFix8 fl = fl01toFix8 fl seed Prelude.False
  where seed = (1, 1)

zeroOneToUFix8 :: Float -> UFix8
zeroOneToUFix8 fl = fl01toUFix8 fl seed Prelude.False
  where seed = (1, 1)


-- | Add two fixed-point values, producing the result at the caller-chosen
-- exponent @e@. Each mantissa is first moved to exponent @e@ with
-- 'leftShift'. Bits below @2^e@ are truncated, and operands with a smaller
-- exponent may overflow the mantissa type.
addFix ::(Integral b,Bits b) =>
              Int -> (Int,Data b) -> (Int,Data b) -> (Int,Data b)
addFix e (e1,i1) (e2,i2) = (e, align e1 i1 + align e2 i2)
  where
    align ex m = m `leftShift` (ex Prelude.- e)

-- | Type-specialised versions of 'addFix' (result at the given exponent).
addFix'' :: Int -> Fix -> Fix -> Fix
addFix'' e x y = addFix e x y

addFix32 :: Int -> Fix32 -> Fix32 -> Fix32
addFix32 e x y = addFix e x y

addUFix32 :: Int -> UFix32 -> UFix32 -> UFix32
addUFix32 e x y = addFix e x y

addFix16 :: Int -> Fix16 -> Fix16 -> Fix16
addFix16 e x y = addFix e x y

addUFix16 :: Int -> UFix16 -> UFix16 -> UFix16
addUFix16 e x y = addFix e x y

addFix8 :: Int -> Fix8 -> Fix8 -> Fix8
addFix8 e x y = addFix e x y

addUFix8 :: Int -> UFix8 -> UFix8 -> UFix8
addUFix8 e x y = addFix e x y

-- | Reciprocal of a fixed-point value, produced at the caller-chosen
-- exponent @target@. The mantissa is @2^(-(target + e)) `div` i@, so the
-- result approximates @1 / (i * 2^e)@. When @target + e > 0@ the shifted
-- numerator @1 `rightShift` (target + e)@ is already 0.
recipFix :: (Integral b,Bits b) =>
                Int -> (Int,Data b) -> (Int,Data b)
recipFix target (e,i) = (target, scaledOne `div` i)
  where
    scaledOne = 1 `rightShift` (target Prelude.+ e)

-- | Type-specialised versions of 'recipFix' (result at the given exponent).
recipFix' :: Int -> Fix -> Fix
recipFix' e x = recipFix e x

recipFix32 :: Int -> Fix32 -> Fix32
recipFix32 e x = recipFix e x

recipUFix32 :: Int -> UFix32 -> UFix32
recipUFix32 e x = recipFix e x

recipFix16 :: Int -> Fix16 -> Fix16
recipFix16 e x = recipFix e x

recipUFix16 :: Int -> UFix16 -> UFix16
recipUFix16 e x = recipFix e x

recipFix8 :: Int -> Fix8 -> Fix8
recipFix8 e x = recipFix e x

recipUFix8 :: Int -> UFix8 -> UFix8
recipUFix8 e x = recipFix e x

-- | Divide two fixed-point values, producing the result at exponent
-- @target@. The dividend mantissa is pre-shifted by @e1 - e2 - target@ so
-- that the integer quotient lands at the requested exponent. Low bits lost
-- in the pre-shift or in 'div' are truncated.
divFix :: (Integral b,Bits b) =>
               Int -> (Int,Data b) -> (Int,Data b)
           -> (Int,Data b)
divFix target (e1,i1) (e2,i2) = (target, dividend `div` i2)
  where
    dividend = i1 `leftShift` (e1 Prelude.- e2 Prelude.- target)

-- | Type-specialised versions of 'divFix' (result at the given exponent).
divFix' :: Int -> Fix -> Fix -> Fix
divFix' e x y = divFix e x y

divFix32 :: Int -> Fix32 -> Fix32 -> Fix32
divFix32 e x y = divFix e x y

divUFix32 :: Int -> UFix32 -> UFix32 -> UFix32
divUFix32 e x y = divFix e x y

divFix16 :: Int -> Fix16 -> Fix16 -> Fix16
divFix16 e x y = divFix e x y

divUFix16 :: Int -> UFix16 -> UFix16 -> UFix16
divUFix16 e x y = divFix e x y

divFix8 :: Int -> Fix8 -> Fix8 -> Fix8
divFix8 e x y = divFix e x y

divUFix8 :: Int -> UFix8 -> UFix8 -> UFix8
divUFix8 e x y = divFix e x y

-- | Add two fixed-point values at the larger of their exponents. The
-- operand with the smaller exponent is shifted right, losing its low bits.
-- This is 'addFix' with the exponent chosen automatically.
addFix' ::(Integral b,Bits b) =>
              (Int,Data b) -> (Int,Data b) -> (Int,Data b)
addFix' x@(e1, _) y@(e2, _) = addFix (Prelude.max e1 e2) x y

-- | Multiply two fixed-point values: exponents add and mantissas multiply.
-- No rescaling is performed, so the product mantissa may overflow.
mulFix' ::(Integral b,Bits b) =>
              (Int,Data b) -> (Int,Data b) -> (Int,Data b)
mulFix' (e1,i1) (e2,i2) = (e1 Prelude.+ e2, i1 * i2)

-- | Negate the mantissa; the exponent is kept. 'fmap' on a pair maps its
-- second component only.
negate' ::(Integral b,Bits b) =>
              (Int,Data b) -> (Int,Data b)
negate' = Prelude.fmap negate

-- | Absolute value of the mantissa; the exponent is kept.
abs' ::(Integral b,Bits b) =>
              (Int,Data b) -> (Int,Data b)
abs' = Prelude.fmap abs

-- | Sign of the mantissa (-1, 0 or 1) at exponent 0. The input exponent
-- does not affect the sign.
signum' ::(Integral b,Bits b) =>
              (Int,Data b) -> (Int,Data b)
signum' (_, i) = (0, signum i)

-- | Integer literal as a fixed-point value at exponent 0.
fromInteger' ::(Integral b,Bits b) =>
                Integer -> (Int,Data b)
fromInteger' n = (0, fromInteger n)

-- | 'Num' instances for every fixed-point type, all delegating to the
-- generic helpers:
--
-- * addition aligns to the larger exponent ('addFix'');
-- * multiplication adds exponents ('mulFix'');
-- * literals get exponent 0 ('fromInteger'').
--
-- Subtraction uses the class default (@x + negate y@).
instance Num Fix where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num Fix32 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num UFix32 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num Fix16 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num UFix16 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num Fix8 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

instance Num UFix8 where
    (+)         = addFix'
    (*)         = mulFix'
    negate      = negate'
    abs         = abs'
    signum      = signum'
    fromInteger = fromInteger'

-- | Reciprocal with the exponent derived from the bit budget @bts@. With
-- @k = bts - 2@, the result is @(-(e + k), 2^k `div` i)@, which
-- approximates @1 / (i * 2^e)@.
recip' ::(Integral b,Bits b) =>
              Int -> (Int,Data b) -> (Int,Data b)
recip' bts (e,i) = (Prelude.negate (e Prelude.+ k), (1 `leftShift` k) `div` i)
  where
    k = bts - 2

-- | Shared implementation of 'fromRational' for the fixed-point types.
--
-- * @s@: 'True' for signed mantissas.
-- * @bts@: bit budget; the magnitude bits are @bts - 1@ when signed and
--   @bts@ otherwise.
-- * @zotf@: converts a fraction in @[0, 1)@ to fixed point.
-- * @fi@: converts the integer part.
--
-- The rational is split exactly as @rat = n + f@ with @n = floor rat@ and
-- @0 <= f < 1@. This is also correct for negative inputs; the former
-- @quot@-based split turned -2.5 into -1.5. The result exponent is the
-- fraction's exponent, raised just far enough that the integer part's
-- mantissa still fits in the magnitude bits; 'addFix' then shifts the
-- fraction right to match.
fromRational' ::(Integral b,Bits b,Num (Int,Data b)) =>
                Bool -> Int->(Float->(Int,Data b))->(Integer->(Int,Data b))
                    -> Rational -> (Int,Data b)
fromRational' s bts zotf fi rat
    -- No fractional part: the integer conversion is exact. Calling zotf on
    -- 0 would produce a meaningless, extremely negative exponent.
    | fl01 Prelude.== 0   = integ
    -- No integer part: the fraction already fits on its own.
    | vinteg Prelude.== 0 = frac
    | otherwise           = addFix e integ frac
  where
    ipart  = Prelude.floor rat :: Integer
    -- Exact fractional part, rounded to Float only at the end.
    fl01   = Prelude.fromRational (rat Prelude.- Prelude.fromInteger ipart)
               :: Float
    integ  = fi ipart
    frac   = zotf fl01
    vinteg = toInteger (evalD (snd integ))
    magBits | s         = bts Prelude.- 1
            | otherwise = bts
    -- Smallest exponent at which the aligned integer mantissa
    -- (|vinteg| shifted left by fst integ - e) still fits in magBits.
    minExp = bitLength (Prelude.abs vinteg) Prelude.+ fst integ
               Prelude.- magBits
    e      = Prelude.max (fst frac) minExp
    -- Number of bits needed to write a non-negative integer (0 for 0).
    bitLength :: Integer -> Int
    bitLength n
      | n Prelude.<= 0 = 0
      | otherwise      = 1 Prelude.+ bitLength (n `Prelude.quot` 2)

-- | 'Fractional' instances. Literals go through 'fromRational'' with the
-- instance's signedness, bit budget and @[0, 1)@ converter; 'recip' uses
-- 'recip'' with the same budget.
-- NOTE(review): the unsigned instances pass a budget one below their signed
-- counterparts (e.g. 31 for 'UFix32' vs 32 for 'Fix32'). This is presumably a
-- deliberate safety margin — confirm.
instance Fractional Fix where
   fromRational = fromRational' True  32 zeroOneToFix fromInteger
   recip        = recip' 32

instance Fractional Fix32 where
   fromRational = fromRational' True  32 zeroOneToFix32 fromInteger
   recip        = recip' 32

instance Fractional UFix32 where
   fromRational = fromRational' False 31 zeroOneToUFix32 fromInteger
   recip        = recip' 31

instance Fractional Fix16 where
   fromRational = fromRational' True  16 zeroOneToFix16 fromInteger
   recip        = recip' 16

instance Fractional UFix16 where
   fromRational = fromRational' False 15 zeroOneToUFix16 fromInteger
   recip        = recip' 15

instance Fractional Fix8 where
   fromRational = fromRational' True  8 zeroOneToFix8 fromInteger
   recip        = recip' 8

instance Fractional UFix8 where
   fromRational = fromRational' False 7 zeroOneToUFix8 fromInteger
   recip        = recip' 7

-- | Arithmetic shared by 'Data Float' and the fixed-point types. Every
-- operation takes an exponent first. The fixed-point instances use it as
-- the result exponent; the 'Data Float' instance ignores it.
class FixFloatLike a where
   -- | Addition, result at the given exponent.
   addFF   :: Int -> a -> a -> a
   -- | Reciprocal, result at the given exponent.
   recipFF :: Int -> a -> a
   -- | Division, result at the given exponent.
   divFF   :: Int -> a -> a -> a

-- | Floating point: the exponent argument is ignored.
instance FixFloatLike (Data Float) where
   addFF   _ = (+)
   recipFF _ = (1 /)
   divFF   _ = (/)

-- Fixed-point types: the generic helpers 'addFix', 'recipFix' and 'divFix'
-- are used directly; their type-specialised wrappers are aliases of them.
instance FixFloatLike Fix where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike Fix32 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike UFix32 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike Fix16 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike UFix16 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike Fix8 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix

instance FixFloatLike UFix8 where
   addFF   = addFix
   recipFF = recipFix
   divFF   = divFix


-- | Types that a Haskell 'Float' constant can be lifted into.
class FromFloat t where
   -- | Lift a 'Float' constant.
   float :: Float -> t

-- | 'Data Float' embeds the constant with 'value'. The fixed-point
-- instances round it to an integer mantissa at exponent 0.
instance FromFloat (Data Float) where
   float f = value f

instance FromFloat Fix where
   float f = floatToFix f

instance FromFloat Fix32 where
   float f = floatToFix32 f

instance FromFloat UFix32 where
   float f = floatToUFix32 f

instance FromFloat Fix16 where
   float f = floatToFix16 f

instance FromFloat UFix16 where
   float f = floatToUFix16 f

instance FromFloat Fix8 where
   float f = floatToFix8 f

instance FromFloat UFix8 where
   float f = floatToUFix8 f

-- Shift helpers whose emitted shift amount is never negative. A negative
-- request is turned into a shift of the same size in the opposite
-- direction.

-- | Shift left by @b@ (or right by @-b@ when @b < 0@).
leftShift :: Bits a => Data a -> Int -> Data a
leftShift a b =
    if b Prelude.< 0
      then a >> value (Prelude.negate b)
      else a << value b

-- | Shift right by @b@ (or left by @-b@ when @b < 0@).
rightShift :: Bits a => Data a -> Int -> Data a
rightShift a b =
    if b Prelude.< 0
      then a << value (Prelude.negate b)
      else a >> value b