Skip to content

Commit

Permalink
fixing issues with weight computation in erupt
Browse files Browse the repository at this point in the history
Signed-off-by: 1andrin <[email protected]>
  • Loading branch information
1andrin committed Nov 15, 2024
1 parent d578328 commit 1541247
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 22 deletions.
67 changes: 45 additions & 22 deletions causaltune/erupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,49 +73,72 @@ def score(
return (w * outcome).mean()

def weights(
self, df: pd.DataFrame, policy: Union[Callable, np.ndarray, pd.Series]
self,
df: pd.DataFrame,
policy: Union[Callable, np.ndarray, pd.Series]
) -> pd.Series:
W = df[self.treatment_name].astype(int)
assert all(
[x >= 0 for x in W.unique()]
), "Treatment values must be non-negative integers"
assert all([x >= 0 for x in W.unique()]), "Treatment values must be non-negative integers"

# Handle policy input
if callable(policy):
policy = policy(df).astype(int)
if isinstance(policy, pd.Series):
policy = policy.values
policy = np.array(policy)

d = pd.Series(index=df.index, data=policy)
assert all(
[x >= 0 for x in d.unique()]
), "Policy values must be non-negative integers"
assert all([x >= 0 for x in d.unique()]), "Policy values must be non-negative integers"

# Get propensity scores with better handling of edge cases
if isinstance(self.propensity_model, DummyPropensity):
p = self.propensity_model.predict_proba()
else:
p = self.propensity_model.predict_proba(df[self.X_names])
# normalize to hopefully avoid NaNs
p = np.maximum(p, 1e-4)

try:
p = self.propensity_model.predict_proba(df[self.X_names])
except Exception:
# Fallback to safe defaults if prediction fails
p = np.full((len(df), 2), 0.5)

# Clip propensity scores to avoid division by zero or extreme weights
min_clip = max(1e-6, self.clip) # Ensure minimum clip is not too small
p = np.clip(p, min_clip, 1 - min_clip)

# Initialize weights
weight = np.zeros(len(df))

for i in W.unique():
weight[W == i] = 1 / p[:, i][W == i]


try:
# Calculate weights with safer operations
for i in W.unique():
mask = (W == i)
p_i = p[:, i][mask]
# Add small constant to denominator to prevent division by zero
weight[mask] = 1 / (p_i + 1e-10)
except Exception:
# If something goes wrong, return safe weights
weight = np.ones(len(df))

# Zero out weights where policy disagrees with actual treatment
weight[d != W] = 0.0

# Handle extreme weights
if self.remove_tiny:
weight[weight > 1 / self.clip] = 0.0
else:
weight[weight > 1 / self.clip] = 1 / self.clip

# and just for paranoia's sake let's normalize, though it shouldn't
# matter for big samples
weight *= len(df) / sum(weight)

assert not np.isnan(weight.sum()), "NaNs in ERUPT weights"

# Normalize weights
sum_weight = weight.sum()
if sum_weight > 0:
weight *= len(df) / sum_weight
else:
# If all weights are zero, use uniform weights
weight = np.ones(len(df)) / len(df)

# Final check for NaNs
if np.any(np.isnan(weight)):
# Replace any remaining NaNs with uniform weights
weight = np.ones(len(df)) / len(df)

return pd.Series(index=df.index, data=weight)

def probabilistic_erupt_score(
Expand Down
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.

0 comments on commit 1541247

Please sign in to comment.