mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix: add try/finally to guarantee gym environment cleanup
This commit is contained in:
parent
708b42a00f
commit
7e5ddbce06
1 changed files with 48 additions and 47 deletions
|
|
@ -177,60 +177,61 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
||||
async with self.server.dedicated_server() as server:
|
||||
for _ in range(self.config.max_episode_turns):
|
||||
if (
|
||||
len(self.tokenizer.apply_chat_template(messages, tokenize=False))
|
||||
> self.config.max_token_length - 50
|
||||
):
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Max token length reached, truncating episode."
|
||||
)
|
||||
break
|
||||
try:
|
||||
async with self.server.dedicated_server() as server:
|
||||
for _ in range(self.config.max_episode_turns):
|
||||
if (
|
||||
len(self.tokenizer.apply_chat_template(messages, tokenize=False))
|
||||
> self.config.max_token_length - 50
|
||||
):
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Max token length reached, truncating episode."
|
||||
)
|
||||
break
|
||||
|
||||
max_tokens_for_action = 512
|
||||
max_tokens_for_action = 512
|
||||
|
||||
try:
|
||||
chat_completions = await server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=max_tokens_for_action,
|
||||
temperature=0.5,
|
||||
)
|
||||
llm_action_response = chat_completions.choices[
|
||||
0
|
||||
].message.content.strip()
|
||||
logger.info(
|
||||
f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||
break
|
||||
try:
|
||||
chat_completions = await server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=max_tokens_for_action,
|
||||
temperature=0.5,
|
||||
)
|
||||
llm_action_response = chat_completions.choices[
|
||||
0
|
||||
].message.content.strip()
|
||||
logger.info(
|
||||
f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||
break
|
||||
|
||||
messages.append({"role": "assistant", "content": llm_action_response})
|
||||
messages.append({"role": "assistant", "content": llm_action_response})
|
||||
|
||||
action = self._parse_action_from_llm(llm_action_response)
|
||||
if action is None:
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Invalid action parsed. Ending episode."
|
||||
)
|
||||
game_reward = -1.0
|
||||
break
|
||||
action = self._parse_action_from_llm(llm_action_response)
|
||||
if action is None:
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Invalid action parsed. Ending episode."
|
||||
)
|
||||
game_reward = -1.0
|
||||
break
|
||||
|
||||
try:
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
game_reward = float(reward)
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] Error stepping env: {e}")
|
||||
break
|
||||
try:
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
game_reward = float(reward)
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] Error stepping env: {e}")
|
||||
break
|
||||
|
||||
if terminated or truncated:
|
||||
break
|
||||
if terminated or truncated:
|
||||
break
|
||||
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
||||
env.close()
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
finally:
|
||||
env.close()
|
||||
self.episode_outcomes_buffer.append(game_reward)
|
||||
|
||||
tokenization_result = tokenize_for_trainer(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue