File size: 2,550 Bytes
00d1de8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import json
import subprocess
import time
import shutil

def _check_exists(path, label):
    """Print a PASS/FAIL line for ``path`` and return whether it exists."""
    if os.path.exists(path):
        print(f"[PASS] {label} exists")
        return True
    print(f"[FAIL] {label} missing")
    return False


def _run_training(log_dir):
    """Launch a short CPU training run that writes its logs into ``log_dir``.

    Returns:
        True if the command exits with status 0; False if it fails or the
        launcher executable cannot be started at all.
    """
    # We use a small model (ff) and cifar10 for speed, with minimal iterations
    cmd = [
        "pixi", "run", "accelerate", "launch", "--cpu",
        "tasks/image_classification/train_energy.py",
        "--model", "ff",
        "--dataset", "cifar10",
        "--batch_size", "4",
        "--training_iterations", "5",  # Run for 5 iterations
        "--track_every", "2",  # Track every 2 iterations to ensure we get logs
        "--save_every", "2",  # Save every 2 iterations
        "--log_dir", log_dir,
        "--device", "-1",  # Use CPU for verification to avoid GPU issues if any
    ]

    print(f"Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except FileNotFoundError as e:
        # The launcher itself (cmd[0]) is not on PATH; report instead of crashing.
        print("Training failed!")
        print(f"Could not start {cmd[0]!r}: {e}")
        return False
    except subprocess.CalledProcessError as e:
        print("Training failed!")
        # The child's output may be on stdout as well as stderr; show both.
        # errors="replace" keeps non-UTF-8 output from raising here.
        if e.stdout:
            print(e.stdout.decode(errors="replace"))
        print((e.stderr or b"").decode(errors="replace"))
        return False
    return True


def _check_status(path):
    """Check that the status JSON exists and parses to an object; print key fields."""
    if not os.path.exists(path):
        print("[FAIL] status.json missing")
        return False
    print("[PASS] status.json exists")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # JSONDecodeError / UnicodeDecodeError are ValueErrors: a truncated or
        # malformed file is a failed check, not a crash of the whole script.
        print(f"[FAIL] status.json unreadable: {e}")
        return False
    if not isinstance(data, dict):
        print("[FAIL] status.json is not a JSON object")
        return False
    print(f"  - Iteration: {data.get('iteration')}")
    print(f"  - Train Loss: {data.get('train_loss')}")
    return True


def _check_index_html(path):
    """Check that the dashboard HTML exists and mentions the expected markers."""
    if not os.path.exists(path):
        print("[FAIL] index.html missing")
        return False
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        content = f.read()
    if 'CTM Training Dashboard' in content and 'status.json' in content:
        print("[PASS] index.html looks correct")
        return True
    print("[FAIL] index.html content incorrect")
    return False


def verify(log_dir="logs/scratch"):
    """Run a short training job and check the artifacts it should produce.

    Args:
        log_dir: Directory passed to the training script as ``--log_dir``.
            It is deleted first so stale files from an earlier run cannot
            make a check pass.

    Returns:
        True if training succeeded and every check passed, False otherwise.
        Individual results are also printed as ``[PASS]``/``[FAIL]`` lines.
    """
    print("Starting verification...")

    # Clean up previous logs
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)

    if not _run_training(log_dir):
        return False

    print("Training finished. Checking files...")

    # Build the full list first so every check runs and prints its line,
    # rather than stopping at the first failure.
    results = [
        _check_status(os.path.join(log_dir, "status.json")),
        _check_exists(os.path.join(log_dir, "artifacts.zip"), "artifacts.zip"),
        _check_exists(os.path.join(log_dir, "losses.png"), "losses.png"),
        _check_exists(os.path.join(log_dir, "accuracies.png"), "accuracies.png"),
        # index.html is resolved against the current working directory, not
        # log_dir (presumably the repo-root dashboard — confirm).
        _check_index_html("index.html"),
    ]
    return all(results)

if __name__ == "__main__":
    import sys

    # Results are printed by verify(); additionally turn an explicit failure
    # (a False return) into a non-zero exit status so CI can detect it.
    # A None return keeps the previous exit status of 0.
    sys.exit(1 if verify() is False else 0)