Skip to content

Commit

Permalink
nice prop setup for distr
Browse files Browse the repository at this point in the history
  • Loading branch information
dtumad committed Jan 22, 2024
1 parent 259c433 commit 3c9fc00
Show file tree
Hide file tree
Showing 11 changed files with 368 additions and 373 deletions.
3 changes: 1 addition & 2 deletions src/computational_monads/distribution_semantics/algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ begin
{ rw [prob_output_eq_zero hy2, mul_zero] } },
{ by_cases hyoa : y ∈ oa'.support,
{ refine congr_arg (λ z, z * ⁅= y | oa'⁆) _,
refine prob_event_eq_prob_output _ _ h (λ x' hx hx' hoa, hx _),
simp [set.mem_def] at hx',
refine prob_event_eq_prob_output x h (λ x' hx hx', _),
refine (mul_left_inj y).1 (hx'.trans h.symm) },
{ rw [prob_output_eq_zero hyoa, mul_zero, mul_zero] } }
end
Expand Down
4 changes: 2 additions & 2 deletions src/computational_monads/distribution_semantics/bool.lean
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ by simp only [prob_output_uniform_select_fintype_bind_apply_eq_sum, fintype.univ
fintype.card_bool, nat.cast_bit0, algebra_map.coe_one]

lemma prob_event_uniform_select_bool_bind
(oa : bool → oracle_comp unif_spec α) (e : α → Prop) :
e | $ᵗ bool >>= oa⁆ = (⁅e | oa tt⁆ + ⁅e | oa ff⁆) / 2 :=
(oa : bool → oracle_comp unif_spec α) (p : α → Prop) :
p | $ᵗ bool >>= oa⁆ = (⁅p | oa tt⁆ + ⁅p | oa ff⁆) / 2 :=
by simp only [prob_event_uniform_select_fintype_bind_eq_sum, fintype.univ_bool,
finset.sum_insert, finset.mem_singleton, not_false_iff, finset.sum_singleton, fintype.card_bool,
nat.cast_bit0, algebra_map.coe_one]
Expand Down
54 changes: 29 additions & 25 deletions src/computational_monads/distribution_semantics/ite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ by split_ifs; refl
@[simp] lemma prob_output_ite (x : α) : ⁅= x | if p then oa else oa'⁆ =
if p then ⁅= x | oa⁆ else ⁅= x | oa'⁆ := by split_ifs; refl

@[simp] lemma prob_event_ite (e : set α) : ⁅e | if p then oa else oa'⁆ =
if p thene | oa⁆ elsee | oa'⁆ := by split_ifs; refl
@[simp] lemma prob_event_ite (q : α → Prop) : ⁅q | if p then oa else oa'⁆ =
if p thenq | oa⁆ elseq | oa'⁆ := by split_ifs; refl

lemma dist_equiv_ite_iff (oa'' : oracle_comp spec' α) :
(oa'' ≃ₚ if p then oa else oa') ↔ (p → oa'' ≃ₚ oa) ∧ (¬ p → oa'' ≃ₚ oa') :=
Expand All @@ -60,7 +60,7 @@ variables (p : α → Prop) [decidable_pred p]

section bind_ite

variables (ob ob' : α → oracle_comp spec β) (y : β) (e : set β)
variables (ob ob' : α → oracle_comp spec β) (y : β)

/-- Running one of two computations based on an `ite` is like running them both from the start,
and then choosing which result to take based on a final `ite` statement. -/
Expand Down Expand Up @@ -88,9 +88,10 @@ end

/-- The probability of an event holding after a computation `oa` bound to an `ite`
can be written as a sum over outputs where the predicate is true and where it's false. -/
@[simp] lemma prob_event_bind_ite : ⁅e | do {x ← oa, if p x then ob x else ob' x}⁆ =
(∑' x, if p x then ⁅= x | oa⁆ * ⁅e | ob x⁆ else 0) +
(∑' x, if ¬ p x then ⁅= x | oa⁆ * ⁅e | ob' x⁆ else 0) :=
@[simp] lemma prob_event_bind_ite (q : β → Prop) :
⁅q | do {x ← oa, if p x then ob x else ob' x}⁆ =
(∑' x, if p x then ⁅= x | oa⁆ * ⁅q | ob x⁆ else 0) +
(∑' x, if ¬ p x then ⁅= x | oa⁆ * ⁅q | ob' x⁆ else 0) :=
begin
rw [← ennreal.tsum_add, prob_event_bind_eq_tsum],
exact tsum_congr (λ x, by split_ifs with hx; simp only [zero_add, add_zero]),
Expand All @@ -104,17 +105,17 @@ variables (ob ob' : α → oracle_comp spec β)

lemma bind_ite_dist_equiv_bind_left (h : ∀ x ∈ oa.support, p x) :
do {x ← oa, if p x then ob x else ob' x} ≃ₚ do {x ← oa, ob x} :=
bind_dist_equiv_bind_of_dist_equiv_right _ _ _ (λ x hx, dist_equiv_of_eq (if_pos (h x hx)))
bind_dist_equiv_bind_of_dist_equiv_right _ (λ x hx, dist_equiv_of_eq (if_pos (h x hx)))

lemma bind_ite_dist_equiv_bind_right (h : ∀ x ∈ oa.support, ¬ p x) :
do {x ← oa, if p x then ob x else ob' x} ≃ₚ do {x ← oa, ob' x} :=
bind_dist_equiv_bind_of_dist_equiv_right _ _ _ (λ x hx, dist_equiv_of_eq (if_neg (h x hx)))
bind_dist_equiv_bind_of_dist_equiv_right _ (λ x hx, dist_equiv_of_eq (if_neg (h x hx)))

end bind_ite_eq_bind

section bind_ite_const_left

variables (ob : oracle_comp spec β) (ob' : α → oracle_comp spec β) (y : β) (e : set β)
variables (ob : oracle_comp spec β) (ob' : α → oracle_comp spec β) (y : β)

lemma support_bind_ite_const_left (h : ∃ x ∈ oa.support, p x) :
(do {x ← oa, if p x then ob else ob' x}).support =
Expand All @@ -128,20 +129,21 @@ end
/-- Version of `prob_output_bind_ite_const` when only the left computation is constant -/
@[simp] lemma prob_output_bind_ite_const_left : ⁅= y | do {x ← oa, if p x then ob else ob' x}⁆ =
⁅p | oa⁆ * ⁅= y | ob⁆ + (∑' x, if ¬ p x then ⁅= x | oa⁆ * ⁅= y | ob' x⁆ else 0) :=
by simpa only [prob_output_bind_ite, prob_event_eq_tsum_ite,
by simp only [prob_output_bind_ite, prob_event_eq_tsum_ite,
← ennreal.tsum_mul_right, ite_mul, zero_mul]

/-- Version of `prob_event_bind_ite_const` when only the left computation is constant -/
@[simp] lemma prob_event_bind_ite_const_left : ⁅e | do {x ← oa, if p x then ob else ob' x}⁆ =
⁅p | oa⁆ * ⁅e | ob⁆ + (∑' x, if ¬ p x then ⁅= x | oa⁆ * ⁅e | ob' x⁆ else 0) :=
by simpa only [prob_event_bind_ite, prob_event_eq_tsum_ite,
@[simp] lemma prob_event_bind_ite_const_left (q : β → Prop) :
⁅q | do {x ← oa, if p x then ob else ob' x}⁆ =
⁅p | oa⁆ * ⁅q | ob⁆ + (∑' x, if ¬ p x then ⁅= x | oa⁆ * ⁅q | ob' x⁆ else 0) :=
by simp only [prob_event_bind_ite, prob_event_eq_tsum_ite,
← ennreal.tsum_mul_right, ite_mul, zero_mul]

end bind_ite_const_left

section bind_ite_const_right

variables (ob : α → oracle_comp spec β) (ob' : oracle_comp spec β) (y : β) (e : set β)
variables (ob : α → oracle_comp spec β) (ob' : oracle_comp spec β) (y : β)

/-- Version of `support_bind_ite_const` when only the right computation is constant -/
lemma support_bind_ite_const_right (h : ∃ x ∈ oa.support, ¬ p x) :
Expand All @@ -153,20 +155,21 @@ trans (support_bind_ite _ _ _ _) (congr_arg (λ x,
/-- Version of `prob_output_bind_ite_const` when only the right computation is constant -/
@[simp] lemma prob_output_bind_ite_const_right : ⁅= y | do {x ← oa, if p x then ob x else ob'}⁆ =
(∑' x, if p x then ⁅= x | oa⁆ * ⁅= y | ob x⁆ else 0) + (1 - ⁅p | oa⁆) * ⁅= y | ob'⁆ :=
by {rw ← prob_event_not, simpa only [prob_output_bind_ite, prob_event_eq_tsum_ite,
by {rw ← prob_event_not, simp only [prob_output_bind_ite, prob_event_eq_tsum_ite,
← ennreal.tsum_mul_right, ite_mul, zero_mul] }

/-- Version of `prob_event_bind_ite_const` when only the right computation is constant -/
@[simp] lemma prob_event_bind_ite_const_right : ⁅e | do {x ← oa, if p x then ob x else ob'}⁆ =
(∑' x, if p x then ⁅= x | oa⁆ * ⁅e | ob x⁆ else 0) + (1 - ⁅p | oa⁆) * ⁅e | ob'⁆ :=
by {rw ← prob_event_not, simpa only [prob_event_bind_ite, prob_event_eq_tsum_ite,
@[simp] lemma prob_event_bind_ite_const_right (q : β → Prop) :
⁅q | do {x ← oa, if p x then ob x else ob'}⁆ =
(∑' x, if p x then ⁅= x | oa⁆ * ⁅q | ob x⁆ else 0) + (1 - ⁅p | oa⁆) * ⁅q | ob'⁆ :=
by {rw ← prob_event_not, simp only [prob_event_bind_ite, prob_event_eq_tsum_ite,
← ennreal.tsum_mul_right, ite_mul, zero_mul] }

end bind_ite_const_right

section bind_ite_const

variables (ob ob' : oracle_comp spec β) (y : β) (e : set β)
variables (ob ob' : oracle_comp spec β) (y : β)

@[simp] lemma support_bind_ite_const (h : ∃ x ∈ oa.support, p x) (h' : ∃ x ∈ oa.support, ¬ p x) :
(do {x ← oa, if p x then ob else ob'}).support = ob.support ∪ ob'.support :=
Expand All @@ -189,12 +192,13 @@ end

/-- Simplified version of `prob_event_bind_ite` when only the predicate depends on the
result of the first computation, and the other two computations are constant. -/
@[simp] lemma prob_event_bind_ite_const : ⁅e | do {x ← oa, if p x then ob else ob'}⁆ =
⁅p | oa⁆ * ⁅e | ob⁆ + (1 - ⁅p | oa⁆) * ⁅e | ob'⁆ :=
@[simp] lemma prob_event_bind_ite_const (q : β → Prop) :
⁅q | do {x ← oa, if p x then ob else ob'}⁆ =
⁅p | oa⁆ * ⁅q | ob⁆ + (1 - ⁅p | oa⁆) * ⁅q | ob'⁆ :=
begin
rw [prob_event_bind_ite_const_left, ← prob_event_not],
refine congr_arg (λ x, _ + x) _,
simpa only [prob_event_eq_tsum_ite, ← ennreal.tsum_mul_right, ite_mul, zero_mul]
simp only [prob_event_eq_tsum_ite, ← ennreal.tsum_mul_right, ite_mul, zero_mul]
end

end bind_ite_const
Expand Down Expand Up @@ -248,9 +252,9 @@ variables (f : α → β)
⁅= y | f <$> if p then oa else oa'⁆ =
if p then ⁅= y | f <$> oa⁆ else ⁅= y | f <$> oa'⁆ := by split_ifs; refl

@[simp] lemma prob_event_map_ite (p : Prop) [decidable p] (e : set β) :
e | f <$> if p then oa else oa'⁆ =
if p thene | f <$> oa⁆ elsee | f <$> oa'⁆ := by split_ifs; refl
@[simp] lemma prob_event_map_ite (p : Prop) [decidable p] (q : β → Prop) :
q | f <$> if p then oa else oa'⁆ =
if p thenq | f <$> oa⁆ elseq | f <$> oa'⁆ := by split_ifs; refl

end map_ite

Expand Down
4 changes: 2 additions & 2 deletions src/computational_monads/distribution_semantics/map.lean
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ dist_equiv.ext (λ x, rfl)
-- ⁅= x | f <$> return' !spec! a⁆ = ⁅= x | return' !spec! (f a)⁆ :=
-- by pairwise_dist_equiv

-- lemma prob_event_map_return (e : set β) :
-- lemma prob_event_map_return (q : set β) :
-- ⁅e | f <$> (return' !spec! a)⁆ = ⁅e | return' !spec! (f a)⁆ :=
-- by rw [map_pure]

Expand All @@ -165,7 +165,7 @@ by simp only [map_map_eq_map_comp]
-- lemma prob_output_map_comp (x : γ) : ⁅= x | g <$> (f <$> oa)⁆ = ⁅= x | (g ∘ f) <$> oa⁆ :=
-- by pairwise_dist_equiv

-- lemma prob_event_map_comp (e : set γ) : ⁅e | g <$> (f <$> oa)⁆ = ⁅e | (g ∘ f) <$> oa⁆ :=
-- lemma prob_event_map_comp (p : set γ) : ⁅e | g <$> (f <$> oa)⁆ = ⁅e | (g ∘ f) <$> oa⁆ :=
-- by pairwise_dist_equiv

end map_comp
Expand Down
Loading

0 comments on commit 3c9fc00

Please sign in to comment.