目录

paddle统计模型信息

打印paddle的pdparams的基本信息

import paddle

def print_paddle_model(model_path):
    state = paddle.load(model_path)

    for k, v in state.items():
        print(
            f"{k:60s} | "
            f"dtype={v.dtype} | "
            f"shape={list(v.shape)} | "
            f"numel={v.numel()}"
        )

# print_paddle_model("xxx/model_state.pdparams")

打印paddle的模型的基本信息

假如我们已经加载了一个paddle的module模型信息,如何打印具体的参数信息。

state_dict = model.state_dict()
        for name, tensor in state_dict.items():
            dtype = getattr(tensor, "dtype", "<unknown>")
            shape = list(tensor.shape) if hasattr(tensor, "shape") else []
            numel = tensor.numel() if hasattr(tensor, "numel") else 0
            print(f"{name:60s} | dtype={dtype} | shape={shape} | numel={numel}")