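{-
Storable instances for pairs, triples, 'Number.Complex.T' and
'Number.Ratio.T'. These are orphan instances; they should really be
provided by the standard library (or by the packages defining these types).
-}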
module StorableInstance where

import Foreign.Storable (Storable (..), )
import Foreign.Ptr (castPtr, )
import qualified Number.Complex as Complex
import qualified Number.Ratio   as Ratio
import qualified Algebra.PrincipalIdealDomain as PID


-- | @roundUp m x@ rounds the offset @x@ up to the next multiple of @m@
-- (for positive @m@); an @x@ that is already a multiple is returned as is.
-- Raises a divide-by-zero exception for @m == 0@.
roundUp :: Int -> Int -> Int
roundUp m x =
   case mod x m of
      0         -> x
      remainder -> x + (m - remainder)

-- | A pair is laid out like the C struct @struct { a fst; b snd; }@:
-- the first component at offset 0, the second at the next offset that is
-- a multiple of its alignment.
--
-- The alignment of the pair must be a multiple of both component
-- alignments, hence 'lcm'. (A 'gcd' can be smaller than either
-- alignment, so a pair at a suitably aligned address could still contain a
-- misaligned component.)
--
-- The size includes trailing padding up to the pair's alignment, because
-- 'peekElemOff' / 'pokeElemOff' step through arrays in units of 'sizeOf';
-- without the padding, every element after the first could be misaligned.
instance (Storable a, Storable b) => Storable (a,b) where
   -- lazy patterns: 'sizeOf' and 'alignment' must not force their argument,
   -- since callers routinely pass 'undefined'
   sizeOf ~(a,b) =
      roundUp (lcm (alignment a) (alignment b))
         (roundUp (alignment b) (sizeOf a) + sizeOf b)
   alignment ~(a,b) = lcm (alignment a) (alignment b)
   peek ptr =
      do a <- peekByteOff ptr 0
         -- A plain @let bu = undefined@ would be let-generalised, so
         -- @alignment bu@ would be ambiguous. Passing a dummy argument to a
         -- local function ties the dummy's type to the result type @b@.
         let peekSecond :: Storable b => b -> IO b
             peekSecond bu =
                peekByteOff ptr (roundUp (alignment bu) (sizeOf a))
         b <- peekSecond undefined
         return (a, b)
   poke ptr (a,b) =
      pokeByteOff ptr 0 a >>
      pokeByteOff ptr (roundUp (alignment b) (sizeOf a)) b


-- | A triple is stored exactly like the nested pair @((a,b),c)@;
-- every method delegates to the pair instance.
instance (Storable a, Storable b, Storable c) => Storable (a,b,c) where
   sizeOf t    = sizeOf    (tripleToPair t)
   alignment t = alignment (tripleToPair t)
   peek ptr =
      do ~(~(x,y),z) <- peek (castPtr ptr)
         return (x,y,z)
   poke ptr t = poke (castPtr ptr) (tripleToPair t)

-- | Regroup a triple as a nested pair. Lazy in its argument, so it can be
-- applied to 'undefined' (as 'sizeOf' / 'alignment' callers do).
tripleToPair :: (a,b,c) -> ((a,b),c)
tripleToPair t =
   let (x,y,z) = t
   in  ((x,y),z)

-- | A complex number is stored as the pair (real part, imaginary part).
instance (Storable a) => Storable (Complex.T a) where
   sizeOf z    = sizeOf    (complexToPair z)
   alignment z = alignment (complexToPair z)
   peek ptr =
      do ~(re, im) <- peek (castPtr ptr)
         return (re Complex.+: im)
   poke ptr z = poke (castPtr ptr) (complexToPair z)

-- | Split a complex number into (real part, imaginary part).
-- The pair is built without forcing the argument.
complexToPair :: Complex.T a -> (a,a)
complexToPair z = (re, im)
   where re = Complex.real z
         im = Complex.imag z

-- | A ratio is stored as the pair (numerator, denominator).
-- NOTE(review): 'peek' rebuilds the value with 'Ratio.%', which presumably
-- re-normalises the fraction (and may fail on a zero denominator) -- confirm.
instance (Storable a, PID.C a) => Storable (Ratio.T a) where
   sizeOf r    = sizeOf    (ratioToPair r)
   alignment r = alignment (ratioToPair r)
   peek ptr =
      do ~(n, d) <- peek (castPtr ptr)
         return (n Ratio.% d)
   poke ptr r = poke (castPtr ptr) (ratioToPair r)

-- | Split a ratio into (numerator, denominator).
-- The pair is built without forcing the argument.
ratioToPair :: Ratio.T a -> (a,a)
ratioToPair r = (num, den)
   where num = Ratio.numerator r
         den = Ratio.denominator r


{-
{- Why is this allowed? -}
test :: Char
test = const 'a' undefined

{- Why is type defaulting applied here? The type of 'c' should be fixed. -}
test1 :: (Integral a, RealField.C a) => a
test1 =
   let c = undefined
   in  asTypeOf (round c) c
-}