mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
ruff
This commit is contained in:
parent
5f5dfdf1a6
commit
1591279b23
5 changed files with 397 additions and 241 deletions
172
lm_game.py
172
lm_game.py
|
|
@ -9,8 +9,7 @@ import json
|
|||
from collections import defaultdict
|
||||
|
||||
# Suppress Gemini/PaLM gRPC warnings
|
||||
os.environ['GRPC_PYTHON_LOG_LEVEL'] = '40' # ERROR level only
|
||||
import google.generativeai as genai # Import after setting log level
|
||||
os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40" # ERROR level only
|
||||
|
||||
from diplomacy import Game
|
||||
from diplomacy.utils.export import to_saved_game_format
|
||||
|
|
@ -21,7 +20,7 @@ from diplomacy.engine.message import Message, GLOBAL
|
|||
# For concurrency:
|
||||
import concurrent.futures
|
||||
|
||||
from lm_service_versus import load_model_client, assign_models_to_powers
|
||||
from ai_diplomacy.clients import load_model_client, assign_models_to_powers
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
|
@ -29,9 +28,10 @@ logger = logging.getLogger(__name__)
|
|||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
|
||||
datefmt="%H:%M:%S"
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def gather_possible_orders(game, power_name):
|
||||
"""
|
||||
Returns a dictionary mapping each orderable location to the list of valid orders.
|
||||
|
|
@ -44,6 +44,7 @@ def gather_possible_orders(game, power_name):
|
|||
result[loc] = all_possible.get(loc, [])
|
||||
return result
|
||||
|
||||
|
||||
def conduct_negotiations(game, model_error_stats, max_rounds=10):
|
||||
"""
|
||||
Conducts a round-robin conversation among all non-eliminated powers.
|
||||
|
|
@ -66,11 +67,15 @@ def conduct_negotiations(game, model_error_stats, max_rounds=10):
|
|||
visible_messages = []
|
||||
for msg in conversation_messages:
|
||||
# Include if message is global or if power is sender/recipient
|
||||
if msg['recipient'] == GLOBAL or msg['sender'] == power_name or msg['recipient'] == power_name:
|
||||
if (
|
||||
msg["recipient"] == GLOBAL
|
||||
or msg["sender"] == power_name
|
||||
or msg["recipient"] == power_name
|
||||
):
|
||||
visible_messages.append(
|
||||
f"{msg['sender']} to {msg['recipient']}: {msg['content']}"
|
||||
)
|
||||
|
||||
|
||||
conversation_so_far = "\n".join(visible_messages)
|
||||
|
||||
# Add few-shot example for message format
|
||||
|
|
@ -99,60 +104,69 @@ If you have your plan already figured out, you can just send a public '.' to ind
|
|||
power_name=power_name,
|
||||
conversation_so_far=conversation_so_far + "\n" + few_shot_example,
|
||||
game_phase=game.current_short_phase,
|
||||
phase_summaries=game.phase_summaries
|
||||
phase_summaries=game.phase_summaries,
|
||||
)
|
||||
|
||||
if new_message:
|
||||
try:
|
||||
# Parse the JSON response
|
||||
# Find the JSON block between curly braces
|
||||
json_match = re.search(r'\{[^}]+\}', new_message)
|
||||
json_match = re.search(r"\{[^}]+\}", new_message)
|
||||
if json_match:
|
||||
message_data = json.loads(json_match.group(0))
|
||||
|
||||
|
||||
# Extract message details
|
||||
message_type = message_data.get('message_type', 'global')
|
||||
content = message_data.get('content', '').strip()
|
||||
recipient = message_data.get('recipient', GLOBAL)
|
||||
|
||||
message_type = message_data.get("message_type", "global")
|
||||
content = message_data.get("content", "").strip()
|
||||
recipient = message_data.get("recipient", GLOBAL)
|
||||
|
||||
# Validate recipient if private message
|
||||
if message_type == 'private' and recipient not in active_powers:
|
||||
logger.warning(f"Invalid recipient {recipient} for private message, defaulting to GLOBAL")
|
||||
if message_type == "private" and recipient not in active_powers:
|
||||
logger.warning(
|
||||
f"Invalid recipient {recipient} for private message, defaulting to GLOBAL"
|
||||
)
|
||||
recipient = GLOBAL
|
||||
|
||||
|
||||
# For private messages, ensure recipient is specified
|
||||
if message_type == 'private' and recipient == GLOBAL:
|
||||
logger.warning("Private message without recipient specified, defaulting to GLOBAL")
|
||||
|
||||
if message_type == "private" and recipient == GLOBAL:
|
||||
logger.warning(
|
||||
"Private message without recipient specified, defaulting to GLOBAL"
|
||||
)
|
||||
|
||||
# Log for debugging
|
||||
logger.info(f"Power {power_name} sends {message_type} message to {recipient}")
|
||||
|
||||
logger.info(
|
||||
f"Power {power_name} sends {message_type} message to {recipient}"
|
||||
)
|
||||
|
||||
# Keep local record for building future conversation context
|
||||
conversation_messages.append({
|
||||
"sender": power_name,
|
||||
"recipient": recipient,
|
||||
"content": content
|
||||
})
|
||||
conversation_messages.append(
|
||||
{
|
||||
"sender": power_name,
|
||||
"recipient": recipient,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
# Create an official message in the Diplomacy engine
|
||||
diplo_message = Message(
|
||||
phase=game.current_short_phase,
|
||||
sender=power_name,
|
||||
recipient=recipient,
|
||||
message=content
|
||||
message=content,
|
||||
)
|
||||
game.add_message(diplo_message)
|
||||
|
||||
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.error(f"Failed to parse message from {power_name}: {e}")
|
||||
# Increment conversation parse error
|
||||
model_id = game.power_model_map.get(power_name, "unknown")
|
||||
model_error_stats[model_id]["conversation_errors"] += 1
|
||||
continue
|
||||
|
||||
logger.info("Negotiation phase complete.")
|
||||
return conversation_messages
|
||||
|
||||
|
||||
def my_summary_callback(system_prompt, user_prompt):
|
||||
# e.g., route to your desired model:
|
||||
client = load_model_client("o3-mini")
|
||||
|
|
@ -160,15 +174,18 @@ def my_summary_callback(system_prompt, user_prompt):
|
|||
# Pseudo-code for generating a response:
|
||||
return client.generate_response(combined_prompt)
|
||||
|
||||
def get_valid_orders_with_retry(game,
|
||||
client,
|
||||
board_state,
|
||||
power_name,
|
||||
possible_orders,
|
||||
conversation_text_for_orders,
|
||||
phase_summaries,
|
||||
model_error_stats,
|
||||
max_retries=3):
|
||||
|
||||
def get_valid_orders_with_retry(
|
||||
game,
|
||||
client,
|
||||
board_state,
|
||||
power_name,
|
||||
possible_orders,
|
||||
conversation_text_for_orders,
|
||||
phase_summaries,
|
||||
model_error_stats,
|
||||
max_retries=3,
|
||||
):
|
||||
"""
|
||||
Tries up to 'max_retries' to generate and validate orders.
|
||||
If invalid, we append the error feedback to the conversation
|
||||
|
|
@ -190,11 +207,11 @@ def get_valid_orders_with_retry(game,
|
|||
possible_orders=possible_orders,
|
||||
conversation_text=augmented_conversation_text,
|
||||
phase_summaries=phase_summaries,
|
||||
model_error_stats=model_error_stats
|
||||
model_error_stats=model_error_stats,
|
||||
)
|
||||
|
||||
print(f'orders: {orders}')
|
||||
|
||||
|
||||
print(f"orders: {orders}")
|
||||
|
||||
# Validate each order
|
||||
invalid_info = []
|
||||
for move in orders:
|
||||
|
|
@ -206,13 +223,15 @@ def get_valid_orders_with_retry(game,
|
|||
)
|
||||
continue
|
||||
unit = " ".join(tokens[:2]) # e.g. "A PAR"
|
||||
order_part = tokens[2] # e.g. "H" or "S A MAR"
|
||||
order_part = tokens[2] # e.g. "H" or "S A MAR"
|
||||
|
||||
# Use the internal game validation method
|
||||
if order_part == 'B':
|
||||
validity = 1 # hack because game._valid_order doesn't support 'B'
|
||||
else:
|
||||
validity = game._valid_order(game.powers[power_name], unit, order_part, report=1)
|
||||
if order_part == "B":
|
||||
validity = 1 # hack because game._valid_order doesn't support 'B'
|
||||
else:
|
||||
validity = game._valid_order(
|
||||
game.powers[power_name], unit, order_part, report=1
|
||||
)
|
||||
if validity != 1:
|
||||
invalid_info.append(
|
||||
f"Order '{move}' returned validity={validity}. (None/-1=invalid, 0=partial, 1=valid)"
|
||||
|
|
@ -224,7 +243,7 @@ def get_valid_orders_with_retry(game,
|
|||
else:
|
||||
# Build feedback for the next retry
|
||||
error_feedback = (
|
||||
f"Attempt {attempt+1}/{max_retries} had invalid orders:\n"
|
||||
f"Attempt {attempt + 1}/{max_retries} had invalid orders:\n"
|
||||
+ "\n".join(invalid_info)
|
||||
)
|
||||
|
||||
|
|
@ -235,17 +254,21 @@ def get_valid_orders_with_retry(game,
|
|||
fallback = client.fallback_orders(possible_orders)
|
||||
return fallback
|
||||
|
||||
|
||||
def main():
|
||||
logger.info("Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!")
|
||||
logger.info(
|
||||
"Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!"
|
||||
)
|
||||
start_whole = time.time()
|
||||
|
||||
from collections import defaultdict
|
||||
model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0})
|
||||
model_error_stats = defaultdict(
|
||||
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
|
||||
)
|
||||
|
||||
# Create a fresh Diplomacy game
|
||||
game = Game()
|
||||
# Ensure game has phase_summaries = {}
|
||||
if not hasattr(game, 'phase_summaries'):
|
||||
if not hasattr(game, "phase_summaries"):
|
||||
game.phase_summaries = {}
|
||||
|
||||
# For storing results in a unique subfolder
|
||||
|
|
@ -260,12 +283,14 @@ def main():
|
|||
stats_file_path = f"{result_folder}/error_stats.json"
|
||||
|
||||
game.power_model_map = assign_models_to_powers()
|
||||
max_year = 1901
|
||||
max_year = 1910
|
||||
|
||||
while not game.is_game_done:
|
||||
phase_start = time.time()
|
||||
current_phase = game.get_current_phase()
|
||||
logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
|
||||
logger.info(
|
||||
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
|
||||
)
|
||||
|
||||
# DEBUG: Print the short phase to confirm
|
||||
logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'")
|
||||
|
|
@ -280,24 +305,31 @@ def main():
|
|||
# Use endswith("M") for movement phases (like F1901M, S1902M)
|
||||
if game.current_short_phase.endswith("M"):
|
||||
logger.info("Starting negotiation phase block...")
|
||||
conversation_messages = conduct_negotiations(game, model_error_stats, max_rounds=10)
|
||||
conversation_messages = conduct_negotiations(
|
||||
game, model_error_stats, max_rounds=10
|
||||
)
|
||||
else:
|
||||
# If we have no conversation_messages in phases that are not Movement (e.g. Retreat/Build)
|
||||
conversation_messages = []
|
||||
|
||||
conversation_text_for_orders = "\n".join([
|
||||
f"{msg['sender']} to {msg['recipient']}: {msg['content']}"
|
||||
for msg in conversation_messages
|
||||
])
|
||||
conversation_text_for_orders = "\n".join(
|
||||
[
|
||||
f"{msg['sender']} to {msg['recipient']}: {msg['content']}"
|
||||
for msg in conversation_messages
|
||||
]
|
||||
)
|
||||
|
||||
# Gather orders from each power concurrently
|
||||
active_powers = [
|
||||
(p_name, p_obj) for p_name, p_obj in game.powers.items()
|
||||
(p_name, p_obj)
|
||||
for p_name, p_obj in game.powers.items()
|
||||
if not p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
# Then proceed with concurrent order generation
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
) as executor:
|
||||
futures = {}
|
||||
for power_name, _ in active_powers:
|
||||
model_id = game.power_model_map.get(power_name, "o3-mini")
|
||||
|
|
@ -319,10 +351,12 @@ def main():
|
|||
conversation_text_for_orders, # existing conversation text
|
||||
game.phase_summaries,
|
||||
model_error_stats,
|
||||
3 # max_retries
|
||||
3, # max_retries
|
||||
)
|
||||
futures[future] = power_name
|
||||
logger.debug(f"Submitted get_valid_orders_with_retry task for {power_name}.")
|
||||
logger.debug(
|
||||
f"Submitted get_valid_orders_with_retry task for {power_name}."
|
||||
)
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
p_name = futures[future]
|
||||
|
|
@ -331,7 +365,9 @@ def main():
|
|||
logger.debug(f"Validated orders for {p_name}: {orders}")
|
||||
if orders:
|
||||
game.set_orders(p_name, orders)
|
||||
logger.debug(f"Set orders for {p_name} in {game.current_short_phase}: {orders}")
|
||||
logger.debug(
|
||||
f"Set orders for {p_name} in {game.current_short_phase}: {orders}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No valid orders returned for {p_name}.")
|
||||
except Exception as exc:
|
||||
|
|
@ -346,7 +382,9 @@ def main():
|
|||
|
||||
# Print in pretty ASCII format
|
||||
border = "=" * 80
|
||||
logger.info(f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}")
|
||||
logger.info(
|
||||
f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}"
|
||||
)
|
||||
|
||||
# Write to unique game_manifesto in the timestamped folder
|
||||
with open(manifesto_path, "a") as f:
|
||||
|
|
@ -369,16 +407,18 @@ def main():
|
|||
to_saved_game_format(game, output_path=output_path)
|
||||
else:
|
||||
logger.info("Game file already exists, saving with unique filename.")
|
||||
output_path = f'{output_path}_{time.strftime("%Y%m%d_%H%M%S")}.json'
|
||||
output_path = f"{output_path}_{time.strftime('%Y%m%d_%H%M%S')}.json"
|
||||
to_saved_game_format(game, output_path=output_path)
|
||||
|
||||
# Dump our error stats to JSON
|
||||
import json
|
||||
|
||||
with open(stats_file_path, "w") as stats_f:
|
||||
json.dump(model_error_stats, stats_f, indent=2)
|
||||
|
||||
logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue