Skip to content

Inject

Layer injection for modifying model forward passes without changing model weights.

LayerInjector

LayerInjector

Inject custom processing between transformer layers.

Proven capabilities:

- Identity injection (pass-through)
- Scaling injection (modify magnitude)
- Additive injection (add signals)
- Extraction injection (capture without modifying)

Source code in src/model_garage/inject/layer.py
class LayerInjector:
    """
    Inject custom processing between transformer layers.

    Proven capabilities:
    - Identity injection (pass-through)
    - Scaling injection (modify magnitude)
    - Additive injection (add signals)
    - Extraction injection (capture without modifying)
    """

    def __init__(self, model: nn.Module):
        """
        Initialize injector for a model.

        Args:
            model: The model to inject into
        """
        self.model = model
        self.hook_manager = HookManager(model)
        self.active_injections: List[str] = []
        # Monotonic counter for auto-generated injection names. Deriving
        # the name from len(active_injections) (the previous behavior)
        # recycles a name once any injection is removed, which can collide
        # with a still-active hook.
        self._auto_name_counter = 0

    def inject(
        self,
        layer_name: str,
        injection_fn: Callable[[torch.Tensor], torch.Tensor],
        name: Optional[str] = None
    ) -> str:
        """
        Inject a function after a layer.

        Args:
            layer_name: Layer to inject after (e.g., "transformer.h.6")
            injection_fn: Function(hidden_states) -> modified_hidden_states
            name: Optional name for this injection

        Returns:
            Injection name for later removal
        """
        # `is None` (not truthiness) so an explicit empty string is kept
        # rather than silently replaced by an auto-generated name.
        if name is None:
            name = f"injection_{self._auto_name_counter}"
            self._auto_name_counter += 1

        hook_name = self.hook_manager.register_injection_hook(
            layer_name=layer_name,
            injection_fn=injection_fn,
            hook_name=name
        )

        self.active_injections.append(hook_name)
        return hook_name

    def inject_identity(self, layer_name: str) -> str:
        """
        Inject identity function (for testing).

        This should have NO effect on output.
        """
        return self.inject(
            layer_name=layer_name,
            injection_fn=lambda x: x,
            name=f"{layer_name}_identity"
        )

    def inject_scaling(self, layer_name: str, scale: float = 0.9) -> str:
        """
        Inject scaling function.

        Multiplies hidden states by scale factor.
        """
        return self.inject(
            layer_name=layer_name,
            injection_fn=lambda x: x * scale,
            name=f"{layer_name}_scale_{scale}"
        )

    def inject_additive(
        self,
        layer_name: str,
        bias: Union[torch.Tensor, float]
    ) -> str:
        """
        Inject additive bias.

        Adds a constant to hidden states.
        """
        return self.inject(
            layer_name=layer_name,
            injection_fn=lambda x: x + bias,
            name=f"{layer_name}_additive"
        )

    def inject_noise(
        self,
        layer_name: str,
        noise_scale: float = 0.01
    ) -> str:
        """
        Inject random noise (for exploration/creativity).
        """
        def add_noise(x):
            noise = torch.randn_like(x) * noise_scale
            return x + noise

        return self.inject(
            layer_name=layer_name,
            injection_fn=add_noise,
            name=f"{layer_name}_noise_{noise_scale}"
        )

    def inject_custom_layer(
        self,
        layer_name: str,
        custom_module: nn.Module
    ) -> str:
        """
        Inject a custom nn.Module.

        The module must accept and return tensors of same shape.
        """
        return self.inject(
            layer_name=layer_name,
            # Pass the module itself (an nn.Module is callable) rather than
            # its bare .forward: going through __call__ keeps the module's
            # own forward hooks and pre/post-processing active.
            injection_fn=custom_module,
            name=f"{layer_name}_custom"
        )

    def remove(self, name: str):
        """Remove a specific injection by name."""
        self.hook_manager.remove_hook(name)
        if name in self.active_injections:
            self.active_injections.remove(name)

    def remove_all(self):
        """Remove all active injections."""
        self.hook_manager.remove_all()
        self.active_injections.clear()

    def list_injections(self) -> List[str]:
        """List all active injection names (copy; safe to mutate)."""
        return self.active_injections.copy()

    def __enter__(self):
        # Context-manager support: injections are cleaned up on exit.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.remove_all()

