mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-29 17:35:18 +00:00
fixed system prompt for summary , made improvements and debugging for summaries too. Much can be optimized still
This commit is contained in:
parent
72327cfb22
commit
8f61ba06b3
4 changed files with 150 additions and 55 deletions
73
lm_game.py
73
lm_game.py
|
|
@ -34,7 +34,7 @@ logging.basicConfig(
|
|||
|
||||
def my_summary_callback(system_prompt, user_prompt, model_name):
|
||||
# Route to the desired model specified by the command-line argument
|
||||
client = load_model_client(model_name)
|
||||
client = load_model_client(model_name, emptysystem=True)
|
||||
combined_prompt = f"{system_prompt}\n\n{user_prompt}"
|
||||
# Pseudo-code for generating a response:
|
||||
return client.generate_response(combined_prompt)
|
||||
|
|
@ -47,7 +47,7 @@ def parse_arguments():
|
|||
parser.add_argument(
|
||||
"--max_year",
|
||||
type=int,
|
||||
default=1910,
|
||||
default=1905,
|
||||
help="Maximum year to simulate. The game will stop once this year is reached.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -59,7 +59,7 @@ def parse_arguments():
|
|||
parser.add_argument(
|
||||
"--num_negotiation_rounds",
|
||||
type=int,
|
||||
default=5,
|
||||
default=1,
|
||||
help="Number of negotiation rounds per phase.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -80,6 +80,42 @@ def parse_arguments():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False):
|
||||
"""
|
||||
Save the current game state and related information
|
||||
|
||||
Args:
|
||||
game: The diplomacy game instance
|
||||
result_folder: Path to the results folder
|
||||
game_file_path: Base path for the game file
|
||||
model_error_stats: Dictionary containing model error statistics
|
||||
args: Command line arguments
|
||||
is_final: Boolean indicating if this is the final save
|
||||
"""
|
||||
# Generate unique filename for periodic saves
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
if not is_final:
|
||||
output_path = f"{game_file_path}_checkpoint_{timestamp}.json"
|
||||
else:
|
||||
output_path = game_file_path
|
||||
# If final file exists, append timestamp
|
||||
if os.path.exists(output_path):
|
||||
logger.info("Game file already exists, saving with unique filename.")
|
||||
output_path = f"{output_path}_{timestamp}.json"
|
||||
|
||||
# Save game state
|
||||
to_saved_game_format(game, output_path=output_path)
|
||||
|
||||
# Save overview data
|
||||
overview_file_path = f"{result_folder}/overview.jsonl"
|
||||
with open(overview_file_path, "w") as overview_file:
|
||||
overview_file.write(json.dumps(model_error_stats) + "\n")
|
||||
overview_file.write(json.dumps(game.power_model_map) + "\n")
|
||||
overview_file.write(json.dumps(vars(args)) + "\n")
|
||||
|
||||
logger.info(f"Saved game checkpoint to: {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
max_year = args.max_year
|
||||
|
|
@ -135,6 +171,8 @@ def main():
|
|||
else:
|
||||
game.power_model_map = assign_models_to_powers(randomize=True)
|
||||
|
||||
round_counter = 0 # Track number of rounds
|
||||
|
||||
while not game.is_game_done:
|
||||
phase_start = time.time()
|
||||
current_phase = game.get_current_phase()
|
||||
|
|
@ -143,7 +181,7 @@ def main():
|
|||
)
|
||||
|
||||
# DEBUG: Print the short phase to confirm
|
||||
logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'")
|
||||
logger.info(f"INFO: The current short phase is '{game.current_short_phase}'")
|
||||
|
||||
# Prevent unbounded simulation based on year
|
||||
year_str = current_phase[1:5]
|
||||
|
|
@ -253,6 +291,14 @@ def main():
|
|||
with open(manifesto_path, "a") as f:
|
||||
f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")
|
||||
|
||||
# Increment round counter after processing each phase
|
||||
round_counter += 1
|
||||
|
||||
# Save every 5 rounds
|
||||
if round_counter % 5 == 0:
|
||||
logger.info(f"Saving checkpoint after round {round_counter}...")
|
||||
save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False)
|
||||
|
||||
# Check if we've exceeded the max year
|
||||
year_str = current_phase[1:5]
|
||||
year_int = int(year_str)
|
||||
|
|
@ -262,21 +308,10 @@ def main():
|
|||
|
||||
# Save final result
|
||||
duration = time.time() - start_whole
|
||||
logger.info(f"Game ended after {duration:.2f}s. Saving to final JSON...")
|
||||
|
||||
output_path = game_file_path
|
||||
# If the file already exists, append a timestamp to the filename
|
||||
if os.path.exists(output_path):
|
||||
logger.info("Game file already exists, saving with unique filename.")
|
||||
output_path = f"{output_path}_{time.strftime('%Y%m%d_%H%M%S')}.json"
|
||||
to_saved_game_format(game, output_path=output_path)
|
||||
|
||||
# Dump error stats and power model mapping to the overview file
|
||||
with open(overview_file_path, "w") as overview_file:
|
||||
overview_file.write(json.dumps(model_error_stats) + "\n")
|
||||
overview_file.write(json.dumps(game.power_model_map) + "\n")
|
||||
overview_file.write(json.dumps(vars(args)) + "\n")
|
||||
|
||||
logger.info(f"Game ended after {duration:.2f}s. Saving final state...")
|
||||
|
||||
save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=True)
|
||||
|
||||
logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
|
||||
logger.info("Done.")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue