Quick Start¶
Get started with COLA in 5 minutes! This guide shows you how to refine counterfactual explanations with minimal code.
Installation¶
pip install xai-cola
Basic Workflow¶
COLA follows a simple 5-step workflow:
1. Load Data → 2. Train Model → 3. Generate CFs → 4. Refine with COLA → 5. Visualize
Complete Example¶
Here’s a complete working example using the built-in German Credit dataset. It expands the five stages above into 12 numbered steps:
# Step 1: Import libraries
from xai_cola import COLA
from xai_cola.datasets.german_credit import GermanCreditDataset
from xai_cola.ce_sparsifier.data import COLAData
from xai_cola.ce_sparsifier.models import Model
from xai_cola.ce_generator import DiCE
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
# Step 2: Load and prepare data
dataset = GermanCreditDataset()
X_train, y_train, X_test, y_test = dataset.get_original_train_test_split()
# Step 3: Train a model
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', LogisticRegression(max_iter=1000))
])
pipe.fit(X_train, y_train)
print(f"Model accuracy: {pipe.score(X_test, y_test):.3f}")
# Step 4: Prepare COLA data interface
numerical_features = ['Age', 'Credit amount', 'Duration']
data = COLAData(
    factual_data=X_test,
    label_column='Risk',
    numerical_features=numerical_features
)
# Step 5: Wrap model for COLA
ml_model = Model(model=pipe, backend="sklearn")
# Step 6: Generate counterfactuals with DiCE
explainer = DiCE(ml_model=ml_model)
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,  # Generate CFs for high-risk instances
    total_cfs=2,  # 2 CFs per instance
    continuous_features=numerical_features
)
# Step 7: Add counterfactuals to data
data.add_counterfactuals(counterfactual, with_target_column=True)
# Step 8: Initialize COLA and set policy
sparsifier = COLA(data=data, ml_model=ml_model)
sparsifier.set_policy(
    matcher="ot",  # Optimal transport matching
    attributor="pshap",  # PSHAP for feature attribution
    random_state=42  # For reproducibility
)
# Step 9: Query minimum actions needed
min_actions = sparsifier.query_minimum_actions()
print(f"Minimum actions needed: {min_actions}")
# Step 10: Sparsify counterfactuals
sparsified_cf = sparsifier.sparsify_counterfactuals(limited_actions=min_actions)
print(f"✓ Sparsified {len(sparsified_cf)} counterfactuals!")
# Step 11: Compare results
factual_df, ce_df, ace_df = sparsifier.get_all_results(
    limited_actions=min_actions
)
original_changes = (factual_df != ce_df).sum().sum()
sparsified_changes = (factual_df != ace_df).sum().sum()
print(f"Original CF: {original_changes} feature changes")
print(f"Sparsified ACE: {sparsified_changes} feature changes")
print(f"Reduction: {original_changes - sparsified_changes} fewer changes!")
# Step 12: Visualize results
sparsifier.heatmap_direction(save_path='./results')
sparsifier.stacked_bar_chart(save_path='./results')
print("✓ Visualizations saved to ./results/")
Expected Output¶
Model accuracy: 0.730
Minimum actions needed: 20
✓ Sparsified 10 counterfactuals!
Original CF: 30 feature changes
Sparsified ACE: 20 feature changes
Reduction: 10 fewer changes!
✓ Visualizations saved to ./results/
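The change counts in this output come from a cell-by-cell comparison of the factual and counterfactual tables. The sketch below reproduces that counting in plain Python on made-up rows. The `count_changes` helper and all feature values are illustrative assumptions, not part of COLA:

```python
def count_changes(factual_rows, cf_rows):
    """Count cells whose value differs between aligned factual and counterfactual rows."""
    if len(factual_rows) != len(cf_rows):
        raise ValueError("factual and counterfactual tables must have the same number of rows")
    changes = 0
    for factual_row, cf_row in zip(factual_rows, cf_rows):
        # Compare feature by feature, using the factual row's columns as reference.
        changes += sum(1 for name in factual_row if factual_row[name] != cf_row[name])
    return changes


# Made-up data: two instances, three features each.
factual = [
    {"Age": 25, "Credit amount": 5000, "Duration": 24},
    {"Age": 40, "Credit amount": 12000, "Duration": 48},
]
original_cf = [
    {"Age": 30, "Credit amount": 3000, "Duration": 12},  # all three features changed
    {"Age": 40, "Credit amount": 8000, "Duration": 36},  # two features changed
]
sparsified_cf = [
    {"Age": 25, "Credit amount": 3000, "Duration": 24},  # one feature changed
    {"Age": 40, "Credit amount": 8000, "Duration": 48},  # one feature changed
]

original_changes = count_changes(factual, original_cf)
sparsified_changes = count_changes(factual, sparsified_cf)
print(f"Original CF: {original_changes} feature changes")
print(f"Sparsified: {sparsified_changes} feature changes")
```

Fewer changed cells means fewer actions for the person receiving the explanation, which is the quantity the "Reduction" line reports.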
Breaking It Down¶
Step-by-Step Explanation¶
Steps 1-3: Standard ML Workflow
Train your model as usual. COLA works with any sklearn-compatible model.
Steps 4-5: Data and Model Interface
# Wrap your data
data = COLAData(
    factual_data=df,
    label_column='target',
    numerical_features=['Age', 'Income']
)
# Wrap your model
ml_model = Model(model=your_model, backend="sklearn")
Steps 6-7: Generate Counterfactuals
# Use a built-in CF explainer (DiCE or DisCount), or supply your own
explainer = DiCE(ml_model=ml_model)
_, cf = explainer.generate_counterfactuals(...)
# Add to data
data.add_counterfactuals(cf, with_target_column=True)
Steps 8-10: Sparsify with COLA
# Initialize COLA
sparsifier = COLA(data=data, ml_model=ml_model)
# Set policy
sparsifier.set_policy(matcher="ot", attributor="pshap")
# Sparsify
refined = sparsifier.sparsify_counterfactuals(limited_actions=5)
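To build intuition for what `limited_actions` controls, the toy sketch below keeps only the k counterfactual changes with the largest attribution scores and reverts the rest to their factual values. It is an illustration under assumed inputs, not COLA's implementation: the `apply_top_k_changes` helper, the scores, and the data are all hypothetical, and COLA's matcher and attributor do considerably more work.

```python
def apply_top_k_changes(factual, counterfactual, scores, k):
    """Keep the k changed cells with the highest score; all other cells stay factual.

    factual, counterfactual: lists of dicts (one dict per instance).
    scores: dict mapping (row_index, feature_name) -> importance of that change.
    """
    # Collect every cell where the counterfactual differs from the factual.
    changed = [
        (i, name)
        for i, (f_row, cf_row) in enumerate(zip(factual, counterfactual))
        for name in f_row
        if f_row[name] != cf_row[name]
    ]
    # Rank the changed cells by score, highest first, and keep only k of them.
    keep = sorted(changed, key=lambda cell: scores.get(cell, 0.0), reverse=True)[:k]

    # Start from copies of the factual rows and apply only the kept changes.
    result = [dict(row) for row in factual]
    for i, name in keep:
        result[i][name] = counterfactual[i][name]
    return result


factual = [{"Age": 25, "Income": 30000, "Duration": 24}]
counterfactual = [{"Age": 30, "Income": 45000, "Duration": 12}]
# Hypothetical attribution scores for each proposed change.
scores = {(0, "Age"): 0.1, (0, "Income"): 0.7, (0, "Duration"): 0.2}

sparse = apply_top_k_changes(factual, counterfactual, scores, k=2)
print(sparse)
```

Nothing in this toy version checks that the reduced counterfactual still changes the model's prediction, which is why querying the minimum number of actions first (Step 9) matters in the real workflow.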
Steps 11-12: Analyze and Visualize
# Get all results
factual, ce, ace = sparsifier.get_all_results(limited_actions=5)
# Visualize
sparsifier.heatmap_direction(save_path='./results')
sparsifier.stacked_bar_chart(save_path='./results')
print("✓ Visualizations saved to ./results/")
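Both plotting calls write into `save_path`. This page does not say whether COLA creates that directory when it is missing, so a cautious script can create it first. A minimal sketch using only the standard library (the `ensure_dir` helper name is an assumption):

```python
from pathlib import Path
import tempfile


def ensure_dir(path):
    """Create the directory (and any parents) if needed and return it as a Path."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return directory


# Demonstrated inside a temporary directory so the example leaves nothing behind.
with tempfile.TemporaryDirectory() as tmp:
    results_dir = ensure_dir(Path(tmp) / "results")
    created = results_dir.is_dir()
    print(f"Results directory ready: {created}")
```

In a real script, calling `ensure_dir('./results')` before `heatmap_direction` and `stacked_bar_chart` would serve the same purpose.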
Using Your Own Data¶
Replace the German Credit dataset with your own:
import pandas as pd
from sklearn.model_selection import train_test_split
# Load your data
df = pd.read_csv('your_data.csv')
# Split into train/test
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
X_train = train_df.drop('target_column', axis=1)
y_train = train_df['target_column']
X_test = test_df.drop('target_column', axis=1)
y_test = test_df['target_column']
# Define your numerical features
numerical_features = ['feature1', 'feature2', 'feature3']
# Continue with COLA workflow...
data = COLAData(
    factual_data=X_test,
    label_column='target_column',
    numerical_features=numerical_features
)
# Rest of the code remains the same
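A frequent slip when swapping in your own data is a feature name that does not match any column. The sketch below checks the names before they are passed on. It works on a plain list of column names (for a DataFrame, `list(df.columns)`); the `missing_features` helper is hypothetical and not part of COLA.

```python
def missing_features(columns, numerical_features):
    """Return the requested feature names that are not present in columns."""
    available = set(columns)
    return [name for name in numerical_features if name not in available]


columns = ["feature1", "feature2", "feature3"]

ok = missing_features(columns, ["feature1", "feature2"])
bad = missing_features(columns, ["feature1", "Feature2"])  # wrong capitalization

print(ok)   # empty list: every name matches a column
print(bad)  # lists the names that need fixing
```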
Using Different Models¶
Scikit-learn¶
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
ml_model = Model(model=clf, backend="sklearn")
PyTorch¶
import torch.nn as nn
class MyModel(nn.Module):
    # ... your model definition ...
    pass
model = MyModel()
# ... train your model ...
ml_model = Model(model=model, backend="pytorch")
Common Variations¶
Matcher: Using ECT Matcher (Faster)¶
# For faster results, use ECT instead of OT
sparsifier.set_policy(matcher="ect", attributor="pshap")
Sparsifier: With Feature Restrictions¶
# Only allow certain features to change
refined = sparsifier.sparsify_counterfactuals(
    limited_actions=5,
    features_to_vary=['Income', 'LoanAmount', 'Duration']
)
CE generator: Using DisCount Instead of DiCE¶
from xai_cola.ce_generator import DisCount
explainer = DisCount(ml_model=ml_model)
_, cf = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,
    cost_type='L1'
)
Jupyter Notebook Tips¶
For better visualization in Jupyter:
from IPython.display import display
# Display highlighted DataFrames
factual_style, ce_style, ace_style = sparsifier.highlight_changes_final()
display(ce_style)
display(ace_style)
# Display inline figures
%matplotlib inline
fig = sparsifier.heatmap_direction()
Troubleshooting¶
Error: “No counterfactuals found”
Solution: Relax constraints or increase total_cfs
# Increase CFs per instance
explainer.generate_counterfactuals(
    data=data,
    factual_class=1,
    total_cfs=5  # Requesting more CFs per instance can raise the chance of finding valid ones
)
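One way to apply this advice is to retry with a growing `total_cfs` until a result comes back. The sketch below shows only the retry pattern: `fake_generate` is a stand-in for a call to `explainer.generate_counterfactuals`, and treating failure as a `ValueError` is an assumption, so adapt the exception type to what the library actually raises.

```python
def retry_with_more_cfs(generate, start=2, step=2, max_cfs=10):
    """Call generate(total_cfs) with increasing total_cfs until it succeeds."""
    total_cfs = start
    last_error = None
    while total_cfs <= max_cfs:
        try:
            return total_cfs, generate(total_cfs)
        except ValueError as err:  # assumed failure signal, not confirmed for COLA
            last_error = err
            total_cfs += step
    raise RuntimeError(f"no counterfactuals found up to total_cfs={max_cfs}") from last_error


def fake_generate(total_cfs):
    """Stand-in generator: fails unless at least 5 CFs per instance are requested."""
    if total_cfs < 5:
        raise ValueError("No counterfactuals found")
    return [f"cf_{i}" for i in range(total_cfs)]


used, cfs = retry_with_more_cfs(fake_generate)
print(f"Succeeded with total_cfs={used}, got {len(cfs)} counterfactuals")
```

Capping the loop with `max_cfs` keeps a hard case from retrying forever; if the cap is reached, relaxing constraints is the other option named above.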
Error: “Must call set_policy before refining”
Solution: Always call set_policy() before sparsify_counterfactuals()
# Don't forget this!
sparsifier.set_policy(matcher="ot", attributor="pshap")
Error: “Counterfactual data not set”
Solution: Add counterfactuals to COLAData before creating COLA, either by calling add_counterfactuals() or by passing them in when initializing COLAData.
# Must do this before creating COLA
data.add_counterfactuals(cf, with_target_column=True)
sparsifier = COLA(data=data, ml_model=ml_model)
Next Steps¶
Now that you’ve completed the quick start:
Tutorial 1: Basic COLA Workflow - Detailed tutorial with explanations
Data Interface - Learn about data management
Counterfactual Explainers - Explore different CF generators
Matching Policies - Understand matching strategies
Visualization - Master visualization tools
Resources¶
Installation - Installation guide
Frequently Asked Questions - Answers to common questions
COLA API - Complete API reference
GitHub Examples - More examples
Getting Help¶
Frequently Asked Questions - Check common questions
GitHub Issues - Report bugs
Contact: leiyo@dtu.dk, s232291@dtu.dk