__init__

__init__(model)

Initialize injector for a model.

Parameters:

Name Type Description Default
model Module

The model to inject into

required
Source code in src/model_garage/inject/layer.py
def __init__(self, model: nn.Module):
    """
    Initialize injector for a model.

    Args:
        model: The model to inject into
    """
    self.model = model
    self.hook_manager = HookManager(model)
    self.active_injections: List[str] = []

inject

inject(layer_name, injection_fn, name=None)

Inject a function after a layer.

Parameters:

Name Type Description Default
layer_name str

Layer to inject after (e.g., "transformer.h.6")

required
injection_fn Callable[[Tensor], Tensor]

Function(hidden_states) -> modified_hidden_states

required
name Optional[str]

Optional name for this injection

None

Returns:

Type Description
str

Injection name for later removal

Source code in src/model_garage/inject/layer.py
def inject(
    self,
    layer_name: str,
    injection_fn: Callable[[torch.Tensor], torch.Tensor],
    name: Optional[str] = None
) -> str:
    """
    Inject a function after a layer.

    Args:
        layer_name: Layer to inject after (e.g., "transformer.h.6")
        injection_fn: Function(hidden_states) -> modified_hidden_states
        name: Optional name for this injection

    Returns:
        Injection name for later removal
    """
    name = name or f"injection_{len(self.active_injections)}"

    hook_name = self.hook_manager.register_injection_hook(
        layer_name=layer_name,
        injection_fn=injection_fn,
        hook_name=name
    )

    self.active_injections.append(hook_name)
    return hook_name

inject_identity

inject_identity(layer_name)

Inject identity function (for testing).

This should have NO effect on output.

Source code in src/model_garage/inject/layer.py
def inject_identity(self, layer_name: str) -> str:
    """
    Inject identity function (for testing).

    This should have NO effect on output.
    """
    return self.inject(
        layer_name=layer_name,
        injection_fn=lambda x: x,
        name=f"{layer_name}_identity"
    )

inject_scaling

inject_scaling(layer_name, scale=0.9)

Inject scaling function.

Multiplies hidden states by scale factor.

Source code in src/model_garage/inject/layer.py
def inject_scaling(self, layer_name: str, scale: float = 0.9) -> str:
    """
    Inject scaling function.

    Multiplies hidden states by scale factor.
    """
    return self.inject(
        layer_name=layer_name,
        injection_fn=lambda x: x * scale,
        name=f"{layer_name}_scale_{scale}"
    )

inject_additive

inject_additive(layer_name, bias)

Inject additive bias.

Adds a constant to hidden states.

Source code in src/model_garage/inject/layer.py
def inject_additive(
    self,
    layer_name: str,
    bias: Union[torch.Tensor, float]
) -> str:
    """
    Inject additive bias.

    Adds a constant to hidden states.
    """
    return self.inject(
        layer_name=layer_name,
        injection_fn=lambda x: x + bias,
        name=f"{layer_name}_additive"
    )

inject_noise

inject_noise(layer_name, noise_scale=0.01)

Inject random noise (for exploration/creativity).

Source code in src/model_garage/inject/layer.py
def inject_noise(
    self,
    layer_name: str,
    noise_scale: float = 0.01
) -> str:
    """
    Inject random noise (for exploration/creativity).
    """
    def add_noise(x):
        noise = torch.randn_like(x) * noise_scale
        return x + noise

    return self.inject(
        layer_name=layer_name,
        injection_fn=add_noise,
        name=f"{layer_name}_noise_{noise_scale}"
    )

inject_custom_layer

inject_custom_layer(layer_name, custom_module)

Inject a custom nn.Module.

