🚀 DepthScale Quick Start Guide¶
Welcome to DepthScale! This guide will get you up and running with the Universal Self-Decoder framework, designed for parameter-shared, iterative reasoning in transformer models.
🎯 Goal Overview¶
DepthScale implements a framework that leverages parameter-shared iterative reasoning. By recursively applying the same transformer weights across multiple reasoning steps, we aim to achieve:

1. Constant Memory Overhead: Efficiently reusing model weights across iterations.
2. Convergence-Based Refinement: Improving multi-step logical reasoning accuracy through iterative refinement.
📦 Installation & Setup¶
Once the necessary dependencies are installed, the core components live in the universal_yoco module.

Prerequisites:

* Python 3.8+
* PyTorch (or a compatible framework)
Structure:
The primary components are:
* universal_yoco/types.py: Defines core data structures and types.
* universal_yoco/yoco_base.py: Contains the core logic for the Self-Decoder and iteration management.
* universal_yoco/__init__.py: Entry point for the module.
🛠️ Core Concepts¶
The framework revolves around the YocoBase class, which manages the state and the iterative application of the shared transformer.
- Parameter Sharing: The core transformer weights are instantiated once and reused in every reasoning step.
- Iterative Reasoning: The process involves feeding the output of step $t$ back as input context for step $t+1$, guided by specialized attention mechanisms.
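These two ideas can be illustrated in a few lines of plain PyTorch, independent of DepthScale itself (the layer choice, dimensions, and step count below are illustrative, not DepthScale's actual architecture):

```python
import torch
import torch.nn as nn

# A single shared block: one set of weights, applied repeatedly.
shared_block = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)

x = torch.randn(1, 10, 64)  # (batch, sequence, hidden)
hidden = x
for step in range(4):
    # The output of step t becomes the input of step t+1;
    # no new parameters are allocated per step.
    hidden = shared_block(hidden)

# The parameter count is independent of the number of reasoning steps.
n_params = sum(p.numel() for p in shared_block.parameters())
print(hidden.shape, n_params)
```

Because the same block is applied at every step, memory for weights stays constant no matter how many reasoning iterations run; only the activations of the current step are needed.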
💡 Usage Examples¶
Here are a few examples demonstrating how to initialize and use the DepthScale components.
Example 1: Basic Initialization and Single-Step Inference¶
This example shows how to instantiate the base decoder and perform a single forward pass without iterative refinement.
```python
import torch

from universal_yoco.yoco_base import YocoBase
from universal_yoco.types import ReasoningState

# 1. Initialize the base model (assuming a pre-trained or initialized transformer).
# In a real scenario, 'transformer_model' would be your actual PyTorch model instance.
class MockTransformer:
    def forward(self, input_ids, attention_mask):
        # Mock output: (batch_size, sequence_length, hidden_size)
        return torch.randn(1, 10, 768)

transformer_model = MockTransformer()
initial_state = ReasoningState(initial_context="The premise is X.")

# 2. Instantiate the YocoBase decoder
yoco_decoder = YocoBase(
    transformer=transformer_model,
    initial_state=initial_state,
)

# 3. Perform a single inference step
output_logits = yoco_decoder.step(input_ids=None, attention_mask=None)

print("--- Example 1: Single Step Inference ---")
print(f"Output shape: {output_logits.shape}")
```
Example 2: Implementing Iterative Refinement (Convergence Loop)¶
This is the core use case: call the step function in a loop, letting the model refine its reasoning over several iterations until it converges or a maximum step count is reached.
```python
import torch

from universal_yoco.yoco_base import YocoBase
from universal_yoco.types import ReasoningState

# Setup (reusing MockTransformer from Example 1)
transformer_model = MockTransformer()
initial_state = ReasoningState(initial_context="Q: What is the capital of France? A: Paris.")
yoco_decoder = YocoBase(transformer=transformer_model, initial_state=initial_state)

MAX_ITERATIONS = 5
TOLERANCE = 0.01  # Convergence threshold (e.g., change in the refinement metric)

print("\n--- Example 2: Iterative Reasoning ---")
current_state = initial_state
history = []

for i in range(MAX_ITERATIONS):
    # Perform the reasoning step
    next_state, refinement_metric = yoco_decoder.step(
        current_state=current_state
    )
    history.append(refinement_metric)
    current_state = next_state
    print(f"Iteration {i + 1}: Metric = {refinement_metric:.4f}")

    # Check for convergence: stop once the metric change between
    # consecutive iterations falls below the tolerance.
    if i > 0 and abs(refinement_metric - history[-2]) < TOLERANCE:
        print(f"Converged at iteration {i + 1}.")
        break

print(f"Final State Context: {current_state.context}")
```
Example 3: Customizing Attention Mechanisms¶
If you need to inject specialized attention logic (e.g., a memory-gated attention layer) into the reasoning process, subclass YocoBase and override the internal attention call, using types.ReasoningState to pass the necessary context.
```python
from universal_yoco.yoco_base import YocoBase
from universal_yoco.types import ReasoningState

# Assume a custom attention module exists alongside this file
from .custom_attention import MemoryGatedAttention

class CustomYocoDecoder(YocoBase):
    def __init__(self, transformer, initial_state):
        super().__init__(transformer, initial_state)
        # Replace the standard attention mechanism with our custom one
        self.custom_attention = MemoryGatedAttention()

    def step(self, current_state: ReasoningState, input_ids=None, attention_mask=None):
        # 1. Get the current context embedding
        context_embedding = self._embed_context(current_state.context)

        # 2. Use the custom attention layer instead of the default
        attended_output = self.custom_attention(
            query=context_embedding,
            memory=current_state.memory_bank,
        )

        # 3. Pass the specialized output to the transformer block
        # (This part would integrate deeply with the original YocoBase logic)
        refined_output = self.transformer.forward(attended_output, attention_mask)

        # 4. Update state and return
        new_state = self._update_state(current_state, refined_output)
        return new_state, 0.5  # Mock metric

# Usage:
# custom_decoder = CustomYocoDecoder(transformer_model, initial_state)
# final_state, metric = custom_decoder.step(current_state=initial_state)
```