# Source code for src.scoring

"""Scoring metrics for concept quality evaluation.

Provides various metrics for quantifying how well a concept explains neuron
behavior by comparing activations on synthetic concept images versus
control images.
"""

import enum

import sklearn
import sklearn.metrics
import torch


def _calculate_auc(
    neuron_control_activations: torch.Tensor,
    neuron_synthetic_activations: torch.Tensor,
) -> float:
    """Calculate the AUC metric for concept scoring.

    Computes the area under the ROC curve, treating control activations as
    the negative class (0) and synthetic activations as the positive
    class (1). The activation values themselves are used as the ranking
    scores.

    Args:
        neuron_control_activations: Activations from control concept images.
        neuron_synthetic_activations: Activations from synthetic concept
            images.

    Returns:
        The ROC AUC score as a Python float; higher means synthetic
        activations rank above control activations more often.
    """
    # One label per activation, in the same order as the concatenation
    # below: zeros for control first, then ones for synthetic.
    concept_labels = torch.cat(
        (
            torch.zeros(neuron_control_activations.shape[0]),
            torch.ones(neuron_synthetic_activations.shape[0]),
        ),
        dim=0,
    )

    # scikit-learn converts its inputs to NumPy arrays. That conversion
    # needs a CPU tensor and fails for tensors that require grad, so drop
    # the autograd graph before handing the scores over.
    activations = (
        torch.cat(
            (neuron_control_activations, neuron_synthetic_activations),
            dim=0,
        )
        .detach()
        .to("cpu")
    )

    return float(sklearn.metrics.roc_auc_score(concept_labels, activations))


def _calculate_average_activation(
    neuron_activations: torch.Tensor,
) -> float:
    """Return the arithmetic mean of a neuron's activations.

    Args:
        neuron_activations: Activation values for a neuron.

    Returns:
        The mean over all elements, as a Python float.
    """
    mean_activation = torch.mean(neuron_activations)
    return mean_activation.item()


def _calculate_mad(
    neuron_control_activations: torch.Tensor,
    neuron_synthetic_activations: torch.Tensor,
) -> float:
    """Calculate the MAD metric (standardized mean activation difference).

    Computes ``(mean(synthetic) - mean(control)) / std(control)``. Despite
    the historical "Mean Absolute Difference" name, no absolute value is
    taken: the score is signed and is negative when the synthetic mean is
    below the control mean.

    When the control spread is zero, or cannot be estimated because fewer
    than two control samples were given, the raw (unnormalized) difference
    of means is returned instead of dividing by zero or NaN.

    Args:
        neuron_control_activations: Activations from control concept images.
        neuron_synthetic_activations: Activations from synthetic concept
            images.

    Returns:
        Standardized activation difference score.
    """
    average_control_activation = neuron_control_activations.mean().item()
    average_synthetic_activation = neuron_synthetic_activations.mean().item()

    # Tensor.std() uses Bessel's correction by default, so with fewer than
    # two samples it yields NaN. NaN != 0 is true, which would slip past
    # the zero check below and turn the whole score into NaN. Treat that
    # case as "no spread" so the same fallback applies.
    if neuron_control_activations.numel() < 2:
        control_std = 0.0
    else:
        control_std = neuron_control_activations.std().item()

    activations_difference = (
        average_synthetic_activation - average_control_activation
    )
    denominator = control_std if control_std != 0 else 1.0
    return activations_difference / denominator


class Metric(enum.Enum):
    """Enumeration of concept scoring metrics.

    Each member's value is the private scoring function that implements
    the metric. The functions are wrapped in ``enum.member`` because a
    plain function assigned in an enum body is treated as a method, not
    as a member.

    Attributes:
        AUC: Area under the ROC curve, with control activations as the
            negative class and synthetic activations as the positive class.
        MAD: Difference between the synthetic and control mean activations,
            divided by the control activation standard deviation. The
            difference is signed; no absolute value is taken.
        AVG_ACTIVATION: Mean of a single activation tensor;
            ``calculate_metric`` applies it to the synthetic activations.
    """

    AUC = enum.member(_calculate_auc)
    MAD = enum.member(_calculate_mad)
    AVG_ACTIVATION = enum.member(_calculate_average_activation)


def calculate_metric(
    neuron_control_activations: torch.Tensor,
    neuron_synthetic_activations: torch.Tensor,
    metric: Metric,
) -> float:
    """Calculate the specified metric for a concept.

    Validates that both activation tensors are one-dimensional, then
    calls the scoring function stored as the value of ``metric``.
    ``Metric.AVG_ACTIVATION`` receives only the synthetic activations;
    every other metric receives the control activations followed by the
    synthetic activations.

    Args:
        neuron_control_activations: Neuron activations from control images.
        neuron_synthetic_activations: Neuron activations from synthetic
            concept images.
        metric: The metric to compute (AUC, MAD, or AVG_ACTIVATION).

    Returns:
        Score value for the selected metric.

    Raises:
        ValueError: If either activation tensor is not one-dimensional.
    """
    # Both inputs must be checked: the scoring functions concatenate and
    # reduce them as flat vectors of per-image activations.
    if (
        neuron_control_activations.ndim != 1
        or neuron_synthetic_activations.ndim != 1
    ):
        raise ValueError("Neuron activations must be one dimensional.")

    if metric is Metric.AVG_ACTIVATION:
        return metric.value(neuron_synthetic_activations)
    return metric.value(
        neuron_control_activations, neuron_synthetic_activations
    )