-- Copyright (c) 2009, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | Defines types and classes for the data computed by "Feldspar" programs.

module Feldspar.Core.Types where



import Control.Applicative
import Control.Monad
import Data.Char
import Data.Foldable (Foldable)
import qualified Data.Foldable as Fold
import Data.Maybe
import Data.Traversable (Traversable, traverse)

import Types.Data.Num

import Feldspar.Utils



-- * Types as arguments

-- | Type proxy. The only constructor, 'T', carries no data, so a value
-- @T :: T a@ passes the type @a@ to a function without using 'undefined'.
data T a
    = T

-- | Returns the type-level integer @n@, passed through the proxy 'T', as an
-- 'Int'. 'fromIntegerT' receives only an 'undefined' placeholder of type @n@;
-- the type alone determines the result.
numberT :: forall n . IntegerT n => T n -> Int
numberT proxy = fromIntegerT (placeholder proxy)
  where
    -- Turns the proxy into a (never evaluated) value of the proxied type.
    placeholder :: T m -> m
    placeholder _ = undefined



-- * Haskell source code

-- | Values that describe a Haskell type and can render it as source code.
class HaskellType a where
    -- | Renders the Haskell type described by the argument as a source-code
    -- string.
    haskellType :: a -> String

-- | A nested tuple of type descriptions is rendered by rendering every leaf
-- and joining the results with 'showTuple'.
instance HaskellType a => HaskellType (Tuple a)
  where
    haskellType tup = showTuple (haskellType <$> tup)



-- | Values that can be rendered as Haskell source code.
class HaskellValue a where
    -- | Renders the argument as a source-code string.
    haskellValue :: a -> String

-- | A string is taken to already be source code and is returned unchanged.
instance HaskellValue String
  where
    haskellValue code = code

-- | Integers are rendered with 'show'.
instance HaskellValue Int
  where
    haskellValue i = show i

-- | A nested tuple of values is rendered leaf by leaf and then joined with
-- 'showTuple'.
instance HaskellValue a => HaskellValue (Tuple a)
  where
    haskellValue tup = showTuple (haskellValue <$> tup)



-- * Tuples

-- | General tuple projection. @getTup (T :: T n)@ selects the component at
-- the zero-based position @n@ (a type-level natural) of a tuple of type @a@.
class NaturalT n => GetTuple n a where
    -- | Type of the component at position @n@ of tuple type @a@
    type Part n a
    -- | Extracts the component at position @n@
    getTup :: T n -> a -> Part n a

-- Projections out of pairs
instance GetTuple D0 (a,b)
  where
    type Part D0 (a,b) = a
    getTup _ = fst

instance GetTuple D1 (a,b)
  where
    type Part D1 (a,b) = b
    getTup _ = snd

-- Projections out of triples
instance GetTuple D0 (a,b,c)
  where
    type Part D0 (a,b,c) = a
    getTup _ (x,_,_) = x

instance GetTuple D1 (a,b,c)
  where
    type Part D1 (a,b,c) = b
    getTup _ (_,y,_) = y

instance GetTuple D2 (a,b,c)
  where
    type Part D2 (a,b,c) = c
    getTup _ (_,_,z) = z

-- Projections out of quadruples
instance GetTuple D0 (a,b,c,d)
  where
    type Part D0 (a,b,c,d) = a
    getTup _ (w,_,_,_) = w

instance GetTuple D1 (a,b,c,d)
  where
    type Part D1 (a,b,c,d) = b
    getTup _ (_,x,_,_) = x

instance GetTuple D2 (a,b,c,d)
  where
    type Part D2 (a,b,c,d) = c
    getTup _ (_,_,y,_) = y

instance GetTuple D3 (a,b,c,d)
  where
    type Part D3 (a,b,c,d) = d
    getTup _ (_,_,_,z) = z



-- | Untyped representation of nested tuples: a tree whose leaves hold the
-- values and whose inner nodes list the components of a tuple in order.
data Tuple a
    = One a          -- ^ A single, non-tuple component
    | Tup [Tuple a]  -- ^ A tuple of (possibly nested) components
  deriving (Eq, Show)

-- | Maps over every leaf while keeping the tuple shape.
instance Functor Tuple
  where
    fmap f tup = case tup of
      One x  -> One (f x)
      Tup ts -> Tup (fmap (fmap f) ts)