The module must accept and return tensors of same shape.

Source code in src/model_garage/inject/layer.py
def inject_custom_layer(
    self,
    layer_name: str,
    custom_module: nn.Module
) -> str:
    """
    Inject a custom nn.Module.

    The module must accept and return tensors of same shape.
    """
    return self.inject(
        layer_name=layer_name,
        injection_fn=custom_module.forward,
        name=f"{layer_name}_custom"
    )

remove

remove(name)

Remove a specific injection by name.

Source code in src/model_garage/inject/layer.py
def remove(self, name: str):
    """Remove a specific injection by name."""
    self.hook_manager.remove_hook(name)
    if name in self.active_injections:
        self.active_injections.remove(name)

remove_all

remove_all()

Remove all active injections.

Source code in src/model_garage/inject/layer.py
def remove_all(self):
    """Remove all active injections."""
    self.hook_manager.remove_all()
    self.active_injections.clear()

list_injections

list_injections()

List all active injection names.

Source code in src/model_garage/inject/layer.py
def list_injections(self) -> List[str]:
    """List all active injection names."""
    return self.active_injections.copy()

quick_inject

quick_inject

quick_inject(model, layer_idx, fn)

Quick helper to inject at a specific layer index.

Assumes GPT-2 style architecture (transformer.h.{idx}).

Usage

with quick_inject(model, 6, lambda x: x * 0.9) as injector:
    output = model(input_ids)

Source code in src/model_garage/inject/layer.py
def quick_inject(model, layer_idx: int, fn: Callable) -> LayerInjector:
    """
    Convenience helper: inject ``fn`` after the layer at ``layer_idx``.

    Assumes GPT-2 style architecture (transformer.h.{idx}).

    Usage:
        with quick_inject(model, 6, lambda x: x * 0.9) as injector:
            output = model(input_ids)
    """
    target_layer = f"transformer.h.{layer_idx}"
    injector = LayerInjector(model)
    injector.inject(target_layer, fn)
    return injector

SelfDebate

SelfDebate

High-level wrapper to add self-debate to any model.

Usage

debate = SelfDebate(model, layer_idx=6)
with debate:
    output = model(input_ids)  # Now uses debate at layer 6

