{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-top-binds #-}
module DataTypeTest
    ( specsWith
    , dataTypeMigrate
    , roundTime
    , roundUTCTime
    ) where

import Control.Applicative (liftA2)
import qualified Data.ByteString as BS
import Data.Fixed (Pico)
import Data.Foldable (for_)
import Data.IntMap (IntMap)
import qualified Data.Text as T
import Data.Time (Day, UTCTime (..), TimeOfDay, timeToTimeOfDay, timeOfDayToTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import Test.QuickCheck.Arbitrary (Arbitrary, arbitrary)
import Test.QuickCheck.Gen (Gen(..))
import Test.QuickCheck.Instances ()
import Test.QuickCheck.Random (newQCGen)

import Database.Persist.TH
import Init

-- | Plain alias for a pair, so the entity definition below can name the
-- @bytesTextTuple@ column type as @Tuple ByteString Text@.
-- NOTE(review): presumably needed because the entity DSL wants a named type
-- constructor rather than tuple syntax -- confirm.
type Tuple a b = (,) a b

-- Test lower case names: the entity below is declared with 'persistLowerCase'.
--
-- 'DataTypeTable' has one column per data type exercised by 'specsWith', and
-- 'mkMigrate' generates the migration exported as 'dataTypeMigrate'.
-- The two @maxlen=100@ columns are matched by the 100-element truncation in
-- the 'Arbitrary' instance for 'DataTypeTable'.
-- NOTE(review): @no-json@ presumably opts this entity out of JSON instance
-- generation -- confirm against the persistent entity-syntax docs.
share [mkPersist persistSettings, mkMigrate "dataTypeMigrate"] [persistLowerCase|
DataTypeTable no-json
    text Text
    textMaxLen Text maxlen=100
    bytes ByteString
    bytesTextTuple (Tuple ByteString Text)
    bytesMaxLen ByteString maxlen=100
    int Int
    intList [Int]
    intMap (IntMap Int)
    double Double
    bool Bool
    day Day
    utc UTCTime
|]

-- | Delete every row of the 'DataTypeTable' entity via 'deleteWhere'.
-- The table itself (and its schema) is left in place.
cleanDB'
    :: (MonadIO m, PersistStoreWrite (BaseBackend backend), PersistQuery backend)
    => ReaderT backend m ()
cleanDB' = deleteWhere everyRow
  where
    -- An empty filter list places no restriction, i.e. selects all rows.
    -- The type variable here is independent of the signature above; it is
    -- fixed by 'deleteWhere' at the use site.
    everyRow :: [Filter (DataTypeTableGeneric backend)]
    everyRow = []

-- | 'round' with its result fixed to 'Integer', so callers that immediately
-- apply 'fromIntegral' need no annotation on the intermediate value.
-- Halfway cases go to the even neighbour, exactly as 'round' does.
roundFn :: RealFrac a => a -> Integer
roundFn value = round value

-- | Round a 'TimeOfDay' to the nearest whole second (ties to even, via
-- 'roundFn').
--
-- NOTE: times from 23:59:59.5 onwards round to 86400 s, which
-- 'timeToTimeOfDay' represents as the leap second 23:59:60.
roundTime :: TimeOfDay -> TimeOfDay
roundTime t =
    let wholeSeconds = roundFn (timeOfDayToTime t)
    in timeToTimeOfDay (fromIntegral wholeSeconds)

-- | Round a 'UTCTime' to the nearest whole POSIX second (ties to even, via
-- 'roundFn'): convert to POSIX seconds, round, and convert back.
roundUTCTime :: UTCTime -> UTCTime
roundUTCTime =
    posixSecondsToUTCTime . fromIntegral . roundFn . utcTimeToPOSIXSeconds

-- | Produce @i@ values from the type's 'Arbitrary' instance.
--
-- Each value gets its own fresh generator from 'newQCGen', and the
-- QuickCheck size parameter counts up 0, 1, 2, ... across the list.
randomValues :: Arbitrary a => Int -> IO [a]
randomValues i = do
  seeds <- replicateM i newQCGen
  pure [ unGen arbitrary seed size | (seed, size) <- zip seeds [0 ..] ]

-- | Random 'DataTypeTable' rows for the round-trip test.
--
-- The two @maxlen=100@ columns are cut to at most 100 elements. The @utc@
-- value is passed through 'truncateUTCTime' (NOTE(review): presumably to drop
-- precision the backend cannot store -- confirm against its definition).
instance Arbitrary DataTypeTable where
  arbitrary =
      DataTypeTable
        <$> arbText                           -- text
        <*> boundedText                       -- textMaxLen
        <*> arbitrary                         -- bytes
        <*> liftA2 (,) arbitrary arbText      -- bytesTextTuple
        <*> boundedBytes                      -- bytesMaxLen
        <*> arbitrary                         -- int
        <*> arbitrary                         -- intList
        <*> arbitrary                         -- intMap
        <*> arbitrary                         -- double
        <*> arbitrary                         -- bool
        <*> arbitrary                         -- day
        <*> (truncateUTCTime =<< arbitrary)   -- utc
    where
      -- Both bounded columns are declared with maxlen=100.
      boundedText = T.take 100 <$> arbText
      boundedBytes = BS.take 100 <$> arbitrary

-- | Round-trip test for every column type of a test entity.
--
-- The single generated example does the following:
--
-- * runs the optional migration twice;
-- * deletes all 'DataTypeTable' rows;
-- * inserts 1000 random @entity@ values;
-- * reads each one back with 'get' and compares it with what was inserted,
--   field by field.
specsWith
    :: forall db backend m entity.
    ( db ~ ReaderT backend m
    , PersistStoreRead backend
    , PersistEntity entity
    , PersistEntityBackend entity ~ BaseBackend backend
    , Arbitrary entity
    , PersistStoreWrite backend
    , PersistStoreWrite (BaseBackend backend)
    , PersistQueryWrite (BaseBackend backend)
    , PersistQueryWrite backend
    , MonadFail m
    , MonadIO m
    )
    => (db () -> IO ())
    -- ^ DB Runner
    -> Maybe (db [Text])
    -- ^ Optional migrations to run
    -> [TestFn entity]
    -- ^ List of entity fields to test (compared with exact equality)
    -> [(String, entity -> Pico)]
    -- ^ List of pico fields to test (compared to within 0.000001)
    -> (entity -> Double)
    -- ^ Projection of the 'Double' field; compared with a relative tolerance
    -- of 1e-14 because the value may lose precision when serialized
    -> Spec
specsWith runDb mmigration checks apprxChecks doubleFn = describe "data type specs" $
    it "handles all types" $ asIO $ runDb $ do
        -- The migration (if any) is run twice: the second run has to cope
        -- with the schema the first one created.
        -- NOTE(review): presumably this verifies that reading the schema back
        -- from the database works -- confirm.
        sequence_ mmigration
        sequence_ mmigration
        cleanDB'
        rvals <- liftIO $ randomValues 1000
        for_ rvals $ \x -> do
            key <- insert x
            -- If 'get' returns Nothing, the test fails via 'MonadFail': the row
            -- that was just inserted could not be read back.
            Just y <- get key
            liftIO $ do
                let check :: (Eq a, Show a) => String -> (entity -> a) -> IO ()
                    check s f = (s, f x) @=? (s, f y)
                -- Check floating-point near equality
                let check' :: (Fractional p, Show p, Real p) => String -> (entity -> p) -> IO ()
                    check' s f
                        | abs (f x - f y) < 0.000001 = return ()
                        | otherwise = (s, f x) @=? (s, f y)
                -- Check individual fields for better error messages
                for_ checks $ \(TestFn msg f) -> check msg f
                for_ apprxChecks $ \(msg, f) -> check' msg f

                -- Do a special check for Double since it may lose precision
                -- when serialized: the exact comparison (which produces the
                -- readable failure message) only runs when the values differ
                -- by more than the relative tolerance.
                when (doubleMismatch (doubleFn x) (doubleFn y)) $
                    check "double" doubleFn
    where
      -- | 'True' when two doubles differ by more than a relative 1e-14.
      --
      -- The difference is scaled by the larger magnitude, floored at 1, so
      -- values with |v| <= 1 are compared absolutely. The scaling stays
      -- stable when the two values straddle a power of ten. It never
      -- computes a power of ten, so an infinite input cannot trigger a
      -- huge-exponent '^'.
      -- A NaN on either side (with no infinity) gives no mismatch, so the
      -- field is skipped. Any infinity that is not matched exactly counts as
      -- a mismatch.
      doubleMismatch :: Double -> Double -> Bool
      doubleMismatch a b
        | a == b = False
        | isInfinite a || isInfinite b = True
        | otherwise = abs (a - b) / max 1 (max (abs a) (abs b)) > 1e-14