""" Train V(s) cumulatively across a schedule of checkpoints, dump a 3D surface snapshot at each, then stitch them into a GIF. SCHEDULE controls steps, default is to show 1K to 50K every 1K episodes Usage: python sweep.py python sweep.py --frame-ms 200 --hold-ms 3000 """ import argparse import os import random import numpy as np import matplotlib.pyplot as plt from PIL import Image from core import train, StateValueFunction PLAYER_TOTALS = np.arange(12, 22) DEALER_CARDS = np.arange(1, 11) # 1 = Ace # Cumulative episode checkpoints. Each frame snapshots V(s) at that total. SCHEDULE = [500_000] #list(range(1000, 50_001, 1000)) # 1K to 50K w/ frame every 1K OUT_DIR = os.path.join(os.path.dirname(__file__), "plots") def value_grid(V, usable_ace): return np.array( [[V.predict((p, d, usable_ace)) for d in DEALER_CARDS] for p in PLAYER_TOTALS] ) def plot_value(V, n_episodes, out_path): fig = plt.figure(figsize=(12, 5)) X, Y = np.meshgrid(DEALER_CARDS, PLAYER_TOTALS) for i, (label, usable) in enumerate( [("No usable ace", False), ("Usable ace", True)] ): ax = fig.add_subplot(1, 2, i + 1, projection="3d") Z = value_grid(V, usable) ax.plot_surface(X, Y, Z, cmap="viridis", edgecolor="none", alpha=0.9) ax.set_xlabel("Dealer showing") ax.set_ylabel("Player total") ax.set_zlabel("V(s)") ax.set_zlim(-1, 1) ax.set_xticks(DEALER_CARDS) ax.set_xticklabels(["A"] + [str(c) for c in range(2, 11)]) ax.set_yticks(PLAYER_TOTALS) ax.set_title(label) fig.suptitle(f"V(s) after {n_episodes:,} episodes (stick >= 20)") # Fixed margins so every frame renders at the same canvas size for the GIF. fig.subplots_adjust(left=0.05, right=0.95, top=0.88, bottom=0.05, wspace=0.05) fig.savefig(out_path, dpi=120) plt.close(fig) def make_gif(frame_paths, gif_path, frame_ms, hold_ms): frames = [Image.open(p).convert("RGBA") for p in frame_paths] target_size = frames[0].size frames = [f if f.size == target_size else f.resize(target_size) for f in frames] durations = [frame_ms] * (len(frames) - 1) + [hold_ms] frames[0].save( gif_path, save_all=True, append_images=frames[1:], duration=durations, loop=0, disposal=2, ) def main(frame_ms, hold_ms, seed=0): os.makedirs(OUT_DIR, exist_ok=True) if seed is not None: random.seed(seed) V = StateValueFunction() frame_paths = [] trained = 0 for target in SCHEDULE: delta = target - trained train(delta, value_fn=V) # cumulative — V keeps its state across calls trained = target path = os.path.join(OUT_DIR, f"frame_{target:09d}.png") plot_value(V, target, path) frame_paths.append(path) print(f" frame {target:>9,}: {path}") gif_path = os.path.join(OUT_DIR, "sweep.gif") make_gif(frame_paths, gif_path, frame_ms, hold_ms) print(f"\nGIF: {gif_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--frame-ms", type=int, default=200, help="Per-frame duration in ms (default 200).", ) parser.add_argument( "--hold-ms", type=int, default=3000, help="Final-frame hold duration in ms (default 3000).", ) args = parser.parse_args() main(frame_ms=args.frame_ms, hold_ms=args.hold_ms)