Install
需要使用pip安装:
pip install torch safetensorsSave Model
使用PyTorch进行训练:
import safetensors.torch
model = Model()
train(model, **kwargs)
# save weight
safetensors.torch.save_model(model, "path/to/model.safetensors")Load Model
如果这个.safetensors文件对应的模型为Model, 那么可以:
import safetensors.torch
model = Model()
safetensors.torch.load_model(model, "path/to/model.safetensors")Inspect
如果不想将模型权重赋值给一个Model, 只希望查看这个模型内部的key, 那么可以使用下面的方法:
from safetensors.torch import load_file
# 你的 safetensors 文件路径
file_path = "path/to/your/model.safetensors"
# 加载整个文件到内存中,返回一个 state_dict
# device="cpu" 确保张量加载到 CPU 上,避免 GPU 显存溢出
state_dict = load_file(file_path, device="cpu")
# state_dict 就是一个字典,可以直接调用 .keys()
all_keys = state_dict.keys()
# 打印所有 keys
print(f"Found {len(all_keys)} tensors.")
for key in all_keys:
print(key)
# 打印第一个 tensor 的形状作为示例
first_key = list(all_keys)[0]
print(f"\nShape of '{first_key}': {state_dict[first_key].shape}")