Skip to content

Registry

Model decomposition, component cataloging, and architecture comparison.

ModelRegistry

ModelRegistry

Central registry for decomposed models.

Usage:

    registry = ModelRegistry()

    # Decompose and register a model
    spec = registry.register("meta-llama/Llama-2-7b-hf", model)

    # Get parts for experiments
    attn_layer_5 = registry.get_part("meta-llama/Llama-2-7b-hf", "attention_5")

    # List all registered models
    models = registry.list_models()

    # Compare architectures
    registry.compare("gpt2", "meta-llama/Llama-2-7b-hf")

Source code in src/model_garage/registry/models.py
class ModelRegistry:
    """
    Central registry for decomposed models.

    Specs produced by the family-specific decomposers are kept in memory
    (``self.specs``) and mirrored as JSON files under ``cache_dir`` so that
    :meth:`get_spec` can recover them later. Live ``nn.Module`` objects are
    only held for models registered through this instance (``self.models``),
    so :meth:`get_module` works only for those.

    Usage:
        registry = ModelRegistry()

        # Decompose and register a model
        spec = registry.register("meta-llama/Llama-2-7b-hf", model)

        # Get parts for experiments
        attn_layer_5 = registry.get_part("meta-llama/Llama-2-7b-hf", "attention_5")

        # List all registered models
        models = registry.list_models()

        # Compare architectures
        registry.compare("gpt2", "meta-llama/Llama-2-7b-hf")
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """
        Args:
            cache_dir: Directory holding cached spec JSON files (a ``Path`` or a
                path string). Defaults to ``~/.model_garage/registry``. The
                directory is created if it does not exist.
        """
        # Decomposers are tried in this order; the first whose detect()
        # accepts a model is the one that handles it.
        self.decomposers: List[ModelDecomposer] = [
            GPT2Decomposer(),
            LlamaDecomposer(),
            MistralDecomposer(),
            GemmaDecomposer(),
            QwenDecomposer(),
            PhiDecomposer(),
        ]
        self.specs: Dict[str, ModelSpec] = {}
        self.models: Dict[str, nn.Module] = {}  # Live modules registered via this instance
        self.cache_dir = (
            Path(cache_dir) if cache_dir else Path.home() / ".model_garage" / "registry"
        )
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _find_decomposer(self, model: nn.Module, model_id: str) -> Optional[ModelDecomposer]:
        """Return the first decomposer whose ``detect`` accepts the model, or None."""
        return next(
            (decomposer for decomposer in self.decomposers if decomposer.detect(model, model_id)),
            None,
        )

    def detect_family(self, model: nn.Module, model_id: str) -> ModelFamily:
        """Detect which family a model belongs to.

        Runs the matching decomposer's ``decompose`` and reads the resulting
        spec's ``family``. Returns ``ModelFamily.UNKNOWN`` when no decomposer
        accepts the model.
        """
        decomposer = self._find_decomposer(model, model_id)
        if decomposer is None:
            return ModelFamily.UNKNOWN
        return decomposer.decompose(model, model_id).family

    def register(self, model_id: str, model: nn.Module) -> ModelSpec:
        """Decompose ``model``, register it under ``model_id`` and cache its spec.

        Raises:
            ValueError: if no decomposer recognizes the model.
        """
        decomposer = self._find_decomposer(model, model_id)
        if decomposer is None:
            raise ValueError(f"No decomposer found for model: {model_id}")

        spec = decomposer.decompose(model, model_id)
        self.specs[model_id] = spec
        self.models[model_id] = model
        self._save_spec(spec)
        return spec

    def get_spec(self, model_id: str) -> Optional[ModelSpec]:
        """Get the spec for a model: in-memory first, then the on-disk cache."""
        if model_id in self.specs:
            return self.specs[model_id]
        return self._load_spec(model_id)

    def get_part(self, model_id: str, part_name: str) -> Optional[PartSpec]:
        """Get a specific part from a model's spec; None if model or part is unknown."""
        spec = self.get_spec(model_id)
        if spec:
            return spec.parts.get(part_name)
        return None

    def get_module(self, model_id: str, part_name: str) -> Optional[nn.Module]:
        """Get the live submodule for a part.

        Returns None when the model was not registered through this instance,
        the part is unknown, or no decomposer accepts the model.
        """
        if model_id not in self.models:
            return None
        spec = self.get_spec(model_id)
        if not spec or part_name not in spec.parts:
            return None

        model = self.models[model_id]
        decomposer = self._find_decomposer(model, model_id)
        if decomposer is None:
            return None
        return decomposer.get_module(model, spec.parts[part_name])

    def list_models(self) -> List[str]:
        """List ids of all models whose specs are currently held in memory."""
        return list(self.specs.keys())

    def list_parts(self, model_id: str, part_type: Optional[PartType] = None) -> List[str]:
        """List all part names for a model, optionally filtered by ``part_type``."""
        spec = self.get_spec(model_id)
        if not spec:
            return []
        if part_type is not None:
            return [name for name, part in spec.parts.items() if part.part_type == part_type]
        return list(spec.parts.keys())

    def compare(self, model_a: str, model_b: str) -> Dict[str, Any]:
        """Compare two model architectures side by side.

        Returns ``{"error": ...}`` when either spec is unavailable.
        """
        spec_a = self.get_spec(model_a)
        spec_b = self.get_spec(model_b)

        if not spec_a or not spec_b:
            return {"error": "One or both models not registered"}

        return {
            "models": [model_a, model_b],
            "families": [spec_a.family.value, spec_b.family.value],
            "hidden_dims": [spec_a.hidden_dim, spec_b.hidden_dim],
            "num_layers": [spec_a.num_layers, spec_b.num_layers],
            "num_heads": [spec_a.num_heads, spec_b.num_heads],
            "vocab_sizes": [spec_a.vocab_size, spec_b.vocab_size],
            "compatible_parts": self._find_compatible_parts(spec_a, spec_b),
        }

    def _find_compatible_parts(self, spec_a: ModelSpec, spec_b: ModelSpec) -> Dict[str, List[str]]:
        """Find parts that could potentially be swapped between models.

        This is a dimension-level heuristic only:

        * ``same_dim``: hidden sizes match.
        * ``attention_compatible``: hidden size AND head count match, over the
          layer indices both models have.
        * ``ffn_compatible``: hidden size matches, over the layer indices both
          models have.

        Each value is an empty list when its condition does not hold, or when
        either model reports no layers.
        """
        compatible: Dict[str, List[str]] = {
            "same_dim": [],
            "attention_compatible": [],
            "ffn_compatible": [],
        }

        same_dim = spec_a.hidden_dim == spec_b.hidden_dim
        if same_dim:
            compatible["same_dim"] = ["All parts potentially swappable"]

        # Only layer indices present in both models can be paired. Without any
        # shared layer the range strings would read "..._0 through ..._-1".
        shared_layers = min(spec_a.num_layers, spec_b.num_layers)
        if shared_layers <= 0:
            return compatible
        last = shared_layers - 1

        # Equal head counts alone are not enough: with different hidden sizes
        # the attention blocks' projections would not line up.
        if same_dim and spec_a.num_heads == spec_b.num_heads:
            compatible["attention_compatible"] = [f"attention_0 through attention_{last}"]

        # Consistent with "same_dim": a matching hidden size is what the FFN
        # interface needs (assumes hidden -> hidden FFN blocks).
        if same_dim:
            compatible["ffn_compatible"] = [f"ffn_0 through ffn_{last}"]

        return compatible

    def _cache_path(self, model_id: str) -> Path:
        """Return the cache file path for ``model_id`` ("/" is replaced by "__")."""
        safe_id = model_id.replace("/", "__")
        return self.cache_dir / f"{safe_id}.json"

    def _save_spec(self, spec: ModelSpec) -> None:
        """Save spec to the on-disk cache.

        Raises:
            TypeError: if ``extra_info`` (model- or part-level) holds values
                that are not JSON-serializable. Nothing is written in that case.
        """
        path = self._cache_path(spec.model_id)

        data = {
            "model_id": spec.model_id,
            "family": spec.family.value,
            "hidden_dim": spec.hidden_dim,
            "num_layers": spec.num_layers,
            "num_heads": spec.num_heads,
            "vocab_size": spec.vocab_size,
            "max_seq_len": spec.max_seq_len,
            "extra_info": spec.extra_info,
            "parts": {
                name: {
                    "part_type": part.part_type.value,
                    "layer_idx": part.layer_idx,
                    "module_path": part.module_path,
                    "input_dim": part.input_dim,
                    "output_dim": part.output_dim,
                    "num_heads": part.num_heads,
                    "head_dim": part.head_dim,
                    "intermediate_dim": part.intermediate_dim,
                    "extra_info": part.extra_info,
                }
                for name, part in spec.parts.items()
            }
        }

        # Serialize fully before touching disk. json.dump streams into the
        # file, so a serialization error mid-way would leave a truncated
        # cache file that breaks every later _load_spec for this id.
        text = json.dumps(data, indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(text, encoding="utf-8")
        # Rename over the target (os.replace semantics) so readers never see
        # a partially written spec.
        tmp_path.replace(path)

    def _load_spec(self, model_id: str) -> Optional[ModelSpec]:
        """Load a spec from the on-disk cache and memoize it in ``self.specs``.

        Returns None when no cache file exists for ``model_id``.
        """
        path = self._cache_path(model_id)

        if not path.exists():
            return None

        data = json.loads(path.read_text(encoding="utf-8"))

        spec = ModelSpec(
            model_id=data["model_id"],
            family=ModelFamily(data["family"]),
            hidden_dim=data["hidden_dim"],
            num_layers=data["num_layers"],
            num_heads=data["num_heads"],
            vocab_size=data["vocab_size"],
            max_seq_len=data["max_seq_len"],
            extra_info=data.get("extra_info", {}),
        )

        for name, part_data in data.get("parts", {}).items():
            spec.parts[name] = PartSpec(
                part_type=PartType(part_data["part_type"]),
                layer_idx=part_data["layer_idx"],
                module_path=part_data["module_path"],
                input_dim=part_data["input_dim"],
                output_dim=part_data["output_dim"],
                num_heads=part_data.get("num_heads"),
                head_dim=part_data.get("head_dim"),
                intermediate_dim=part_data.get("intermediate_dim"),
                extra_info=part_data.get("extra_info", {}),
            )

        self.specs[model_id] = spec
        return spec

detect_family

detect_family(model, model_id)

Detect which family a model belongs to.

Source code in src/model_garage/registry/models.py
def detect_family(self, model: nn.Module, model_id: str) -> ModelFamily:
    """Return the architecture family of ``model``.

    Decomposers are consulted in list order. The first one that claims the
    model decomposes it, and the spec's ``family`` is returned.
    ``ModelFamily.UNKNOWN`` is returned when none of them claims it.
    """
    claimants = (d for d in self.decomposers if d.detect(model, model_id))
    chosen = next(claimants, None)
    if chosen is None:
        return ModelFamily.UNKNOWN
    return chosen.decompose(model, model_id).family

register

register(model_id, model)

Decompose and register a model.

Source code in src/model_garage/registry/models.py
def register(self, model_id: str, model: nn.Module) -> ModelSpec:
    """Decompose ``model`` with the first matching decomposer and record it.

    The resulting spec and the live module are stored under ``model_id``,
    and the spec is handed to ``_save_spec`` for caching.

    Raises:
        ValueError: when no decomposer recognizes the model.
    """
    handler = next(
        (d for d in self.decomposers if d.detect(model, model_id)),
        None,
    )
    if handler is None:
        raise ValueError(f"No decomposer found for model: {model_id}")

    result = handler.decompose(model, model_id)
    self.specs[model_id] = result
    self.models[model_id] = model
    self._save_spec(result)
    return result

get_spec

get_spec(model_id)

Get the spec for a registered model.

Source code in src/model_garage/registry/models.py
def get_spec(self, model_id: str) -> Optional[ModelSpec]:
    """Return the spec for ``model_id``.

    A spec already held in ``self.specs`` wins. Otherwise the lookup falls
    back to ``_load_spec`` (the on-disk cache), which may return None.
    """
    if model_id not in self.specs:
        return self._load_spec(model_id)
    return self.specs[model_id]

get_part

get_part(model_id, part_name)

Get a specific part from a registered model.

Source code in src/model_garage/registry/models.py
def get_part(self, model_id: str, part_name: str) -> Optional[PartSpec]:
    """Look up one named part of a model's spec.

    Returns None when the model has no spec or the part name is unknown.
    """
    spec = self.get_spec(model_id)
    return spec.parts.get(part_name) if spec else None

get_module

get_module(model_id, part_name)

Get the actual module for a part.

Source code in src/model_garage/registry/models.py
def get_module(self, model_id: str, part_name: str) -> Optional[nn.Module]:
    """Resolve a part name to the live submodule of a registered model.

    Returns None when the model is not in ``self.models``, when its spec
    lacks ``part_name``, or when no decomposer claims the model.
    """
    if model_id not in self.models:
        return None
    spec = self.get_spec(model_id)
    if not (spec and part_name in spec.parts):
        return None

    model = self.models[model_id]
    part = spec.parts[part_name]
    owner = next((d for d in self.decomposers if d.detect(model, model_id)), None)
    return None if owner is None else owner.get_module(model, part)

list_models

list_models()

List all registered models.

Source code in src/model_garage/registry/models.py
def list_models(self) -> List[str]:
    """Return the ids of all models whose specs are held in ``self.specs``."""
    return [model_id for model_id in self.specs]

list_parts

list_parts(model_id, part_type=None)

List all parts for a model, optionally filtered by type.

Source code in src/model_garage/registry/models.py
def list_parts(self, model_id: str, part_type: Optional[PartType] = None) -> List[str]:
    """Return the part names of ``model_id``, keeping spec order.

    If ``part_type`` is given, only parts of that type are returned. A model
    without a spec yields an empty list.
    """
    spec = self.get_spec(model_id)
    if not spec:
        return []
    return [
        name
        for name, part in spec.parts.items()
        if not part_type or part.part_type == part_type
    ]

compare

compare(model_a, model_b)

Compare two model architectures.

Source code in src/model_garage/registry/models.py
def compare(self, model_a: str, model_b: str) -> Dict[str, Any]:
    """Summarize two models' architectures side by side.

    Each metric maps to ``[value_for_a, value_for_b]``. A
    ``compatible_parts`` entry comes from ``_find_compatible_parts``. When
    either spec is unavailable, an ``{"error": ...}`` dict is returned instead.
    """
    specs = (self.get_spec(model_a), self.get_spec(model_b))
    if not all(specs):
        return {"error": "One or both models not registered"}
    first, second = specs

    def pair(attr):
        return [getattr(first, attr), getattr(second, attr)]

    return {
        "models": [model_a, model_b],
        "families": [first.family.value, second.family.value],
        "hidden_dims": pair("hidden_dim"),
        "num_layers": pair("num_layers"),
        "num_heads": pair("num_heads"),
        "vocab_sizes": pair("vocab_size"),
        "compatible_parts": self._find_compatible_parts(first, second),
    }

ModelSpec

ModelSpec dataclass

Full specification of a decomposed model.

Source code in src/model_garage/registry/models.py
@dataclass
class ModelSpec:
    """Complete description of a decomposed model.

    Holds model-wide hyper-parameters plus ``parts``, a mapping from part
    name (``"attention_<i>"``, ``"ffn_<i>"``, ``"layer_<i>"``, ...) to its
    :class:`PartSpec`.
    """
    model_id: str                     # HuggingFace model ID
    family: ModelFamily
    hidden_dim: int
    num_layers: int
    num_heads: int
    vocab_size: int
    max_seq_len: int
    parts: Dict[str, PartSpec] = field(default_factory=dict)
    extra_info: Dict[str, Any] = field(default_factory=dict)

    def _part_for_layer(self, prefix: str, layer_idx: int) -> Optional[PartSpec]:
        """Return the part named ``<prefix>_<layer_idx>``, or None if absent."""
        return self.parts.get("{}_{}".format(prefix, layer_idx))

    def _parts_of_type(self, kind: PartType) -> List[PartSpec]:
        """Return every part whose ``part_type`` equals ``kind``, in insertion order."""
        return list(filter(lambda part: part.part_type == kind, self.parts.values()))

    def get_attention(self, layer_idx: int) -> Optional[PartSpec]:
        """Return the attention part of layer ``layer_idx``, or None."""
        return self._part_for_layer("attention", layer_idx)

    def get_ffn(self, layer_idx: int) -> Optional[PartSpec]:
        """Return the FFN part of layer ``layer_idx``, or None."""
        return self._part_for_layer("ffn", layer_idx)

    def get_full_layer(self, layer_idx: int) -> Optional[PartSpec]:
        """Return the whole-layer part of layer ``layer_idx``, or None."""
        return self._part_for_layer("layer", layer_idx)

    def all_attention_parts(self) -> List[PartSpec]:
        """Return all parts typed ``PartType.ATTENTION``."""
        return self._parts_of_type(PartType.ATTENTION)

    def all_ffn_parts(self) -> List[PartSpec]:
        """Return all parts typed ``PartType.FFN``."""
        return self._parts_of_type(PartType.FFN)

get_attention

get_attention(layer_idx)

Get attention part for a specific layer.

Source code in src/model_garage/registry/models.py
def get_attention(self, layer_idx: int) -> Optional[PartSpec]:
    """Return the ``attention_<layer_idx>`` part, or None if the spec has none."""
    return self.parts.get("attention_{}".format(layer_idx))

get_ffn

get_ffn(layer_idx)

Get FFN part for a specific layer.

Source code in src/model_garage/registry/models.py
def get_ffn(self, layer_idx: int) -> Optional[PartSpec]:
    """Return the ``ffn_<layer_idx>`` part, or None if the spec has none."""
    return self.parts.get("ffn_{}".format(layer_idx))

get_full_layer

get_full_layer(layer_idx)

Get full layer part.

Source code in src/model_garage/registry/models.py
def get_full_layer(self, layer_idx: int) -> Optional[PartSpec]:
    """Return the ``layer_<layer_idx>`` (whole-layer) part, or None if absent."""
    return self.parts.get("layer_{}".format(layer_idx))

all_attention_parts

all_attention_parts()

Get all attention parts.

Source code in src/model_garage/registry/models.py
def all_attention_parts(self) -> List[PartSpec]:
    """Return every part typed ``PartType.ATTENTION``, in insertion order."""
    return list(filter(lambda part: part.part_type == PartType.ATTENTION, self.parts.values()))

all_ffn_parts

all_ffn_parts()

Get all FFN parts.

Source code in src/model_garage/registry/models.py
def all_ffn_parts(self) -> List[PartSpec]:
    """Return every part typed ``PartType.FFN``, in insertion order."""
    return list(filter(lambda part: part.part_type == PartType.FFN, self.parts.values()))

PartSpec

PartSpec dataclass

Specification for a model part (rust bucket component).

Source code in src/model_garage/registry/models.py
@dataclass
class PartSpec:
    """Specification for a model part (rust bucket component).

    Describes one component of a decomposed model: its kind, the layer it
    belongs to, and where it lives in the module tree. Instances are built by
    a ModelDecomposer or rebuilt from the registry's JSON spec cache. Field
    order matters, because it defines the generated ``__init__`` signature.
    """
    part_type: PartType
    layer_idx: Optional[int]          # None for non-layer parts (embedding, output)
    module_path: str                  # Path in model (e.g., "transformer.h.0.attn")
    input_dim: int                    # presumably the feature size the part consumes; confirm per decomposer
    output_dim: int                   # presumably the feature size the part produces; confirm per decomposer
    num_heads: Optional[int] = None   # For attention
    head_dim: Optional[int] = None    # For attention
    intermediate_dim: Optional[int] = None  # For FFN
    extra_info: Dict[str, Any] = field(default_factory=dict)  # Free-form metadata; written to the JSON cache, so keep it JSON-serializable

ModelDecomposer

ModelDecomposer

Bases: ABC

Abstract base for model-family-specific decomposers.

Source code in src/model_garage/registry/models.py
class ModelDecomposer(ABC):
    """Interface every model-family decomposer implements.

    A concrete decomposer recognizes models of its family (``detect``),
    breaks them into a :class:`ModelSpec` (``decompose``), and maps a
    :class:`PartSpec` back to the live submodule (``get_module``).
    """

    @abstractmethod
    def detect(self, model: nn.Module, model_id: str) -> bool:
        """Return True when this decomposer can handle ``model``."""

    @abstractmethod
    def decompose(self, model: nn.Module, model_id: str) -> ModelSpec:
        """Break ``model`` into its parts and return the resulting spec."""

    @abstractmethod
    def get_module(self, model: nn.Module, part: PartSpec) -> nn.Module:
        """Return the submodule of ``model`` that ``part`` describes."""

detect abstractmethod

detect(model, model_id)

Check if this decomposer handles the given model.

Source code in src/model_garage/registry/models.py
@abstractmethod
def detect(self, model: nn.Module, model_id: str) -> bool:
    """Return True when this decomposer can handle ``model``."""

decompose abstractmethod

decompose(model, model_id)

Decompose the model into parts.

Source code in src/model_garage/registry/models.py
@abstractmethod
def decompose(self, model: nn.Module, model_id: str) -> ModelSpec:
    """Break ``model`` into its parts and return the resulting spec."""

get_module abstractmethod

get_module(model, part)

Get the actual module for a part.

Source code in src/model_garage/registry/models.py
@abstractmethod
def get_module(self, model: nn.Module, part: PartSpec) -> nn.Module:
    """Return the submodule of ``model`` that ``part`` describes."""

ModelFamily

ModelFamily

Bases: Enum

Supported model architecture families.

Source code in src/model_garage/registry/models.py
class ModelFamily(Enum):
    """Supported model architecture families.

    The string values are written to the registry's JSON spec cache and
    parsed back with ``ModelFamily(value)``. Changing a value therefore
    invalidates existing cache files. ``UNKNOWN`` is what the registry's
    family detection returns when no decomposer accepts a model.
    """
    GPT2 = "gpt2"
    LLAMA = "llama"
    MISTRAL = "mistral"
    GEMMA = "gemma"
    QWEN = "qwen"
    PHI = "phi"
    UNKNOWN = "unknown"  # Fallback when no decomposer matches

PartType

PartType

Bases: Enum

Standard rust bucket part types.

Source code in src/model_garage/registry/models.py
class PartType(Enum):
    """Standard rust bucket part types.

    Each PartSpec carries one of these. The string values are written to the
    registry's JSON spec cache and parsed back with ``PartType(value)``, so
    changing a value invalidates existing cache files.
    """
    EMBEDDING = "embedding"           # Token + position embeddings
    ATTENTION = "attention"           # Self-attention block
    FFN = "ffn"                       # Feed-forward network
    LAYER_NORM = "layer_norm"         # Normalization layer
    OUTPUT_HEAD = "output_head"       # Final projection to vocab
    ROTARY_EMB = "rotary_emb"         # RoPE embeddings (modern models)
    GATE = "gate"                     # MoE gating (Mixtral, etc.)
    FULL_LAYER = "full_layer"         # Complete transformer layer