-- Copyright (c) 2009-2010, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | Primitive and helper functions supported by Feldspar

module Feldspar.Core.Functions where



import qualified Prelude

import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Expr
import Feldspar.Prelude

import qualified Data.Bits as B

-- Fixities for the comparison operators mirror the Haskell Prelude (infix 4,
-- non-associative). '(?)' binds looser than all of them, so
-- @a < b ? (x, y)@ parses as @(a < b) ? (x, y)@.
--
-- NOTE(review): (&&), (||), (<<<) and (>>>) get no fixity declaration in this
-- section, so they fall back to the default infixl 9 (tighter than the
-- comparisons) — confirm that is intended.
infix 4 ==, /=, <, >, <=, >=
infix 1 ?



-- * Misc.

-- | Size-propagation argument for unary primitives whose result carries no
-- size information: the argument is ignored and '()' is returned.
noSizeProp :: a -> ()
noSizeProp = \_ -> ()

-- | Binary counterpart of 'noSizeProp': both arguments are ignored and '()'
-- is returned.
noSizeProp2 :: a -> b -> ()
noSizeProp2 _ = \_ -> ()



-- | Equality on 'Data' values. If the two operands already compare equal on
-- the host side ('Prelude.==' on 'Data', presumably structural equality of
-- the expressions — TODO confirm), the result is the constant 'true';
-- otherwise an @(==)@ primitive node is built.
(==) :: Storable a => Data a -> Data a -> Data Bool
x == y =
    if x Prelude.== y
      then true
      else function2 "(==)" noSizeProp2 (Prelude.==) x y
  -- XXX Partial evaluation

-- | Inequality on 'Data' values. Operands that are equal on the host side
-- ('Prelude.==') give the constant 'false'; otherwise a @(/=)@ primitive
-- node is built.
(/=) :: Storable a => Data a -> Data a -> Data Bool
x /= y =
    if x Prelude.== y
      then false
      else function2 "(/=)" noSizeProp2 (Prelude./=) x y
  -- XXX Partial evaluation

-- | Strict less-than on 'Data' values. A value is never less than itself,
-- so host-side equal operands give the constant 'false'; otherwise a @(<)@
-- primitive node is built.
(<) :: Storable a => Data a -> Data a -> Data Bool
x < y =
    if x Prelude.== y
      then false
      else function2 "(<)" noSizeProp2 (Prelude.<) x y

-- | Strict greater-than on 'Data' values. Host-side equal operands give the
-- constant 'false'; otherwise a @(>)@ primitive node is built.
(>) :: Storable a => Data a -> Data a -> Data Bool
x > y =
    if x Prelude.== y
      then false
      else function2 "(>)" noSizeProp2 (Prelude.>) x y

-- | Less-than on @'Data' 'Int'@ that also consults the operands' value ranges
-- ('dataSize'). This lets it fold to a constant in more cases than '(<)'.
(<<<) :: Data Int -> Data Int -> Data Bool
x <<< y
  | x Prelude.== y              = false
    -- presumably: every value in rangeX is below every value in rangeY
  | rangeX `rangeLess` rangeY   = true
    -- presumably: every value in rangeY is <= every value in rangeX
  | rangeY `rangeLessEq` rangeX = false
  | otherwise                   = function2 "(<)" noSizeProp2 (Prelude.<) x y
  where
    rangeX = dataSize x
    rangeY = dataSize y
  -- XXX Enables more partial evaluation than (<). This function should be
  --     generalized and then replace (<).

-- | Greater-than on @'Data' 'Int'@ that also consults the operands' value
-- ranges ('dataSize'). This lets it fold to a constant in more cases than
-- '(>)'.
(>>>) :: Data Int -> Data Int -> Data Bool
x >>> y
  | x Prelude.== y              = false
    -- presumably: every value in rangeY is below every value in rangeX
  | rangeY `rangeLess` rangeX   = true
    -- presumably: every value in rangeX is <= every value in rangeY
  | rangeX `rangeLessEq` rangeY = false
  | otherwise                   = function2 "(>)" noSizeProp2 (Prelude.>) x y
  where
    rangeX = dataSize x
    rangeY = dataSize y
  -- XXX Enables more partial evaluation than (>). This function should be
  --     generalized and then replace (>).

-- | Less-than-or-equal on 'Data' values. Host-side equal operands give the
-- constant 'true'; otherwise a @(<=)@ primitive node is built.
(<=) :: Storable a => Data a -> Data a -> Data Bool
x <= y =
    if x Prelude.== y
      then true
      else function2 "(<=)" noSizeProp2 (Prelude.<=) x y
  -- XXX Partial evaluation

-- | Greater-than-or-equal on 'Data' values. Host-side equal operands give
-- the constant 'true'; otherwise a @(>=)@ primitive node is built.
(>=) :: Storable a => Data a -> Data a -> Data Bool
x >= y =
    if x Prelude.== y
      then true
      else function2 "(>=)" noSizeProp2 (Prelude.>=) x y
  -- XXX Partial evaluation

-- | Boolean negation, built as a @not@ primitive with no size information.
not :: Data Bool -> Data Bool
not b = function "not" noSizeProp Prelude.not b

