Skip to content

Core

The core module provides model loading, hook management, tensor utilities, and device management that power the rest of Model Garage.

ModelLoader

ModelLoader

Standardized model loading for Model Garage.

Handles HuggingFace model loading, device placement, memory optimization, and model info extraction.

Source code in src/model_garage/core/loader.py
class ModelLoader:
    """
    Standardized model loading for Model Garage.

    Handles:
    - HuggingFace models
    - Device placement
    - Memory optimization
    - Model info extraction
    """

    # Known architecture families -> HuggingFace class names (informational lookup).
    SUPPORTED_ARCHITECTURES = {
        "gpt2": "GPT2LMHeadModel",
        "llama": "LlamaForCausalLM",
        "gemma": "GemmaForCausalLM",
        "phi": "PhiForCausalLM",
        "mistral": "MistralForCausalLM",
    }

    def __init__(self, device: Optional[str] = None):
        """
        Initialize loader.

        Args:
            device: Target device ("cuda", "cpu", "auto"). Default: auto-detect.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def load(
        self,
        model_id: str,
        load_tokenizer: bool = True,
        dtype: Optional[torch.dtype] = None,
        **kwargs
    ) -> Tuple[Any, Optional[Any], Dict[str, Any]]:
        """
        Load a model and optionally its tokenizer.

        Args:
            model_id: HuggingFace model ID or local path
            load_tokenizer: Whether to load tokenizer
            dtype: Optional dtype override (e.g., torch.float16)
            **kwargs: Additional args passed to from_pretrained

        Returns:
            (model, tokenizer, model_info)

        Raises:
            Whatever from_pretrained raises when loading fails even after
            retrying without device_map.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        load_kwargs = dict(kwargs)

        if self.device == "cuda":
            # Default to fp16 on GPU to halve memory; an explicit dtype wins.
            load_kwargs["torch_dtype"] = torch.float16 if dtype is None else dtype
            # Let accelerate place/shard large models automatically.
            # Fix: setdefault so a caller-supplied device_map is not clobbered.
            load_kwargs.setdefault("device_map", "auto")
        elif dtype is not None:
            # Fix: previously an explicit dtype was silently ignored off-GPU.
            load_kwargs["torch_dtype"] = dtype

        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        except Exception:
            # device_map="auto" requires accelerate and can fail on some
            # models; retry once without it, then place the model manually.
            if "device_map" not in load_kwargs:
                raise  # retrying identically would just fail again
            load_kwargs.pop("device_map")
            model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
            model = model.to(self.device)

        model.eval()

        # Load tokenizer
        tokenizer = None
        if load_tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            if tokenizer.pad_token is None:
                # Many causal LMs ship without a pad token; reuse EOS so
                # batching/padding works out of the box.
                tokenizer.pad_token = tokenizer.eos_token

        model_info = self._extract_info(model, model_id)

        return model, tokenizer, model_info

    def _extract_info(self, model: Any, model_id: str) -> Dict[str, Any]:
        """Extract useful info about the model (architecture, sizes, placement)."""
        config = model.config

        # Fallback attribute names cover GPT-2-style configs (n_embd/n_layer/...).
        info = {
            "model_id": model_id,
            "architecture": config.architectures[0] if hasattr(config, "architectures") and config.architectures else "unknown",
            "hidden_size": getattr(config, "hidden_size", getattr(config, "n_embd", None)),
            "num_layers": getattr(config, "num_hidden_layers", getattr(config, "n_layer", None)),
            "num_heads": getattr(config, "num_attention_heads", getattr(config, "n_head", None)),
            "vocab_size": config.vocab_size,
            "max_position": getattr(config, "max_position_embeddings", getattr(config, "n_positions", None)),
            # Device/dtype of the first parameter; with device_map="auto" other
            # parameters may live elsewhere — this is a representative sample.
            "device": str(next(model.parameters()).device),
            "dtype": str(next(model.parameters()).dtype),
            "total_params": sum(p.numel() for p in model.parameters()),
            "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad),
        }

        return info

    def get_layer_names(self, model: Any) -> Dict[str, Any]:
        """
        Get standard layer names for a model.

        Returns dict mapping generic names to model-specific module paths
        ("layers" maps to a list of per-layer paths).
        """
        config = model.config
        arch = config.architectures[0] if hasattr(config, "architectures") and config.architectures else ""

        if "GPT2" in arch:
            n_layers = config.n_layer
            return {
                "embedding": "transformer.wte",
                "position_embedding": "transformer.wpe",
                "layers": [f"transformer.h.{i}" for i in range(n_layers)],
                "final_norm": "transformer.ln_f",
                "output_head": "lm_head",
            }
        elif "Llama" in arch or "Gemma" in arch or "Mistral" in arch:
            n_layers = config.num_hidden_layers
            return {
                "embedding": "model.embed_tokens",
                "layers": [f"model.layers.{i}" for i in range(n_layers)],
                "final_norm": "model.norm",
                "output_head": "lm_head",
            }
        elif "Phi" in arch:
            n_layers = config.num_hidden_layers
            return {
                "embedding": "model.embed_tokens",
                "layers": [f"model.layers.{i}" for i in range(n_layers)],
                # Phi uses a differently named final norm than Llama-family.
                "final_norm": "model.final_layernorm",
                "output_head": "lm_head",
            }
        else:
            # Generic fallback: no hard-coded mapping for this architecture.
            return {
                "note": f"Unknown architecture: {arch}. Inspect model manually.",
            }

