1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
| import os
import time
import torch
import argparse
def generate_dummy_checkpoint(size_mb):
"""
生成一个指定大小的模拟 checkpoint 文件。
size_mb: 文件大小,以 MB 为单位。
"""
# 1 MB = 1024 * 1024 bytes
size_bytes = size_mb * 1024 * 1024
dummy_data = torch.rand(size_bytes // 4) # 随机生成数据,假设每个浮点数占 4 字节
return {"dummy_data": dummy_data}
def save_checkpoint(checkpoint, save_dir):
"""
将 checkpoint 保存到指定目录,并测量存储速度。
checkpoint: 要保存的 checkpoint 数据。
save_dir: 保存目录。
"""
start_time = time.time()
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "checkpoint.pt")
torch.save(checkpoint, save_path)
end_time = time.time()
# 计算存储速度 (MB/s)
file_size_mb = os.path.getsize(save_path) / (1024 * 1024)
duration = end_time - start_time
speed_mb_per_sec = file_size_mb / duration
return speed_mb_per_sec
def main():
parser = argparse.ArgumentParser(description="Test storage speed for checkpoint saving.")
parser.add_argument("--size_mb", type=int, default=100, help="Size of the dummy checkpoint in MB (default: 100MB)")
parser.add_argument("--save_dir", type=str, default="/3fs/stage/checkpoint", help="Directory to save the checkpoint (default: /3fs/stage/checkpoint)")
args = parser.parse_args()
# 生成模拟的 checkpoint 文件
checkpoint = generate_dummy_checkpoint(args.size_mb)
# 数与存储速度
speed = save_checkpoint(checkpoint, args.save_dir)
print(f"Storage speed: {speed:.2f} MB/s")
if __name__ == "__main__":
main()
|