Counterfactual Generators¶
Module contents¶
class DiCE (ml_model)
Bases:
CounterFactualExplainerDiverse Counterfactual Explanations (DiCE) generator.
Generates multiple diverse counterfactuals for each factual instance using the DiCE algorithm. DiCE creates counterfactual explanations by finding diverse alternative scenarios that would lead to a different prediction outcome.
- Parameters:
ml_model (Model) – The machine learning model wrapper. Should be created with
Model(model=..., backend="sklearn"). Can wrap:
Plain sklearn model (not recommended, preprocessing needed)
sklearn Pipeline with preprocessing (recommended)
- References:
Mothilal, R. K., Sharma, A., & Tan, C. (2020). Explaining machine learning classifiers through diverse counterfactual explanations. In Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency. https://doi.org/10.1145/3351095.3372850
Example:
from xai_cola.ce_generator import DiCE from xai_cola.ce_sparsifier.models import Model from xai_cola.ce_sparsifier.data import COLAData from sklearn.pipeline import Pipeline # Create pipeline with preprocessing pipe = Pipeline([("pre", column_transformer), ("clf", lgbm_clf)]) ml_model = Model(model=pipe, backend="sklearn") # Initialize DiCE (no data needed) dice = DiCE(ml_model=ml_model) # Generate counterfactuals (provide data here) data = COLAData(factual_data=df, label_column='Risk') factual_df, cf_df = dice.generate_counterfactuals( data=data, continuous_features=['age', 'income'], total_cfs=5 )generate_counterfactuals (data, factual_class=1, total_cfs=1, features_to_keep=None, continuous_features=None, permitted_range=None)
Generate counterfactuals for the given factual data.
- Parameters:
data (COLAData*, **required*) – Data wrapper containing the factual data (original raw data). Must include both features and target column.
factual_class (int*, **default=1*) – The class of the factual data. Normally, we set the factual_class as 1 (positive class) and we hope the counterfactual is 0 (negative class).
total_cfs (int*, **default=1*) – Total number of counterfactuals required for each query instance.
features_to_keep (list*, **optional*) – List of feature names to keep unchanged in the counterfactuals. Uses original feature names (before any preprocessing).
continuous_features (list*, **optional*) – List of continuous/numerical feature names for dice_ml.Data. Uses original feature names (before any preprocessing). If None, will use all features as continuous features. Categorical features are automatically inferred as all features minus continuous_features.
permitted_range (dict*, **optional*) – Explicit feature range constraints. Format:
{'feature_name': [min, max]}for numerical features, or{'feature_name': ['value1', 'value2']}for categorical features.- Returns:
factual_df (pd.DataFrame) – DataFrame with shape (n_samples, n_features + 1), includes target column. Contains the original factual data.
counterfactual_df (pd.DataFrame) – DataFrame with shape (n_samples, n_features + 1), includes target column. Target column values are set based on actual model predictions.
- Return type:
tuple
- Raises:
ValueError – If data is None
Example:
# Prepare data data = COLAData(factual_data=X_raw, label_column='target', numerical_features=['age', 'income']) # Generate counterfactuals factual_df, cf_df = dice.generate_counterfactuals( data=data, continuous_features=data.get_numerical_features(), features_to_keep=['gender'], # Keep gender unchanged total_cfs=5 )
class DisCount (ml_model)
Bases:
CounterFactualExplainerDistributional Counterfactual Explanations with Optimal Transport (DisCount) generator.
Generates distributional counterfactuals that maintain similar structure to the factual distribution while achieving desired prediction outcomes. Uses optimal transport theory and Wasserstein distances for robust distributional matching.
IMPORTANT: Only compatible with PyTorch models (backend=’pytorch’).
- Parameters:
ml_model (Model) – Pre-trained model wrapped in Model interface. Must be a PyTorch model (
backend='pytorch').- Raises:
ValueError – If ml_model.backend is not ‘pytorch’
- References:
You, L., Cao, L., Nilsson, M., Zhao, B., & Lei, L. (2025). DIStributional COUNTerfactual Explanation With Optimal Transport. In Proceedings of AISTATS 2025 (Oral). https://arxiv.org/pdf/2401.13112
Mathematical Foundation:
DisCount solves a constrained optimization problem:
\[\begin{split}\\min_x Q = (1-\\eta) Q_x(x, \\mu) + \\eta Q_y(x, \\nu)\end{split}\]subject to:
\[\begin{split}\\text{SW}_2(x, x') &\\leq U_1 \\\\ \\text{W}_2(b(x), y^*) &\\leq U_2\end{split}\]where:
x: Counterfactual distribution (optimized)
x’: Factual distribution (fixed)
y = b(x): Model predictions
y*: Target prediction distribution
SW₂: Sliced Wasserstein-2 distance
W₂: Wasserstein-2 distance
U₁, U₂: Upper bounds on distributional distances
Key Features:
Distribution-level counterfactual generation
Preserves distributional structure
Uses sliced Wasserstein distance for efficient computation
Robust to outliers with trimmed distances
Statistical guarantees via confidence upper bounds
Supports data transformation via COLAData.transform_method parameter
Automatically handles inverse transformation of counterfactuals
Example:
from xai_cola.ce_generator import DisCount from xai_cola.ce_sparsifier.models import Model from xai_cola.ce_sparsifier.data import COLAData # Setup with PyTorch model data = COLAData( factual_data=df, label_column='Risk', numerical_features=['Age', 'Credit amount', 'Duration'] ) ml_model = Model(model=pytorch_model, backend="pytorch") # Create DisCount explainer explainer = DisCount(ml_model=ml_model) # Generate distributional counterfactuals factual, counterfactual = explainer.generate_counterfactuals( data=data, factual_class=1, lr=5e-2, n_proj=10, delta=0.15, U_1=0.4, U_2=0.02, max_iter=100 )generate_counterfactuals (data=None, factual_class=1, lr=5e-2, n_proj=10, delta=0.15, U_1=0.4, U_2=0.02, l=0.15, r=1, max_iter=100, tau=1e1, silent=False, explain_columns=None)
Implement the DisCount algorithm to generate counterfactuals.
- Parameters:
data (COLAData*, **optional*) – Factual data wrapper
factual_class (int*, **default=1*) – The class of the factual data. Normally, we set the factual_class as 1 as the prediction of factual data is 1. And we hope the prediction of counterfactual data is 0.
lr (float*, **default=5e-2*) – Learning rate for optimization. Controls the step size in gradient descent updates.
n_proj (int*, **default=10*) – Number of random projections for computing sliced Wasserstein distance. More projections give better approximation but slower computation.
delta (float*, **default=0.15*) – Trimming constant ∈ (0, 0.5) for robust distance computation. Trims delta fraction from both tails of the distribution.
U_1 (float*, **default=0.4*) – Upper bound for input distributional distance (sliced Wasserstein distance between factual and counterfactual feature distributions).
U_2 (float*, **default=0.02*) – Upper bound for output prediction distance (Wasserstein distance between counterfactual predictions and target distribution).
l (float*, **default=0.15*) – Lower bound for interval narrowing in the balancing weight η search.
r (float*, **default=1*) – Upper bound for interval narrowing in the balancing weight η search.
max_iter (int*, **default=100*) – Maximum number of optimization iterations.
tau (float*, **default=1e1*) – Step size for manifold gradient descent. Cannot be too large or too small.
silent (bool*, **default=False*) – Whether to suppress optimization logs during training.
explain_columns (list*, **optional*) – List of column names that can be modified during optimization. If None, all transformed features can be modified.
- Returns:
factual_df (pd.DataFrame) – DataFrame with shape (n_samples, n_features + 1), includes target column
counterfactual_df (pd.DataFrame) – DataFrame with shape (n_samples, n_features + 1), includes target column. Target column values for counterfactual are set based on actual model predictions.
- Return type:
tuple
Note:
Only compatible with PyTorch models (backend=’pytorch’)
Supports data transformation via COLAData.transform_method parameter
Automatically handles inverse transformation of counterfactuals
Example:
# Generate with custom parameters factual, cf = explainer.generate_counterfactuals( data=data, factual_class=1, lr=1e-2, # Smaller learning rate n_proj=50, # More projections for accuracy delta=0.1, # Less trimming U_1=0.3, # Tighter distributional constraint U_2=0.01, # Tighter prediction constraint max_iter=200, # More iterations silent=False # Show progress )
Examples¶
Using DiCE¶
Basic Usage¶
from xai_cola.ce_generator import DiCE
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
# Create and train pipeline
pipe = Pipeline([
('scaler', StandardScaler()),
('clf', RandomForestClassifier())
])
pipe.fit(X_train, y_train)
# Setup
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Income']
)
ml_model = Model(model=pipe, backend="sklearn")
# Create DiCE explainer
explainer = DiCE(ml_model=ml_model)
# Generate counterfactuals
factual, counterfactual = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=2,
continuous_features=['Age', 'Income']
)
# Add to data for further processing
data.add_counterfactuals(counterfactual, with_target_column=True)
With Immutable Features¶
# Don't change Age or Gender
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=3,
features_to_keep=['Age', 'Gender'], # These features stay unchanged
continuous_features=['Income', 'Duration']
)
With Feature Ranges¶
# Set realistic bounds for features
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=2,
continuous_features=['Age', 'Income'],
permitted_range={
'Age': [20, 30], # Numerical range
'Education': ['Doctorate', 'Prof-school'] # Categorical values
}
)
Using DisCount¶
Basic Usage¶
from xai_cola.ce_generator import DisCount
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
import torch.nn as nn
# Define PyTorch model
class NeuralNet(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, 2)
)
def forward(self, x):
return self.layers(x)
# Train model
pytorch_model = NeuralNet(input_dim=10)
# ... training code ...
# Setup with PyTorch model
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Credit amount', 'Duration']
)
ml_model = Model(model=pytorch_model, backend="pytorch")
# Create DisCount explainer
explainer = DisCount(ml_model=ml_model)
# Generate distributional counterfactuals
factual, counterfactual = explainer.generate_counterfactuals(
data=data,
factual_class=1,
lr=5e-2, # Learning rate
n_proj=10, # Number of projections
delta=0.15, # Trimming constant
U_1=0.4, # Distributional distance bound
U_2=0.02, # Prediction distance bound
max_iter=100, # Maximum iterations
silent=False # Print optimization logs
)
# Add to data
data.add_counterfactuals(counterfactual, with_target_column=True)
With Selective Feature Modification¶
# Only allow financial features to change
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
explain_columns=['Credit amount', 'Duration', 'Saving accounts'],
lr=5e-2,
U_1=0.3, # Tighter distributional constraint
U_2=0.01, # Tighter prediction constraint
max_iter=150
)
Fast Optimization¶
# Trade accuracy for speed
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
n_proj=5, # Fewer projections
max_iter=50, # Fewer iterations
silent=True, # No logs
lr=1e-1 # Larger learning rate
)
High-Quality Results¶
# Prioritize quality over speed
factual, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
n_proj=50, # More projections
max_iter=300, # More iterations
delta=0.1, # Less trimming
U_1=0.2, # Tighter distributional bound
U_2=0.01, # Tighter prediction bound
lr=1e-2 # Smaller, more stable learning rate
)
Integration with COLA¶
Complete Workflow with DiCE¶
from xai_cola import COLA
from xai_cola.ce_generator import DiCE
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
# 1. Setup
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Income']
)
ml_model = Model(model=trained_model, backend="sklearn")
# 2. Generate CFs with DiCE
explainer = DiCE(ml_model=ml_model)
_, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
total_cfs=2,
continuous_features=['Age', 'Income']
)
# 3. Add to data
data.add_counterfactuals(cf, with_target_column=True)
# 4. Refine with COLA
cola = COLA(data=data, ml_model=ml_model)
cola.set_policy(matcher='ot', attributor='pshap', random_state=42)
refined = cola.get_refined_counterfactual(limited_actions=5)
# 5. Visualize
cola.heatmap_direction(save_path='./results')
Complete Workflow with DisCount¶
from xai_cola import COLA
from xai_cola.ce_generator import DisCount
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
# 1. Setup with PyTorch model
data = COLAData(
factual_data=df,
label_column='Risk',
numerical_features=['Age', 'Credit amount', 'Duration']
)
ml_model = Model(model=pytorch_model, backend="pytorch")
# 2. Generate distributional CFs with DisCount
explainer = DisCount(ml_model=ml_model)
_, cf = explainer.generate_counterfactuals(
data=data,
factual_class=1,
lr=5e-2,
U_1=0.4,
U_2=0.02,
max_iter=100
)
# 3. Add to data
data.add_counterfactuals(cf, with_target_column=True)
# 4. Refine with COLA
cola = COLA(data=data, ml_model=ml_model)
cola.set_policy(matcher='nn', attributor='pshap', random_state=42)
refined = cola.get_refined_counterfactual(limited_actions=5)
# 5. Visualize
cola.heatmap_binary(save_path='./results')
cola.stacked_bar_chart(save_path='./results')
Algorithm Comparison¶
Aspect |
DiCE |
DisCount |
|---|---|---|
Level |
Instance-wise |
Distribution-wise |
Backend |
Sklearn |
PyTorch only |
Best for |
Individual explanations |
Group explanations |
Diversity |
High (multiple CFs per instance) |
Moderate (maintains distribution) |
Computational Cost |
Fast |
Slower (optimization) |
Distributional Guarantees |
No |
Yes (preserves structure) |
Robustness |
Standard |
High (trimmed distances) |
See Also¶
COLA API - COLA refinement API
Data API - Data interface API
Models API - Model interface API