Counterfactual Explainers¶
Overview¶
COLA provides built-in counterfactual explanation generators and supports integration with external explainers. These explainers generate the initial counterfactuals that COLA then refines to minimize the number of required actions.
Built-in Explainers¶
COLA includes two main explainers:
DiCE - Instance-wise counterfactual generation (Diverse Counterfactual Explanations) - FAccT 2020 paper
DisCount - Distributional counterfactual generation (Distributional Counterfactual Explanations With Optimal Transport) - AISTATS 2025 oral
DiCE Explainer¶
DiCE generates multiple diverse counterfactuals for individual instances. For more details, please refer to https://interpret.ml/DiCE/.
When to use DiCE:
You want diverse counterfactual options for each instance
You need instance-level explanations
You want to control which features can be changed
You care about proximity and sparsity
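As a rough illustration of the last point, proximity and sparsity can be measured per counterfactual. The sketch below assumes that proximity is the L1 distance over numerical features and that sparsity is the number of changed features. Both definitions and both helper names are illustrative assumptions, not part of the COLA or DiCE API.

```python
def sparsity(factual, counterfactual):
    """Count the features whose value differs between the two rows."""
    return sum(1 for name in factual if factual[name] != counterfactual[name])


def proximity(factual, counterfactual, numerical_features):
    """L1 distance over the numerical features only (smaller is closer)."""
    return sum(abs(factual[name] - counterfactual[name]) for name in numerical_features)


# Hypothetical rows, represented as plain dicts for the sake of the example.
factual = {"Age": 35, "Income": 42000, "Job": "clerk"}
counterfactual = {"Age": 35, "Income": 50000, "Job": "clerk"}

n_changed = sparsity(factual, counterfactual)
distance = proximity(factual, counterfactual, ["Age", "Income"])
print(n_changed)  # 1
print(distance)   # 8000
```

A counterfactual with low values on both measures asks for few, small changes.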
Basic Usage of the Built-in DiCE¶
from xai_cola.ce_generator import DiCE
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
# 1. Prepare data
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Income']
)
# 2. Wrap model
ml_model = Model(model=trained_model, backend="sklearn")
# 3. Create explainer
explainer = DiCE(ml_model=ml_model)
# 4. Generate counterfactuals
factual, counterfactual = explainer.generate_counterfactuals(
data=data,
factual_class=1, # Generate CFs for instances with class 1
total_cfs=2, # Generate 2 CFs per instance
continuous_features=['Age', 'Income']
)
# 5. Add to data
data.add_counterfactuals(counterfactual, with_target_column=True)
Parameters Explained¶
- factual_class (int)
The class label of factual instances. The explainer will generate counterfactuals for the opposite class.
# Factual instances have class 1, generate CFs for class 0
factual_class=1
# Factual instances have class 0, generate CFs for class 1
factual_class=0
- total_cfs (int)
Number of counterfactuals per instance.
total_cfs=1 # One CF per instance (faster)
total_cfs=5 # Five CFs per instance (more diverse)
- continuous_features (list)
Features that can take continuous values.
continuous_features=['Age', 'Income', 'Duration']
- features_to_keep (list)
Features that should NOT be changed (immutable features).
# Don't change Age or Gender
features_to_keep=['Age', 'Gender']
- features_to_vary (list)
Only these features can be changed.
# Only allow changing Income and Duration
features_to_vary=['Income', 'Duration']
Note
Use either features_to_keep or features_to_vary, not both.
- permitted_range (dict)
Allowed ranges for features. For numerical features, specify [min, max]. For categorical features, specify a list of allowed values.
permitted_range={
'age': [20, 30], # Age must be between 20-30
'education': ['Doctorate', 'Prof-school'] # Only these education levels allowed
}
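To make the two kinds of entries concrete, the following sketch checks a single row against a permitted_range dict. The helper range_violations is hypothetical and not part of DiCE or COLA. It assumes that a two-element list of numbers denotes a [min, max] interval and that any other list denotes a set of allowed categorical values.

```python
from numbers import Number


def range_violations(row, permitted_range):
    """Return the names of features in `row` that fall outside `permitted_range`."""
    violations = []
    for feature, allowed in permitted_range.items():
        value = row[feature]
        # Assumption: exactly two numbers means an inclusive [min, max] interval.
        is_interval = len(allowed) == 2 and all(isinstance(a, Number) for a in allowed)
        if is_interval:
            low, high = allowed
            ok = low <= value <= high
        else:
            ok = value in allowed
        if not ok:
            violations.append(feature)
    return violations


permitted_range = {"age": [20, 30], "education": ["Doctorate", "Prof-school"]}

valid_row = {"age": 25, "education": "Doctorate"}
invalid_row = {"age": 41, "education": "Bachelors"}

print(range_violations(valid_row, permitted_range))    # []
print(range_violations(invalid_row, permitted_range))  # ['age', 'education']
```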
Advanced DiCE Usage¶
Example 1: With Immutable Features¶
# Scenario: Age and Gender cannot be changed
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=3,
features_to_keep=['Age', 'Gender'],
continuous_features=['Income', 'Duration']
)
Example 2: With Feature Ranges¶
# Scenario: Realistic constraints on features
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=2,
continuous_features=['Age', 'Income', 'LoanAmount'],
permitted_range={
'Age': [18, 70],
'Income': [10000, 200000],
'LoanAmount': [1000, 50000]
}
)
Example 3: Selective Feature Changes¶
# Scenario: Only allow financial features to change
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=2,
features_to_vary=['Income', 'LoanAmount', 'Duration'],
continuous_features=['Income', 'LoanAmount', 'Duration']
)
DisCount Explainer¶
DisCount generates distributional counterfactuals: it finds a counterfactual distribution that keeps a structure similar to that of the factual distribution while achieving the desired prediction outcome.
When to use DisCount:
You have groups of instances to explain
You care about distributional properties
You want cost-efficient group-level changes
You need to maintain data distribution shape
Requirements:
Only compatible with PyTorch models (backend='pytorch')
Model must be wrapped with Model(model=..., backend='pytorch')
Mathematical Background¶
DisCount solves a constrained optimization problem to find a counterfactual distribution x that is close to the factual distribution x’ while ensuring the model predictions y = b(x) are close to a target distribution y*.
Optimization Objective:
The algorithm minimizes a weighted objective function:
$$Q(x, \mu, \nu, \eta) = (1 - \eta)\, Q_x(x, \mu) + \eta\, Q_y(x, \nu)$$
where:
Q_x(x, μ): Sliced Wasserstein distance between counterfactual and factual feature distributions
Q_y(x, ν): Wasserstein distance between counterfactual predictions and target distribution
η ∈ [l, r]: Balancing weight that trades off between input similarity and output accuracy
Constraints:
The optimization is subject to distributional constraints:
$$\mathrm{SW}_2(x, x') \le U_1, \qquad W_2(b(x), y^*) \le U_2$$
Key Components:
Sliced Wasserstein Distance (SW₂): Computed by projecting distributions onto random 1D directions Θ = {θₖ}ₖ₌₁ᴺ and averaging the 1D Wasserstein distances
Trimmed Distance: Uses trimming constant δ to remove outliers from both tails of the distribution for robust computation
Unified Confidence Upper Bound (UCL): Statistical upper bounds on W₂ and SW₂ with Bonferroni correction for the n_proj projections
Interval Narrowing: Adaptive algorithm to find optimal balancing weight η based on constraint slack: a = U₁ - SW₂, b = U₂ - W₂
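The sliced Wasserstein component can be sketched in plain Python. This is a minimal illustration, assuming equal-size samples and no trimming. The function names are hypothetical, and the code is not the DisCount implementation, which operates on PyTorch tensors.

```python
import math
import random


def wasserstein2_sq_1d(u, v):
    """Squared 1D Wasserstein-2 distance between two equal-size samples."""
    return sum((a - b) ** 2 for a, b in zip(sorted(u), sorted(v))) / len(u)


def random_unit_vector(dim, rng):
    """Sample a direction uniformly from the unit sphere."""
    g = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in g))
    return [c / norm for c in g]


def sliced_wasserstein2_sq(x, x_prime, n_proj=10, seed=0):
    """Average the squared 1D distances over n_proj random projections."""
    rng = random.Random(seed)
    dim = len(x[0])
    total = 0.0
    for _ in range(n_proj):
        theta = random_unit_vector(dim, rng)
        proj_x = [sum(t * c for t, c in zip(theta, row)) for row in x]
        proj_xp = [sum(t * c for t, c in zip(theta, row)) for row in x_prime]
        total += wasserstein2_sq_1d(proj_x, proj_xp)
    return total / n_proj


x_prime = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x = [[a + 2.0, b] for a, b in x_prime]  # every point shifted by 2 along the first axis

same = sliced_wasserstein2_sq(x_prime, x_prime)
shifted = sliced_wasserstein2_sq(x, x_prime)
print(same)     # 0.0
print(shifted)  # positive, and at most 4.0 (the squared length of the shift)
```

Raising n_proj averages over more directions, which reduces the variance of the estimate at a proportional cost in computation.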
Basic Usage¶
from xai_cola.ce_generator import DisCount
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
# 1. Prepare data
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Credit amount', 'Duration']
)
# 2. Wrap PyTorch model
ml_model = Model(model=pytorch_model, backend="pytorch")
# 3. Create explainer
explainer = DisCount(ml_model=ml_model)
# 4. Generate distributional counterfactuals
factual, counterfactual = explainer.generate_counterfactuals(
data=data,
factual_class=1, # Factual instances class
lr=5e-2, # Learning rate
n_proj=10, # Number of projections
delta=0.15, # Trimming constant
U_1=0.4, # Distributional distance bound
U_2=0.02, # Prediction distance bound
max_iter=100, # Maximum iterations
silent=False # Print optimization logs
)
# 5. Add to data
data.add_counterfactuals(counterfactual, with_target_column=True)
Parameters Explained¶
- factual_class (int, default=1)
The class label of factual instances. The explainer will generate counterfactuals for the opposite class.
- lr (float, default=5e-2)
Learning rate for optimization. Controls the step size in gradient descent updates.
lr=5e-2 # Default: moderate learning rate
lr=1e-1 # Faster convergence, may be unstable
lr=1e-3 # Slower but more stable
- n_proj (int, default=10)
Number of random projections for computing sliced Wasserstein distance. More projections give better approximation but slower computation.
n_proj=10 # Default: good balance
n_proj=50 # More accurate, slower
n_proj=5 # Faster, less accurate
- delta (float, default=0.15)
Trimming constant ∈ (0, 0.5) for robust distance computation. Trims delta fraction from both tails of the distribution when computing Wasserstein distance.
delta=0.15 # Default: moderate robustness
delta=0.05 # Less trimming, more sensitive to outliers
delta=0.3 # More trimming, more robust to outliers
- U_1 (float, default=0.4)
Upper bound for input distributional distance (sliced Wasserstein distance between factual and counterfactual feature distributions). Controls how much the counterfactual distribution can differ from the factual distribution.
U_1=0.4 # Default: moderate similarity
U_1=0.2 # Stricter - counterfactuals closer to factuals
U_1=0.6 # Looser - allow more distribution shift
- U_2 (float, default=0.02)
Upper bound for output prediction distance (Wasserstein distance between counterfactual predictions and target distribution). Controls how close the counterfactual predictions should be to the desired outcome.
U_2=0.02 # Default: tight prediction control
U_2=0.01 # Stricter - predictions closer to target
U_2=0.05 # Looser - allow more prediction variance
- l (float, default=0.15)
Lower bound for interval narrowing in the balancing weight η search. Used in the optimization algorithm to balance input and output constraints.
l=0.15 # Default lower bound
# Typically kept in [0, 0.5] range
- r (float, default=1)
Upper bound for interval narrowing in the balancing weight η search.
r=1 # Default upper bound
# Typically kept at 1.0
- max_iter (int, default=100)
Maximum number of optimization iterations. More iterations allow better convergence but take longer.
max_iter=100 # Default
max_iter=200 # More iterations for complex problems
max_iter=50 # Faster but may not converge
- tau (float, default=1e1)
Step size for manifold gradient descent. Controls the scale of gradient updates in the optimization.
tau=1e1 # Default
tau=1e2 # Larger steps, faster but may be unstable
tau=1e0 # Smaller steps, more stable
- silent (bool, default=False)
Whether to suppress optimization logs during training.
silent=False # Print logs (useful for debugging)
silent=True # Suppress logs (cleaner output)
- explain_columns (list, optional)
List of column names that can be modified during optimization. If None, all transformed features can be modified.
explain_columns=['Credit amount', 'Duration'] # Only modify these
explain_columns=None # Allow all features to change
Advanced DisCount Usage¶
Example 1: With Selective Feature Modification¶
# Only allow financial features to change
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
explain_columns=['Credit amount', 'Duration', 'Saving accounts'],
lr=5e-2,
U_1=0.3, # Tighter distributional constraint
U_2=0.01, # Tighter prediction constraint
max_iter=150
)
Example 2: Faster Optimization¶
# Trade accuracy for speed
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
n_proj=5, # Fewer projections
max_iter=50, # Fewer iterations
silent=True, # No logs
lr=1e-1 # Larger learning rate
)
Example 3: High-Quality Results¶
# Prioritize quality over speed
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
n_proj=50, # More projections
max_iter=300, # More iterations
delta=0.1, # Less trimming
U_1=0.2, # Tighter distributional bound
U_2=0.01, # Tighter prediction bound
lr=1e-2 # Smaller, more stable learning rate
)
Algorithm Details¶
Parameter Correspondence with Theory:
The implementation maps to the theoretical algorithm as follows:
Core Optimization Variables:
x: Counterfactual feature distribution (optimized)
x’: Factual feature distribution (fixed, from data)
y = b(x): Model predictions on counterfactuals
y*: Target prediction distribution (default: all instances predicted as 1 - factual_class)
b(·): Prediction model (from ml_model)
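Distance Metrics:
$$\mathrm{SW}_2^2(x, x') = \frac{1}{N} \sum_{k=1}^{N} W_2^2\left(\theta_k^\top x,\ \theta_k^\top x'\right)$$
where:
N = n_proj: Number of random projection directions
θₖ: Random unit vectors sampled from unit sphere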
W₂: 1D Wasserstein-2 distance with δ-trimming
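Trimmed Wasserstein Distance:
$$W_{2,\delta}^2(P, Q) = \frac{1}{1 - 2\delta} \int_{\delta}^{1-\delta} \left(F_P^{-1}(u) - F_Q^{-1}(u)\right)^2 \, du$$
where F_P⁻¹ and F_Q⁻¹ are the quantile functions of the two 1D distributions and δ = delta removes outliers from distribution tails.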
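An empirical version of the trimmed distance can be sketched as follows. The sketch assumes equal-size 1D samples and drops floor(δ·n) points from each tail before comparing the remaining sorted values. The function name is hypothetical, and the rounding rule is an assumption rather than the library's exact behaviour.

```python
import math


def trimmed_wasserstein2_sq_1d(u, v, delta=0.15):
    """Squared 1D Wasserstein-2 distance after trimming a delta fraction per tail."""
    if not 0.0 <= delta < 0.5:
        raise ValueError("delta must lie in [0, 0.5)")
    if len(u) != len(v):
        raise ValueError("this sketch assumes equal sample sizes")
    n = len(u)
    k = math.floor(delta * n)  # number of points dropped from each tail
    u_sorted, v_sorted = sorted(u), sorted(v)
    kept = list(zip(u_sorted[k:n - k], v_sorted[k:n - k]))
    return sum((a - b) ** 2 for a, b in kept) / len(kept)


clean = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
noisy = clean[:-1] + [1000.0]  # one extreme outlier in the upper tail

untrimmed = trimmed_wasserstein2_sq_1d(clean, noisy, delta=0.0)
trimmed = trimmed_wasserstein2_sq_1d(clean, noisy, delta=0.15)
print(untrimmed)  # large: dominated by the single outlier
print(trimmed)    # 0.0: the outlier falls in the trimmed tail
```

The comparison shows why a larger delta makes the distance more robust to outliers, at the price of ignoring genuine mass in the tails.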
Balancing Weight Update:
The interval narrowing algorithm updates η at each iteration based on constraint slack:
a = U₁ - SW₂(x, x’): Slack in input constraint
b = U₂ - W₂(b(x), y*): Slack in output constraint
The balancing weight η is computed as:
If a < 0 and b ≥ 0: η = l (prioritize input constraint)
If a ≥ 0 and b < 0: η = r (prioritize output constraint)
If a < 0 and b < 0: η = l + b/(a+b) × (r-l) (both constraints violated)
If a ≥ 0 and b ≥ 0: η = l + a/(a+b) × (r-l) (both constraints satisfied)
where [l, r] is the search interval (l, r parameters).
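The four rules above translate directly into a small function. This is a sketch of the update as stated in the rules, with a hypothetical function name. The handling of the degenerate case a = b = 0, where the ratio is undefined, is an added assumption.

```python
def balancing_weight(a, b, l=0.15, r=1.0):
    """Pick eta in [l, r] from the constraint slacks a (input) and b (output)."""
    if a < 0 <= b:
        return l  # only the input constraint is violated
    if b < 0 <= a:
        return r  # only the output constraint is violated
    if a + b == 0:
        # Both slacks are exactly zero, so the ratio is undefined.
        # Assumption of this sketch: fall back to the midpoint of [l, r].
        return (l + r) / 2
    # Same sign: both violated uses b / (a + b), both satisfied uses a / (a + b).
    ratio = b / (a + b) if a < 0 else a / (a + b)
    return l + ratio * (r - l)


print(balancing_weight(-0.1, 0.01))               # 0.15
print(balancing_weight(0.1, -0.01))               # 1.0
print(balancing_weight(1.0, 3.0, l=0.0, r=1.0))   # 0.25
print(balancing_weight(-1.0, -3.0, l=0.0, r=1.0)) # 0.75
```

In the last two calls the ratio is always in [0, 1], so the returned weight never leaves the search interval.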
Gradient Update:
Each iteration takes a gradient step on x to reduce the weighted objective, where:
τ = tau: Step size for manifold gradient descent
The actual update uses SGD/Adam with learning rate = lr
Unified Confidence Upper Bound (UCL):
The algorithm uses statistical upper bounds to ensure the constraints hold with high probability. The UCL is computed with δ-trimming and a Bonferroni correction (α/N for N projections).
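The Bonferroni step can be illustrated with a few lines of arithmetic. The sketch assumes an overall significance level alpha that is split evenly across the N = n_proj projections. The helper name and the value alpha = 0.05 are illustrative and not taken from the library.

```python
def per_projection_alpha(alpha, n_proj):
    """Significance level for each projection under a Bonferroni correction."""
    if n_proj < 1:
        raise ValueError("n_proj must be at least 1")
    return alpha / n_proj


alpha, n_proj = 0.05, 10
level = per_projection_alpha(alpha, n_proj)
print(level)  # about 0.005

# Union bound: the chance that any of the N bounds fails is at most N * (alpha / N).
print(level * n_proj <= alpha + 1e-12)  # True
```

More projections therefore mean a stricter level per projection, which makes each individual bound more conservative.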
Hyperparameter Tuning Recommendations:
n_proj grows with feature dimension: n_proj ≈ 10-50 is a typical range, with larger values for higher-dimensional features
delta ∈ [0.05, 0.3]: Smaller for clean data, larger for noisy data
U_1/U_2 from domain knowledge: Analyze historical distribution shifts or business tolerance
lr and tau: Tune the two step sizes together, starting from the defaults (lr=5e-2, tau=1e1)
max_iter: 100-300 for convergence; use early stopping based on constraint satisfaction
External Explainers¶
You can use COLA with any counterfactual explainer that returns counterfactuals as DataFrames or arrays.
# Your custom explainer
def my_explainer(X, model):
# ... your counterfactual generation logic ...
return counterfactuals_df
# Generate CFs
cf_df = my_explainer(X_test, model)
# Use with COLA
data = COLAData(factual_data=X_test, label_column='y', numerical_features=['age', 'income'])
data.add_counterfactuals(cf_df, with_target_column=True)
# Refine with COLA
from xai_cola import COLA
sparsifier = COLA(data=data, ml_model=ml_model)
refined = sparsifier.refine_counterfactuals(limited_actions=5)
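As one possible body for a custom explainer, the sketch below implements a nearest-unlike-neighbour search: for each factual row it returns the closest candidate row that the model already assigns to the desired class. It works on plain lists of numbers rather than DataFrames, and the function name, the toy model, and the data are all hypothetical.

```python
def nearest_unlike_neighbor(factual_rows, candidate_rows, predict, desired_class):
    """For each factual row, return the closest candidate predicted as desired_class."""
    pool = [row for row in candidate_rows if predict(row) == desired_class]
    if not pool:
        raise ValueError("no candidate row is predicted as the desired class")

    def sq_dist(p, q):
        return sum((a - b) ** 2 for a, b in zip(p, q))

    return [min(pool, key=lambda c: sq_dist(row, c)) for row in factual_rows]


def predict(row):
    """Toy model: class 1 when the two features sum to more than 10."""
    return int(row[0] + row[1] > 10)


factuals = [[8.0, 5.0], [20.0, 1.0]]               # both predicted as class 1
candidates = [[1.0, 1.0], [7.0, 2.0], [4.0, 5.0]]  # all predicted as class 0

counterfactuals = nearest_unlike_neighbor(factuals, candidates, predict, desired_class=0)
print(counterfactuals)  # [[7.0, 2.0], [7.0, 2.0]]
```

The returned rows can then be wrapped in a DataFrame with the same columns as the factual data before they are passed to add_counterfactuals.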
Common Issues¶
Issue 1: No Counterfactuals Found¶
Error:
ValueError: No valid counterfactuals found
Possible causes:
Too many immutable features - relax features_to_keep
Too strict ranges - widen permitted_range
Model is too confident - increase total_cfs or adjust proximity weight
Solutions:
# ❌ Too restrictive
explainer.generate_counterfactuals(
data=data,
features_to_keep=['Age', 'Gender', 'Income', 'Job'], # Too many!
permitted_range={'Duration': [1, 2]} # Too narrow!
)
# ✅ More flexible
explainer.generate_counterfactuals(
data=data,
features_to_keep=['Age', 'Gender'], # Only truly immutable
permitted_range={'Duration': [1, 60]} # Reasonable range
)
Issue 2: Shape Mismatch¶
Error:
ValueError: Factual and counterfactual have different number of columns
Cause: Counterfactual DataFrame doesn’t match factual structure.
Solution: Ensure column consistency:
# Check columns
print("Factual columns:", data.factual_df.columns.tolist())
print("CF columns:", cf_df.columns.tolist())
# Make sure they match (order doesn't matter, but names must)
assert set(data.factual_df.columns) == set(cf_df.columns)
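When the assertion fails, it helps to see which names differ. The helper below is a small sketch that works on plain lists of column names, for example the output of columns.tolist(). The function name is hypothetical.

```python
def column_mismatch(factual_columns, cf_columns):
    """Return (missing, extra) column names of the counterfactual frame."""
    factual_set, cf_set = set(factual_columns), set(cf_columns)
    missing = sorted(factual_set - cf_set)  # in factual, absent from CF
    extra = sorted(cf_set - factual_set)    # in CF, absent from factual
    return missing, extra


missing, extra = column_mismatch(
    ["Age", "Income", "Risk"],
    ["Age", "income", "Risk"],  # note the lower-case typo
)
print(missing)  # ['Income']
print(extra)    # ['income']
```

Case differences and stray whitespace in column names are common causes of this error.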
Best Practices¶
✅ DO:
Start with fewer CFs for faster iteration
# Start with 1 CF per instance
total_cfs=1
# Increase later if needed
total_cfs=5
Always specify continuous_features
continuous_features=['Age', 'Income', 'Duration']
Use realistic feature constraints
features_to_keep=['Age', 'Gender'] # Truly immutable
permitted_range={'Income': [0, 500000]} # Realistic bounds
Verify counterfactuals before refinement
# Check CF predictions
desired_class = 0 # Opposite of factual_class=1
cf_preds = ml_model.predict(cf_df.drop('Risk', axis=1))
print("CF predictions:", cf_preds)
print("Desired class:", desired_class)
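The check can be condensed into a single number: the fraction of counterfactuals that the model actually assigns to the desired class. The helper below is a minimal sketch with a hypothetical name. It accepts any iterable of predicted labels.

```python
def validity_rate(cf_predictions, desired_class):
    """Fraction of counterfactual predictions equal to the desired class."""
    preds = list(cf_predictions)
    if not preds:
        return 0.0  # no counterfactuals at all counts as zero validity
    return sum(1 for p in preds if p == desired_class) / len(preds)


rate = validity_rate([0, 0, 1, 0], desired_class=0)
print(rate)  # 0.75
```

A rate below 1.0 means some counterfactuals do not flip the prediction and are worth inspecting before refinement.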
❌ DON’T:
Don’t use too many immutable features
Don’t forget to add counterfactuals to COLAData
# ❌ Forgot this step
factual, cf = explainer.generate_counterfactuals(...)
sparsifier = COLA(data=data, ml_model=ml_model) # Error!
# ✅ Remember to add CFs
data.add_counterfactuals(cf, with_target_column=True)
sparsifier = COLA(data=data, ml_model=ml_model) # Works!
Don’t list categorical features in continuous_features
API Reference¶
For complete parameter details, see:
DiCE
DisCount
Next Steps¶
Learn about Matching Policies - Configuring COLA refinement
Explore Visualization - Visualizing results
See Data Interface - Managing data