-- | @cond ? (t, e)@ selects @t@ when @cond@ holds and @e@ otherwise. It is
-- expressed with 'ifThenElse'; both branches ignore the 'unit' argument they
-- receive.
(?) :: Computable a => Data Bool -> (a,a) -> a
c ? (onTrue, onFalse) = ifThenElse c (\_ -> onTrue) (\_ -> onFalse) unit

-- | Conjunction, built as a single @(&&)@ primitive over both operands.
-- For a version that only runs the second operand when needed, see '(&&*)'.
(&&) :: Data Bool -> Data Bool -> Data Bool
p && q = function2 "(&&)" noSizeProp2 (Prelude.&&) p q

-- | Disjunction, built as a single @(||)@ primitive over both operands.
-- For a version that only runs the second operand when needed, see '(||*)'.
(||) :: Data Bool -> Data Bool -> Data Bool
p || q = function2 "(||)" noSizeProp2 (Prelude.||) p q

-- | Lazy conjunction of two predicates. The second predicate is only run
-- (through 'ifThenElse') when the first yields 'true'; otherwise the result
-- is 'false'.
(&&*) :: Computable a =>
    (a -> Data Bool) -> (a -> Data Bool) -> (a -> Data Bool)
(&&*) f g = \x -> ifThenElse (f x) g (\_ -> false) x

-- | Lazy disjunction of two predicates. The second predicate is only run
-- (through 'ifThenElse') when the first yields 'false'; otherwise the result
-- is 'true'.
(||*) :: Computable a =>
    (a -> Data Bool) -> (a -> Data Bool) -> (a -> Data Bool)
(||*) f g = \x -> ifThenElse (f x) (\_ -> true) g x

-- | The smaller of two values: @a@ when @a < b@, otherwise @b@.
min :: Storable a => Data a -> Data a -> Data a
min x y = xSmaller ? (x, y)
  where
    xSmaller = x < y

-- | The larger of two values: @a@ when @a > b@, otherwise @b@.
max :: Storable a => Data a -> Data a -> Data a
max x y = xLarger ? (x, y)
  where
    xLarger = x > y

-- | Minimum of two 'Int's built on the range-aware '(<<<)'.
--
-- If @a <<< b@ has already been reduced to a 'Value' expression, that
-- comparison drives the selection. Otherwise @b <<< a@ is used instead,
-- picking @b@ only when it is strictly smaller. This gives the
-- partial evaluator a second chance to fold the choice.
minX :: Data Int -> Data Int -> Data Int
minX a b
    | firstFolded = aLtB ? (a, b)
    | otherwise   = bLtA ? (b, a)
  where
    aLtB = a <<< b
    bLtA = b <<< a
    firstFolded = case dataToExpr aLtB of
      Value _ _ -> Prelude.True
      _         -> Prelude.False
  -- XXX Enables more partial evaluation than min. This function should be
  --     generalized and then replace min.

-- | Maximum of two 'Int's built on the range-aware '(>>>)'.
--
-- If @a >>> b@ has already been reduced to a 'Value' expression, that
-- comparison drives the selection. Otherwise @b >>> a@ is used instead,
-- picking @b@ only when it is strictly larger.
maxX :: Data Int -> Data Int -> Data Int
maxX a b
    | firstFolded = aGtB ? (a, b)
    | otherwise   = bGtA ? (b, a)
  where
    aGtB = a >>> b
    bGtA = b >>> a
    firstFolded = case dataToExpr aGtB of
      Value _ _ -> Prelude.True
      _         -> Prelude.False
  -- XXX Enables more partial evaluation than max. This function should be
  --     generalized and then replace max.

-- | Integer division via 'Prelude.div'. The result's range is left at
-- 'fullRange'.
--
-- NOTE(review): 'Prelude.div' throws on a zero divisor; confirm how this
-- behaves if the host function is applied to literal operands.
div :: Data Int -> Data Int -> Data Int
div x y = function2 "div" (\_ _ -> fullRange) Prelude.div x y
  -- XXX Improve size propagation

-- | Integer modulus via 'Prelude.mod'. The result's range is left at
-- 'fullRange'.
mod :: Data Int -> Data Int -> Data Int
mod x y = function2 "mod" (\_ _ -> fullRange) Prelude.mod x y
  -- XXX Improve size propagation

-- | Integer exponentiation via 'Prelude.^'. The result's range is left at
-- 'fullRange'.
--
-- NOTE(review): 'Prelude.^' errors on a negative exponent; confirm how this
-- behaves if the host function is applied to literal operands.
(^) :: Data Int -> Data Int -> Data Int
x ^ n = function2 "(^)" (\_ _ -> fullRange) (Prelude.^) x n
  -- XXX Improve size propagation



-- * Loops

