diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 3a0a9b15..774e2480 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -65,6 +65,8 @@ class ScoredData(BaseModel): messages: Optional[List[List[Dict[str, Any]]]] = ( None # Changed from Message TypedDict to Dict ) + generation_params: Optional[Dict[str, Any]] = None + inference_logprobs: Optional[List[List[float]]] = None overrides: Optional[List[dict]] = None group_overrides: Optional[dict] = None images: Optional[Any] = None @@ -268,6 +270,7 @@ async def get_latest_example(): "scores": [], "advantages": [], "ref_logprobs": [], + "inference_logprobs": [], "messages": [], "images": [], } @@ -282,6 +285,8 @@ async def scored_data(scored_data: ScoredData): "advantages": scored_data.advantages, "ref_logprobs": scored_data.ref_logprobs, "messages": scored_data.messages, + "generation_params": scored_data.generation_params, + "inference_logprobs": scored_data.inference_logprobs, "overrides": scored_data.overrides, "group_overrides": scored_data.group_overrides, "images": scored_data.images, @@ -344,6 +349,8 @@ async def scored_data_list(scored_data_list: List[ScoredData]): "ref_logprobs": scored_data.ref_logprobs, "images": scored_data.images, "messages": scored_data.messages, + "generation_params": scored_data.generation_params, + "inference_logprobs": scored_data.inference_logprobs, "overrides": scored_data.overrides, "group_overrides": scored_data.group_overrides, "env_id": scored_data.env_id, diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 64d2eb04..bf6491f9 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -59,6 +59,8 @@ class ScoredDataGroup(TypedDict): advantages: Optional[List[List[float]]] ref_logprobs: Optional[List[List[float]]] messages: Optional[List[List[Message]]] + generation_params: Optional[Dict[str, Any]] + inference_logprobs: Optional[List[List[float]]] group_overrides: Optional[Dict] overrides: Optional[List[Dict]] 
images: Optional[Any] diff --git a/atroposlib/tests/test_api_messages_handling.py b/atroposlib/tests/test_api_messages_handling.py index 1d54aaee..0f7ac922 100644 --- a/atroposlib/tests/test_api_messages_handling.py +++ b/atroposlib/tests/test_api_messages_handling.py @@ -140,12 +140,16 @@ class TestAPIMessagesHandling: "messages": [messages], "advantages": [[0.5, 0.5, 0.5, 0.5, 0.5]], "ref_logprobs": [[-0.1, -0.2, -0.3, -0.4, -0.5]], + "generation_params": {"temperature": 0.7}, }, ) if response.status_code != 200: print(f"Error response: {response.text}") assert response.status_code == 200 assert response.json()["status"] == "received" + # Verify generation_params round-trips through the latest_example endpoint + latest = requests.get("http://localhost:8000/latest_example").json() + assert latest.get("generation_params", {}).get("temperature") == 0.7 def test_scored_data_list_with_messages(self, api_server): """Test posting a list of scored data with messages.""" @@ -307,20 +311,25 @@ class TestAPIMessagesHandling: ) assert register_response.status_code == 200 - # Post two items to make a full batch for i in range(2): messages = [ {"role": "user", "content": f"Test message {i}", "reward": None}, {"role": "assistant", "content": f"Response {i}", "reward": None}, ] + payload = { + "tokens": [[i * 10 + j for j in range(5)]], + "masks": [[1] * 5], + "scores": [float(i)], + "messages": [messages], + } + if i == 0: + payload["overrides"] = [{"temperature": 0.5}] + else: + payload["generation_params"] = {"temperature": 0.8} + response = requests.post( "http://localhost:8000/scored_data", - json={ - "tokens": [[i * 10 + j for j in range(5)]], - "masks": [[1] * 5], - "scores": [float(i)], - "messages": [messages], - }, + json=payload, ) assert response.status_code == 200 @@ -339,6 +348,13 @@ class TestAPIMessagesHandling: assert len(item["messages"][0]) == 2 assert item["messages"][0][0]["role"] == "user" assert item["messages"][0][1]["role"] == "assistant" + # Verify per-sample overrides and group generation_params pass through the batch + if i == 0: + assert item.get("overrides") is not None + 
assert item["overrides"][0].get("temperature") == 0.5 + else: + assert item.get("generation_params") is not None + assert item["generation_params"].get("temperature") == 0.8 def test_latest_example_with_messages(self, api_server): """Test that latest example endpoint includes messages.""" @@ -376,6 +392,7 @@ class TestAPIMessagesHandling: "masks": [[1, 1, 1]], "scores": [0.95], "messages": [messages], + "inference_logprobs": [[-1.0, -0.7, -0.2]], }, ) if response.status_code != 200: @@ -390,6 +407,7 @@ class TestAPIMessagesHandling: assert "messages" in latest_data assert latest_data["messages"] == [messages] assert len(latest_data["messages"][0]) == 3 + assert latest_data.get("inference_logprobs") == [[-1.0, -0.7, -0.2]] def test_empty_messages_handling(self, api_server): """Test handling of empty or None messages.""" diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index fb340a2f..70d22c8f 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -128,6 +128,7 @@ def pad_data_to_good_offset(data, batch_size: int): labels = list() advantages = list() lengths = list() + temperatures = list() for item in data["batch"]: scores = item["scores"] scores = np.array(scores) @@ -166,10 +167,21 @@ def pad_data_to_good_offset(data, batch_size: int): input_ids.append(item["tokens"][i][:-1]) labels.append(label_item[1:]) advantages.append(item["scores"][i]) + # Precedence: per-sample override -> group generation_params -> group_overrides -> 1.0 + # TODO: update docs — per-sample overrides can now set the temperature for each sample + t = 1.0 + if item.get("overrides") and i < len(item["overrides"]) and isinstance(item["overrides"][i], dict) and ("temperature" in item["overrides"][i]): + t = float(item["overrides"][i]["temperature"]) + elif item.get("generation_params") and ("temperature" in item["generation_params"]): + t = float(item["generation_params"]["temperature"]) + elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
+ t = float(item["group_overrides"]["temperature"]) + temperatures.append(t) # combine all lists into tensors token_batches = [] label_batches = [] advantage_batches = [] + temperature_batches = [] for i in range(len(input_ids) // batch_size): token_batches.append( torch.tensor( @@ -186,12 +198,22 @@ def pad_data_to_good_offset(data, batch_size: int): np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0) ).view(-1, 1) ) - return token_batches, label_batches, advantage_batches + # Temperatures: one per sample, shaped for broadcasting to [B, 1, 1] + temperature_batches.append( + torch.tensor( + np.array( + temperatures[i * batch_size : (i + 1) * batch_size], + dtype=np.float32, + ) + ).view(-1, 1, 1) + ) + + return token_batches, label_batches, advantage_batches, temperature_batches def get_data( batch_size: int, seq_len: int -) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: +) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: """ getting data from the api """ @@ -321,7 +343,7 @@ def train(config: TrainingConfig): total_neg = 0 if len(batches) == 0: batches = get_data(config.batch_size, config.seq_len) - token_batches, label_batches, advantage_batches = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) # Terminate existing vLLM process if running if ( step + 1 @@ -339,8 +361,8 @@ def train(config: TrainingConfig): vllm_process.kill() vllm_process.wait() vllm_process = None - for tokens, labels, advantages in zip( - token_batches, label_batches, advantage_batches + for tokens, labels, advantages, temperatures in zip( + token_batches, label_batches, advantage_batches, temperature_batches ): tokens, labels, advantages = ( @@ -353,6 +375,10 @@ def train(config: TrainingConfig): # User specified that tokens/labels are already prepared by get_data outputs = model(tokens) # Assuming model just needs tokens logits = outputs.logits # 
Assuming this is the structure + # Temperature-scale logits before cross entropy; non-positive temperatures are clamped to 1.0 to avoid division by zero + t = temperatures.to(logits.device, logits.dtype) + t = torch.where(t <= 0, torch.ones_like(t), t) + logits = logits / t # Calculate GRPO loss (reverting to user's previous logic) # User stated ignore_index is -100 and tokens/labels are aligned by get_data