In the second-to-last chapter, "For a Few Monads More", of the very nice tutorial "Learn You a Haskell for Great Good!", the author defines the following monad:

import Data.Ratio  
-- | A discrete probability distribution: a list of outcomes, each paired
-- with its probability. Equal outcomes are NOT merged; they simply appear
-- several times in the list.
newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

-- | Collapse a distribution of distributions into a single distribution.
-- Every inner probability is multiplied by the probability of the inner
-- distribution it came from.
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concatMap multAll xs
  where multAll (Prob innerxs, p) = map (\(x, r) -> (x, p * r)) innerxs

-- Since GHC 7.10 every Monad must also be an Applicative (and therefore a
-- Functor); '>>=' below also relies on 'fmap'.
instance Functor Prob where
  -- Transform the outcomes; probabilities are left untouched.
  fmap f (Prob xs) = Prob [(f x, p) | (x, p) <- xs]

instance Applicative Prob where
  -- A certain outcome: probability 1.
  pure x = Prob [(x, 1 % 1)]
  -- Every function/argument pair, with the probabilities multiplied
  -- (same result and order as 'Control.Monad.ap').
  Prob fs <*> Prob xs = Prob [(f x, p * q) | (f, p) <- fs, (x, q) <- xs]

instance Monad Prob where
  -- 'return' defaults to 'pure'.
  m >>= f = flatten (fmap f m)

-- Since GHC 8.8 'fail' lives in 'MonadFail', not 'Monad'. A failed pattern
-- match in a do block contributes the empty distribution.
instance MonadFail Prob where
  fail _ = Prob []

I wondered whether it is possible in Haskell to specialize the bind operator (>>=) for the case where the values inside the monad belong to a particular type class such as Eq, because I'd like to add up all the probabilities that belong to the same value.


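The standard Monad class cannot do this: its >>= must work for every element type, so an instance cannot demand Eq (or any other constraint) on the values. What you want is called a "restricted monad" (also known as a constrained monad), in which each instance chooses the constraint its element types must satisfy. You can define it like this: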

{-# LANGUAGE ConstraintKinds, TypeFamilies, KindSignatures, FlexibleContexts, UndecidableInstances #-}
-- | Restricted (constrained) versions of the Functor / Applicative / Monad /
-- MonadPlus hierarchy. Each instance chooses, through the associated
-- 'Restriction' constraint, what its element types must satisfy (for
-- example 'Eq' for a probability monad that merges equal outcomes).
--
-- Client modules are expected to enable @RebindableSyntax@ so that @do@
-- notation uses the '>>=', '>>' and 'fail' exported from here.
module Control.Restricted
    ( RFunctor(..)
    , RApplicative(..)
    , RMonad(..)
    , return
    , RMonadPlus(..)
    , guard
    ) where

-- Hide every Prelude class whose method names are redefined below.
-- Otherwise unqualified uses such as 'pure' in 'return' are ambiguous.
-- (Applicative is exported by the Prelude since GHC 7.10, MonadFail since
-- GHC 8.8; hiding a name the Prelude does not export is only a warning.)
import Prelude hiding (Functor(..), Applicative(..), Monad(..), MonadFail(..))
import Data.Foldable (Foldable(foldMap))
-- Needed for the Semigroup superclass of Monoid (GHC >= 8.4) on compilers
-- whose Prelude does not yet export Semigroup.
import Data.Semigroup (Semigroup(..))
import GHC.Exts (Constraint)

-- | Functors whose element types must satisfy @Restriction f@.
class RFunctor f where
    -- | The constraint every element type stored in @f@ must satisfy.
    type Restriction f a :: Constraint
    fmap :: (Restriction f a, Restriction f b) => (a -> b) -> f a -> f b

-- | Restricted counterpart of the Prelude's Applicative.
class (RFunctor f) => RApplicative f where
    -- Same fixity as the Prelude's (<*>).
    infixl 4 <*>
    pure :: (Restriction f a) => a -> f a
    (<*>) :: (Restriction f a, Restriction f b) => f (a -> b) -> f a -> f b

-- | Restricted counterpart of the Prelude's Monad.
class (RApplicative m) => RMonad m where
    -- Same fixities as the Prelude's, so e.g. @m >>= return . f@ parses as
    -- usual (without them the operators would default to infixl 9).
    infixl 1 >>=, >>
    (>>=) :: (Restriction m a, Restriction m b) => m a -> (a -> m b) -> m b
    (>>) :: (Restriction m a, Restriction m b)  => m a -> m b ->  m b
    a >> b = a >>= \_ -> b
    -- | Flatten one level of nesting by binding with 'id'.
    join :: (Restriction m a, Restriction m (m a)) => m (m a) -> m a
    join a = a >>= id
    -- | Under RebindableSyntax, do notation uses this on pattern-match
    -- failure. Defaults to 'error'.
    fail :: (Restriction m a) => String -> m a
    fail = error

-- | Restricted 'return': a synonym for 'pure'.
return :: (RMonad m, Restriction m a) => a -> m a
return = pure

-- | Restricted counterpart of MonadPlus.
class (RMonad m) => RMonadPlus m where
    mplus :: (Restriction m a) => m a -> m a -> m a
    mzero :: (Restriction m a) => m a
    -- | Combine all values with 'mplus' (via the 'RMonadPlusMonoid' monoid).
    msum :: (Restriction m a, Foldable t) => t (m a) -> m a
    msum t = getRMonadPlusMonoid $ foldMap RMonadPlusMonoid t

-- | Views an 'RMonadPlus' value as a monoid under 'mplus' / 'mzero'. Used by
-- the default definition of 'msum'.
data RMonadPlusMonoid m a = RMonadPlusMonoid { getRMonadPlusMonoid :: m a }

-- Semigroup is a superclass of Monoid since GHC 8.4.
instance (RMonadPlus m, Restriction m a) => Semigroup (RMonadPlusMonoid m a) where
    RMonadPlusMonoid x <> RMonadPlusMonoid y = RMonadPlusMonoid $ mplus x y

instance (RMonadPlus m, Restriction m a) => Monoid (RMonadPlusMonoid m a) where
    mempty = RMonadPlusMonoid mzero
    mappend = (<>)
    mconcat t = RMonadPlusMonoid . msum $ map getRMonadPlusMonoid t

-- | @guard True@ is @return ()@ and @guard False@ is 'mzero'. The element
-- type is @()@, so that is the type the restriction must admit.
guard :: (RMonadPlus m, Restriction m ()) => Bool -> m ()
guard p = if p then return () else mzero

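To use a restricted monad, begin your file like this. RebindableSyntax makes do notation use whichever >>=, >> and fail are in scope (here, the restricted ones) instead of the Prelude's: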

{-# LANGUAGE ConstraintKinds, TypeFamilies, RebindableSyntax #-}
module {- module line -} where
-- RebindableSyntax implies NoImplicitPrelude, so the Prelude is imported
-- explicitly. Every Prelude class whose methods Control.Restricted
-- redefines is hidden (fmap; pure, <*>; >>=, >>, return; fail). Otherwise
-- those names are ambiguous, including the ones do notation uses implicitly.
-- Note: under RebindableSyntax, if/then/else also desugars to whatever
-- ifThenElse is in scope, so define one if you use conditionals.
import Prelude hiding (Functor(..), Applicative(..), Monad(..), MonadFail(..))
import Control.Restricted


