import os, subprocess, sys
import torch, torch.nn as nn, torch.optim as optim
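# Job identity comes from the environment: the job name and completion index
# scope the remote checkpoint prefix, and the restart-attempt counter tells us
# whether this is a fresh run or a retry that should resume.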
JOB_NAME = os.environ.get("KONDUKTOR_JOB_NAME", "job")
ATTEMPT = int(os.environ.get("RESTART_ATTEMPT", "0"))
IDX = os.environ.get("JOB_COMPLETION_INDEX", "0")
BUCKET = "my-konduktor-bucket"
PREFIX = f"checkpoints/{JOB_NAME}"
REMOTE = f"s3://{BUCKET}/{PREFIX}/idx_{IDX}"
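# Local scratch paths: TMP is where a checkpoint is written before upload,
# RESUME is where the downloaded latest.pt lands.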
TMP = "/tmp/ckpt.pt"
RESUME = "/tmp/resume.pt"
def sh(cmd: str):
    subprocess.check_call(cmd, shell=True)
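# Assumes the aws CLI is available in the container and credentials are supplied
# (e.g. via the pod environment or an IAM role). check_call raises
# CalledProcessError on any non-zero exit, which try_resume() relies on.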
# Upload checkpoints to S3
def upload(step: int):
    name = f"step_{step:06d}.pt"
    sh(f"aws s3 cp {TMP} {REMOTE}/{name} --only-show-errors")     # versioned
    sh(f"aws s3 cp {TMP} {REMOTE}/latest.pt --only-show-errors")  # stable pointer
    print(f"Saved {name} and latest.pt")
# Resume from latest checkpoint if available
def try_resume():
    try:
        sh(f"aws s3 cp {REMOTE}/latest.pt {RESUME} --only-show-errors")
        ckpt = torch.load(RESUME, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["opt"])
        start = int(ckpt.get("step", 0)) + 1
        print(f"Resumed from latest.pt @ step {start}")
        return start
    except subprocess.CalledProcessError:
        # nothing remote yet; start fresh
        return 0
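# model and opt are module-level globals defined just below; Python resolves
# them at call time, and try_resume() is only called after they exist.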
# Tiny model: the training itself is trivial; the point is the checkpoint/restart plumbing.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 1))
opt = optim.SGD(model.parameters(), lr=0.1)
loss = nn.MSELoss()
print(f"ATTEMPT={ATTEMPT} REMOTE={REMOTE}")
start = try_resume() if ATTEMPT > 0 else 0
for step in range(start, 501):
    x, y = torch.randn(64, 10), torch.randn(64, 1)
    opt.zero_grad()
    out = model(x)
    step_loss = loss(out, y)
    step_loss.backward()
    opt.step()
    if step % 20 == 0:
        print(f"[{step}] loss={step_loss.item():.4f}")
    # save every 100 steps
    if step > 0 and step % 100 == 0:
        torch.save({"model": model.state_dict(), "opt": opt.state_dict(), "step": step}, TMP)
        upload(step)
    # Force a crash on the first attempt once we have a good checkpoint
    if ATTEMPT == 0 and step == 200:
        print("Intentionally crashing after saving step_000200.pt (attempt 0).")
        sys.exit(1)  # non-zero exit -> pod fails -> Job restarts (max_restarts must be >= 1)
print("Training complete")