{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Helpers for QuickCheck tests of LLVM-generated signals:
random input signals, rendering of parameterized signals,
and comparison results that keep both compared signals around
so that they can be inspected before they become a property.
-}
module Test.Synthesizer.LLVM.Utility where

import qualified Synthesizer.LLVM.Parameterized.SignalPacked as SigPS
import qualified Synthesizer.LLVM.Parameterized.Signal as SigP
import qualified Synthesizer.LLVM.Parameter as Param
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.State.Signal as SigS

import Control.Monad (liftM, liftM2, )
import Control.Applicative ((<$>), )

import qualified Data.StorableVector.Lazy as SVL
import qualified Data.StorableVector as SV
import Data.StorableVector.Lazy (ChunkSize, )
import Foreign.Storable (Storable, )

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import System.Random (Random, randomRs, StdGen, mkStdGen, )

import qualified Test.QuickCheck as QC

import qualified Algebra.RealRing as RealRing
import qualified Algebra.Absolute as Absolute

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | Generator for the parameters of 'randomStorableVector':
-- a length chosen from the inclusive range 1 to 100,
-- and a 'StdGen' seeded with an arbitrary 'Int'.
genRandomVectorParam :: QC.Gen (Int, StdGen)
genRandomVectorParam = do
   len <- QC.choose (1, 100)
   gen <- mkStdGen <$> QC.arbitrary
   return (len, gen)

-- | Strict storable vector holding the first @len@ values
-- that 'randomRs' draws from @range@ with the generator @seed@.
-- The remaining random values are dropped.
randomStorableVector ::
   (Storable a, Random a) =>
   (a, a) -> (Int, StdGen) -> SV.Vector a
randomStorableVector range (len, seed) =
   let (vector, _unused) = SV.packN len (randomRs range seed)
   in  vector

-- | Lazy storable vector that repeats the output of
-- 'randomStorableVector' over and over, by means of 'SVL.cycle'.
randomStorableVectorLoop ::
   (Storable a, Random a) =>
   (a, a) -> (Int, StdGen) -> SVL.Vector a
randomStorableVectorLoop range param =
   let chunk = randomStorableVector range param
   in  SVL.cycle (SVL.fromChunks [chunk])

-- | Parameterized signal that plays the looped random vector of
-- 'randomStorableVectorLoop'.
-- Its length and seed are taken from the signal parameter.
randomSignal ::
   (Class.MakeValueTuple a, Class.ValueTuple a ~ tuple, Memory.C tuple,
    Storable a, Random a) =>
   (a, a) -> Param.T p (Int, StdGen) -> SigP.T p (Class.ValueTuple a)
randomSignal range =
   SigP.fromStorableVectorLazy . fmap (randomStorableVectorLoop range)

-- | Obtain a rendering function for a parameterized signal via
-- 'SigP.runChunky', and post-process every rendered vector with @limit@.
render ::
   (Storable a, Class.MakeValueTuple a, Class.ValueTuple a ~ al, Memory.C al) =>
   (SVL.Vector a -> sig) ->
   SigP.T p (Class.ValueTuple a) ->
   IO (ChunkSize -> p -> sig)
render limit sig = do
   run <- SigP.runChunky sig
   return (\chunkSize param -> limit (run chunkSize param))


-- | Pending approximate comparison of a rendered signal
-- against a 'SigS.T' signal.
data CheckSimilarityState a =
   CheckSimilarityState
      a               -- ^ absolute tolerance
      (SVL.Vector a)  -- ^ rendered signal
      (SigS.T a)      -- ^ signal to compare against

instance (Storable a, Ord a, Absolute.C a) =>
      QC.Testable (CheckSimilarityState a) where
   property (CheckSimilarityState tol xs ys) =
      let close x y = abs (x-y) < tol
          -- dangerous, since shortened signals would be tolerated:
          -- zipWith only compares the common prefix of both signals
          matches = SigS.zipWith close (SigS.fromStorableSignal xs) ys
      in  QC.property (SigS.foldR (&&) True matches)

