{-# LANGUAGE NoImplicitPrelude, Rank2Types, NoMonomorphismRestriction #-}
module Knots.Util where

import Knots.Prelude

import Control.DeepSeq
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.Vector as V

-- | @choose n k@ lists every @k@-element subset of @{0, 1, ..., n-1}@,
-- packed into a 'Vector' in the order generated by @choose'@.
choose :: Int -> Int -> Vector (Set Int)
choose n = V.fromList . choose' universe
  where
    -- The ground set {0 .. n-1}; empty when n <= 0.
    universe = Set.fromList [0 .. n - 1]

-- | @choose' s k@ computes all cardinality-@k@ subsets of the set @s@.
--
-- Subsets containing the minimum element of @s@ are listed first, followed
-- by those that omit it (recursively).  When @k < 0@ or @k@ exceeds the
-- size of @s@ there are no such subsets and the result is @[]@.  Previously
-- that case hit an 'assert' (ignored under @-O@) and then crashed inside
-- 'Set.deleteFindMin' on an empty set.
choose' :: Ord a => Set a -> Int -> [ Set a ]
choose' s k
  | k < 0 || k > n = []             -- no subsets of that size exist
  | k == 0         = [ Set.empty ]
  | k == n         = [ s ]          -- the only n-subset is s itself
  | otherwise      = map (Set.insert x) (choose' s' (k-1)) ++ choose' s' k
  where
    n       = Set.size s
    -- Only forced in the 'otherwise' branch, where 0 < k < n, so s is non-empty.
    (x, s') = Set.deleteFindMin s

-- | @choose'' n k@ is the set of all @k@-element subsets of @{0, 1, ..., n-1}@,
-- collected into a 'Set' rather than a list or vector.
choose'' :: Int -> Int -> Set (Set Int)
choose'' n = Set.fromList . choose' (Set.fromList [0 .. n - 1])

-- | Map a function over the elements two functor layers deep.
map2 :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
map2 f = fmap (fmap f)

-- | Map a function over the elements three functor layers deep.
map3 :: (Functor f, Functor g, Functor h) => (a -> b) -> f (g (h a)) -> f (g (h b))
map3 f = fmap (fmap (fmap f))

-- | Apply an index-preserving natural transformation @f k ~> g k@ at both
-- levels of a nested structure. The outer layer is transformed first, then
-- the transformation is 'fmap'ped over the result to convert each inner layer.
--
-- Only @Functor (g i)@ is required, because the mapping happens after the
-- outer layer has already become a @g i@.  The former @Functor (f i)@
-- constraint was never used.  The explicit @forall@ keeps the type
-- variables in their original order (@f i g j a@) for callers that use
-- type applications.
map'i :: forall f i g j a. Functor (g i)
      => (forall x k. f k x -> g k x) -> f i (f j a) -> g i (g j a)
map'i f = fmap f . f

-- | Apply a natural transformation @f ~> g@ at both levels of a nested
-- structure. The outer layer is transformed first, then the transformation
-- is 'fmap'ped over the result to convert each inner layer.
--
-- Only @Functor g@ is required, because the mapping happens after the outer
-- layer has already become a @g@.  The former @Functor f@ constraint was
-- never used.  The explicit @forall@ keeps the type variables in their
-- original order (@f g a@) for callers that use type applications.
map' :: forall f g a. Functor g => (forall x. f x -> g x) -> f (f a) -> g (g a)
map' f = fmap f . f

-- | @groupMap classify m@ partitions the entries of @m@ by the label
-- @classify k a@. Each label maps to the sub-map of entries that received it.
-- Labels that no entry produces do not appear in the result.
groupMap :: (Ord k, Ord l) => (k -> a -> l) -> Map k a -> Map l (Map k a)
groupMap classify = Map.foldlWithKey' step Map.empty
  where
    -- Merge a singleton for this entry into its group; keys of the input are
    -- distinct, so the union never has to resolve a collision.
    step groups key val =
      Map.insertWith Map.union (classify key val) (Map.singleton key val) groups

-- | Flatten a two-level map into a single map keyed by @(outer, inner)@ pairs.
-- Outer keys whose inner map is empty contribute no entries.
convertMap1 :: (Ord k1, Ord k2) => Map k1 (Map k2 a) -> Map (k1,k2) a
convertMap1 m = Map.unions [ Map.mapKeys ((,) i) inner | (i, inner) <- Map.toList m ]

-- | Regroup a map keyed by @(outer, inner)@ pairs into a two-level map.
-- Every inner map in the result is non-empty.
convertMap2 :: (Ord k1, Ord k2) => Map (k1,k2) a -> Map k1 (Map k2 a)
convertMap2 = Map.foldrWithKey place Map.empty
  where
    -- The pairs are distinct, so inner keys never collide within one group.
    place (i, j) x = Map.insertWith Map.union i (Map.singleton j x)

-- | Transpose a two-level map, so that @m ! i ! j == x@ becomes
-- @result ! j ! i == x@. Outer keys with empty inner maps disappear.
convertMap3 :: (Ord k1, Ord k2) => Map k1 (Map k2 a) -> Map k2 (Map k1 a)
convertMap3 m = Map.fromListWith Map.union
  [ (j, Map.singleton i x) | (i, inner) <- Map.toList m, (j, x) <- Map.toList inner ]

-- | Zip a list of maps column-wise. Every key present in at least one map is
-- sent to the list of its lookups, one per input map and in input order;
-- 'Nothing' marks a map that lacks the key.
convertMap4 :: (Ord k) => [Map k a] -> Map k [Maybe a]
convertMap4 maps = Map.fromSet column allKeys
  where
    allKeys  = Set.unions (map Map.keysSet maps)
    column k = map (Map.lookup k) maps

-- | @prod s t@ is the set of all products @x <> y@ with @x@ taken from @s@
-- and @y@ taken from @t@. The left operand always comes from @s@.
prod :: (Monoid a, Ord a) => Set a -> Set a -> Set a
prod s t = Set.unions [ Set.map (x <>) t | x <- Set.toList s ]

-- | Cartesian product of two sets as a set of pairs.
prod' :: (Ord a, Ord b) => Set a -> Set b -> Set (a,b)
prod' s t = Set.unions [ Set.map ((,) x) t | x <- Set.toList s ]

-- | All subsets of a set (its power set) as a list of @2^n@ elements.
--
-- Elements are processed from largest to smallest. Each step keeps every
-- subset built so far and appends a copy of each with the current element
-- added. For @{1,2}@ this yields @[{}, {2}, {1}, {1,2}]@.
power :: Ord a => Set a -> [ Set a ]
power = Set.foldr addElem [ Set.empty ]
  where
    addElem x subsets = subsets ++ map (Set.insert x) subsets

-- | Left scalar multiplication. @r .* v@ replaces every element @y@ of the
-- container @v@ with @r * y@, multiplying by @r@ on the left.
(.*) :: (Ring r, Functor f) => r -> f r -> f r
r .* v = fmap (r *) v

-- | A pair of two @Int@ values whose fields are strict and unpacked directly
-- into the constructor.
data IntPair
    = IntPair {-# UNPACK #-} !Int   -- first component
              {-# UNPACK #-} !Int   -- second component
    deriving (Eq, Ord, Read, Show)

-- Both fields are strict, so forcing an 'IntPair' to weak head normal form
-- already evaluates both Ints; that is all normal form requires here.
instance NFData IntPair where
    rnf pair = pair `seq` ()

-- | @replace old new z@ returns @new@ if @z == old@ and @z@ unchanged
-- otherwise. It is handy as @map (replace old new)@ to substitute every
-- occurrence of one value in a container.
replace :: Eq a => a -> a -> a -> a
replace old new z | old == z    = new
                  | otherwise   = z