Source code in src/model_garage/inject/debate.py
class SelfDebate:
    """
    High-level wrapper to add self-debate to any model.

    Registers a forward hook on one layer and routes that layer's hidden
    states through a DebateChamber before the rest of the model sees them.
    No model weights are modified; disabling the hook fully restores the
    original behavior.

    Usage:
        debate = SelfDebate(model, layer_idx=6)
        with debate:
            output = model(input_ids)  # Now uses debate at layer 6
    """

    def __init__(
        self,
        model: nn.Module,
        layer_idx: int = 6,
        divergence_method: str = "dropout",
        reconciliation_method: str = "average",
        divergence_strength: float = 0.1,
        layer_name_template: str = "transformer.h.{idx}"
    ):
        """
        Initialize self-debate wrapper.

        Args:
            model: The model to wrap
            layer_idx: Which layer to inject debate after
            divergence_method: How to create perspectives
            reconciliation_method: How to merge perspectives
            divergence_strength: How different perspectives should be
            layer_name_template: Template for layer names (GPT-2 style default)
        """
        self.model = model
        self.layer_idx = layer_idx
        self.layer_name = layer_name_template.format(idx=layer_idx)

        # Get hidden dim from model: try HF-style `hidden_size`, then
        # GPT-2's `n_embd`. Falls back to 768 if neither attribute exists
        # — assumes a GPT-2-base-sized model in that case; TODO confirm.
        config = model.config
        hidden_dim = getattr(config, "hidden_size", getattr(config, "n_embd", 768))

        # Create chamber on the same device as the model's parameters
        # (uses the first parameter's device as representative).
        device = next(model.parameters()).device
        self.chamber = DebateChamber(
            hidden_dim=hidden_dim,
            divergence_method=divergence_method,
            reconciliation_method=reconciliation_method,
            divergence_strength=divergence_strength
        ).to(device)

        # Hook handle (None while debate is disabled) and per-forward
        # diagnostic records collected by _hook_fn.
        self._hook_handle = None
        self._debate_info = []

    def _hook_fn(self, module, input, output):
        """Forward hook: replace the layer's hidden states with debated ones."""
        # Transformer blocks commonly return a tuple whose first element is
        # the hidden states (presumably (hidden, present, ...) — verify for
        # the wrapped architecture); only that element is rewritten.
        if isinstance(output, tuple):
            hidden = output[0]
            debated, info = self.chamber.forward_with_info(hidden)
            self._debate_info.append(info)
            return (debated,) + output[1:]
        else:
            debated, info = self.chamber.forward_with_info(output)
            self._debate_info.append(info)
            return debated

    def enable(self):
        """Enable debate. Idempotent: a second call only clears the info log."""
        if self._hook_handle is None:
            layer = self._get_layer(self.layer_name)
            self._hook_handle = layer.register_forward_hook(self._hook_fn)
        self._debate_info = []

    def disable(self):
        """Disable debate by removing the forward hook (no-op if disabled)."""
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

    def get_debate_info(self):
        """
        Get collected debate diagnostics.

        NOTE: entries accumulate — one per hooked layer invocation since the
        last enable() — so multiple forward passes yield multiple entries.
        """
        return self._debate_info.copy()

    def _get_layer(self, name: str):
        """Resolve a dot-separated module path (e.g. "transformer.h.6")."""
        parts = name.split(".")
        module = self.model
        for part in parts:
            # Numeric components index into container modules
            # (nn.ModuleList / nn.Sequential); others are attributes.
            if part.isdigit():
                module = module[int(part)]
            else:
                module = getattr(module, part)
        return module

    def __enter__(self):
        self.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disable()

__init__

__init__(model, layer_idx=6, divergence_method='dropout', reconciliation_method='average', divergence_strength=0.1, layer_name_template='transformer.h.{idx}')

Initialize self-debate wrapper.

Parameters:

Name Type Description Default
model Module

The model to wrap

required
layer_idx int

Which layer to inject debate after

6
divergence_method str

How to create perspectives

'dropout'
reconciliation_method str

How to merge perspectives

'average'
divergence_strength float

How different perspectives should be

0.1
layer_name_template str

Template for layer names (GPT-2 style default)

'transformer.h.{idx}'
Source code in src/model_garage/inject/debate.py
def __init__(
    self,
    model: nn.Module,
    layer_idx: int = 6,
    divergence_method: str = "dropout",
    reconciliation_method: str = "average",
    divergence_strength: float = 0.1,
    layer_name_template: str = "transformer.h.{idx}"
):
    """
    Initialize self-debate wrapper.

    Args:
        model: The model to wrap
        layer_idx: Which layer to inject debate after
        divergence_method: How to create perspectives
        reconciliation_method: How to merge perspectives
        divergence_strength: How different perspectives should be
        layer_name_template: Template for layer names (GPT-2 style default)
    """
    self.model = model
    self.layer_idx = layer_idx
    self.layer_name = layer_name_template.format(idx=layer_idx)

    # Get hidden dim from model
    config = model.config
    hidden_dim = getattr(config, "hidden_size", getattr(config, "n_embd", 768))

    # Create chamber
    device = next(model.parameters()).device
    self.chamber = DebateChamber(
        hidden_dim=hidden_dim,
        divergence_method=divergence_method,
        reconciliation_method=reconciliation_method,
        divergence_strength=divergence_strength
    ).to(device)

    # Hook handle
    self._hook_handle = None
    self._debate_info = []

enable

enable()

Enable debate.

Source code in src/model_garage/inject/debate.py
def enable(self):
    """Enable debate."""
    if self._hook_handle is None:
        layer = self._get_layer(self.layer_name)
        self._hook_handle = layer.register_forward_hook(self._hook_fn)
    self._debate_info = []

disable

disable()

Disable debate.

Source code in src/model_garage/inject/debate.py
def disable(self):
    """Disable debate."""
    if self._hook_handle is not None:
        self._hook_handle.remove()
        self._hook_handle = None

get_debate_info

get_debate_info()

Get info from last forward pass.

Source code in src/model_garage/inject/debate.py
def get_debate_info(self):
    """Get info from last forward pass."""
    return self._debate_info.copy()

DebateChamber

DebateChamber

Bases: Module

A debate chamber that creates divergent perspectives and reconciles them.

Can be injected between any two layers using LayerInjector.

Source code in src/model_garage/inject/debate.py
class DebateChamber(nn.Module):
    """
    A debate chamber that creates divergent perspectives and reconciles them.

    Can be injected between any two layers using LayerInjector.
    """

    def __init__(
        self,
        hidden_dim: int,
        divergence_method: str = "dropout",
        reconciliation_method: str = "average",
        divergence_strength: float = 0.1
    ):
        """
        Initialize debate chamber.

        Args:
            hidden_dim: Dimension of hidden states
            divergence_method: How to create different perspectives
                - "dropout": Different dropout masks
                - "perturbation": Add different noise
                - "projection": Different learned projections
            reconciliation_method: How to merge perspectives
                - "average": Simple mean
                - "confidence": Weight by magnitude
                - "gated": Learned gating
            divergence_strength: How different the perspectives should be
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.divergence_method = divergence_method
        self.reconciliation_method = reconciliation_method
        self.divergence_strength = divergence_strength

        # Projection divergence needs learnable maps: A starts as the exact
        # identity, B as identity plus a random perturbation so the two
        # views disagree from the first step.
        if divergence_method == "projection":
            self.proj_a = nn.Linear(hidden_dim, hidden_dim)
            self.proj_b = nn.Linear(hidden_dim, hidden_dim)
            nn.init.eye_(self.proj_a.weight)
            nn.init.zeros_(self.proj_a.bias)
            nn.init.eye_(self.proj_b.weight)
            with torch.no_grad():
                self.proj_b.weight.add_(
                    torch.randn_like(self.proj_b.weight) * divergence_strength
                )
            nn.init.zeros_(self.proj_b.bias)

        # Gated reconciliation learns a per-dimension mixing coefficient
        # from the concatenation of both views.
        if reconciliation_method == "gated":
            self.gate = nn.Linear(hidden_dim * 2, hidden_dim)

    def create_perspectives(
        self,
        hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create two different perspectives of the hidden states.
        """
        method = self.divergence_method

        if method == "dropout":
            # Two independent inverted-dropout masks at the same keep rate.
            keep = 1 - self.divergence_strength

            def dropped() -> torch.Tensor:
                mask = torch.bernoulli(torch.ones_like(hidden) * keep) / keep
                return hidden * mask

            return dropped(), dropped()

        if method == "perturbation":
            # Two independent Gaussian jitters of the same magnitude.
            def jittered() -> torch.Tensor:
                return hidden + torch.randn_like(hidden) * self.divergence_strength

            return jittered(), jittered()

        if method == "projection":
            return self.proj_a(hidden), self.proj_b(hidden)

        raise ValueError(f"Unknown divergence method: {self.divergence_method}")

    def reconcile(
        self,
        view_a: torch.Tensor,
        view_b: torch.Tensor
    ) -> torch.Tensor:
        """
        Reconcile two perspectives into one.
        """
        method = self.reconciliation_method

        if method == "average":
            return (view_a + view_b) / 2

        if method == "confidence":
            # Mean absolute activation acts as a confidence score; mix
            # proportionally (epsilon guards against a zero denominator).
            conf_a = view_a.abs().mean(dim=-1, keepdim=True)
            conf_b = view_b.abs().mean(dim=-1, keepdim=True)
            total = conf_a + conf_b + 1e-8
            return (conf_a / total) * view_a + (conf_b / total) * view_b

        if method == "gated":
            # Sigmoid gate decides, per dimension, how much of each view to keep.
            gate = torch.sigmoid(self.gate(torch.cat([view_a, view_b], dim=-1)))
            return gate * view_a + (1 - gate) * view_b

        raise ValueError(f"Unknown reconciliation method: {self.reconciliation_method}")

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Apply debate: create perspectives, reconcile, return result.
        """
        return self.reconcile(*self.create_perspectives(hidden))

    def forward_with_info(self, hidden: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Apply debate and return additional info.
        """
        view_a, view_b = self.create_perspectives(hidden)
        reconciled = self.reconcile(view_a, view_b)

        # Diagnostics only — keep them out of the autograd graph.
        with torch.no_grad():
            cosine_sim = F.cosine_similarity(
                view_a.flatten(), view_b.flatten(), dim=0
            ).item()
            l2_diff = (view_a - view_b).norm().item()

        return reconciled, {
            "cosine_similarity": cosine_sim,
            "l2_difference": l2_diff,
            "divergence_method": self.divergence_method,
            "reconciliation_method": self.reconciliation_method,
        }

__init__

__init__(hidden_dim, divergence_method='dropout', reconciliation_method='average', divergence_strength=0.1)

Initialize debate chamber.

Parameters:

Name Type Description Default
hidden_dim int

Dimension of hidden states

required
divergence_method str

How to create different perspectives - "dropout": Different dropout masks - "perturbation": Add different noise - "projection": Different learned projections

'dropout'
reconciliation_method str

How to merge perspectives - "average": Simple mean - "confidence": Weight by magnitude - "gated": Learned gating

'average'
divergence_strength float

How different the perspectives should be

0.1
Source code in src/model_garage/inject/debate.py
def __init__(
    self,
    hidden_dim: int,
    divergence_method: str = "dropout",
    reconciliation_method: str = "average",
    divergence_strength: float = 0.1
):
    """
    Initialize debate chamber.

    Args:
        hidden_dim: Dimension of hidden states
        divergence_method: How to create different perspectives
            - "dropout": Different dropout masks
            - "perturbation": Add different noise
            - "projection": Different learned projections
        reconciliation_method: How to merge perspectives
            - "average": Simple mean
            - "confidence": Weight by magnitude
            - "gated": Learned gating
        divergence_strength: How different the perspectives should be
    """
    super().__init__()
    self.hidden_dim = hidden_dim
    self.divergence_method = divergence_method
    self.reconciliation_method = reconciliation_method
    self.divergence_strength = divergence_strength

    # Setup divergence
    if divergence_method == "projection":
        self.proj_a = nn.Linear(hidden_dim, hidden_dim)
        self.proj_b = nn.Linear(hidden_dim, hidden_dim)
        # Initialize A as identity, B as perturbed
        nn.init.eye_(self.proj_a.weight)
        nn.init.zeros_(self.proj_a.bias)
        nn.init.eye_(self.proj_b.weight)
        with torch.no_grad():
            self.proj_b.weight.add_(torch.randn_like(self.proj_b.weight) * divergence_strength)
        nn.init.zeros_(self.proj_b.bias)

    # Setup reconciliation
    if reconciliation_method == "gated":
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)

create_perspectives

create_perspectives(hidden)

Create two different perspectives of the hidden states.