-- | Folds over the leaves from left to right.
instance Foldable Tuple
  where
    foldr step z tup = case tup of
      One x  -> step x z
      Tup ts -> Fold.foldr (\t acc -> Fold.foldr step acc t) z ts

-- | Traverses the leaves from left to right and rebuilds the same shape.
instance Traversable Tuple
  where
    traverse f tup = case tup of
      One x  -> One <$> f x
      Tup ts -> Tup <$> traverse (traverse f) ts



-- | Shows a nested tuple in Haskell's tuple syntax, e.g. @\"(a,(b,c))\"@.
-- A leaf is shown as its string unchanged. The components of a 'Tup' are
-- shown recursively and wrapped in parentheses by 'showSeq'.
showTuple :: Tuple String -> String
showTuple tup = case tup of
    One str -> str
    Tup ts  -> showSeq "(" (showTuple <$> ts) ")"

-- | Replaces each element by its path in the tuple tree. The path is the list
-- of zero-based component indices, from the root down to the leaf. Example:
--
-- > tuplePath (Tup [One 'a',Tup [One 'b', One 'c']])
-- >   ==
-- > Tup [One [0],Tup [One [1,0],One [1,1]]]
tuplePath :: Tuple a -> Tuple [Int]
tuplePath = go []
  where
    -- The path is accumulated innermost-index-first and reversed at each
    -- leaf, which avoids repeated appends at every level.
    go revPath (One _)  = One (reverse revPath)
    go revPath (Tup ts) = Tup (zipWith (\i t -> go (i : revPath) t) [0 ..] ts)



-- * Data

-- | Representation of the primitive types.
data PrimitiveType
    = UnitType   -- ^ @()@
    | BoolType   -- ^ 'Bool'
    | IntType    -- ^ 'Int'
    | FloatType  -- ^ 'Float'
  deriving (Eq, Show)

-- | Untyped representation of primitive data, one constructor per primitive
-- type.
data PrimitiveData
    = UnitData             -- ^ The value @()@
    | BoolData  Bool       -- ^ A boolean value
    | IntData   Int        -- ^ An integer value
    | FloatData Float      -- ^ A floating-point value
  deriving (Eq, Show)

-- | Representation of storable types, i.e. arrays of primitive data. The list
-- gives the array dimensions, outermost level first. A primitive type is
-- treated as a zero-dimensional array, so its dimension list is empty.
data StorableType
    = StorableType [Int] PrimitiveType
  deriving (Eq, Show)

-- | Untyped representation of storable data.
--
-- An array carries a length: the number of elements on its outermost level.
-- When the element list is shorter than this length, the missing elements have
-- undefined value. When it is longer, the extra elements are ignored.
data StorableData
    = PrimitiveData PrimitiveData       -- ^ A single primitive value
    | StorableData Int [StorableData]   -- ^ An array: length and elements
  deriving (Eq, Show)

-- | Each primitive type is rendered as the name of the matching Haskell type.
instance HaskellType PrimitiveType
  where
    haskellType prim = case prim of
      UnitType  -> "()"
      BoolType  -> "Bool"
      IntType   -> "Int"
      FloatType -> "Float"

