Visualization API¶
Overview¶
COLA provides a suite of pure functions for visualizing counterfactual refinement results. These functions are stateless and can be used independently, but it is recommended to call them through COLA class methods, as the COLA class automatically manages the required data.
Recommended Usage:
Call through COLA class methods (see COLA API documentation):
cola.heatmap_binary()- Binary heatmapcola.heatmap_direction()- Direction heatmapcola.stacked_bar_chart()- Stacked bar chartcola.highlight_changes_comparison()- Comparison highlight tablecola.highlight_changes_final()- Final value highlight tablecola.diversity()- Diversity analysis
Direct Usage:
Advanced users can import functions directly from this module for greater flexibility. This documentation describes the signatures and parameters of these low-level functions.
Available Visualization Types:
Highlighted DataFrames - Interactive tables with color-coded changes
Direction Heatmaps - Show which features increased/decreased
Binary Heatmaps - Show which features changed (yes/no)
Stacked Bar Charts - Compare feature change counts before/after refinement
Diversity Analysis - Explore alternative minimal feature combinations
Module contents¶
Highlighted DataFrames¶
function highlight_changes_comparison (factual_df, counterfactual_df, refined_counterfactual_df, label_column)
Pure function to highlight changes with comparison format (old -> new).
This function displays changes in the format “factual_value -> counterfactual_value” to show both the original and modified values side by side.
- Parameters:
factual_df (pd.DataFrame) – Original factual data
counterfactual_df (pd.DataFrame) – Full counterfactual data (corresponding counterfactual)
refined_counterfactual_df (pd.DataFrame) – Action-limited counterfactual data
label_column (str) – Name of the target/label column
- Returns:
factual_df (pd.DataFrame) – Original factual data (copy)
ce_style (pandas.io.formats.style.Styler) – Styled DataFrame showing factual → full counterfactual
ace_style (pandas.io.formats.style.Styler) – Styled DataFrame showing factual → action-limited counterfactual
- Return type:
tuple
Color Coding:
Yellow background: Feature value changed (except label column)
Light gray background + black border: Label column changed
No highlight: Value unchanged
Note:
Typically called through
cola.highlight_changes_comparison()method, which automatically provides the required data.Example:
from xai_cola.ce_sparsifier.visualization import highlight_changes_comparison factual_copy, ce_style, ace_style = highlight_changes_comparison( factual_df=factual, counterfactual_df=counterfactual, refined_counterfactual_df=refined_cf, label_column='Risk' ) # Display in Jupyter display(ce_style) display(ace_style)
function highlight_changes_final (factual_df, counterfactual_df, refined_counterfactual_df, label_column)
Pure function to highlight changes showing only the final values.
This function displays only the final counterfactual values without showing the “factual -> counterfactual” format, making it cleaner for presentation.
- Parameters:
factual_df (pd.DataFrame) – Original factual data
counterfactual_df (pd.DataFrame) – Full counterfactual data (corresponding counterfactual)
refined_counterfactual_df (pd.DataFrame) – Action-limited counterfactual data
label_column (str) – Name of the target/label column
- Returns:
factual_df (pd.DataFrame) – Original factual data (copy)
ce_style (pandas.io.formats.style.Styler) – Styled DataFrame showing only full counterfactual values
ace_style (pandas.io.formats.style.Styler) – Styled DataFrame showing only action-limited counterfactual values
- Return type:
tuple
Color Coding:
Same as
highlight_changes_comparison():
Yellow background: Feature value changed
Light gray background + black border: Label column changed
No highlight: Value unchanged
Note:
Typically called through
cola.highlight_changes_final()method.
function highlight_differences (data, df_a, df_b, target_name)
Low-level pure function to highlight differences between two DataFrames.
Creates a style DataFrame with background colors and borders to highlight differences between the two input DataFrames. Used internally by other highlight functions.
- Parameters:
data (pd.DataFrame) – The DataFrame to style (should match df_a and df_b structure)
df_a (pd.DataFrame) – First DataFrame for comparison (typically factual)
df_b (pd.DataFrame) – Second DataFrame for comparison (typically counterfactual)
target_name (str) – Name of the target/label column
- Returns:
Style DataFrame with CSS styling strings for highlighting differences
- Return type:
pd.DataFrame
Binary Heatmap¶
function generate_binary_heatmap (factual_df, counterfactual_df, refined_counterfactual_df, label_column, save_path=None, save_mode=’combined’, show_axis_labels=True)
Pure function to generate binary change heatmap visualizations.
This function generates heatmaps that show binary changes: whether a value changed (red) or remained unchanged (lightgrey).
- Parameters:
factual_df (pd.DataFrame) – Original factual data
counterfactual_df (pd.DataFrame) – Full counterfactual data (corresponding counterfactual)
refined_counterfactual_df (pd.DataFrame) – Action-limited counterfactual data
label_column (str) – Name of the target/label column
save_path (str*, **optional*) – Path to save the heatmap images. If None, plots are automatically displayed in Jupyter (default). If provided, saves to the specified path and closes the plots. Can be a directory path or file path.
save_mode (str*, **default=’combined’*) – How to save the heatmaps when save_path is provided: - “combined”: Save both heatmaps in a single combined image (top and bottom) - “separate”: Save two separate image files (heatmap_ce.png and heatmap_ace.png) - Ignored if save_path is None
show_axis_labels (bool*, **default=True*) – Whether to show x and y axis labels (column names and row indices). If True, displays column names and row indices. If False, hides them.
- Returns:
(plot1, plot2) - Heatmap plots (matplotlib Figure objects)
- Return type:
tuple
Color Coding:
Light grey: Value unchanged
Red: Feature value changed
Dark blue (#000080): Label column changed
Note:
Typically called through
cola.heatmap_binary()method.
function heatmap_binary_changes (factual, counterfactual, target_name, background_color=’lightgrey’, changed_feature_color=’red’, show_axis_labels=False)
Pure function to generate binary change heatmap (shows if value changed or not).
This function generates a heatmap that shows binary changes: whether a value changed (red) or remained unchanged (lightgrey). For the target column, changed values are shown in dark blue.
- Parameters:
factual (pd.DataFrame*, **optional*) – Factual DataFrame
counterfactual (pd.DataFrame*, **optional*) – Counterfactual DataFrame
target_name (str*, **optional*) – Name of the target/label column
background_color (str*, **optional*) – Background color for unchanged cells. Defaults to ‘lightgrey’.
changed_feature_color (str*, **optional*) – Color for changed features. Defaults to ‘red’.
show_axis_labels (bool*, **optional*) – Whether to show x and y axis labels. Defaults to False.
- Returns:
The heatmap figure for changes from factual to counterfactual
- Return type:
matplotlib.figure.Figure
Direction Heatmap¶
function generate_direction_heatmap (factual_df, counterfactual_df, refined_counterfactual_df, label_column, numerical_features=None, save_path=None, save_mode=’combined’, show_axis_labels=True)
Pure function to generate directional change heatmap visualizations with distinction between numerical and categorical features.
This function generates heatmaps that show the direction of changes:
Numerical features (increased): light blue (#56B4E9)
Numerical features (decreased): light cyan (#7FCDCD)
Categorical features (changed): peru (#CD853F)
Value unchanged: lightgrey
Target column: changed values shown in black
- Parameters:
factual_df (pd.DataFrame) – Original factual data
counterfactual_df (pd.DataFrame) – Full counterfactual data (corresponding counterfactual)
refined_counterfactual_df (pd.DataFrame) – Action-limited counterfactual data
label_column (str) – Name of the target/label column
numerical_features (list*, **optional*) – List of numerical feature names. If None, all features (except label) are treated as numerical.
save_path (str*, **optional*) – Path to save the heatmap images. If None, plots are automatically displayed in Jupyter (default). If provided, saves to the specified path and closes the plots. Can be a directory path or file path.
save_mode (str*, **default=’combined’*) – How to save the heatmaps when save_path is provided: - “combined”: Save both heatmaps in a single combined image (top and bottom) - “separate”: Save two separate image files - Ignored if save_path is None
show_axis_labels (bool*, **default=True*) – Whether to show x and y axis labels (column names and row indices). If True, displays column names and row indices. If False, hides them.
- Returns:
(plot1, plot2) - Heatmap plots (matplotlib Figure objects)
- Return type:
tuple
Note:
Typically called through
cola.heatmap_direction()method.
function heatmap_direction_changes (factual, counterfactual, target_name, numerical_features=None, unchanged_color=’lightgrey’, increased_color=’#56B4E9’, decreased_color=’#7FCDCD’, categorical_changed_color=’#CD853F’, show_axis_labels=False)
Pure function to generate directional change heatmap with distinction between numerical and categorical features.
This function generates a heatmap that shows the direction of changes:
Numerical features (increased): light blue (#56B4E9)
Numerical features (decreased): light cyan (#7FCDCD)
Categorical features (changed): peru (#CD853F)
Value unchanged: lightgrey
Target column: changed values shown in black (unchanged remain lightgrey)
- Parameters:
factual (pd.DataFrame*, **optional*) – Factual DataFrame
counterfactual (pd.DataFrame*, **optional*) – Counterfactual DataFrame
target_name (str*, **optional*) – Name of the target/label column
numerical_features (list*, **optional*) – List of numerical feature names. If None, all features are treated as numerical.
unchanged_color (str*, **optional*) – Background color for unchanged cells. Defaults to ‘lightgrey’.
increased_color (str*, **optional*) – Color for increased values (numerical only). Defaults to ‘#56B4E9’ (light blue).
decreased_color (str*, **optional*) – Color for decreased values (numerical only). Defaults to ‘#7FCDCD’ (light cyan).
categorical_changed_color (str*, **optional*) – Color for categorical feature changes. Defaults to ‘#CD853F’ (peru).
show_axis_labels (bool*, **optional*) – Whether to show x and y axis labels. Defaults to False.
- Returns:
The heatmap figure showing direction of changes from factual to counterfactual
- Return type:
matplotlib.figure.Figure
Stacked Bar Chart¶
function generate_stacked_bar_chart (factual_df, counterfactual_df, refined_counterfactual_df, label_column, save_path=None, refined_color=’#D9F2D0’, counterfactual_color=’#FBE3D6’, instance_labels=None)
Pure function to generate and optionally save a stacked percentage bar chart.
This is a wrapper around create_stacked_bar_chart() for consistency with other visualization functions.
Creates a percentage-based stacked bar chart where each bar shows:
Green segment (#D9F2D0): percentage of positions modified by refined_counterfactual
Orange segment (#FBE3D6): percentage of additional positions modified only by counterfactual
Total bar length: 100% (representing all counterfactual modifications)
Labels show both percentage and actual count (e.g., “60.0% (3)”)
- Parameters:
factual_df (pd.DataFrame) – Original factual data
counterfactual_df (pd.DataFrame) – Full counterfactual data (corresponding counterfactual)
refined_counterfactual_df (pd.DataFrame) – Action-limited counterfactual data
label_column (str) – Name of the target/label column to exclude from comparison
save_path (str*, **optional*) – Path to save the chart image. If None, chart is not saved.
refined_color (str*, **default=’#D9F2D0’*) – Color for refined counterfactual modified positions (light green)
counterfactual_color (str*, **default=’#FBE3D6’*) – Color for counterfactual modified positions (light pink/orange)
instance_labels (list*, **optional*) – Custom labels for instances
- Returns:
The stacked bar chart figure
- Return type:
matplotlib.figure.Figure
Note:
Typically called through
cola.stacked_bar_chart()method.
function create_stacked_bar_chart (factual_df, counterfactual_df, refined_counterfactual_df, label_column, save_path=None, refined_color=’#D9F2D0’, counterfactual_color=’#FBE3D6’, instance_labels=None)
Pure function to create a horizontal stacked percentage bar chart comparing modification positions.
This function creates a percentage-based stacked bar chart where each bar represents an instance (100% total), showing the proportion of modified positions in refined counterfactual vs. original counterfactual relative to factual data.
Each bar shows:
Green segment (#D9F2D0): percentage of positions modified by refined_counterfactual
Orange segment (#FBE3D6): percentage of additional positions modified only by counterfactual
Total bar length: 100% (representing all counterfactual modifications)
Labels on bars show both percentage and actual count (e.g., “60.0% (3)”)
- Parameters:
Same as
generate_stacked_bar_chart()- Returns:
The stacked bar chart figure
- Return type:
matplotlib.figure.Figure
Diversity Analysis¶
function generate_diversity_for_all_instances (factual_df, refined_counterfactual_df, ml_model, label_column, cola_data=None)
Generate diversity analysis for all instances.
This is the main pure function that processes all instances. For each instance, it finds minimal feature combinations that can flip the prediction from factual’s target value to refined counterfactual’s target value (e.g., from 1 to 0).
- Parameters:
factual_df (pd.DataFrame) – Factual data (with label column)
refined_counterfactual_df (pd.DataFrame) – Refined counterfactual data (with label column)
ml_model (Model) – ML model for prediction
label_column (str) – Name of target column
cola_data (COLAData*, **optional*) – COLAData object for data transformation (needed when using transform_method)
- Returns:
factual_df (pd.DataFrame) – Original factual data (copy)
diversity_styles (List[Styler]) – List of styled DataFrames (one per instance), each showing all minimal feature combinations for that instance
- Return type:
Tuple[pd.DataFrame, List[Styler]]
Note:
Typically called through
cola.diversity()method.
function find_minimal_feature_combinations (factual_row, refined_counterfactual_row, ml_model, label_column, cola_data=None)
Find all minimal feature combinations that can flip the target from factual to refined counterfactual.
This function finds minimal sets of features that, when changed from factual to refined counterfactual values, will cause the model prediction to change from factual’s target value to refined counterfactual’s target value (e.g., from 1 to 0).
- Parameters:
factual_row (pd.Series) – Single factual instance (with label column)
refined_counterfactual_row (pd.Series) – Single refined counterfactual instance (with label column)
ml_model (Model) – ML model for prediction
label_column (str) – Name of target column
cola_data (COLAData*, **optional*) – COLAData object for data transformation (needed when using transform_method)
- Returns:
List of minimal feature combinations (each is a set of feature names)
- Return type:
List[Set[str]]
Algorithm:
Uses exhaustive enumeration starting from 1 feature up to all changed features. Ensures true minimality: if changing feature A alone works, combinations like {A, B} are excluded.
function generate_diversity_dataframe (factual_row, refined_counterfactual_row, minimal_combinations, label_column)
Generate diversity DataFrame with all minimal combinations.
The first row shows the refined counterfactual (all changes from refined CF), followed by rows showing each minimal combination. If a minimal combination modifies exactly the same features as the refined counterfactual, it is excluded since it provides no additional diversity.
- Parameters:
factual_row (pd.Series) – Single factual instance
refined_counterfactual_row (pd.Series) – Single refined counterfactual instance
minimal_combinations (List[Set[str]]) – List of minimal feature combinations
label_column (str) – Name of target column
- Returns:
DataFrame with first row as refined counterfactual, followed by one row per minimal combination (excluding duplicates)
- Return type:
pd.DataFrame
function highlight_diversity_changes (factual_row, diversity_df, label_column)
Highlight changes in diversity DataFrame.
The first row shows the refined counterfactual with all changes highlighted, with yellow background to distinguish it from minimal combinations. Following rows show minimal combinations with only their specific changes highlighted.
- Parameters:
factual_row (pd.Series) – Single factual instance
diversity_df (pd.DataFrame) – DataFrame with diversity combinations (first row is complete counterfactual)
label_column (str) – Name of target column
- Returns:
Styled DataFrame with highlighted changes
- Return type:
Styler
Color Coding:
First row (complete counterfactual): Yellow background with black border
Other rows (minimal combinations): #FFFFCC background with black border
Label column changes: Light gray background with black border
Unchanged values: No highlight
Examples¶
Note: The following examples show direct function calls. In practice, it is recommended to call through COLA class methods (see COLA API documentation).
Highlighted DataFrames¶
Comparison Format (old -> new)¶
from xai_cola.ce_sparsifier.visualization import highlight_changes_comparison
factual_copy, ce_style, ace_style = highlight_changes_comparison(
factual_df=factual,
counterfactual_df=counterfactual,
refined_counterfactual_df=refined_cf,
label_column='Risk'
)
# Display in Jupyter
display(ce_style)
display(ace_style)
# Save to HTML
ce_style.to_html('original_cf.html')
ace_style.to_html('refined_ace.html')
Final Values Only¶
from xai_cola.ce_sparsifier.visualization import highlight_changes_final
factual_copy, ce_style, ace_style = highlight_changes_final(
factual_df=factual,
counterfactual_df=counterfactual,
refined_counterfactual_df=refined_cf,
label_column='Risk'
)
display(ace_style)
Binary Heatmap¶
from xai_cola.ce_sparsifier.visualization import generate_binary_heatmap
plot1, plot2 = generate_binary_heatmap(
factual_df=factual,
counterfactual_df=counterfactual,
refined_counterfactual_df=refined_cf,
label_column='Risk',
save_path='./results',
save_mode='combined',
show_axis_labels=True
)
Direction Heatmap¶
from xai_cola.ce_sparsifier.visualization import generate_direction_heatmap
plot1, plot2 = generate_direction_heatmap(
factual_df=factual,
counterfactual_df=counterfactual,
refined_counterfactual_df=refined_cf,
label_column='Risk',
numerical_features=['Age', 'Income', 'Credit amount'],
save_path='./results',
save_mode='combined',
show_axis_labels=True
)
Stacked Bar Chart¶
from xai_cola.ce_sparsifier.visualization import generate_stacked_bar_chart
fig = generate_stacked_bar_chart(
factual_df=factual,
counterfactual_df=counterfactual,
refined_counterfactual_df=refined_cf,
label_column='Risk',
save_path='./results',
instance_labels=['Sample 1', 'Sample 2', 'Sample 3']
)
Diversity Analysis¶
from xai_cola.ce_sparsifier.visualization import generate_diversity_for_all_instances
factual_copy, diversity_styles = generate_diversity_for_all_instances(
factual_df=factual,
refined_counterfactual_df=refined_cf,
ml_model=ml_model,
label_column='Risk',
cola_data=data # Optional, needed if using transform_method
)
# Display results for each instance
for i, style in enumerate(diversity_styles):
print(f"Instance {i+1} diversity:")
display(style)
See Also¶
COLA API - COLA main class (recommended way to call visualizations)
Data API - Data interface
Models API - Model interface