Source code in src/model_garage/inject/debate.py
def create_perspectives(
    self,
    hidden: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create two different perspectives of the hidden states.
    """
    if self.divergence_method == "dropout":
        # Different dropout masks
        mask_a = torch.bernoulli(
            torch.ones_like(hidden) * (1 - self.divergence_strength)
        ) / (1 - self.divergence_strength)
        mask_b = torch.bernoulli(
            torch.ones_like(hidden) * (1 - self.divergence_strength)
        ) / (1 - self.divergence_strength)
        view_a = hidden * mask_a
        view_b = hidden * mask_b

    elif self.divergence_method == "perturbation":
        # Add different noise
        noise_a = torch.randn_like(hidden) * self.divergence_strength
        noise_b = torch.randn_like(hidden) * self.divergence_strength
        view_a = hidden + noise_a
        view_b = hidden + noise_b

    elif self.divergence_method == "projection":
        # Different learned projections
        view_a = self.proj_a(hidden)
        view_b = self.proj_b(hidden)

    else:
        raise ValueError(f"Unknown divergence method: {self.divergence_method}")

    return view_a, view_b

reconcile

reconcile(view_a, view_b)

Reconcile two perspectives into one.

Source code in src/model_garage/inject/debate.py
def reconcile(
    self,
    view_a: torch.Tensor,
    view_b: torch.Tensor
) -> torch.Tensor:
    """
    Reconcile two perspectives into one.
    """
    if self.reconciliation_method == "average":
        return (view_a + view_b) / 2

    elif self.reconciliation_method == "confidence":
        # Weight by magnitude (confidence)
        conf_a = view_a.abs().mean(dim=-1, keepdim=True)
        conf_b = view_b.abs().mean(dim=-1, keepdim=True)
        total = conf_a + conf_b + 1e-8
        return (conf_a / total) * view_a + (conf_b / total) * view_b

    elif self.reconciliation_method == "gated":
        # Learned gating
        combined = torch.cat([view_a, view_b], dim=-1)
        gate = torch.sigmoid(self.gate(combined))
        return gate * view_a + (1 - gate) * view_b

    else:
        raise ValueError(f"Unknown reconciliation method: {self.reconciliation_method}")

forward

forward(hidden)

Apply debate: create perspectives, reconcile, return result.

Source code in src/model_garage/inject/debate.py
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
    """
    Apply debate: create perspectives, reconcile, return result.
    """
    view_a, view_b = self.create_perspectives(hidden)
    return self.reconcile(view_a, view_b)

forward_with_info

forward_with_info(hidden)

Apply debate and return additional info.

Source code in src/model_garage/inject/debate.py
def forward_with_info(self, hidden: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
    """
    Apply debate and return additional info.
    """
    view_a, view_b = self.create_perspectives(hidden)
    reconciled = self.reconcile(view_a, view_b)

    # Compute metrics
    with torch.no_grad():
        cosine_sim = F.cosine_similarity(
            view_a.flatten(), view_b.flatten(), dim=0
        ).item()
        l2_diff = (view_a - view_b).norm().item()

    info = {
        "cosine_similarity": cosine_sim,
        "l2_difference": l2_diff,
        "divergence_method": self.divergence_method,
        "reconciliation_method": self.reconciliation_method,
    }

    return reconciled, info

Temperature Debates

TemperatureDebate

Base class for temperature-based debate strategies.

Creates two perspectives:

- Conservative: low temperature, focused distribution
- Exploratory: high temperature, diverse distribution

Then reconciles them for creative yet coherent output.

Source code in src/model_garage/inject/temperature.py
class TemperatureDebate:
    """
    Base class for temperature-based debate strategies.

    Creates two perspectives:
    - Conservative: low temperature, focused distribution
    - Exploratory: high temperature, diverse distribution

    Then reconciles them for creative yet coherent output.
    """

    def __init__(
        self,
        conservative_temp: float = 0.5,
        exploratory_temp: float = 1.2,
    ):
        # Temperatures below 1 sharpen the softmax; above 1 flatten it.
        self.conservative_temp = conservative_temp
        self.exploratory_temp = exploratory_temp

    def get_perspectives(
        self,
        logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create conservative and exploratory probability distributions.

        Args:
            logits: Raw logits [batch, vocab_size]

        Returns:
            (probs_conservative, probs_exploratory)
        """
        focused, diverse = (
            F.softmax(logits / temp, dim=-1)
            for temp in (self.conservative_temp, self.exploratory_temp)
        )
        return focused, diverse

    def debate(self, logits: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Apply debate and return blended probabilities.

        Override in subclasses for specific strategies.
        """
        raise NotImplementedError("Subclasses must implement debate()")

get_perspectives

get_perspectives(logits)

Create conservative and exploratory probability distributions.

Parameters:

Name Type Description Default
logits Tensor

Raw logits [batch, vocab_size]

required

Returns:

Type Description
Tuple[Tensor, Tensor]

(probs_conservative, probs_exploratory)

Source code in src/model_garage/inject/temperature.py
def get_perspectives(
    self,
    logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create conservative and exploratory probability distributions.

    Args:
        logits: Raw logits [batch, vocab_size]

    Returns:
        (probs_conservative, probs_exploratory)
    """
    probs_cons = F.softmax(logits / self.conservative_temp, dim=-1)
    probs_wild = F.softmax(logits / self.exploratory_temp, dim=-1)
    return probs_cons, probs_wild

debate

debate(logits)

Apply debate and return blended probabilities.

Override in subclasses for specific strategies.

Source code in src/model_garage/inject/temperature.py
def debate(self, logits: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
    """
    Apply debate and return blended probabilities.

    Override in subclasses for specific strategies.
    """
    raise NotImplementedError("Subclasses must implement debate()")

AdaptiveDebate

Bases: TemperatureDebate

Adapt wild probability based on context entropy.

High entropy context -> use more conservative (already uncertain) Low entropy context -> can afford to be wilder

Source code in src/model_garage/inject/temperature.py
class AdaptiveDebate(TemperatureDebate):
    """
    Adapt wild probability based on context entropy.

    High entropy context -> use more conservative (already uncertain)
    Low entropy context -> can afford to be wilder
    """

    def __init__(
        self,
        conservative_temp: float = 0.5,
        exploratory_temp: float = 1.2,
        base_wild_prob: float = 0.2,
        entropy_threshold: float = 2.0
    ):
        super().__init__(conservative_temp, exploratory_temp)
        self.base_wild_prob = base_wild_prob
        self.entropy_threshold = entropy_threshold

    def debate(self, logits: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Adapt strategy based on current entropy.
        """
        probs_cons, probs_wild = self.get_perspectives(logits)

        # Shannon entropy (nats) of the conservative distribution,
        # averaged over the batch; epsilon avoids log(0).
        entropy = -(probs_cons * torch.log(probs_cons + 1e-10)).sum(dim=-1).mean().item()

        # Uncertain context -> damp wildness; confident context -> boost it.
        factor = 0.5 if entropy > self.entropy_threshold else 1.5
        wild_prob = min(self.base_wild_prob * factor, 0.5)  # Cap at 50%

        # Single Bernoulli draw selects which distribution to emit.
        use_wild = random.random() < wild_prob
        chosen = probs_wild if use_wild else probs_cons

        return chosen, {
            "strategy": "adaptive",
            "entropy": entropy,
            "adapted_wild_prob": wild_prob,
            "choice": "exploratory" if use_wild else "conservative",
        }

debate

debate(logits)

Adapt strategy based on current entropy.

Source code in src/model_garage/inject/temperature.py
def debate(self, logits: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
    """
    Adapt strategy based on current entropy.
    """
    probs_cons, probs_wild = self.get_perspectives(logits)

    # Compute entropy of conservative distribution
    entropy = -(probs_cons * torch.log(probs_cons + 1e-10)).sum(dim=-1).mean().item()

    # High entropy = be more conservative
    # Low entropy = can be wilder
    if entropy > self.entropy_threshold:
        wild_prob = self.base_wild_prob * 0.5  # Reduce wildness
    else:
        wild_prob = self.base_wild_prob * 1.5  # Increase wildness

    wild_prob = min(wild_prob, 0.5)  # Cap at 50%

    # Apply random switch with adapted probability
    if random.random() < wild_prob:
        chosen = probs_wild
        choice = "exploratory"
    else:
        chosen = probs_cons
        choice = "conservative"

    info = {
        "strategy": "adaptive",
        "entropy": entropy,
        "adapted_wild_prob": wild_prob,
        "choice": choice,
    }

    return chosen, info