这是本文档旧的修订版!
2-paddle模型缓存相关细节-aistudio-模型仓库
paddle在使用 模型时,使用了aistudio的sdk,具体模型缓存在:
/home/ctbots/.cache/aistudio/hub/models/[项目id坐标] 下
如何避免paddle重复下载大文件,如何不改代码骗过aistudio的sdk
我们在本地处理预训练模型的时候,如果处理完毕,可以上传到paddle的 aistudio上,并获取项目坐标,例如: ctbots/first-paddle-model
但是在代码中,直接使用 ctbots/first-paddle-model 模型的时候,会发现 paddle框架会自动下载模型到本地cache文件中 /home/ctbots/.cache/aistudio/hub/models/[项目id坐标] 下
几百兆的文件还好,对于几十G的文件,实在不想重新下载;而且每次执行,paddle都会对比 远程仓库文件的sha值和本地的是否一致,如果不一致直接重新下载。
我们灵机一动,既然模型在本地push上去的,那么我直接复制到cache目录就行了,实际不行,还有sha值对比的环节。
我们这里给出一段代码,模拟sha值的cache逻辑在本地做一份欺骗的缓存,在不修改paddle代码的情况下,避免前期预训练调整模型阶段,重复下载。
import hashlib
import os
import pickle
from urllib.parse import quote
import requests
def compute_hash(file_path):
buffer_size = 1024 * 64 # 64k buffer size
sha256_hash = hashlib.sha256()
with open(file_path, 'rb') as f:
while True:
data = f.read(buffer_size)
if not data:
break
sha256_hash.update(data)
return sha256_hash.hexdigest()
def mk_aistudio_cache_file(project_id: str, local_dir: str):
if not os.path.exists(local_dir):
return False, "模型目录不存在"
model_meta = {
"id": project_id
}
meta_file_path = os.path.join(local_dir, ".mdl")
with open(meta_file_path, 'wb') as f:
pickle.dump(model_meta, f)
files = [f for f in os.listdir(local_dir) if os.path.isfile(os.path.join(local_dir, f))]
cached_files = []
base_url = "https://git.aistudio.baidu.com/api/v1/repos"
for f in files:
encoded_path = quote(f, safe='')
url = f"{base_url}/{project_id}/contents/{encoded_path}"
try:
response = requests.get(url)
response.raise_for_status()
data = response.json()
commit_id = data.get("sha", "N/A")
cached_files.append({
"Path": f,
"Revision": commit_id
})
except Exception as e:
cached_files.append({
"Path": f,
"Revision": "N/A"
})
cache_keys_file_path = os.path.join(local_dir, ".msc")
with open(cache_keys_file_path, 'wb') as f:
pickle.dump(cached_files, f)
return True, "成功写入cache"
def test_mk_cache():
rs = mk_aistudio_cache_file("ctbots/first-paddle-model", "/home/ctbots/.cache/aistudio/hub/models/ctbots/first-paddle-model")
print(rs)
评论