====== paddle统计模型信息 ======
===== 打印paddle的pdparams的基本信息 =====
import paddle
def print_paddle_model(model_path):
state = paddle.load(model_path)
for k, v in state.items():
print(
f"{k:60s} | "
f"dtype={v.dtype} | "
f"shape={list(v.shape)} | "
f"numel={v.numel()}"
)
# print_paddle_model("xxx/model_state.pdparams")
===== 打印paddle的模型的基本信息 =====
假如我们已经加载了一个paddle的module模型信息,如何打印具体的参数信息。
state_dict = model.state_dict()
for name, tensor in state_dict.items():
dtype = getattr(tensor, "dtype", "")
shape = list(tensor.shape) if hasattr(tensor, "shape") else []
numel = tensor.numel() if hasattr(tensor, "numel") else 0
print(f"{name:60s} | dtype={dtype} | shape={shape} | numel={numel}")