{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module TheModule where
import Prelude hiding (head,map,foldl,foldr,zipWith)
import qualified Prelude as P
import Foreign




-- the generic vector type : a fixed-length list

-- | Cons cell used to build fixed-length vectors at the type level:
-- a head of type @a@ followed by a tail of type @b@. Both fields are
-- strict. A vector is a chain of cells terminated by @()@.
data a :. b = !a :. !b
  deriving (Eq, Ord, Read, Show)

-- Right-associative so that @x :. y :. ()@ reads as @x :. (y :. ())@.
-- Precedence 9 is what a fixity declaration without a level defaults to;
-- it is spelled out here for the reader.
infixr 9 :.

-- Homogeneous vectors of length 2, 3 and 4, each built by consing one
-- more element of type @a@ onto the next shorter vector.
type Vec2  a = a :. (a :. ())
type Vec3  a = a :. Vec2 a
type Vec4  a = a :. Vec3 a
-- A 4-vector whose elements are themselves 4-vectors.
type Mat44 a = Vec4 (Vec4 a)





-- | Dot product: multiply the two vectors element-wise with 'zipWith',
-- then sum the resulting products with 'fold'.
dot :: (Num a, ZipWith a a a u v w, Fold a w) => u -> v -> a
dot u v = fold (+) $ zipWith (*) u v
{-# INLINE dot #-}

-- | Matrix-by-vector product, treating @m@ as a vector of rows: every
-- row of @m@ is replaced by its 'dot' product with @v@.
multmv :: (Num a, ZipWith a a a v r w, Fold a w, Map r a m u) => m -> v -> u
multmv m v = map (\row -> dot v row) m
{-# INLINE multmv #-}



-- basic vector functions : map,fold,zipWith

-- | Apply a function to every element of a fixed-length vector.
-- @u@ is the input vector with elements of type @a@, @v@ the output
-- vector with elements of type @b@; the functional dependencies tie the
-- element types and the two vector shapes to one another.
class Map a b u v | u -> a, v -> b, b u -> v, a v -> u where
  map :: (a -> b) -> u -> v

-- Base case: a vector holding exactly one element.
instance Map a b (a :. ()) (b :. ()) where
  map g (e :. ()) = g e :. ()
  {-# INLINE map #-}

-- Inductive case: transform the head, then recurse into the non-empty tail.
instance Map a b (a':.u) (b':.v) => Map a b (a:.a':.u) (b:.b':.v) where
  map g (e :. rest) = g e :. map g rest
  {-# INLINE map #-}


-- | Combine two fixed-length vectors element by element.
-- @u@ and @v@ are the input vectors (elements @a@ and @b@), @w@ is the
-- result vector (elements @c@). When the inputs differ in length the
-- result has the length of the shorter one (see the two mixed instances).
class ZipWith a b c u v w | u -> a, v -> b, w -> c, u v c -> w where
  zipWith :: (a -> b -> c) -> u -> v -> w

-- Both inputs have exactly one element.
instance ZipWith a b c (a :. ()) (b :. ()) (c :. ()) where
  zipWith op (p :. _) (q :. _) = op p q :. ()
  {-# INLINE zipWith #-}

-- Left input has one element, right input is longer: the extra
-- elements of the right input are dropped.
instance ZipWith a b c (a :. ()) (b :. b :. bs) (c :. ()) where
  zipWith op (p :. _) (q :. _) = op p q :. ()
  {-# INLINE zipWith #-}

-- Right input has one element, left input is longer: the extra
-- elements of the left input are dropped.
instance ZipWith a b c (a :. a :. as) (b :. ()) (c :. ()) where
  zipWith op (p :. _) (q :. _) = op p q :. ()
  {-# INLINE zipWith #-}

-- Both inputs have at least two elements: combine the heads and recurse
-- on the tails.
instance ZipWith a b c (a':.u) (b':.v) (c':.w)
  => ZipWith a b c (a:.a':.u) (b:.b':.v) (c:.c':.w) where
  zipWith op (p :. ps) (q :. qs) = op p q :. zipWith op ps qs
  {-# INLINE zipWith #-}


-- | Reductions over a fixed-length vector @v@ whose elements have type @a@.
class Fold a v | v -> a where
  -- | Seedless reduction, nested to the right:
  -- @fold f (a :. b :. c :. ()) == f a (f b c)@.
  fold  :: (a -> a -> a) -> v -> a
  -- | Left fold with a seed, visiting elements first to last:
  -- @foldl f z (a :. b :. c :. ()) == f (f (f z a) b) c@.
  foldl :: (b -> a -> b) -> b -> v -> b
  -- | Right fold with a seed:
  -- @foldr f z (a :. b :. c :. ()) == f a (f b (f c z))@.
  foldr :: (a -> b -> b) -> b -> v -> b

-- Base case: a vector holding exactly one element.
instance Fold a (a:.()) where
  fold  _   (a:._) = a
  foldl f z (a:._) = f z a
  foldr f z (a:._) = f a z
  {-# INLINE fold #-}
  {-# INLINE foldl #-}
  {-# INLINE foldr #-}

-- Inductive case: a head followed by a non-empty tail.
instance Fold a (a':.u) => Fold a (a:.a':.u) where
  fold  f   (a:.v) = f a (fold f v)
  -- The accumulator must absorb the head *before* the tail is visited.
  -- The previous definition, @f (foldl f z v) a@, consumed the elements
  -- last-to-first, i.e. it was a left fold over the reversed vector.
  foldl f z (a:.v) = foldl f (f z a) v
  foldr f z (a:.v) = f a (foldr f z v)
  {-# INLINE fold #-}
  {-# INLINE foldl #-}
  {-# INLINE foldr #-}


-- storable instances for vectors

-- A one-element vector is stored exactly like its single element:
-- same size, same alignment, element at offset 0.
instance Storable a => Storable (a:.()) where
  sizeOf _ = sizeOf (undefined :: a)
  alignment _ = alignment (undefined :: a)
  peek !p = do
    hd <- peek (castPtr p)
    return (hd :. ())
  peekByteOff !p !o = peek (p `plusPtr` o)
  -- Element i starts i element-sizes past the base pointer.
  peekElemOff !p !i = peekByteOff p (i * sizeOf (undefined :: a))
  poke !p (hd :. _) = poke (castPtr p) hd
  pokeByteOff !p !o !x = poke (p `plusPtr` o) x
  pokeElemOff !p !i !x = pokeByteOff p (i * sizeOf (undefined :: a)) x
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE peekByteOff #-}
  {-# INLINE peekElemOff #-}
  {-# INLINE poke #-}
  {-# INLINE pokeByteOff #-}
  {-# INLINE pokeElemOff #-}

-- A vector of two or more elements is stored as its head immediately
-- followed by its tail: the size is the sum of the two sizes, the
-- alignment is that of a single element, and the tail begins
-- @sizeOf head@ bytes after the base pointer.
instance (Storable a, Storable (a:.v))
  => Storable (a:.a:.v)
  where
  sizeOf _ = sizeOf (undefined :: a) + sizeOf (undefined :: a :. v)
  alignment _ = alignment (undefined :: a)
  peek !p = do
    hd <- peek (castPtr p)
    -- plusPtr already yields a pointer of whatever type is required,
    -- so no further cast is needed for the tail.
    tl <- peek (p `plusPtr` sizeOf (undefined :: a))
    return (hd :. tl)
  peekByteOff !p !o = peek (p `plusPtr` o)
  -- Element i starts i whole-vector sizes past the base pointer.
  peekElemOff !p !i = peekByteOff p (i * sizeOf (undefined :: a :. a :. v))
  poke !p (hd :. tl) = do
    poke (castPtr p) hd
    poke (p `plusPtr` sizeOf (undefined :: a)) tl
  pokeByteOff !p !o !x = poke (p `plusPtr` o) x
  pokeElemOff !p !i !x = pokeByteOff p (i * sizeOf (undefined :: a :. a :. v)) x
  {-# INLINE sizeOf #-}
  {-# INLINE alignment #-}
  {-# INLINE peek #-}
  {-# INLINE peekByteOff #-}
  {-# INLINE peekElemOff #-}
  {-# INLINE poke #-}
  {-# INLINE pokeByteOff #-}
  {-# INLINE pokeElemOff #-}


