Source code for src.model.language_model

"""Large language model wrapper for concept generation and reasoning.

This module provides functionality to use a language model to generate and reason
about concepts based on activation patterns and generation history.
"""

from collections.abc import Mapping

import torch
import transformers

from model import concept_history
from model import model
from prompts import prompt_utils


_ASSISTANT_TAG = "<|start_header_id|>assistant<|end_header_id|>"


class LanguageModel(model.Model):
    """Language model for generating and reasoning about neuron concepts.

    Uses a large language model to generate concept names and explanations
    based on activation history and concept scores.

    Attributes:
        generation_history: One ``"<concept>,<score>"`` string per call to
            ``update_concept_history``, in call order.
        concept_history: Mapping from concept name to score, maintained via
            ``concept_history.update_concept_history``.
    """

    def __init__(
        self,
        model_id: str,
        device: str,
        model_swapping: bool,
        max_new_tokens: int,
        prompt_path: str,
        summary_prompt_path: str,
    ) -> None:
        """Initialize the language model.

        Args:
            model_id: Identifier of the language model
                (e.g., 'meta-llama/Llama-2-7b-chat').
            device: Device to load model on ('cuda' or 'cpu').
            model_swapping: Whether to enable model swapping functionality.
            max_new_tokens: Maximum number of tokens to generate per response.
            prompt_path: Path to the prompt template file for generation.
            summary_prompt_path: Path to the prompt template for summary
                generation.
        """
        super().__init__(model_id, device, model_swapping)
        self._max_new_tokens = max_new_tokens
        self._prompt_path = prompt_path
        self._summary_prompt_path = summary_prompt_path
        self.generation_history: list[str] = []
        self.concept_history: Mapping[str, float] = {}

    def _load(self, **kwargs) -> torch.nn.Module:
        """Load the language model pipeline.

        Initializes a text-generation pipeline with the specified model and
        configures tokenizer settings for LLaMA models.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The loaded language model.
        """
        pipeline = transformers.pipeline(
            task="text-generation",
            model=self._model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device=self._device,
        )
        if "llama" in self._model_id:
            # The config may declare `eos_token_id` either as a single int or
            # as a list of ids; indexing an int would raise TypeError, so only
            # take the first element when it is actually a sequence.
            eos_token_id = pipeline.model.config.eos_token_id
            if isinstance(eos_token_id, (list, tuple)):
                eos_token_id = eos_token_id[0]
            pipeline.tokenizer.pad_token_id = eos_token_id
        self._pipeline = pipeline  # TODO(piechotam) refactor this assignment
        return pipeline.model

    def update_concept_history(self, new_concept: str, score: float) -> None:
        """Update the concept history with a new concept and score.

        Appends the new concept to the generation history and updates the
        concept history mapping with the new concept.

        Args:
            new_concept: The new concept name to add.
            score: The score associated with the new concept.
        """
        self.generation_history.append(f"{new_concept},{score}")
        self.concept_history = concept_history.update_concept_history(
            self.concept_history, new_concept, score
        )

    @staticmethod
    def _extract_tagged_section(text: str, tag: str) -> str:
        """Return the stripped text between ``<tag>`` and ``</tag>``.

        If the closing tag is absent, everything after the opening tag is
        returned.

        Args:
            text: Text to search.
            tag: Tag name without angle brackets (e.g. ``"answer"``).

        Returns:
            The stripped content following the first opening tag, cut at the
            first closing tag when present.

        Raises:
            IndexError: If the opening tag does not occur in ``text``.
        """
        sections = text.split(f"<{tag}>")
        if len(sections) < 2:
            # IndexError (rather than a new exception type) keeps callers that
            # already handle the previous bare-indexing failure working.
            raise IndexError(
                f"Model response does not contain a <{tag}> section."
            )
        return sections[1].split(f"</{tag}>")[0].strip()

    @model.gpu_inference_wrapper
    def generate_concept(self, top_k: int | None = None) -> tuple[str, str]:
        """Generate a new concept name with reasoning using the language model.

        Uses the concept history and generation history to build a prompt and
        generate a new concept name. When ``top_k`` is given, the summary
        prompt template is used and ``top_k`` is forwarded to the prompt
        builder; otherwise the full prompt template is used. Parses the model
        response to extract both the answer (concept) and reasoning.

        Args:
            top_k: Optional number of top concepts to use in summary mode.
                If None, uses the full prompt template.

        Returns:
            Tuple of (concept_name, reasoning_text) from the model.

        Raises:
            IndexError: If the generated text lacks the assistant header, or
                the response lacks a ``<thinking>`` or ``<answer>`` section.
        """
        # NOTE(review): the device is hard-coded to "cuda" instead of using
        # self._device; presumably the gpu_inference_wrapper has moved the
        # model to the GPU before this body runs -- confirm.
        self._pipeline.device = torch.device("cuda")

        if top_k is not None:
            generation_prompt = prompt_utils.generate_prompt(
                self.concept_history,
                self.generation_history,
                self._summary_prompt_path,
                top_k,
            )
        else:
            generation_prompt = prompt_utils.generate_prompt(
                self.concept_history, self.generation_history, self._prompt_path
            )

        # only llama for now
        tokenizer = self._pipeline.tokenizer
        message = tokenizer.apply_chat_template(
            [{"role": "user", "content": generation_prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # Stop on either the tokenizer's EOS token or the "<|eot_id|>" token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        output = self._pipeline(
            message,
            max_new_tokens=self._max_new_tokens,
            eos_token_id=terminators,
            temperature=0.5,
            top_p=0.9,
        )

        # The generated text is split on the assistant header and the part
        # that follows it is treated as the model's reply.
        turns = output[0]["generated_text"].split(_ASSISTANT_TAG)
        if len(turns) < 2:
            raise IndexError(
                "Generated text does not contain the assistant header "
                f"{_ASSISTANT_TAG!r}."
            )
        response = turns[1].strip()

        reasoning = self._extract_tagged_section(response, "thinking")
        answer = self._extract_tagged_section(response, "answer")
        return answer, reasoning

    def get_best_concept(self) -> tuple[str, float]:
        """Get the best concept and its score from the history.

        Returns the concept with the highest score from the concept history.

        Returns:
            Tuple of (best_concept_name, score) for the highest-scoring
            concept.

        Raises:
            ValueError: If the concept history is empty.
        """
        if not self.concept_history:
            # max() would raise ValueError here as well; raise it explicitly
            # so the message says what is actually missing.
            raise ValueError(
                "Cannot determine the best concept: concept history is empty."
            )
        best_concept = max(self.concept_history, key=self.concept_history.get)
        return best_concept, self.concept_history[best_concept]