{-# INLINE checkSimilarityState #-}
-- | Render @gen0@ (post-processed by @limit@) and pair it with
-- @sig1@ evaluated at the same parameter.
-- The pair is to be compared with absolute tolerance @tol@.
checkSimilarityState ::
   (RealRing.C a, Storable a, Class.MakeValueTuple a,
    Class.ValueTuple a ~ av, Memory.C av) =>
   a ->
   (SVL.Vector a -> SVL.Vector a) ->
   SigP.T p av ->
   (p -> SigS.T a) ->
   IO (ChunkSize -> p -> CheckSimilarityState a)
checkSimilarityState tol limit gen0 sig1 =
   liftM pairUp (render limit gen0)
   where
      pairUp sig0 chunkSize p =
         CheckSimilarityState tol (sig0 chunkSize p) (sig1 p)


-- | Pending approximate comparison of two rendered signals.
data CheckSimilarity a =
   CheckSimilarity
      a               -- ^ absolute tolerance
      (SVL.Vector a)  -- ^ first rendered signal
      (SVL.Vector a)  -- ^ second rendered signal

instance (Storable a, Ord a, Absolute.C a) =>
      QC.Testable (CheckSimilarity a) where
   property (CheckSimilarity tol xs ys) =
      let close x y = abs (x-y) < tol
          -- dangerous, since shortened signals would be tolerated:
          -- zipWith only compares the common prefix of both signals
          matches =
             SigS.zipWith close
                (SigS.fromStorableSignal xs)
                (SigS.fromStorableSignal ys)
      in  QC.property (SigS.foldR (&&) True matches)

{-# INLINE checkSimilarity #-}
-- | Render two parameterized signals with the same @limit@, for the same
-- chunk size and parameter.
-- The two results are to be compared with absolute tolerance @tol@.
checkSimilarity ::
   (RealRing.C b, Storable b, Storable a, Class.MakeValueTuple a,
    Class.ValueTuple a ~ av, Memory.C av) =>
   b ->
   (SVL.Vector a -> SVL.Vector b) ->
   SigP.T p av ->
   SigP.T p av ->
   IO (ChunkSize -> p -> CheckSimilarity b)
checkSimilarity tol limit gen0 gen1 =
   liftM2 pairUp (render limit gen0) (render limit gen1)
   where
      pairUp sig0 sig1 chunkSize p =
         CheckSimilarity tol (sig0 chunkSize p) (sig1 chunkSize p)

-- | Compare a scalar 'Float' signal with a signal of 'TypeNum.D4'-sized
-- serial vectors.
-- The vector signal is first turned into a scalar signal by 'SigPS.unpack'.
checkSimilarityPacked ::
   Float ->
   (SVL.Vector Float -> SVL.Vector Float) ->
   SigP.T p (LLVM.Value Float) ->
   SigP.T p (Serial.Value TypeNum.D4 Float) ->
   IO (ChunkSize -> p -> CheckSimilarity Float)
checkSimilarityPacked tol limit scalar =
   checkSimilarity tol limit scalar . SigPS.unpack


{- |
Pending exact comparison of two rendered signals.
Both signals are kept in this data type instead of being tested for
equality right away, so that the compared signals can be inspected.
-}
data CheckEquality a = CheckEquality (SVL.Vector a) (SVL.Vector a)

instance (Storable a, Eq a) => QC.Testable (CheckEquality a) where
   property (CheckEquality xs ys) = QC.property $ xs == ys

-- | Render two parameterized signals with the same @limit@, for the same
-- chunk size and parameter, and pair the results for an exact comparison.
checkEquality ::
   (Eq a, Storable a, Class.MakeValueTuple a,
    Class.ValueTuple a ~ av, Memory.C av) =>
   (SVL.Vector a -> SVL.Vector a) ->
   SigP.T p av ->
   SigP.T p av ->
   IO (ChunkSize -> p -> CheckEquality a)
checkEquality limit gen0 gen1 =
   liftM2 pairUp (render limit gen0) (render limit gen1)
   where
      pairUp sig0 sig1 chunkSize p =
         CheckEquality (sig0 chunkSize p) (sig1 chunkSize p)