Skip to content

Commit

Permalink
testing out these changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dtumad committed Jan 18, 2024
1 parent c721583 commit f1c563b
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ begin
support_bind, support_query, set.mem_univ] }
end

lemma mem_support_of_mem_support_eval_dist {oa : oracle_comp spec α} {x : α}
(hx : x ∈ ⁅oa⁆.support) : x ∈ oa.support := by rwa [← support_eval_dist]

lemma mem_support_eval_dist_of_mem_support {oa : oracle_comp spec α} {x : α}
(hx : x ∈ oa.support) : x ∈ ⁅oa⁆.support := by rwa [support_eval_dist]

/-- The support of the `pmf` associated to a computation is the coercion of its `fin_support`. -/
lemma support_eval_dist_eq_fin_support [decidable_eq α] : ⁅oa⁆.support = ↑oa.fin_support :=
(support_eval_dist oa).trans (coe_fin_support oa).symm
Expand Down
206 changes: 130 additions & 76 deletions src/computational_monads/distribution_semantics/defs/prob_event.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,113 +24,170 @@ namespace oracle_comp
open oracle_spec
open_locale big_operators ennreal

variables {α β γ : Type} {spec spec' : oracle_spec} (oa : oracle_comp spec α)
(e e' : set α) (p p' : α → Prop)
variables {α β γ : Type} {spec spec' : oracle_spec}
(oa : oracle_comp spec α)

/-- Probability of a predicate holding after running a particular experiment.
/-- Probability of a predicate `p` holding after running a computation `oa`.
Defined in terms of the outer measure associated to the corresponding `pmf` by `eval_dist`.
TODO: is this a better way to formulate most things? -/
noncomputable def prob_event' (oa : oracle_comp spec α) (p : α → Prop) : ℝ≥0∞ :=
⁅oa⁆.to_outer_measure p

/-- Probability of a predicate holding after running a particular experiment.
Defined in terms of the outer measure associated to the corresponding `pmf` by `eval_dist`. -/
noncomputable def prob_event (oa : oracle_comp spec α) (event : set α) : ℝ≥0∞ :=
⁅oa⁆.to_outer_measure event
See `prob_event_eq_tsum` for an expression in terms of sums. -/
noncomputable def prob_event (oa : oracle_comp spec α) (p : α → Prop) : ℝ≥0∞ :=
⁅oa⁆.to_outer_measure (p : set α)

notation `⁅` p `|` oa `⁆` := prob_event oa p

lemma prob_event.def (oa : oracle_comp spec α) (event : set α) :
⁅event | oa⁆ = ⁅oa⁆.to_outer_measure event := rfl
lemma prob_event.def' : prob_event oa = ⁅oa⁆.to_outer_measure := rfl

lemma prob_event.def (p : α → Prop) : ⁅p | oa⁆ = ⁅oa⁆.to_outer_measure p := rfl

lemma prob_event_eq_to_measure_apply (oa : oracle_comp spec α) (event : set α) :
⁅event | oa⁆ = (@pmf.to_measure α ⊤ ⁅oa⁆ event) :=
(@pmf.to_measure_apply_eq_to_outer_measure_apply α ⊤ ⁅oa⁆ event
lemma prob_event_eq_to_measure_apply (p : α → Prop) : ⁅p | oa⁆ = (@pmf.to_measure α ⊤ ⁅oa⁆ p) :=
(@pmf.to_measure_apply_eq_to_outer_measure_apply α ⊤ ⁅oa⁆ p
measurable_space.measurable_set_top).symm

lemma prob_event_le_one : ⁅e | oa⁆ ≤ 1 :=
(⁅oa⁆.to_outer_measure.mono (set.subset_univ e)).trans
@[simp] lemma prob_event_coe_sort (p : α → bool) : ⁅λ x, p x | oa⁆ = ⁅λ x, p x = tt | oa⁆ :=
by simp_rw [eq_self_iff_true]

@[simp] lemma prob_event_set (e : set α) : ⁅e | oa⁆ = ⁅(∈ e) | oa⁆ := rfl

@[simp] lemma prob_event_le_one (p : α → Prop) : ⁅p | oa⁆ ≤ 1 :=
(⁅oa⁆.to_outer_measure.mono (set.subset_univ p)).trans
(le_of_eq $ (⁅oa⁆.to_outer_measure_apply_eq_one_iff _).2 (set.subset_univ ⁅oa⁆.support))

lemma prob_event_eq_of_iff {p p' : α → Prop} (h : ∀ x, p x ↔ p' x) :
⁅p | oa⁆ = ⁅p' | oa⁆ := congr_arg (λ e, ⁅e | oa⁆) (set.ext h)
section basic

lemma prob_event_eq_of_mem_iff (h : ∀ x, x ∈ e ↔ x ∈ e') : ⁅e | oa⁆ = ⁅e' | oa⁆ :=
congr_arg (λ e, ⁅e | oa⁆) (set.ext h)
variables {p p' : α → Prop}

/-- If the a set `e'` contains all elements of `e` that have non-zero
probability under `⁅oa⁆` then the probability of `e'` is at least as big as that of `e`. -/
lemma prob_event_mono {e e'} (h : (e ∩ oa.support) ⊆ e') : ⁅e | oa⁆ ≤ ⁅e' | oa⁆ :=
pmf.to_outer_measure_mono ⁅oa⁆ (by simpa only [support_eval_dist])
lemma prob_event_mono' (h : ∀ x ∈ oa.support, p x → p' x) : ⁅p | oa⁆ ≤ ⁅p' | oa⁆ :=
pmf.to_outer_measure_mono ⁅oa⁆ (λ x hx, h x (mem_support_of_mem_support_eval_dist hx.2) hx.1)

/-- Weaker version of `prob_event_mono` when the subset holds without intersection `support`. -/
lemma prob_event_mono' {e e'} (h : e ⊆ e') : ⁅e | oa⁆ ≤ ⁅e' | oa⁆ :=
prob_event_mono oa (trans (set.inter_subset_left _ _) h)
lemma prob_event_mono (h : ∀ x, p x → p' x) : ⁅p | oa⁆ ≤ ⁅p' | oa⁆ :=
prob_event_mono' oa (λ x _, h x)

/-- The probability of a singleton set happening is just the `prob_output` of that element. -/
@[simp] lemma prob_event_singleton_eq_prob_output (x : α) : ⁅{x} | oa⁆ = ⁅= x | oa⁆ :=
by rw [prob_event.def, pmf.to_outer_measure_apply_singleton, prob_output]
/-- If two propositions are equivalent on the support of a computation,
then the both have the same probability of holding on the result of the computation. -/
lemma prob_event_ext' (h : ∀ x ∈ oa.support, p x ↔ p' x) : ⁅p | oa⁆ = ⁅p' | oa⁆ :=
le_antisymm (prob_event_mono' oa (λ x hx, (h x hx).1)) (prob_event_mono' oa (λ x hx, (h x hx).2))

/-- The probaility of the `(=) x` event is just the `prob_output` of that element. -/
@[simp] lemma prob_event_eq_eq_prob_output (x : α) : ⁅(=) x | oa⁆ = ⁅= x | oa⁆ :=
trans (congr_arg _ (set.ext $ λ y, eq_comm)) (prob_event_singleton_eq_prob_output oa x)
/-- Weaker version of `prob_event_ext'` when equivalence holds outside the `support`. -/
lemma prob_event_ext {p p' : αProp} (h : ∀ x, p x ↔ p' x) : ⁅p | oa⁆ = ⁅p' | oa⁆ :=
congr_arg (λ q, ⁅q | oa⁆) (funext (λ x, propext (h x)))

/-- The probability of the `(= x)` event is just the `prob_output` of that element. -/
@[simp] lemma prob_event_eq_eq_prob_output' (x : α) : ⁅(= x) | oa⁆ = ⁅= x | oa⁆ :=
trans (congr_arg _ (funext (λ x', by simp [@eq_comm α x' x]))) (prob_event_eq_eq_prob_output oa x)
pmf.to_outer_measure_apply_singleton ⁅oa⁆ x

/-- The probaility of the `(=) x` event is just the `prob_output` of that element. -/
@[simp] lemma prob_event_eq_eq_prob_output (x : α) : ⁅(=) x | oa⁆ = ⁅= x | oa⁆ :=
trans (prob_event_ext oa (λ _, eq_comm)) (prob_event_eq_eq_prob_output' oa x)

lemma prob_event_eq_of_eval_dist_eq {oa : oracle_comp spec α} {oa' : oracle_comp spec' α}
(h : ⁅oa⁆ = ⁅oa'⁆) (e : set α) : ⁅e | oa⁆ = ⁅e | oa'⁆ :=
(h : ⁅oa⁆ = ⁅oa'⁆) (p : α → Prop) : ⁅p | oa⁆ = ⁅p | oa'⁆ :=
by simp only [h, prob_event.def]

section sums

/-- Probability of an event in terms of non-decidable `set.indicator` sum -/
lemma prob_event_eq_tsum_indicator : ⁅e | oa⁆ = ∑' x, e.indicator ⁅oa⁆ x :=
pmf.to_outer_measure_apply ⁅oa⁆ e
end basic

lemma prob_event_eq_sum_indicator [fintype α] : ⁅e | oa⁆ = ∑ x, e.indicator ⁅oa⁆ x :=
(prob_event_eq_tsum_indicator oa e).trans (tsum_eq_sum (λ x hx, (hx $ finset.mem_univ x).elim))
section sums

lemma prob_event_eq_sum_fin_support_indicator [decidable_eq α] :
⁅e | oa⁆ = ∑ x in oa.fin_support, e.indicator ⁅oa⁆ x :=
(prob_event_eq_tsum_indicator oa e).trans (tsum_eq_sum $
λ a ha, set.indicator_apply_eq_zero.2 (λ _, prob_output_eq_zero' ha))
variable (p : α → Prop)

/-- Probability of an event in terms of a decidable `ite` sum-/
lemma prob_event_eq_tsum_ite [decidable_pred e] : ⁅e | oa⁆ = ∑' x, if x ∈ e then ⁅= x | oa⁆ else 0 :=
trans (prob_event_eq_tsum_indicator oa e) (tsum_congr $ λ _, by { rw set.indicator, congr} )
/-- The probability of an event `p` as a sum over the output type, using `set.indicator`
to filter out elements that don't satisfy `p x`. -/
lemma prob_event_eq_tsum_indicator :
⁅p | oa⁆ = ∑' x : α, {x ∈ oa.support | p x}.indicator ⁅oa⁆ x :=
begin
refine trans (⁅oa⁆.to_outer_measure_apply p) (tsum_congr (λ x, _)),
by_cases hx : x ∈ oa.support ∧ p x,
{ exact trans (set.indicator_of_mem hx.2 ⁅oa⁆) (symm $ set.indicator_of_mem hx ⁅oa⁆) },
{ refine or.rec_on (not_and_distrib.1 hx) (λ hx, _) (λ hx, _),
{ exact trans (set.indicator_apply_eq_zero.2 (λ _, prob_output_eq_zero hx))
(symm $ set.indicator_apply_eq_zero.2 (λ _, prob_output_eq_zero hx)) },
{ exact trans (set.indicator_of_not_mem hx ⁅oa⁆)
(symm $ set.indicator_of_not_mem (λ h, hx h.2) ⁅oa⁆) } }
end

lemma prob_event_eq_sum_ite [fintype α] [decidable_pred e] :
⁅e | oa⁆ = ∑ x, ite (x ∈ e) (⁅oa⁆ x) 0 :=
trans (prob_event_eq_sum_indicator oa e) (finset.sum_congr rfl $
λ _ _, by {rw set.indicator, congr})
/-- Weaker version of `prob_event_eq_tsum_indicator` that doesn't explicitly restrict the
set of elements to the support of the computation (implicitly the probabilities are still zero). -/
lemma prob_event_eq_tsum_indicator' : ⁅p | oa⁆ = ∑' x : α, {x | p x}.indicator ⁅oa⁆ x :=
⁅oa⁆.to_outer_measure_apply p

lemma prob_event_mem_set_eq_tsum_indicator (e : set α) :
⁅(∈ e) | oa⁆ = ∑' x, e.indicator ⁅oa⁆ x :=
by rw [prob_event_eq_tsum_indicator', set.set_of_mem_eq]

/-- If we have `decidable_eq` on the output type of a computation,
we can write the probability of an event as a finite sum over the `fin_support` of the computation,
using `set.indicator` to filter elements not in the event. -/
lemma prob_event_eq_sum_indicator [decidable_eq α] :
⁅p | oa⁆ = ∑ x in oa.fin_support, {x | p x}.indicator ⁅oa⁆ x :=
trans (prob_event_eq_tsum_indicator' oa p) (tsum_eq_sum (λ x hx,
set.indicator_apply_eq_zero.2 (λ h, prob_output_eq_zero' hx)))

/-- The probability of an event `p` as a sum over all outputs `x` of `oa` that satisfy `p`,
using a `subtype` in the domain to restrict the included probabilities. -/
lemma prob_event_eq_tsum_subtype : ⁅p | oa⁆ = ∑' x : {x ∈ oa.support| p x}, ⁅= x | oa⁆ :=
by rw [prob_event_eq_tsum_indicator, tsum_subtype, prob_output.def']

/-- Version of `prob_event_eq_tsum_subtype` that doesn't explicitly restrict the set of elements
to the support of the computation (implicitly the probabilities are still zero). -/
lemma prob_event_eq_tsum_subtype' : ⁅p | oa⁆ = ∑' x : {x | p x}, ⁅= x | oa⁆ :=
by rw [prob_event_eq_tsum_indicator', tsum_subtype, prob_output.def']

/-- If `p` is decidable, then we can write the probability of an event as a `tsum` over the
entire output type, using an `ite` statement to exclude outputs not satisfying `p`. -/
@[simp] lemma prob_event_eq_tsum_ite [decidable_pred p] :
⁅p | oa⁆ = ∑' x : α, if p x then ⁅= x | oa⁆ else 0 :=
trans (⁅oa⁆.to_outer_measure_apply p) (by simp_rw [set.indicator, set.mem_def, prob_output.def])

/-- If we have `decidable_eq` on the output type and `decidable_pred` of the event,
we can write the probability of an event as a finite sum over the `fin_support` of the computation,
using an if-then-else statement to filter elements not in the event. -/
lemma prob_event_eq_sum_ite [decidable_eq α] [decidable_pred p] :
⁅p | oa⁆ = ∑ x in oa.fin_support, if p x then ⁅= x | oa⁆ else 0 :=
trans (prob_event_eq_tsum_ite oa p) (tsum_eq_sum (λ x hx,
ite_eq_right_iff.2 (λ _, prob_output_eq_zero' hx)))

/-- If we have `decidable_eq` on the output type and `decidable_pred` of the event,
we can write the probability of an event as a finite sum over the `fin_support` of the computation,
by filtering the computation's `fin_support` by the predicate. -/
@[simp] lemma prob_event_eq_sum_filter [decidable_eq α] [decidable_pred p] :
⁅p | oa⁆ = ∑ x in oa.fin_support.filter p, ⁅= x | oa⁆ :=
trans (prob_event_eq_tsum_ite oa p) (trans (tsum_eq_sum (λ x hx, ite_eq_right_iff.2 (λ hpx,
prob_output_eq_zero' (λ h, hx (finset.mem_filter.2 ⟨h, hpx⟩)))))
(finset.sum_congr rfl (λ x hx, if_pos (finset.mem_filter.1 hx).2)))

/-- The probability of an output belonging to a `finset` can be written as the sum
of the probabilities of getting each element of the set from the computation. -/
@[simp] lemma prob_event_mem_finset_eq_sum (s : finset α) :
⁅(∈ s) | oa⁆ = ∑ x in s, ⁅= x | oa⁆ :=
trans (prob_event_eq_tsum_indicator' oa (∈ s)) (trans (tsum_eq_sum ((λ x hx,
set.indicator_of_not_mem hx _))) (finset.sum_congr rfl (λ x hx, set.indicator_of_mem hx ⁅oa⁆)))

lemma prob_event_eq_sum_fin_support_ite [decidable_eq α] [decidable_pred e] :
⁅e | oa⁆ = ∑ x in oa.fin_support, if x ∈ e then ⁅= x | oa⁆ else 0 :=
trans (prob_event_eq_sum_fin_support_indicator oa e) (finset.sum_congr rfl $
λ _ _, by {rw set.indicator, congr})
end sums

/-- If the event is a finite set, then the probability can be written as a sum over itself. -/
lemma prob_event_coe_finset_eq_sum (e : finset α) : ⁅↑e | oa⁆ = ∑ x in e, ⁅oa⁆ x :=
by rw [prob_event_eq_tsum_indicator, sum_eq_tsum_indicator]
section sets

end sums
/-- The probability of a singleton set is just the `prob_output` of that element. -/
lemma prob_event_mem_singleton_eq_prob_output (x : α) : ⁅(∈ ({x} : set α)) | oa⁆ = ⁅= x | oa⁆ :=
by simp only [set.mem_singleton_iff, prob_event_eq_eq_prob_output']

lemma prob_event_ext (h : ∀ x ∈ oa.support, x ∈ e ↔ x ∈ e') : ⁅e | oa⁆ = ⁅e' | oa⁆ :=
begin
rw [prob_event_eq_tsum_indicator, prob_event_eq_tsum_indicator],
refine tsum_congr (λ x, _),
by_cases hx : x ∈ oa.support,
{ by_cases hx' : x ∈ e,
{ simp only [hx', (h x hx).1 hx', set.indicator_of_mem, eval_dist_apply_eq_prob_output] },
{ simp only [hx', mt ((h x hx).2) hx', set.indicator_of_not_mem, not_false_iff] } },
{ rw [set.indicator_apply_eq_zero.2 (λ _, eval_dist_apply_eq_zero hx),
set.indicator_apply_eq_zero.2 (λ _, eval_dist_apply_eq_zero hx)] }
end
end sets

lemma prob_event_ext' (h : ∀ x ∈ oa.support, p x ↔ p' x) : ⁅p | oa⁆ = ⁅p' | oa⁆ :=
prob_event_ext oa p p' h
-- lemma prob_event_ext (h : ∀ x ∈ oa.support, x ∈ e ↔ x ∈ e') : ⁅e | oa⁆ = ⁅e' | oa⁆ :=
-- begin
-- rw [prob_event_eq_tsum_indicator, prob_event_eq_tsum_indicator],
-- refine tsum_congr (λ x, _),
-- by_cases hx : x ∈ oa.support,
-- { by_cases hx' : x ∈ e,
-- { simp only [hx', (h x hx).1 hx', set.indicator_of_mem, eval_dist_apply_eq_prob_output] },
-- { simp only [hx', mt ((h x hx).2) hx', set.indicator_of_not_mem, not_false_iff] } },
-- { rw [set.indicator_apply_eq_zero.2 (λ _, eval_dist_apply_eq_zero hx),
-- set.indicator_apply_eq_zero.2 (λ _, eval_dist_apply_eq_zero hx)] }
-- end

-- lemma prob_event_ext' (h : ∀ x ∈ oa.support, p x ↔ p' x) : ⁅p | oa⁆ = ⁅p' | oa⁆ :=
-- prob_event_ext oa p p' h

section support

Expand Down Expand Up @@ -242,7 +299,4 @@ begin
{ refl }
end

@[simp] lemma prob_event_coe_sort (p : α → bool) : ⁅λ x, p x | oa⁆ = ⁅λ x, p x = tt | oa⁆ :=
by simp_rw [eq_self_iff_true]

end oracle_comp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ noncomputable def prob_output (oa : oracle_comp spec α) (x : α) := ⁅oa⁆ x

notation `⁅=` x `|` oa `⁆` := prob_output oa x

lemma prob_output.def' (oa : oracle_comp spec α) : prob_output oa = ⁅oa⁆ := rfl

lemma prob_output.def (oa : oracle_comp spec α) (x : α) : ⁅= x | oa⁆ = ⁅oa⁆ x := rfl

lemma eval_dist.prob_output_ext_iff {oa : oracle_comp spec α} {oa' : oracle_comp spec' α} :
Expand Down

0 comments on commit f1c563b

Please sign in to comment.