mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-29 17:35:09 +00:00
Publish Python SDK
Hello world! Signed-off-by: Daniel Xu <dxu@dxuuu.xyz>
This commit is contained in:
commit
829c151ba7
192 changed files with 25717 additions and 0 deletions
373
tests/mock_api_server.py
Normal file
373
tests/mock_api_server.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
"""
|
||||
Mock API server for the Tinker Python SDK.
|
||||
|
||||
This server provides mock implementations of all Tinker API endpoints for testing purposes.
|
||||
"""
|
||||
|
||||
import random
|
||||
import uuid
|
||||
import traceback
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Import types from tinker
|
||||
from tinker import types
|
||||
|
||||
# Configure logging at DEBUG so every mock request/response is visible in tests.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Single FastAPI application that hosts all mock Tinker endpoints.
app = FastAPI(title="Tinker Mock API Server")
|
||||
|
||||
# Custom exception handler to log stack traces
|
||||
@app.exception_handler(Exception)
async def log_exceptions(request: Request, exc: Exception):
    """Log all exceptions with full stack traces.

    Catch-all handler: anything unhandled by a more specific handler is
    logged with enough context to reconstruct the failing request, then
    surfaced to the client as a 500.
    """
    for message in (
        f"Unhandled exception in {request.method} {request.url}",
        f"Exception type: {type(exc).__name__}",
        f"Exception message: {str(exc)}",
        f"Full traceback:\n{traceback.format_exc()}",
    ):
        logger.error(message)

    body = {"detail": f"Internal server error: {str(exc)}"}
    return JSONResponse(status_code=500, content=body)
|
||||
|
||||
|
||||
# Handler for validation errors (422)
|
||||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log validation errors with details.

    Returns the standard 422 payload plus a ``debug_info`` section so test
    failures show exactly which field of which request was rejected.
    """
    errors = exc.errors()
    logger.error(f"Validation error in {request.method} {request.url}")
    logger.error(f"Validation errors: {errors}")
    logger.error(f"Request body: {exc.body}")
    logger.error(f"Full traceback:\n{traceback.format_exc()}")

    payload = {
        "detail": errors,
        "body": exc.body,
        "debug_info": {
            "url": str(request.url),
            "method": request.method,
            "errors": errors,
        },
    }
    return JSONResponse(status_code=422, content=payload)
|
||||
|
||||
# Handler for Pydantic validation errors
|
||||
@app.exception_handler(ValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
    """Log Pydantic validation errors with details.

    Handles ValidationError raised inside endpoint bodies (as opposed to
    FastAPI's own request validation) and reports it as a 422 with debug info.
    """
    errors = exc.errors()
    logger.error(f"Pydantic validation error in {request.method} {request.url}")
    logger.error(f"Validation errors: {errors}")
    logger.error(f"Full traceback:\n{traceback.format_exc()}")

    payload = {
        "detail": errors,
        "debug_info": {
            "url": str(request.url),
            "method": request.method,
            "errors": errors,
        },
    }
    return JSONResponse(status_code=422, content=payload)
|
||||
# In-memory storage for futures and their results.
# Keyed by future id -> {"result": ..., "status": ..., "created_at": ...};
# written by the training endpoints, read back by /retrieve_future.
futures_store: Dict[str, Any] = {}

# In-memory storage for LoRA adapters.
# Keyed by base model id -> {lora model id: adapter config dict}.
lora_adapters: Dict[str, Dict[str, Any]] = {}

# Mock model configurations advertised by /get_server_capabilities and
# looked up by /get_info.
SUPPORTED_MODELS = [
    {"model_id": "llama-3-8b", "model_name": "meta-llama/Meta-Llama-3-8B", "arch": "llama"},
    {"model_id": "llama-3-70b", "model_name": "meta-llama/Meta-Llama-3-70B", "arch": "llama"},
    {"model_id": "qwen2-72b", "model_name": "Qwen/Qwen2-72B", "arch": "qwen2"},
]
|
||||
|
||||
|
||||
def generate_future_id() -> str:
    """Generate a unique future ID of the form ``future_<8 hex chars>``."""
    suffix = uuid.uuid4().hex[:8]
    return "future_" + suffix
|
||||
|
||||
|
||||
@app.get("/healthz", response_model=types.HealthResponse)
async def health_check():
    """Health check endpoint; always reports the server as healthy."""
    response = types.HealthResponse(status="ok")
    return response
|
||||
|
||||
|
||||
@app.get("/get_server_capabilities", response_model=types.GetServerCapabilitiesResponse)
async def get_server_capabilities():
    """Get server capabilities including supported models."""
    # Expose only the advertised fields of each configured model entry.
    models = []
    for entry in SUPPORTED_MODELS:
        models.append(
            {
                "model_id": entry["model_id"],
                "model_name": entry["model_name"],
                "arch": entry["arch"],
            }
        )
    return types.GetServerCapabilitiesResponse(supported_models=models)
|
||||
|
||||
def generate_mock_logprobs(seq_len: int) -> List[float]:
    """Generate mock log probabilities (one per token) for a sequence.

    Values are drawn uniformly from [-4.0, 0.0], a plausible logprob range.
    """
    sample = random.uniform
    return [sample(-4.0, 0.0) for _ in range(seq_len)]
|
||||
|
||||
def generate_mock_loss() -> float:
    """Generate a mock loss value, uniform in [0.5, 3.0]."""
    low, high = 0.5, 3.0
    return random.uniform(low, high)
|
||||
|
||||
|
||||
def chunk_length(chunk: types.ModelInputChunkParam) -> int:
    """Return the token length contributed by a single model-input chunk.

    Raises ValueError for chunk types this mock does not understand.
    """
    kind = chunk["type"]
    if kind == "encoded_text":
        # Text chunks carry an explicit token sequence.
        return len(chunk["tokens"])
    if kind == "image_asset_pointer":
        # NOTE(review): "tokens" is used directly as the length here, so for
        # image chunks it is presumably already a token *count* — confirm
        # against the tinker type definitions.
        return chunk["tokens"]
    raise ValueError(f"Unknown chunk type: {chunk['type']}")
|
||||
|
||||
def sequence_length(model_input: types.ModelInputParam) -> int:
    """Total token length across all chunks of a model input."""
    total = 0
    for chunk in model_input["chunks"]:
        total += chunk_length(chunk)
    return total
|
||||
|
||||
|
||||
@app.post("/fwd", response_model=types.UntypedAPIFuture)
async def forward(params: types.TrainingForwardParams):
    """Perform forward pass.

    Generates mock per-token logprobs for each datum plus mock loss metrics,
    stores the FwdBwdOutput under a fresh future id, and returns that future;
    the client fetches the result later via /retrieve_future.
    """
    future_id = generate_future_id()

    # BUG FIX: the original mixed attribute access (params.fwdbwd_input.data,
    # datum.input_sequence) with dict access (params.get, and sequence_length
    # subscripts its argument with ["chunks"]) in the same call chain, which
    # cannot work for either representation. Every other endpoint in this file
    # treats params dict-style, so unify on subscription here.
    data = params["fwdbwd_input"]["data"]
    result = types.FwdBwdOutput(
        loss_fn_outputs={
            "logprobs": [
                generate_mock_logprobs(sequence_length(datum["input_sequence"]))
                for datum in data
            ]
        },
        metrics={
            "loss": generate_mock_loss(),
            "perplexity": generate_mock_loss(),
        },
    )

    # Store the result for future retrieval.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(request_id=future_id, model_id=params.get("model_id"))
|
||||
|
||||
|
||||
@app.post("/fwdbwd", response_model=types.UntypedAPIFuture)
async def forward_backward(params: types.TrainingForwardBackwardParams):
    """Perform forward and backward pass.

    The mock behaves exactly like /fwd, so it delegates there; a real
    implementation would additionally accumulate gradients.
    """
    return await forward(params)
|
||||
|
||||
|
||||
@app.post("/optim_step", response_model=types.UntypedAPIFuture)
async def optim_step(params: types.TrainingOptimStepParams):
    """Perform optimization step.

    Stores a mock dict of optimizer norms under a fresh future id for later
    retrieval via /retrieve_future, and returns the future handle.
    """
    future_id = generate_future_id()

    # Mock optimization step result (OptimStepResponse is just a
    # Dict[str, Union[float, str]], so a plain dict is stored).
    result = {
        "grad_norm": random.uniform(0.1, 10.0),
        "weight_norm": random.uniform(10.0, 100.0),
        "update_norm": random.uniform(0.001, 0.1),
    }

    # Store the result for future retrieval.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    # BUG FIX: the original returned bare `UntypedAPIFuture`, a name never
    # imported at module scope (NameError at request time); every other
    # endpoint uses types.UntypedAPIFuture.
    return types.UntypedAPIFuture(request_id=future_id, model_id=params.get("model_id"))
|
||||
|
||||
@app.post("/retrieve_future")
async def retrieve_future(params: types.FutureRetrieveParams):
    """Retrieve the result of a future.

    Raises a 404 for unknown future ids. Pydantic response models are
    serialized explicitly; plain-dict results pass straight through.
    """
    future_id = params["request_id"]

    if future_id not in futures_store:
        raise HTTPException(status_code=404, detail=f"Future {future_id} not found")

    result = futures_store[future_id]["result"]

    # Result types produced by the various endpoints as Pydantic models.
    pydantic_result_types = (
        types.FwdBwdOutput,
        types.AddLoraResponse,
        types.UnloadModelResponse,
        types.LoadWeightsResponse,
        types.SaveWeightsResponse,
        types.SaveWeightsForSamplerResponse,
    )
    if not isinstance(result, pydantic_result_types):
        # For dict results (like OptimStepResponse).
        print(f"RETRIEVE_FUTURE: Returning dict result: {result}")
        return result

    serialized_result = result.model_dump()
    print(f"RETRIEVE_FUTURE: Returning Pydantic model result: {serialized_result}")
    return serialized_result
|
||||
|
||||
@app.post("/add_lora", response_model=types.UntypedAPIFuture)
async def add_lora(params: types.LoraAddParams):
    """Add a LoRA adapter to the model.

    Mints a new "<base>_lora_<hex>" model id, records the adapter config
    under its base model, and exposes the AddLoraResponse via a future.
    """
    future_id = generate_future_id()

    # Derive a fresh model id for the adapted model.
    base_model = params["base_model"]
    lora_model_id = f"{base_model}_lora_{uuid.uuid4().hex[:8]}"

    # Record the adapter configuration, creating the per-base bucket on demand.
    adapters_for_base = lora_adapters.setdefault(base_model, {})
    adapters_for_base[lora_model_id] = {
        "rank": params.get("rank", 16),
        "alpha": params.get("alpha", 32),
        "created_at": datetime.now().isoformat(),
    }

    # Result retrieved later through /retrieve_future.
    result = types.AddLoraResponse(model_id=lora_model_id)
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(request_id=future_id, model_id=lora_model_id)
|
||||
|
||||
|
||||
@app.post("/remove_lora", response_model=types.UntypedAPIFuture)
async def remove_lora(params: types.LoraRemoveParams):
    """Remove a LoRA adapter from the model.

    Validates that the id names a LoRA model, drops it from the in-memory
    registry, and exposes a confirmation result via a future.
    """
    future_id = generate_future_id()

    model_id = params["model_id"]

    # BUG FIX: input validation previously used `assert`, which is stripped
    # under -O and otherwise surfaces as a 500 via the generic handler;
    # reject bad ids with an explicit client error instead.
    if "_lora_" not in model_id:
        raise HTTPException(status_code=400, detail=f"Model {model_id} is not a LoRA model")

    # Remove from our tracking.
    base_model_id = model_id.split("_lora_")[0]
    if base_model_id in lora_adapters and model_id in lora_adapters[base_model_id]:
        del lora_adapters[base_model_id][model_id]

    # BUG FIX: the original stored an undefined `result` (NameError on every
    # call). A plain dict is retrievable via /retrieve_future's dict branch.
    # TODO(review): confirm whether types.UnloadModelResponse is the intended
    # payload for this endpoint.
    result = {"message": f"LoRA adapter {model_id} removed", "success": True}

    # Store the result for future retrieval.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(
        request_id=future_id,
        model_id=params.get("model_id"),
    )
|
||||
|
||||
|
||||
@app.post("/load_weights", response_model=types.UntypedAPIFuture)
async def load_weights(params: types.WeightLoadParams):
    """Load model weights from a path.

    Mock implementation — no weights are actually read; a success message
    is stored under a fresh future id for /retrieve_future.
    """
    future_id = generate_future_id()

    # BUG FIX: both response classes below were referenced without the
    # `types.` prefix in the original, raising NameError at request time
    # (every other endpoint, and /retrieve_future's isinstance check, use
    # the types.-qualified names).
    result = types.LoadWeightsResponse(
        message=f"Weights loaded from {params['path']}", success=True
    )

    # Store the result for future retrieval.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(request_id=future_id, model_id=params.get("model_id"))
|
||||
|
||||
|
||||
@app.post("/save_weights", response_model=types.UntypedAPIFuture)
async def save_weights(params: types.WeightSaveParams):
    """Save model weights to a path.

    Mock implementation — no weights are actually written; a timestamped
    checkpoint path is reported back through a future.
    """
    future_id = generate_future_id()

    # Mock implementation - in reality this would save weights to storage.
    save_path = (
        f"{params.get('path', '/tmp')}/checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )

    # BUG FIX: both response classes below were referenced without the
    # `types.` prefix in the original, raising NameError at request time.
    result = types.SaveWeightsResponse(message=f"Weights saved to {save_path}", success=True)

    # Store the result for future retrieval.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(request_id=future_id, model_id=params.get("model_id"))
|
||||
|
||||
|
||||
@app.post("/save_weights_for_sampler", response_model=types.UntypedAPIFuture)
async def save_weights_for_sampler(params: types.WeightSaveForSamplerParams):
    """Save weights in a format suitable for the sampler.

    Mock implementation: fabricates a timestamped path and exposes the
    confirmation via a future.
    """
    future_id = generate_future_id()

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_path = f"{params.get('path', '/tmp')}/sampler_weights_{timestamp}"

    result = types.SaveWeightsForSamplerResponse(
        message=f"Sampler weights saved to {save_path}", success=True
    )

    # Park the result where /retrieve_future can find it.
    futures_store[future_id] = {
        "result": result,
        "status": "completed",
        "created_at": datetime.now().isoformat(),
    }

    return types.UntypedAPIFuture(request_id=future_id, model_id=params.get("model_id"))
|
||||
|
||||
|
||||
@app.post("/get_info", response_model=types.GetInfoResponse)
async def get_info(params: types.ModelGetInfoParams):
    """Get information about a model.

    LoRA ids of the form "<base>_lora_<hex>" resolve to their base model;
    unknown ids fall back to a generic "unknown" record.
    """
    model_id = params["model_id"]

    # Strip any LoRA suffix before looking the model up.
    lookup_id = model_id.split("_lora_")[0] if "_lora_" in model_id else model_id
    model_info = next((m for m in SUPPORTED_MODELS if m["model_id"] == lookup_id), None)

    if model_info is None:
        # Default model info for unknown models.
        model_info = {"model_name": f"unknown/{model_id}", "arch": "unknown"}

    return types.GetInfoResponse(
        model_data={"model_name": model_info["model_name"], "arch": model_info["arch"]},
        model_id=model_id,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the mock server standalone (e.g. for manual SDK testing).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Loading…
Add table
Add a link
Reference in a new issue