-- | For-loop
--
-- @`for` start end init body@:
--
--   * @start@\/@end@ are the start\/end indexes.
--
--   * @init@ is the starting state.
--
--   * @body@ computes the next state given the current loop index (ranging over
--     @[start .. end]@) and the current state.
--
-- Implemented as a 'whileSized' loop over the pair (index, state); only the
-- state component is returned.
for :: Computable a => Data Int -> Data Int -> a -> (Data Int -> a -> a) -> a
for start end init body =
    snd (whileSized stateSize inRange step (start, init))
  where
    -- Size annotation for the (index, state) pair. The index range is built
    -- from the ranges of start and end; the user state is unconstrained.
    -- NOTE(review): the index leaves the loop holding end+1, which may fall
    -- outside this range — confirm whileSized does not rely on it there.
    stateSize = (rangeByRange (dataSize start) (dataSize end), universal)

    inRange (ix, _)   = ix <= end
    step    (ix, acc) = (ix + 1, body ix acc)



-- | A sequential \"unfolding\" of a vector
--
-- @`unfoldCore` l init step@:
--
--   * @l@ is the length of the resulting vector.
--
--   * @init@ is the initial state.
--
--   * @step@ is a function computing a new element and the next state from the
--     current index and current state. The index is the position of the new
--     element in the output vector.
--
-- Returns the filled vector together with the final state.
unfoldCore
    :: (Computable state, Storable a)
    => Data Length
    -> state
    -> (Data Int -> state -> (Data a, state))
    -> (Data [a], state)

unfoldCore l init step = for 0 (l - 1) (emptyOut, init) fill
  where
    -- Initially empty array. Its declared size is derived from l's range
    -- (presumably ':>' pairs the length range with the element size).
    emptyOut = array (mapMonotonic fromIntegral (dataSize l) :> universal) []

    -- Write element i into the array and thread the state along.
    fill i (out, st) = (setIx out i x, st')
      where
        (x, st') = step i st


-- * Bit manipulation

infixl 5 <<,>>
infixl 4 ⊕

-- | The following class provides functions for bit level manipulation.
--
-- Every method has a default that wraps the corresponding "Data.Bits"
-- function in a primitive node. An instance therefore only needs to be
-- declared; see @instance Bits Int@ below.
--
-- The defaults carry no useful size information. They use 'universal', or
-- 'noSizeProp'/'noSizeProp2' for the 'Bool' results of 'testBit' and
-- 'isSigned'. The exception is 'bitSize', which gets 'naturalRange'.
--
-- NOTE(review): '(.&.)', '(.|.)' and 'xor' have no fixity declaration in this
-- section, so they default to @infixl 9@ rather than the "Data.Bits"
-- fixities — confirm that is intended.
class (B.Bits a, Storable a) => Bits a
  where
  -- Logical operations
  (.&.)         :: Data a -> Data a -> Data a
  (.&.)         =  function2 "(.&.)" (\_ _ -> universal) (B..&.)
  (.|.)         :: Data a -> Data a -> Data a
  (.|.)         =  function2 "(.|.)" (\_ _ -> universal) (B..|.)
  xor           :: Data a -> Data a -> Data a
  xor           =  function2 "xor" (\_ _ -> universal) B.xor
  -- | Infix synonym for 'xor' (the operator name is U+2295 CIRCLED PLUS).
  (⊕)           :: Data a -> Data a -> Data a
  (⊕)           =  xor
  complement    :: Data a -> Data a
  complement    =  function "complement" (const universal) B.complement

  -- Operations on individual bits
  bit           :: Data Int -> Data a
  bit           =  function "bit" (const universal) B.bit
  setBit        :: Data a -> Data Int -> Data a
  setBit        =  function2 "setBit" (\_ _ -> universal) B.setBit
  clearBit      :: Data a -> Data Int -> Data a
  clearBit      =  function2 "clearBit" (\_ _ -> universal) B.clearBit
  complementBit :: Data a -> Data Int -> Data a
  complementBit =  function2 "complementBit" (\_ _ -> universal) B.complementBit
  testBit       :: Data a -> Data Int -> Data Bool
  testBit       =  function2 "testBit" noSizeProp2 B.testBit

  -- Moving bits around
  shiftL        :: Data a -> Data Int -> Data a
  shiftL        =  function2 "shiftL" (\_ _ -> universal) B.shiftL
  -- | Infix synonym for 'shiftL'.
  (<<)          :: Data a -> Data Int -> Data a
  (<<)          =  shiftL
  shiftR        :: Data a -> Data Int -> Data a
  shiftR        =  function2 "shiftR" (\_ _ -> universal) B.shiftR
  -- | Infix synonym for 'shiftR'.
  (>>)          :: Data a -> Data Int -> Data a
  (>>)          =  shiftR
  rotateL       :: Data a -> Data Int -> Data a
  rotateL       =  function2 "rotateL" (\_ _ -> universal) B.rotateL
  rotateR       :: Data a -> Data Int -> Data a
  rotateR       =  function2 "rotateR" (\_ _ -> universal) B.rotateR

  -- Queries about the type
  bitSize       :: Data a -> Data Int
  bitSize       =  function "bitSize" (const naturalRange) B.bitSize
  isSigned      :: Data a -> Data Bool
  isSigned      =  function "isSigned" noSizeProp B.isSigned

-- 'Int' uses every default method above.
instance Bits Int