Tutorial 1: Basic COLA Workflow

Learning Objectives

By the end of this tutorial, you will:

  • Understand the complete COLA workflow

  • Generate and refine counterfactual explanations

  • Visualize the results

  • Interpret the output

Prerequisites

  • COLA installed (pip install xai-cola)

  • Basic Python knowledge

  • Understanding of machine learning concepts

The Problem

Imagine you’re building a loan approval system. A customer was denied a loan, and you need to explain what they should change to get approved. Traditional counterfactual explainers often suggest changing many features at once. COLA helps you narrow that down to the few changes that matter most.

Step 1: Prepare Your Data

import pandas as pd
from xai_cola.data import COLAData
from datasets.german_credit import GermanCreditDataset

# Load the German Credit dataset
dataset = GermanCreditDataset()
df = dataset.get_dataframe()

# Select instances that were denied (Risk = 1)
# In this dataset, Risk=1 means bad credit
denied_customers = df[df['Risk'] == 1].sample(5, random_state=42)

# Create the data interface
data = COLAData(
    factual_data=denied_customers,
    label_column='Risk',
    transform='ohe-zscore'  # One-hot encode categoricals and z-score normalize
)

print(f"Selected {len(denied_customers)} customers to explain")
print(denied_customers.columns)

What’s happening:

  • Load the German Credit dataset (bundled with COLA)

  • Select 5 customers who were denied loans

  • Wrap the data in COLAData for automatic preprocessing
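To make the 'ohe-zscore' option more concrete, here is a small dependency-free sketch of the two operations it names: one-hot encoding and z-score normalization. The helper names (one_hot, zscore) and the toy values are made up for illustration. This is not COLAData's internal code, and whether COLA uses the population or the sample standard deviation is an assumption here.

```python
from statistics import mean, pstdev

def one_hot(values):
    """Map a list of category labels to one 0/1 column per distinct label."""
    categories = sorted(set(values))
    return {cat: [1 if v == cat else 0 for v in values] for cat in categories}

def zscore(values):
    """Standardize numbers to zero mean and unit (population) standard deviation."""
    mu = mean(values)
    sigma = pstdev(values)
    if sigma == 0:
        # A constant column carries no information; map it to zeros
        return [0.0 for _ in values]
    return [(v - mu) / sigma for v in values]

# Toy columns (hypothetical values, not taken from German Credit)
housing = ["own", "rent", "own", "free"]
duration = [12, 24, 36, 48]

encoded = one_hot(housing)
scaled = zscore(duration)

print(encoded)
print([round(x, 3) for x in scaled])
```

After this kind of transform every column is numeric and on a comparable scale, which is what most models and distance-based matchers expect.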

Step 2: Load Your ML Model

import joblib
from xai_cola.models import Model

# Load pre-trained model (you can use your own model here)
classifier = joblib.load('lgbm_GremanCredit.pkl')

# Wrap it in COLA's model interface
ml_model = Model(model=classifier, backend="sklearn")

# Verify the model works
predictions = ml_model.predict(data.get_transformed_data())
print(f"Model predictions: {predictions}")

What’s happening:

  • Load a pre-trained LightGBM classifier

  • Wrap it in COLA’s Model interface so COLA can interact with it

  • The model interface handles backend-specific details
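The idea of a model interface can be sketched in a few lines. The classes below (ThresholdClassifier, ModelWrapper) and the "callable" backend name are hypothetical stand-ins written for this illustration; they are not xai_cola.models.Model and say nothing about how it is implemented.

```python
class ThresholdClassifier:
    """Stand-in for a trained classifier with an sklearn-style predict method."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, rows):
        # Label 1 ("bad credit") when the first feature exceeds the threshold
        return [1 if row[0] > self.threshold else 0 for row in rows]


class ModelWrapper:
    """Hypothetical wrapper that gives every backend the same predict() call."""

    def __init__(self, model, backend):
        if backend not in ("sklearn", "callable"):
            raise ValueError(f"Unsupported backend: {backend}")
        self.model = model
        self.backend = backend

    def predict(self, rows):
        if self.backend == "sklearn":
            # sklearn-style models predict a whole batch at once
            return list(self.model.predict(rows))
        # A plain function is applied row by row
        return [self.model(row) for row in rows]


rows = [[12], [48]]

wrapped = ModelWrapper(model=ThresholdClassifier(threshold=24), backend="sklearn")
predictions = wrapped.predict(rows)

wrapped_fn = ModelWrapper(model=lambda row: int(row[0] > 24), backend="callable")
predictions_fn = wrapped_fn.predict(rows)

print(predictions, predictions_fn)
```

Code that consumes the wrapper only ever calls predict(), so it does not need to know which library produced the model.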

Step 3: Generate Initial Counterfactuals

from xai_cola.ce_generator import DiCE

# Initialize the DiCE explainer
explainer = DiCE(ml_model=ml_model)

# Generate counterfactuals
# We want to flip predictions from 1 (bad) to 0 (good)
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,           # Current class: denied
    total_cfs=1,               # Generate 1 counterfactual per factual
    features_to_keep=['Age', 'Sex']  # Don't change these features
)

print(f"Generated {len(counterfactual)} counterfactuals")
print(f"Factual shape: {factual.shape}")
print(f"Counterfactual shape: {counterfactual.shape}")

What’s happening:

  • DiCE generates counterfactual explanations

  • Each denied customer gets a “what-if” scenario that would lead to approval

  • We keep Age and Sex fixed (immutable features)
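For intuition about what a counterfactual with immutable features is, the sketch below runs a brute-force search on a toy rule. The rule, the feature names, and the candidate values are invented for this example. DiCE itself uses a more sophisticated search, so treat this as a stand-in for the concept rather than a description of its algorithm.

```python
def predict(row):
    """Toy rule: class 1 (denied) when the credit-to-income ratio exceeds 0.5."""
    return 1 if row["credit_amount"] / row["income"] > 0.5 else 0

def find_counterfactual(row, candidates, features_to_keep):
    """Return the first single-feature change that flips the prediction to 0."""
    for feature, values in candidates.items():
        if feature in features_to_keep:
            continue  # immutable features are never modified
        for value in values:
            trial = dict(row, **{feature: value})
            if predict(trial) == 0:
                return trial
    return None  # no single-feature change among the candidates works

factual = {"age": 30, "credit_amount": 6000, "income": 10000}
candidates = {
    "age": [40, 50],
    "credit_amount": [5500, 5000, 4000],
    "income": [11000, 13000],
}

cf = find_counterfactual(factual, candidates, features_to_keep=["age"])
print(cf)
```

Here the search skips age entirely and settles on lowering credit_amount to 5000, the first candidate that moves the ratio to 0.5 or below.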

Output example:

Generated 5 counterfactuals
Factual shape: (5, 20)
Counterfactual shape: (5, 20)

Step 4: Initialize COLA and Set Policy

from xai_cola import COLA

# Add counterfactuals to the data
data.add_counterfactuals(counterfactual, with_target_column=True)

# Initialize COLA refiner
refiner = COLA(
    data=data,
    ml_model=ml_model
)

# Configure the refinement policy
refiner.set_policy(
    matcher="ect",           # Exact matching (best for DiCE)
    attributor="pshap",      # Use PSHAP for feature importance
    Avalues_method="max"     # How to aggregate importance scores
)

print("COLA initialized and ready to refine")

What’s happening:

  • We tell COLA about the counterfactuals

  • Set up the refinement policy:

    • matcher="ect": Match each factual to its counterfactual

    • attributor="pshap": Use Shapley values with joint probability

    • Avalues_method="max": Take maximum importance when aggregating
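As a refresher on the Shapley values that the attributor builds on, here is an exact computation for a toy linear risk score. A coalition is the set of features that take their counterfactual values while the rest stay factual. The score, the feature names, and the numbers are assumptions made for this sketch, and it leaves out the joint-probability part of PSHAP entirely.

```python
from itertools import combinations
from math import factorial

def shapley_values(features, value):
    """Exact Shapley values for a value function defined on feature coalitions."""
    n = len(features)
    phi = {}
    for feature in features:
        others = [f for f in features if f != feature]
        total = 0.0
        for size in range(n):
            # Standard Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                coalition = frozenset(subset)
                total += weight * (value(coalition | {feature}) - value(coalition))
        phi[feature] = total
    return phi

# Toy factual and counterfactual rows (hypothetical values)
factual = {"duration": 48, "credit_amount": 8000, "savings": 100}
counterfactual = {"duration": 24, "credit_amount": 5000, "savings": 1000}

def risk_score(row):
    """Toy linear risk score: higher means more likely to be denied."""
    return 0.01 * row["duration"] + 0.0001 * row["credit_amount"] - 0.0002 * row["savings"]

def risk_reduction(coalition):
    """Risk drop when only the features in `coalition` take counterfactual values."""
    mixed = {k: (counterfactual[k] if k in coalition else factual[k]) for k in factual}
    return risk_score(factual) - risk_score(mixed)

phi = shapley_values(list(factual), risk_reduction)
for name, score in phi.items():
    print(f"{name}: {score:.2f}")
```

Because the toy score is linear, each feature's Shapley value equals its own contribution to the risk reduction, and the values sum to the total reduction achieved by the full counterfactual.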

Step 5: Refine Counterfactuals

# Refine to use at most 3 feature changes
factual_refined, ce_refined, ace_refined = refiner.get_all_results(
    limited_actions=3
)

print(f"Original counterfactuals: {ce_refined.shape}")
print(f"Action-limited counterfactuals: {ace_refined.shape}")

# Check how many features actually changed
original_changes = (factual_refined != ce_refined).sum(axis=1)
refined_changes = (factual_refined != ace_refined).sum(axis=1)

print(f"\nOriginal CE - features changed: {original_changes.tolist()}")
print(f"Refined ACE - features changed: {refined_changes.tolist()}")

What’s happening:

  • COLA refines the counterfactuals to change at most 3 features

  • We compare the original vs refined versions

  • Refined counterfactuals require fewer changes

Example output:

Original CE - features changed: [8, 10, 7, 9, 11]
Refined ACE - features changed: [3, 3, 3, 3, 3]
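The core idea of limiting actions can be sketched without any library: keep only the most important changes from the counterfactual and revert the rest to their factual values. The function names, importance scores, and rows below are hypothetical. Unlike a real refinement, this sketch does not check that the refined row still flips the model's prediction.

```python
def count_changes(factual, candidate):
    """Number of features whose value differs from the factual row."""
    return sum(1 for k in factual if factual[k] != candidate[k])

def limit_actions(factual, counterfactual, importance, limited_actions):
    """Keep the top-k most important changes; revert the rest to factual values."""
    changed = [k for k in factual if factual[k] != counterfactual[k]]
    keep = sorted(changed, key=lambda k: importance[k], reverse=True)[:limited_actions]
    refined = dict(factual)
    for k in keep:
        refined[k] = counterfactual[k]
    return refined

# Toy rows and importance scores (hypothetical values)
factual = {"duration": 48, "credit_amount": 8000, "savings": 100, "housing": "rent", "age": 30}
counterfactual = {"duration": 24, "credit_amount": 5000, "savings": 1000, "housing": "own", "age": 30}
importance = {"duration": 0.24, "credit_amount": 0.30, "savings": 0.18, "housing": 0.05, "age": 0.0}

refined = limit_actions(factual, counterfactual, importance, limited_actions=2)

print("Original CE changes:", count_changes(factual, counterfactual))
print("Refined ACE changes:", count_changes(factual, refined))
print(refined)
```

With a budget of two actions, only credit_amount and duration are changed, which mirrors the drop in change counts shown in the example output above.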

Step 6: Visualize Results

Method 1: Highlighted DataFrames (Best for Small Datasets)

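# display() is available automatically in Jupyter; import it when running elsewhere
from IPython.display import display

# Get highlighted versions showing what changed
refine_factual, refine_ce, refine_ace = refiner.highlight_changes_final()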

print("\n=== FACTUAL (Original Customer) ===")
display(refine_factual)

print("\n=== COUNTERFACTUAL (DiCE Suggestion) ===")
display(refine_ce)

print("\n=== ACTION-LIMITED COUNTERFACTUAL (COLA Refinement) ===")
display(refine_ace)

What you’ll see:

  • Color-coded DataFrames where:

    • 🟢 Green: No change

    • 🟡 Yellow: Feature changed

  • Easy to see which features were modified

Method 2: Heatmaps (Best for Any Dataset)

# Binary heatmap: Shows which features changed (0 = no change, 1 = changed)
refiner.heatmap_binary(save_path='./results', save_mode='combined')

# Directional heatmap: Shows if features increased (+1), decreased (-1), or stayed same (0)
refiner.heatmap_direction(save_path='./results', save_mode='combined')

What you’ll see:

  • Visual comparison of CE vs ACE

  • Clear patterns of which features change most often
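The two heatmaps are drawn from simple change matrices. The sketch below computes both encodings for a pair of toy rows: 0/1 for "changed or not" and -1/0/+1 for the direction of change. The function names and numbers are illustrative, and the plotting itself is left out.

```python
def binary_changes(factual_rows, cf_rows):
    """1 where a feature value changed, 0 where it stayed the same."""
    return [
        [int(f != c) for f, c in zip(f_row, c_row)]
        for f_row, c_row in zip(factual_rows, cf_rows)
    ]

def direction_changes(factual_rows, cf_rows):
    """+1 where a feature increased, -1 where it decreased, 0 otherwise."""
    return [
        [(c > f) - (c < f) for f, c in zip(f_row, c_row)]
        for f_row, c_row in zip(factual_rows, cf_rows)
    ]

# Columns: duration, credit_amount, savings (hypothetical numeric values)
factual_rows = [[48, 8000, 100], [36, 3000, 500]]
cf_rows = [[24, 8000, 1000], [36, 2000, 500]]

binary = binary_changes(factual_rows, cf_rows)
direction = direction_changes(factual_rows, cf_rows)

print(binary)     # [[1, 0, 1], [0, 1, 0]]
print(direction)  # [[-1, 0, 1], [0, -1, 0]]
```

Each row of a matrix corresponds to one customer and each column to one feature, so a column that is mostly nonzero points to a feature that changes often.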

Method 3: Stacked Bar Chart

# Compare efficiency: What % of features changed?
refiner.stacked_bar_chart(save_path='./results')

What you’ll see:

  • Bar chart showing percentage of features modified

  • Clear comparison: ACE requires fewer changes than CE
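The percentages behind such a chart are easy to reproduce by hand. The sketch below reuses the change counts from the Step 5 example output and assumes 20 feature columns; that column count is an assumption for illustration, since the reported shape may include the target column.

```python
def percent_changed(change_counts, n_features):
    """Share of features modified per instance, as a percentage."""
    return [100.0 * count / n_features for count in change_counts]

n_features = 20                 # assumed number of feature columns
ce_counts = [8, 10, 7, 9, 11]   # original counterfactuals
ace_counts = [3, 3, 3, 3, 3]    # action-limited counterfactuals

ce_pct = percent_changed(ce_counts, n_features)
ace_pct = percent_changed(ace_counts, n_features)

for i, (ce, ace) in enumerate(zip(ce_pct, ace_pct), start=1):
    print(f"Customer {i}: CE {ce:.0f}% of features, ACE {ace:.0f}% of features")
```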

Method 4: Diversity Analysis

# Find minimal feature combinations
factual_df, diversity_styles = refiner.diversity()

print("\n=== DIVERSITY ANALYSIS ===")
for i, style in enumerate(diversity_styles):
    print(f"\nCustomer {i+1} - Minimal changes needed:")
    display(style)

What you’ll see:

  • For each customer, shows the minimal set of features that need to change

  • Multiple valid combinations highlighted
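One way to picture "minimal feature combinations" is a brute-force search over subsets of the changed features, smallest first, that stops at the first size where some subset flips the prediction. The toy rule and the function name below are assumptions made for this sketch; it is not a description of how refiner.diversity() works internally.

```python
from itertools import combinations

def predict(row):
    """Toy rule: approved (0) when at least two favourable conditions hold."""
    points = (
        (row["duration"] <= 24)
        + (row["credit_amount"] <= 5000)
        + (row["savings"] >= 1000)
    )
    return 0 if points >= 2 else 1

def minimal_flipping_sets(factual, counterfactual, predict_fn, target=0):
    """All smallest subsets of changed features that reach the target class."""
    changed = [k for k in factual if factual[k] != counterfactual[k]]
    for size in range(1, len(changed) + 1):
        hits = []
        for subset in combinations(changed, size):
            trial = dict(factual)
            for k in subset:
                trial[k] = counterfactual[k]
            if predict_fn(trial) == target:
                hits.append(subset)
        if hits:
            return hits  # stop at the smallest size that works
    return []

factual = {"duration": 48, "credit_amount": 8000, "savings": 100}
counterfactual = {"duration": 24, "credit_amount": 5000, "savings": 1000}

options = minimal_flipping_sets(factual, counterfactual, predict)
for option in options:
    print("Change:", ", ".join(option))
```

In this toy case no single change is enough, but any pair of changes is, so the customer has three equally small options to choose from.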

Step 7: Interpret the Results

# Let's look at a specific customer
customer_idx = 0

# Original customer data
print(f"Customer {customer_idx} (denied) has these attributes:")
print(denied_customers.iloc[customer_idx])

# What DiCE suggests (many changes)
# Use .iloc: the change counts may keep the sampled rows' index labels, so [0] might not exist
print(f"\nDiCE suggests changing {original_changes.iloc[customer_idx]} features")

# What COLA suggests (fewer changes)
print(f"COLA suggests changing only {refined_changes.iloc[customer_idx]} features")

# Which features to change
changed_features = factual_refined.columns[
    (factual_refined.iloc[customer_idx] != ace_refined.iloc[customer_idx])
]
print(f"\nCOLA recommends changing: {changed_features.tolist()}")
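To turn the list of changed features into something a customer can act on, it helps to print the old and new value side by side. The helper below is a hypothetical sketch on plain dictionaries with made-up values; with DataFrames, the same idea applies to one factual row and its refined counterpart.

```python
def describe_changes(factual, refined):
    """One plain-language instruction per feature that differs."""
    return [
        f"Change {name} from {factual[name]} to {refined[name]}"
        for name in factual
        if factual[name] != refined[name]
    ]

# Toy factual row and action-limited counterfactual (hypothetical values)
factual = {"duration": 48, "credit_amount": 8000, "savings": 100, "age": 30}
refined = {"duration": 24, "credit_amount": 5000, "savings": 100, "age": 30}

recommendations = describe_changes(factual, refined)
for line in recommendations:
    print(line)
```

Values printed this way should come from the untransformed data, since z-scored numbers mean little to an end user.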

Complete Example

Here’s the complete code in one place:

import pandas as pd
import joblib
from xai_cola import COLA
from xai_cola.data import COLAData
from xai_cola.models import Model
from xai_cola.ce_generator import DiCE
from datasets.german_credit import GermanCreditDataset
from IPython.display import display  # built in to Jupyter; needed when running elsewhere

# 1. Load data
dataset = GermanCreditDataset()
df = dataset.get_dataframe()
denied_customers = df[df['Risk'] == 1].sample(5, random_state=42)

# 2. Create data interface
data = COLAData(
    factual_data=denied_customers,
    label_column='Risk',
    transform='ohe-zscore'
)

# 3. Load model
classifier = joblib.load('lgbm_GremanCredit.pkl')
ml_model = Model(model=classifier, backend="sklearn")

# 4. Generate counterfactuals
explainer = DiCE(ml_model=ml_model)
factual, counterfactual = explainer.generate_counterfactuals(
    data=data,
    factual_class=1,
    total_cfs=1,
    features_to_keep=['Age', 'Sex']
)

# 5. Refine with COLA
data.add_counterfactuals(counterfactual, with_target_column=True)
refiner = COLA(data=data, ml_model=ml_model)
refiner.set_policy(matcher="ect", attributor="pshap", Avalues_method="max")

factual_refined, ce_refined, ace_refined = refiner.get_all_results(limited_actions=3)

# 6. Visualize
refine_factual, refine_ce, refine_ace = refiner.highlight_changes_final()
display(refine_ace)

refiner.heatmap_binary(save_path='./results', save_mode='combined')
refiner.stacked_bar_chart(save_path='./results')

Exercises

Exercise 1: Vary the Number of Actions

Try different values for limited_actions (1, 3, 5, 10). How does this affect the results?

Exercise 2: Try Different Matchers

Change matcher="ect" to "ot", "nn", or "cem". Which works best for this dataset?

Exercise 3: Feature Selection

Use features_to_vary to restrict which features can be changed:

factual_refined, ce_refined, ace_refined = refiner.get_all_results(
    limited_actions=3,
    features_to_vary=['Credit amount', 'Duration', 'Age']
)

Solutions

Exercise 1 Solution

for actions in [1, 3, 5, 10]:
    f, c, a = refiner.get_all_results(limited_actions=actions)
    changes = (f != a).sum(axis=1)
    print(f"Limited to {actions} actions: {changes.tolist()}")

Exercise 2 Solution

for matcher in ["ect", "ot", "nn", "cem"]:
    refiner.set_policy(matcher=matcher, attributor="pshap", Avalues_method="max")
    f, c, a = refiner.get_all_results(limited_actions=3)
    changes = (f != a).sum(axis=1)
    print(f"Matcher {matcher}: {changes.tolist()}")

Exercise 3 Solution

f, c, a = refiner.get_all_results(
    limited_actions=3,
    features_to_vary=['Credit amount', 'Duration', 'Age']
)
# Check which features actually changed
for col in f.columns:
    if (f[col] != a[col]).any():
        print(f"Changed: {col}")

Summary

In this tutorial, you learned:

  • ✅ The complete COLA workflow

  • ✅ How to generate and refine counterfactuals

  • ✅ Multiple visualization techniques

  • ✅ How to interpret COLA’s output

COLA reduces the number of feature changes needed while maintaining the same outcome, making counterfactual explanations more actionable and practical.