-- | Renders primitive data. Numbers use 'show'. Booleans are lower-cased
-- (@\"true\"@ / @\"false\"@).
--
-- NOTE(review): lower-cased booleans are not Haskell's @True@/@False@. Confirm
-- this is intended, e.g. because the output feeds a C-like backend.
instance HaskellValue PrimitiveData
  where
    haskellValue dat = case dat of
      UnitData    -> "()"
      BoolData b  -> toLower <$> show b
      IntData i   -> show i
      FloatData x -> show x

-- | Renders a storable type as a nested Haskell list type, one bracket pair
-- per dimension. For arrays, the type is followed by a @{- ... -}@ comment
-- listing the dimensions (formatted by 'showSeq').
instance HaskellType StorableType
  where
    haskellType (StorableType dims prim) =
        concat [opening, haskellType prim, closing, sizes]
      where
        depth   = length dims
        opening = replicate depth '['
        closing = replicate depth ']'
        sizes
          | null dims = ""
          | otherwise = showSeq "{-" (map haskellValue dims) "-}"

-- | Renders storable data as Haskell source. Primitive data uses its own
-- 'HaskellValue' instance; arrays are rendered as (nested) list literals.
--
-- An array renders at most as many elements as its length argument. This
-- matches the 'StorableData' semantics, under which elements beyond that
-- length are ignored.
instance HaskellValue StorableData
  where
    haskellValue (PrimitiveData a)   = haskellValue a
    haskellValue (StorableData n as) =
        showSeq "[" (map haskellValue (take n as)) "]"



-- | The primitive types @()@, 'Bool', 'Int' and 'Float'. Every primitive type
-- is also 'Storable'.
class Storable a => Primitive a

instance Primitive ()
instance Primitive Bool
instance Primitive Int
instance Primitive Float



-- The array constructor is right-associative, so @D3:>D10:>Int@ parses as
-- @D3:>(D10:>Int)@.
infixr 5 :>

-- | Array represented as a (nested) list. If @a@ is a storable type and @n@
-- is a type-level natural number, @n :> a@ is an array of @n@ elements of
-- type @a@. For example, @D3:>D10:>Int@ is a 3 by 10 array of integers.
--
-- Arrays built with 'fromList' never have too many elements in any
-- dimension. Missing elements in any dimension are taken to have undefined
-- value.
data n :> a = (NaturalT n, Storable a) => ArrayList [a]

-- | Arrays are equal when their element lists are equal.
instance (NaturalT n, Storable a, Eq a) => Eq (n :> a)
  where
    (==) (ArrayList xs) (ArrayList ys) = xs == ys

-- | An array is shown as its nested-list representation (see 'toList').
instance (NaturalT n, Storable a, Show (ListBased a)) => Show (n :> a)
  where
    show arr = show (toList arr)

-- | Arrays are ordered lexicographically by their element lists.
instance (NaturalT n, Storable a, Ord a) => Ord (n :> a)
  where
    compare (ArrayList xs) (ArrayList ys) = compare xs ys



-- | Applies a function to every element on the outermost level of an array.
-- This is not a 'Functor' instance because of the extra class constraints on
-- the element types.
mapArray ::
    (NaturalT n, Storable a, Storable b) => (a -> b) -> (n :> a) -> (n :> b)
mapArray f (ArrayList elems) = ArrayList (f <$> elems)



-- | Storable types: zero- or higher-level arrays of primitive data.
--
-- 'Typeable' is a superclass, so every storable type is also typeable. In the
-- other direction, the 'Typeable' instance for @n :> a@ requires the element
-- type @a@ to be 'Storable'.
--
-- Example:
--
-- > *Feldspar.Core.Types> toList (replicateArray 3 :: D4 :> D2 :> Int)
-- > [[3,3],[3,3],[3,3],[3,3]]
class Typeable a => Storable a
  where
    -- | List-based representation of a storable type: the type itself for
    -- primitive types, and a nested list for arrays
    type ListBased a :: *
    -- | The innermost element of a storable type
    type Element a :: *

    -- | Constructs an array filled with the given element. For primitive types,
    -- this is just the identity function.
    replicateArray :: Element a -> a

    -- | Converts a storable type to a (zero- or higher-level) nested list.
    toList :: a -> ListBased a

    -- | Constructs a storable type from a (zero- or higher-level) nested list.
    -- The resulting value is guaranteed not to have too many elements in any
    -- dimension. Excessive elements are simply cut away.
    fromList :: ListBased a -> a

    -- | Converts a storable value to its untyped representation.
    toData :: a -> StorableData

-- | Unit is primitive: its list-based form and element type are both @()@.
instance Storable ()
  where
    type ListBased () = ()
    type Element   () = ()

    replicateArray u = u
    toList         u = u
    fromList       u = u

    -- The unit argument is forced only once the payload is demanded; the
    -- outer 'PrimitiveData' constructor is returned right away.
    toData u = PrimitiveData (u `seq` UnitData)

-- | 'Bool' is primitive: its list-based form and element type are itself.
instance Storable Bool
  where
    type ListBased Bool = Bool
    type Element   Bool = Bool

    replicateArray b = b
    toList         b = b
    fromList       b = b
    toData         b = PrimitiveData (BoolData b)

-- | 'Int' is primitive: its list-based form and element type are itself.
instance Storable Int
  where
    type ListBased Int = Int
    type Element   Int = Int

    replicateArray i = i
    toList         i = i
    fromList       i = i
    toData         i = PrimitiveData (IntData i)

-- | 'Float' is primitive: its list-based form and element type are itself.
instance Storable Float
  where
    type ListBased Float = Float
    type Element   Float = Float

    replicateArray x = x
    toList         x = x
    fromList       x = x
    toData         x = PrimitiveData (FloatData x)

-- | An array of @n@ storable elements. The array size comes from the
-- type-level natural @n@, and every operation recurses into the elements.
instance (NaturalT n, Storable a) => Storable (n :> a)
  where
    type ListBased (n :> a) = [ListBased a]
    type Element   (n :> a) = Element a

    -- Fills all @n@ slots with the same replicated element.
    replicateArray x = ArrayList (replicate size (replicateArray x))
      where
        size = fromIntegerT (undefined :: n)

    toList (ArrayList elems) = toList <$> elems

    -- Elements beyond @n@ are dropped, so the array never has too many.
    fromList rows = ArrayList (take size (fromList <$> rows))
      where
        size = fromIntegerT (undefined :: n)

    -- The declared length is the type-level size @n@.
    toData (ArrayList elems) = StorableData size (toData <$> elems)
      where
        size = fromIntegerT (undefined :: n)



-- | Checks whether a storable value has a uniform shape: on every array level,
-- all sub-arrays have the same dimensions. A primitive value, or an array with
-- no elements, is rectangular.
isRectangular :: Storable a => a -> Bool
isRectangular = isJust . shapeOf . toData
  where
    -- Dimensions of a uniformly shaped value, outermost first, or 'Nothing'
    -- when some level has sub-arrays of differing shapes.
    shapeOf (PrimitiveData _)     = Just []
    shapeOf (StorableData _ [])   = Just []
    shapeOf (StorableData _ subs) =
        mapM shapeOf subs >>= \dims -> case dims of
          -- dims is non-empty here because subs is non-empty.
          d : _ | allEqual dims -> Just (length subs : d)
          _                     -> Nothing



-- | All supported types of data: nested tuples of storable data.
class (Eq a, Ord a) => Typeable a where
    -- | Gives the untyped representation of the type @a@: a tree of
    -- 'StorableType's that mirrors the tuple structure of @a@.
    typeOf :: T a -> Tuple StorableType

-- Primitive types are zero-dimensional storable types (empty dimension list).
instance Typeable ()
  where
    typeOf _ = One (StorableType [] UnitType)

instance Typeable Bool
  where
    typeOf _ = One (StorableType [] BoolType)

instance Typeable Int
  where
    typeOf _ = One (StorableType [] IntType)

instance Typeable Float
  where
    typeOf _ = One (StorableType [] FloatType)

-- | An array of @n@ elements of type @a@ has the dimensions of @a@ with @n@
-- prepended as the new outermost dimension, and the same primitive type.
instance (NaturalT n, Storable a) => Typeable (n :> a)
  where
    typeOf _ = One $ StorableType (n : dim) t
      where
        n = fromIntegerT (undefined :: n)
        -- The element type is expected to be a single storable leaf. The tuple
        -- binding is lazy, so the outer constructors are produced without
        -- inspecting the element type. A non-leaf representation fails with
        -- an explicit message instead of an irrefutable-pattern error.
        (dim, t) = case typeOf (T :: T a) of
          One (StorableType d p) -> (d, p)
          _ -> error "Feldspar.Core.Types.typeOf: array element type is not a single storable type"

-- A tuple is represented by a 'Tup' node whose children are the
-- representations of its components, in order.
instance (Typeable a, Typeable b) => Typeable (a,b)
  where
    typeOf _ = Tup [typeOf (T :: T a), typeOf (T :: T b)]

instance (Typeable a, Typeable b, Typeable c) => Typeable (a,b,c)
  where
    typeOf _ = Tup [typeOf (T :: T a), typeOf (T :: T b), typeOf (T :: T c)]

instance (Typeable a, Typeable b, Typeable c, Typeable d) => Typeable (a,b,c,d)
  where
    typeOf _ =
      Tup [ typeOf (T :: T a), typeOf (T :: T b)
          , typeOf (T :: T c), typeOf (T :: T d)
          ]



-- | Checks if the given type is primitive: it must be a single storable type
-- (not a tuple) with no array dimensions.
isPrimitive :: Typeable a => T a -> Bool
isPrimitive t
    | One (StorableType dims _) <- typeOf t = null dims
    | otherwise                             = False