__init__

__init__(device=None)

Initialize loader.

Parameters:

Name Type Description Default
device Optional[str]

Target device ("cuda", "cpu", "auto"). Default: auto-detect.

None
Source code in src/model_garage/core/loader.py
def __init__(self, device: Optional[str] = None):
    """
    Initialize loader.

    Args:
        device: Target device ("cuda", "cpu", "auto"). Default: auto-detect.
    """
    resolved = device
    if resolved is None:
        # Prefer the GPU whenever CUDA is available.
        resolved = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = resolved

load

load(model_id, load_tokenizer=True, dtype=None, **kwargs)

Load a model and optionally its tokenizer.

Parameters:

Name Type Description Default
model_id str

HuggingFace model ID or local path

required
load_tokenizer bool

Whether to load tokenizer

True
dtype Optional[dtype]

Optional dtype override (e.g., torch.float16)

None
**kwargs

Additional args passed to from_pretrained

{}

Returns:

Type Description
Tuple[Any, Optional[Any], Dict[str, Any]]

(model, tokenizer, model_info)

Source code in src/model_garage/core/loader.py
def load(
    self,
    model_id: str,
    load_tokenizer: bool = True,
    dtype: Optional[torch.dtype] = None,
    **kwargs
) -> Tuple[Any, Optional[Any], Dict[str, Any]]:
    """
    Load a model and optionally its tokenizer.

    Args:
        model_id: HuggingFace model ID or local path
        load_tokenizer: Whether to load tokenizer
        dtype: Optional dtype override (e.g., torch.float16)
        **kwargs: Additional args passed to from_pretrained

    Returns:
        (model, tokenizer, model_info)

    Raises:
        Whatever from_pretrained raises when loading fails even after
        retrying without device_map.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    load_kwargs = dict(kwargs)

    if self.device == "cuda":
        # Default to fp16 on GPU to halve memory; an explicit dtype wins.
        load_kwargs["torch_dtype"] = torch.float16 if dtype is None else dtype
        # Let accelerate place/shard large models automatically.
        # Fix: setdefault so a caller-supplied device_map is not clobbered.
        load_kwargs.setdefault("device_map", "auto")
    elif dtype is not None:
        # Fix: previously an explicit dtype was silently ignored off-GPU.
        load_kwargs["torch_dtype"] = dtype

    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    except Exception:
        # device_map="auto" requires accelerate and can fail on some models;
        # retry once without it, then place the model manually.
        if "device_map" not in load_kwargs:
            raise  # retrying identically would just fail again
        load_kwargs.pop("device_map")
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        model = model.to(self.device)

    model.eval()

    # Load tokenizer
    tokenizer = None
    if load_tokenizer:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS so
            # batching/padding works out of the box.
            tokenizer.pad_token = tokenizer.eos_token

    # Extract model info
    model_info = self._extract_info(model, model_id)

    return model, tokenizer, model_info

get_layer_names

get_layer_names(model)

Get standard layer names for a model.

Returns dict mapping generic names to model-specific paths.

Source code in src/model_garage/core/loader.py
def get_layer_names(self, model: Any) -> Dict[str, str]:
    """
    Get standard layer names for a model.

    Returns dict mapping generic names to model-specific paths.
    """
    config = model.config
    archs = getattr(config, "architectures", None)
    arch = archs[0] if archs else ""

    if "GPT2" in arch:
        depth = config.n_layer
        return {
            "embedding": "transformer.wte",
            "position_embedding": "transformer.wpe",
            "layers": [f"transformer.h.{i}" for i in range(depth)],
            "final_norm": "transformer.ln_f",
            "output_head": "lm_head",
        }

    if any(family in arch for family in ("Llama", "Gemma", "Mistral")):
        depth = config.num_hidden_layers
        return {
            "embedding": "model.embed_tokens",
            "layers": [f"model.layers.{i}" for i in range(depth)],
            "final_norm": "model.norm",
            "output_head": "lm_head",
        }

    if "Phi" in arch:
        depth = config.num_hidden_layers
        return {
            "embedding": "model.embed_tokens",
            "layers": [f"model.layers.{i}" for i in range(depth)],
            # Phi names its final norm differently from Llama-family models.
            "final_norm": "model.final_layernorm",
            "output_head": "lm_head",
        }

    # Nothing matched: point the caller at manual inspection.
    return {
        "note": f"Unknown architecture: {arch}. Inspect model manually.",
    }

quick_load

quick_load

quick_load(model_id, device=None)

Quick helper to load a model.

Usage

model, tokenizer, info = quick_load("gpt2")

Source code in src/model_garage/core/loader.py
def quick_load(model_id: str, device: Optional[str] = None):
    """
    Quick helper to load a model.

    Usage:
        model, tokenizer, info = quick_load("gpt2")
    """
    # One-shot convenience: build a loader and immediately load the model.
    return ModelLoader(device).load(model_id)

HookManager

HookManager

Centralized hook management for model manipulation.

Features: named hooks for easy tracking, automatic cleanup, hook chaining, and debug logging.

Source code in src/model_garage/core/hooks.py
class HookManager:
    """
    Centralized hook management for model manipulation.

    Features:
    - Named hooks for easy tracking
    - Automatic cleanup
    - Hook chaining
    - Debug logging
    """

    def __init__(self, model: nn.Module, debug: bool = False):
        self.model = model
        self.hooks: Dict[str, HookHandle] = {}
        self.debug = debug
        self._captured_data: Dict[str, Any] = {}

    def register_forward_hook(
        self,
        layer_name: str,
        hook_fn: Callable,
        hook_name: Optional[str] = None
    ) -> str:
        """
        Register a forward hook on a named layer.

        Args:
            layer_name: Name of layer (e.g., "transformer.h.6")
            hook_fn: Function(module, input, output) -> modified_output or None
            hook_name: Optional name for this hook

        Returns:
            Hook name for later reference
        """
        target = self._get_layer(layer_name)
        name = hook_name or f"{layer_name}_forward_{len(self.hooks)}"

        callback = hook_fn
        if self.debug:
            wrapped = hook_fn

            def callback(module, input, output):
                # Trace every invocation when debug logging is enabled.
                print(f"[Hook] {name} triggered on {layer_name}")
                return wrapped(module, input, output)

        handle = target.register_forward_hook(callback)

        self.hooks[name] = HookHandle(
            name=name,
            layer_name=layer_name,
            hook_type="forward",
            handle=handle,
        )
        return name

    def register_capture_hook(
        self,
        layer_name: str,
        hook_name: Optional[str] = None,
        capture_input: bool = False,
        capture_output: bool = True
    ) -> str:
        """
        Register a hook that captures activations without modifying them.

        Captured data accessible via get_captured(hook_name).
        """
        name = hook_name or f"{layer_name}_capture_{len(self.hooks)}"

        def grab(module, input, output):
            snapshot = {}
            if capture_input:
                src = input[0] if isinstance(input, tuple) else input
                snapshot["input"] = src.detach().clone()
            if capture_output:
                src = output[0] if isinstance(output, tuple) else output
                snapshot["output"] = src.detach().clone()
            self._captured_data[name] = snapshot
            return None  # leave the forward pass untouched

        return self.register_forward_hook(layer_name, grab, name)

    def register_injection_hook(
        self,
        layer_name: str,
        injection_fn: Callable[[torch.Tensor], torch.Tensor],
        hook_name: Optional[str] = None
    ) -> str:
        """
        Register a hook that modifies layer output.

        Args:
            layer_name: Name of layer to inject after
            injection_fn: Function(hidden_states) -> modified_hidden_states
            hook_name: Optional name
        """
        name = hook_name or f"{layer_name}_inject_{len(self.hooks)}"

        def rewrite(module, input, output):
            # HF layers often return tuples; only hidden states are replaced.
            if not isinstance(output, tuple):
                return injection_fn(output)
            return (injection_fn(output[0]),) + output[1:]

        return self.register_forward_hook(layer_name, rewrite, name)

    def get_captured(self, hook_name: str) -> Optional[Dict[str, torch.Tensor]]:
        """Get data captured by a capture hook (None if it hasn't fired)."""
        return self._captured_data.get(hook_name)

    def clear_captured(self):
        """Clear all captured data; registered hooks stay active."""
        self._captured_data.clear()

    def remove_hook(self, hook_name: str):
        """Remove a specific hook by name (no-op for unknown names)."""
        if hook_name not in self.hooks:
            return
        self.hooks[hook_name].remove()
        del self.hooks[hook_name]

    def remove_all(self):
        """Remove all registered hooks and drop any captured data."""
        for entry in self.hooks.values():
            entry.remove()
        self.hooks.clear()
        self._captured_data.clear()

    def list_hooks(self) -> List[str]:
        """List all registered hook names."""
        return list(self.hooks.keys())

    def _get_layer(self, layer_name: str) -> nn.Module:
        """Resolve a dot-separated name (digits index into module lists)."""
        node = self.model
        for token in layer_name.split("."):
            node = node[int(token)] if token.isdigit() else getattr(node, token)
        return node

    def __enter__(self):
        """Context manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Cleanup hooks on exit."""
        self.remove_all()

register_forward_hook

register_forward_hook(layer_name, hook_fn, hook_name=None)

Register a forward hook on a named layer.

Parameters:

Name Type Description Default
layer_name str

Name of layer (e.g., "transformer.h.6")

required
hook_fn Callable

Function(module, input, output) -> modified_output or None

required
hook_name Optional[str]

Optional name for this hook

None

Returns:

Type Description
str

Hook name for later reference

Source code in src/model_garage/core/hooks.py
def register_forward_hook(
    self,
    layer_name: str,
    hook_fn: Callable,
    hook_name: Optional[str] = None
) -> str:
    """
    Register a forward hook on a named layer.

    Args:
        layer_name: Name of layer (e.g., "transformer.h.6")
        hook_fn: Function(module, input, output) -> modified_output or None
        hook_name: Optional name for this hook

    Returns:
        Hook name for later reference
    """
    target = self._get_layer(layer_name)
    name = hook_name or f"{layer_name}_forward_{len(self.hooks)}"

    callback = hook_fn
    if self.debug:
        wrapped = hook_fn

        def callback(module, input, output):
            # Trace every invocation when debug logging is enabled.
            print(f"[Hook] {name} triggered on {layer_name}")
            return wrapped(module, input, output)

    handle = target.register_forward_hook(callback)

    self.hooks[name] = HookHandle(
        name=name,
        layer_name=layer_name,
        hook_type="forward",
        handle=handle,
    )
    return name

register_capture_hook

register_capture_hook(layer_name, hook_name=None, capture_input=False, capture_output=True)

Register a hook that captures activations without modifying them.

Captured data accessible via get_captured(hook_name).

Source code in src/model_garage/core/hooks.py
def register_capture_hook(
    self,
    layer_name: str,
    hook_name: Optional[str] = None,
    capture_input: bool = False,
    capture_output: bool = True
) -> str:
    """
    Register a hook that captures activations without modifying them.

    Captured data accessible via get_captured(hook_name).
    """
    name = hook_name or f"{layer_name}_capture_{len(self.hooks)}"

    def grab(module, input, output):
        snapshot = {}
        if capture_input:
            src = input[0] if isinstance(input, tuple) else input
            snapshot["input"] = src.detach().clone()
        if capture_output:
            src = output[0] if isinstance(output, tuple) else output
            snapshot["output"] = src.detach().clone()
        self._captured_data[name] = snapshot
        return None  # leave the forward pass untouched

    return self.register_forward_hook(layer_name, grab, name)

register_injection_hook

register_injection_hook(layer_name, injection_fn, hook_name=None)

Register a hook that modifies layer output.

Parameters:

Name Type Description Default
layer_name str

Name of layer to inject after

required
injection_fn Callable[[Tensor], Tensor]

Function(hidden_states) -> modified_hidden_states

required
hook_name Optional[str]

Optional name

None
Source code in src/model_garage/core/hooks.py
def register_injection_hook(
    self,
    layer_name: str,
    injection_fn: Callable[[torch.Tensor], torch.Tensor],
    hook_name: Optional[str] = None
) -> str:
    """
    Register a hook that modifies layer output.

    Args:
        layer_name: Name of layer to inject after
        injection_fn: Function(hidden_states) -> modified_hidden_states
        hook_name: Optional name
    """
    name = hook_name or f"{layer_name}_inject_{len(self.hooks)}"

    def rewrite(module, input, output):
        # HF layers often return tuples; only the hidden states are replaced.
        if not isinstance(output, tuple):
            return injection_fn(output)
        return (injection_fn(output[0]),) + output[1:]

    return self.register_forward_hook(layer_name, rewrite, name)

get_captured

get_captured(hook_name)

Get data captured by a capture hook.

Source code in src/model_garage/core/hooks.py
def get_captured(self, hook_name: str) -> Optional[Dict[str, torch.Tensor]]:
    """Get data captured by a capture hook.

    Returns:
        The {"input"/"output": tensor} dict stored by the hook, or None if
        the name is unknown or the hook has not fired yet.
    """
    return self._captured_data.get(hook_name)

clear_captured

clear_captured()

Clear all captured data.

Source code in src/model_garage/core/hooks.py
def clear_captured(self):
    """Clear all captured data.

    Frees the stored tensors; hooks themselves stay registered and will
    repopulate the cache on the next forward pass.
    """
    self._captured_data.clear()

remove_hook

remove_hook(hook_name)

Remove a specific hook by name.

Source code in src/model_garage/core/hooks.py
def remove_hook(self, hook_name: str):
    """Detach and forget the hook registered under *hook_name* (no-op if absent)."""
    if hook_name not in self.hooks:
        return
    self.hooks[hook_name].remove()
    del self.hooks[hook_name]

remove_all

remove_all()

Remove all registered hooks.

Source code in src/model_garage/core/hooks.py
def remove_all(self):
    """Detach every registered hook and drop all captured activations."""
    for entry in self.hooks.values():
        entry.remove()
    self.hooks.clear()
    self._captured_data.clear()

list_hooks

list_hooks()

List all registered hook names.

Source code in src/model_garage/core/hooks.py
def list_hooks(self) -> List[str]:
    """List all registered hook names.

    Returns:
        A snapshot list of names in registration (insertion) order.
    """
    return list(self.hooks.keys())

__enter__

__enter__()

Context manager support.

Source code in src/model_garage/core/hooks.py
def __enter__(self):
    """Context manager support: `with HookManager(model) as hooks:` yields self."""
    return self

__exit__

__exit__(exc_type, exc_val, exc_tb)

Cleanup hooks on exit.

Source code in src/model_garage/core/hooks.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Cleanup hooks on exit.

    Removes every hook unconditionally; returns None (falsy), so any
    in-flight exception still propagates.
    """
    self.remove_all()

TensorUtils

TensorUtils

Common tensor operations used across all tools.

Source code in src/model_garage/core/tensor.py
class TensorUtils:
    """
    Common tensor operations used across all tools.

    All methods are static; the class is just a namespace.
    """

    @staticmethod
    def ensure_device(tensor: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
        """Move tensor to specified device if not already there.

        NOTE(review): torch.device("cuda") != torch.device("cuda:0"), so a
        bare "cuda" request may trigger a (harmless) .to() even when the
        tensor already lives on GPU 0 — confirm whether that matters here.
        """
        if tensor.device != torch.device(device):
            return tensor.to(device)
        return tensor

    @staticmethod
    def ensure_shape(tensor: torch.Tensor, target_shape: Tuple[int, ...]) -> torch.Tensor:
        """
        Reshape tensor to target shape, handling common cases.

        Supports:
        - Adding batch dimension
        - Adding sequence dimension
        - Padding/truncating sequence length (dim 1 is assumed to be sequence)
        """
        current = tensor.shape
        target = target_shape

        # Add batch dim if missing (tensor has one fewer dim than target)
        if len(current) == len(target) - 1:
            tensor = tensor.unsqueeze(0)
            current = tensor.shape

        # Still one dim short after adding batch: sequence dim is missing too
        if len(current) == len(target) - 1:
            tensor = tensor.unsqueeze(1)
            current = tensor.shape

        # Handle sequence length mismatch: zero-pad or truncate along dim 1
        if len(current) == len(target) and current[1] != target[1]:
            if current[1] < target[1]:
                # Pad with zeros of the same device/dtype
                padding = torch.zeros(
                    current[0], target[1] - current[1], *current[2:],
                    device=tensor.device, dtype=tensor.dtype
                )
                tensor = torch.cat([tensor, padding], dim=1)
            else:
                # Truncate to the target length
                tensor = tensor[:, :target[1]]

        return tensor

    @staticmethod
    def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
        """Compute cosine similarity between two tensors (flattened, fp32)."""
        a_flat = a.flatten().float()
        b_flat = b.flatten().float()
        return torch.nn.functional.cosine_similarity(
            a_flat.unsqueeze(0), b_flat.unsqueeze(0)
        ).item()

    @staticmethod
    def l2_distance(a: torch.Tensor, b: torch.Tensor) -> float:
        """Compute L2 distance between two tensors."""
        return (a - b).norm().item()

    @staticmethod
    def stats(tensor: torch.Tensor) -> dict:
        """Get basic statistics about a tensor.

        Note: "std" uses torch's default (unbiased) estimator, which is NaN
        for single-element tensors.
        """
        f = tensor.float()  # hoist the fp32 copy instead of converting 4x
        return {
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "device": str(tensor.device),
            "mean": f.mean().item(),
            "std": f.std().item(),
            "min": f.min().item(),
            "max": f.max().item(),
            "sparsity": (tensor == 0).float().mean().item(),
        }

    @staticmethod
    def project(tensor: torch.Tensor, from_dim: int, to_dim: int) -> torch.Tensor:
        """
        Project tensor from one dimension to another using learned linear.

        Note: This creates a NEW projection each time. For reusable projections,
        use the Projector class instead.
        """
        # Fix: match the input's dtype as well as its device — the fp32
        # default weights previously crashed on fp16/bf16 inputs.
        projection = torch.nn.Linear(from_dim, to_dim, device=tensor.device, dtype=tensor.dtype)
        return projection(tensor)
ensure_device staticmethod

ensure_device(tensor, device)

Move tensor to specified device if not already there.

Source code in src/model_garage/core/tensor.py
@staticmethod
def ensure_device(tensor: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
    """Move tensor to specified device if not already there."""
    if tensor.device != torch.device(device):
        return tensor.to(device)
    return tensor

ensure_shape staticmethod

ensure_shape(tensor, target_shape)

Reshape tensor to target shape, handling common cases.

Supports adding a batch dimension, adding a sequence dimension, and padding or truncating the sequence length.

Source code in src/model_garage/core/tensor.py
@staticmethod
def ensure_shape(tensor: torch.Tensor, target_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Reshape tensor to target shape, handling common cases.

    Supports:
    - Adding batch dimension
    - Adding sequence dimension
    - Padding/truncating sequence length
    """
    current = tensor.shape
    target = target_shape

    # Add batch dim if missing
    if len(current) == len(target) - 1:
        tensor = tensor.unsqueeze(0)
        current = tensor.shape

    # Add sequence dim if missing
    if len(current) == len(target) - 1:
        tensor = tensor.unsqueeze(1)
        current = tensor.shape

    # Handle sequence length mismatch
    if len(current) == len(target) and current[1] != target[1]:
        if current[1] < target[1]:
            # Pad
            padding = torch.zeros(
                current[0], target[1] - current[1], *current[2:],
                device=tensor.device, dtype=tensor.dtype
            )
            tensor = torch.cat([tensor, padding], dim=1)
        else:
            # Truncate
            tensor = tensor[:, :target[1]]

    return tensor

cosine_similarity staticmethod

cosine_similarity(a, b)

Compute cosine similarity between two tensors.

Source code in src/model_garage/core/tensor.py
@staticmethod
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Compute cosine similarity between two tensors."""
    a_flat = a.flatten().float()
    b_flat = b.flatten().float()
    return torch.nn.functional.cosine_similarity(
        a_flat.unsqueeze(0), b_flat.unsqueeze(0)
    ).item()

l2_distance staticmethod

l2_distance(a, b)

Compute L2 distance between two tensors.

Source code in src/model_garage/core/tensor.py
@staticmethod
def l2_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    """Compute L2 distance between two tensors."""
    return (a - b).norm().item()

stats staticmethod

stats(tensor)

Get basic statistics about a tensor.

Source code in src/model_garage/core/tensor.py
@staticmethod
def stats(tensor: torch.Tensor) -> dict:
    """Get basic statistics about a tensor."""
    return {
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype),
        "device": str(tensor.device),
        "mean": tensor.float().mean().item(),
        "std": tensor.float().std().item(),
        "min": tensor.float().min().item(),
        "max": tensor.float().max().item(),
        "sparsity": (tensor == 0).float().mean().item(),
    }

project staticmethod

project(tensor, from_dim, to_dim)

Project tensor from one dimension to another using learned linear.

Note: This creates a NEW projection each time. For reusable projections, use the Projector class instead.

Source code in src/model_garage/core/tensor.py
@staticmethod
def project(tensor: torch.Tensor, from_dim: int, to_dim: int) -> torch.Tensor:
    """
    Project tensor from one dimension to another using learned linear.

    Note: This creates a NEW projection each time. For reusable projections,
    use the Projector class instead.
    """
    projection = torch.nn.Linear(from_dim, to_dim, device=tensor.device)
    return projection(tensor)

Projector

Projector

Reusable dimension projector.

Like an adapter socket - converts between different sizes.

Source code in src/model_garage/core/tensor.py
class Projector:
    """
    Reusable dimension projector.

    Like an adapter socket - converts between different sizes.
    """

    def __init__(self, from_dim: int, to_dim: int, device: str = "cpu"):
        # Single linear layer mapping from_dim -> to_dim, kept on `device`.
        self.projection = torch.nn.Linear(from_dim, to_dim).to(device)
        self.from_dim = from_dim
        self.to_dim = to_dim

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Convenience: allow projector(x) as well as projector.forward(x)."""
        return self.forward(tensor)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project tensor to new dimension (trailing dim from_dim -> to_dim)."""
        return self.projection(tensor)

    def save(self, path: str):
        """Save projector weights plus dimension metadata to `path`."""
        torch.save({
            "weights": self.projection.state_dict(),
            "from_dim": self.from_dim,
            "to_dim": self.to_dim,
        }, path)

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "Projector":
        """Load projector from a file written by save().

        Args:
            path: Checkpoint path.
            device: Device to map the restored weights onto.
        """
        # weights_only=True: the checkpoint holds only tensors and ints, so
        # refuse arbitrary pickled objects (safer against tampered files).
        data = torch.load(path, map_location=device, weights_only=True)
        proj = cls(data["from_dim"], data["to_dim"], device)
        proj.projection.load_state_dict(data["weights"])
        return proj

forward

forward(tensor)

Project tensor to new dimension.

Source code in src/model_garage/core/tensor.py
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
    """Project tensor to new dimension.

    Applies the stored linear layer: the input's trailing dim must equal
    from_dim; the output's trailing dim is to_dim.
    """
    return self.projection(tensor)

save

save(path)

Save projector weights.

Source code in src/model_garage/core/tensor.py
def save(self, path: str):
    """Persist the projection weights plus dimension metadata to `path`."""
    payload = {
        "weights": self.projection.state_dict(),
        "from_dim": self.from_dim,
        "to_dim": self.to_dim,
    }
    torch.save(payload, path)

load classmethod

load(path, device='cpu')

Load projector from file.

Source code in src/model_garage/core/tensor.py
@classmethod
def load(cls, path: str, device: str = "cpu") -> "Projector":
    """Load projector from a file written by save().

    Args:
        path: Checkpoint path.
        device: Device to map the restored weights onto.
    """
    # weights_only=True: the checkpoint holds only tensors and ints, so
    # refuse arbitrary pickled objects (safer against tampered files).
    data = torch.load(path, map_location=device, weights_only=True)
    proj = cls(data["from_dim"], data["to_dim"], device)
    proj.projection.load_state_dict(data["weights"])
    return proj

DeviceManager

DeviceManager

Manage device placement for a session.

Provides consistent device handling across multiple operations.

Source code in src/model_garage/core/device.py
class DeviceManager:
    """
    Manage device placement for a session.

    Provides consistent device handling across multiple operations.
    """

    def __init__(self, device: Optional[str] = None):
        # Fall back to the module-level auto-detected device when unspecified.
        self.device = get_device() if device is None else torch.device(device)

    def to(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move tensor to managed device."""
        return ensure_device(tensor, self.device)

    def to_dict(self, d: dict) -> dict:
        """Move all tensors in a dict to managed device (non-tensors pass through)."""
        moved = {}
        for key, value in d.items():
            moved[key] = self.to(value) if isinstance(value, torch.Tensor) else value
        return moved

    @contextmanager
    def scope(self):
        """Context manager that makes the managed device the default."""
        with torch.device(self.device):
            yield

    @property
    def is_gpu(self) -> bool:
        """True when the managed device is CUDA."""
        return self.device.type == "cuda"

    def memory_stats(self) -> dict:
        """Get GPU memory stats if available (CPU returns a stub)."""
        if not self.is_gpu:
            return {"device": "cpu"}
        mb = 1024 * 1024
        return {
            "device": str(self.device),
            "allocated_mb": torch.cuda.memory_allocated(self.device) / mb,
            "cached_mb": torch.cuda.memory_reserved(self.device) / mb,
            "max_allocated_mb": torch.cuda.max_memory_allocated(self.device) / mb,
        }

    def clear_cache(self):
        """Clear GPU cache if on CUDA (no-op otherwise)."""
        if not self.is_gpu:
            return
        torch.cuda.empty_cache()

to

to(tensor)

Move tensor to managed device.

Source code in src/model_garage/core/device.py
# Thin wrapper: delegates the actual placement to the module-level
# ensure_device() helper, always targeting the session's managed device.
def to(self, tensor: torch.Tensor) -> torch.Tensor:
    """Move tensor to managed device."""
    return ensure_device(tensor, self.device)

to_dict

to_dict(d)

Move all tensors in a dict to managed device.

Source code in src/model_garage/core/device.py
def to_dict(self, d: dict) -> dict:
    """Move all tensors in a dict to managed device."""
    # Only torch.Tensor values are moved; all other values pass through
    # unchanged. Returns a new dict; the input is not mutated.
    return {
        k: self.to(v) if isinstance(v, torch.Tensor) else v
        for k, v in d.items()
    }

scope

scope()

Context manager for device scope.

Source code in src/model_garage/core/device.py
# NOTE(review): `with torch.device(...)` as a context manager requires
# PyTorch >= 2.0 — confirm against the project's minimum torch version.
@contextmanager
def scope(self):
    """Context manager for device scope."""
    with torch.device(self.device):
        yield

memory_stats

memory_stats()

Get GPU memory stats if available.

Source code in src/model_garage/core/device.py
def memory_stats(self) -> dict:
    """Get GPU memory stats if available."""
    # On non-CUDA devices there is nothing to report beyond the device name.
    if not self.is_gpu:
        return {"device": "cpu"}
    return {
        "device": str(self.device),
        # Bytes -> MiB conversions of PyTorch's CUDA allocator counters.
        "allocated_mb": torch.cuda.memory_allocated(self.device) / 1024 / 1024,
        "cached_mb": torch.cuda.memory_reserved(self.device) / 1024 / 1024,
        "max_allocated_mb": torch.cuda.max_memory_allocated(self.device) / 1024 / 1024,
    }

clear_cache

clear_cache()

Clear GPU cache if on CUDA.

Source code in src/model_garage/core/device.py
def clear_cache(self):
    """Clear GPU cache if on CUDA."""
    # No-op on CPU devices.
    if self.is_gpu:
        torch.cuda.empty_cache()

Serialization

save_component

save_component(component, path, metadata=None, **extra_metadata)

Save a component with metadata.

Parameters:

Name Type Description Default
component Union[Module, Tensor, Dict[str, Tensor]]

Module, Tensor, or state dict to save

required
path Union[str, Path]

Directory to save to

required
metadata Optional[ComponentMetadata]

Optional ComponentMetadata

None
**extra_metadata

Additional metadata fields

{}

Returns:

Type Description
Path

Path to saved component directory

Source code in src/model_garage/core/serialization.py
def save_component(
    component: Union[torch.nn.Module, torch.Tensor, Dict[str, torch.Tensor]],
    path: Union[str, Path],
    metadata: Optional[ComponentMetadata] = None,
    **extra_metadata,
) -> Path:
    """
    Save a component with metadata.

    Args:
        component: Module, Tensor, or state dict to save
        path: Directory to save to
        metadata: Optional ComponentMetadata
        **extra_metadata: Additional metadata fields

    Returns:
        Path to saved component directory
    """
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)

    # Normalize every supported input into a flat name -> tensor mapping.
    if isinstance(component, torch.nn.Module):
        component_type = "module"
        state_dict = component.state_dict()
    elif isinstance(component, torch.Tensor):
        component_type = "tensor"
        state_dict = {"tensor": component}
    elif isinstance(component, dict):
        component_type = "state_dict"
        state_dict = component
    else:
        raise ValueError(f"Cannot save component of type {type(component)}")

    torch.save(state_dict, target / "weights.pt")

    # Record per-entry shapes/dtypes so the component can be inspected
    # without loading the weights.
    shapes = {}
    dtypes = {}
    for name, tensor in state_dict.items():
        shapes[name] = list(tensor.shape)
        dtypes[name] = str(tensor.dtype)

    config = {
        "component_type": component_type,
        "saved_at": datetime.now().isoformat(),
        "toolkit_version": "0.1.0",
        "shapes": shapes,
        "dtypes": dtypes,
    }

    if metadata:
        config["metadata"] = metadata.to_dict()
    # Caller-supplied fields are merged last and may override the defaults.
    config.update(extra_metadata)

    with open(target / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    return target

load_component

load_component(path, device='cpu', return_metadata=False)

Load a saved component.

Parameters:

Name Type Description Default
path Union[str, Path]

Directory containing saved component

required
device str

Device to load to

'cpu'
return_metadata bool

If True, also return metadata

False

Returns:

Type Description
Union[Dict[str, Tensor], tuple]

state_dict, or (state_dict, metadata) if return_metadata=True

Source code in src/model_garage/core/serialization.py
def load_component(
    path: Union[str, Path],
    device: str = "cpu",
    return_metadata: bool = False,
) -> Union[Dict[str, torch.Tensor], tuple]:
    """
    Load a saved component.

    Args:
        path: Directory containing saved component
        device: Device to load to
        return_metadata: If True, also return metadata

    Returns:
        state_dict, or (state_dict, metadata) if return_metadata=True
    """
    component_dir = Path(path)
    # weights_only=True restricts unpickling to plain tensor data.
    state_dict = torch.load(
        component_dir / "weights.pt", map_location=device, weights_only=True
    )

    if not return_metadata:
        return state_dict

    metadata = None
    config_file = component_dir / "config.json"
    if config_file.exists():
        with open(config_file) as f:
            raw_metadata = json.load(f).get("metadata")
        if raw_metadata:
            metadata = ComponentMetadata.from_dict(raw_metadata)

    return state_dict, metadata