module Test.Agata.Common where


import Test.QuickCheck

import Control.Monad (liftM)
import Control.Monad.State.Lazy

import Data.Tagged


-- | An 'Int' carrying a phantom type tag @a@ (via 'Tagged').
-- NOTE(review): presumably the tag records which type the dimension/size
-- belongs to, so sizes for different types are not mixed up — confirm.
type Dimension a = Tagged a Int

-- | Arithmetic is performed on the wrapped value; the phantom tag is simply
-- carried through to the result.
instance Num b => Num (Tagged a b) where
  Tagged x + Tagged y = Tagged (x + y)
  Tagged x * Tagged y = Tagged (x * y)
  Tagged x - Tagged y = Tagged (x - y)
  negate (Tagged x)   = Tagged (negate x)
  abs (Tagged x)      = Tagged (abs x)
  signum (Tagged x)   = Tagged (signum x)
  fromInteger n       = Tagged (fromInteger n)

-- | Conversion to 'Rational' ignores the tag and converts the wrapped value.
instance Real b => Real (Tagged a b) where
  toRational (Tagged x) = toRational x

-- | Integer division on the wrapped value; results keep the same tag.
instance Integral b => Integral (Tagged a b) where
  quot (Tagged x) (Tagged y) = Tagged (x `quot` y)
  rem  (Tagged x) (Tagged y) = Tagged (x `rem` y)
  div  (Tagged x) (Tagged y) = Tagged (x `div` y)
  mod  (Tagged x) (Tagged y) = Tagged (x `mod` y)
  -- The 'case' forces the underlying pair, matching the strict tuple
  -- pattern of the lifted formulation.
  quotRem (Tagged x) (Tagged y) =
    case quotRem x y of
      (q, r) -> (Tagged q, Tagged r)
  divMod (Tagged x) (Tagged y) =
    case divMod x y of
      (d, m) -> (Tagged d, Tagged m)
  toInteger (Tagged x) = toInteger x

-- | Enumeration delegates to the wrapped type; every produced element is
-- re-wrapped with the same tag.
instance Enum b => Enum (Tagged a b) where
  succ (Tagged x)     = Tagged (succ x)
  pred (Tagged x)     = Tagged (pred x)
  toEnum n            = Tagged (toEnum n)
  fromEnum (Tagged x) = fromEnum x
  enumFrom (Tagged x) = map Tagged (enumFrom x)
  enumFromThen (Tagged x) (Tagged y) =
    map Tagged (enumFromThen x y)
  enumFromTo (Tagged x) (Tagged y) =
    map Tagged (enumFromTo x y)
  enumFromThenTo (Tagged x) (Tagged y) (Tagged z) =
    map Tagged (enumFromThenTo x y z)

-- | Return the tagged value unchanged. The second argument is never
-- inspected; it only pins down the tag type @b@ during type inference.
taggedWith :: Tagged b a -> b -> Tagged b a
taggedWith t _ = t

-- | A QuickCheck 'Gen' computation threading state @(Int, Int, [Int])@.
--
-- In this module the components are used as follows: the first is read by
-- 'currentDimension', the second is incremented by 'request', and the third
-- is a list from which 'acquire' pops values.
type Improving a = StateT (Int, Int, [Int]) Gen a
-- | Read the first state component (the current level) and return it
-- wrapped as a 'Dimension'. The state is left unchanged.
currentDimension :: Improving (Dimension a)
currentDimension = gets levelOf
  where
    levelOf (lvl, _, _) = Tagged lvl
-- | Increment the second state component by one.
-- NOTE(review): presumably a count of pending requests — confirm at callers.
request :: Improving ()
request = modify bumpPending
  where
    bumpPending (lvl, pending, supply) = (lvl, pending + 1, supply)
-- | Pop and return the head of the third state component (the supply list),
-- leaving the rest of the list in the state.
--
-- Calls 'error' with the full state if the supply list is empty.
--
-- The state is read once and matched exhaustively with 'case'; a refutable
-- pattern bind here would require a 'MonadFail' instance for
-- @StateT _ Gen@, which QuickCheck's 'Gen' does not provide.
acquire :: Improving Int
acquire = do
  st <- get
  case st of
    (lvl, pending, next : rest) -> do
      put (lvl, pending, rest)
      return next
    _ -> error $ "acquire: " ++ show st


-- | @piles a b@ randomly splits the total @b@ into @a@ piles, returned in
-- shuffled order (via 'permute').
--
-- The piles always sum to @b@; for a non-negative @b@ each pile is
-- non-negative. Zero piles give the empty list, and a negative pile count
-- is an error.
-- NOTE(review): a negative total @b@ is not rejected and yields
-- meaningless piles — confirm callers never pass one.
piles 0 _      = return []
piles a b
  | a <= 0     = error "piling 0 or fewer piles"
  | otherwise  = genSorted a b b >>= permute where
  -- genSorted p n m: p non-increasing piles summing to n, none above m.
  genSorted 1 n _ = return [n]
  genSorted p n m = do
    -- The largest of the remaining p piles must be at least ceiling (n/p),
    -- otherwise the other p-1 piles (each capped at r) could not absorb
    -- the rest of n.
    r <- choose (ceiling $ fromIntegral n / fromIntegral p, min m n)
    liftM (r:) $ genSorted (p-1) (n-r) (min m r)

-- | Randomly reorder a list by a recursive merge shuffle.
--
-- The list is split in half, each half is shuffled recursively, and the two
-- shuffled halves are randomly interleaved. At each interleaving step the
-- next element is taken from the left half with probability nx/(nx+ny),
-- where nx and ny are the remaining lengths. Assuming 'choose' is uniform,
-- every interleaving is equally likely.
--
-- Note: the exact order of monadic binds here determines how the 'Gen'
-- seed is split, so restructuring it changes the output for a given seed.
permute :: [a] -> Gen [a]
permute = fromList
  where
  -- Split in half, shuffle both halves, then randomly interleave them.
  fromList []  = return []
  fromList [x] = return [x]
  fromList xs  = fromList l `merge` fromList r
      where (l,r) = splitAt (length xs `div` 2) xs
  -- Run both generators (left first), then interleave the results.
  merge :: Gen [a] -> Gen [a] -> Gen [a]
  merge rxs rys = do
    xs <- rxs; ys <- rys
    -- Track the remaining lengths alongside the lists so they are not
    -- recomputed at every step.
    merge' (length xs, xs) (length ys, ys)
   where
    -- Once one side is exhausted, the other side is appended as-is.
    merge' (0 , [])   (_ , ys)   = return ys
    merge' (_ , xs)   (0 , [])   = return xs
    merge' (nx, x:xs) (ny, y:ys) = do
      -- k is drawn from 1..nx+ny; taking the left head when k <= nx gives
      -- probability nx/(nx+ny) of picking from the left.
      k <- choose (1,nx+ny)
      if k <= nx
        then (x:) `liftM` ((nx-1, xs) `merge'` (ny, y:ys))
        else (y:) `liftM` ((nx, x:xs) `merge'` (ny-1, ys))