src.model.explained_model module

Explained model for capturing intermediate layer activations.

This module provides functionality to load pretrained models and capture activations from specific layers for neuron interpretability analysis.

class src.model.explained_model.ExplainedModel(model_id: str, layer: str, device: str, model_swapping: bool)

Bases: Model

Model wrapper that captures and returns intermediate layer activations.

Extends the base Model class by registering forward hooks on a specified layer to capture neuron activations for interpretation and analysis.
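The forward-hook mechanism can be sketched with PyTorch's `nn.Module.register_forward_hook`. This is a minimal illustration, not this module's actual implementation. The `ActivationCapture` class is hypothetical, and the sketch assumes that the `layer` argument is a dotted submodule name as returned by `named_modules()`.

```python
from typing import Optional

import torch
import torch.nn as nn


class ActivationCapture:
    """Hypothetical helper: stores the output of one named submodule via a forward hook."""

    def __init__(self, model: nn.Module, layer: str) -> None:
        self.activations: Optional[torch.Tensor] = None
        # Assumption: `layer` is a dotted name from named_modules(), e.g. "layer4" or "1".
        module = dict(model.named_modules())[layer]
        # Forward hooks are called as hook(module, inputs, output) after forward() runs.
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # Detach so the stored tensor does not keep the autograd graph alive.
        self.activations = output.detach()

    def remove(self) -> None:
        # Unregister the hook once activations are no longer needed.
        self.handle.remove()


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)
capture = ActivationCapture(model, "1")  # submodule "1" is the ReLU
with torch.no_grad():
    model(torch.randn(2, 3, 32, 32))
print(capture.activations.shape)  # torch.Size([2, 8, 32, 32])
capture.remove()
```

Calling `remove()` matters when the same model is reused across several layers, since otherwise every forward pass keeps overwriting stale captures.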

get_activations(input_batch: torch.Tensor) → torch.Tensor

Run a forward pass through the model and return the captured layer activations.

Passes the input through the model and captures the output of the hooked layer. Handles the different activation shapes produced by CNN and Vision Transformer architectures, pooling them so that each neuron yields a single value per input sample.

Parameters:

input_batch – Input tensor of shape (N, C, H, W).

Returns:

Neuron activations of shape (N, num_neurons).

Raises:

ValueError – If input_batch is not 4-dimensional, or if the captured activations have an unexpected number of dimensions.
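The shape handling described for get_activations might look roughly like the sketch below. The function names `validate_input` and `pool_activations` are hypothetical. The pooling choices are also assumptions, since the exact pooling is not specified here: spatial mean pooling for CNN feature maps of shape (N, C, H, W), and mean pooling over tokens for ViT outputs of shape (N, T, D).

```python
import torch


def validate_input(input_batch: torch.Tensor) -> None:
    # get_activations expects an image batch shaped (N, C, H, W).
    if input_batch.dim() != 4:
        raise ValueError(
            f"Expected a 4D input batch (N, C, H, W), got shape {tuple(input_batch.shape)}"
        )


def pool_activations(acts: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: reduce a captured layer output to shape (N, num_neurons)."""
    if acts.dim() == 4:
        # Assumed CNN feature map (N, C, H, W): average over H and W -> (N, C).
        return acts.mean(dim=(2, 3))
    if acts.dim() == 3:
        # Assumed ViT token sequence (N, T, D): average over tokens -> (N, D).
        # Selecting only the CLS token, acts[:, 0], is a common alternative.
        return acts.mean(dim=1)
    raise ValueError(f"Unexpected activation shape {tuple(acts.shape)}")


validate_input(torch.randn(2, 3, 224, 224))
print(pool_activations(torch.randn(2, 16, 7, 7)).shape)   # torch.Size([2, 16])
print(pool_activations(torch.randn(2, 197, 768)).shape)   # torch.Size([2, 768])
```

Mean pooling keeps the neuron axis intact while discarding spatial or token position, which is what a per-neuron score of shape (N, num_neurons) requires.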