1. 进入 FlyAI 预训练模型页面。
2. 找到需要的 Keras 模型，点击相应链接后确定。
3. 得到复制后的内容：
# The model has to be downloaded through this helper first; load it
# afterwards from the path it returns.
from flyai.utils import remote_helper

path = remote_helper.get_remote_date(
    "https://www.flyai.com/m/v0.8|NASNet-mobile.h5"
)
直接在本地运行会出错。需要修改以下两个文件:
- `remote_helper.py`：把其中的 `os.path.join(sys.path[0], 'data', 'input', 'model')` 改成自己的地址（即保存模型的地址）。
- `download.py`：把其中的地址改成自己的地址。
- 此外，将下载好的模型放到本地的 `.keras` 文件夹下，即可在任意位置直接调用。
附：`remote_helper.py` 代码
# Helpers for downloading FlyAI remote models using the user's login token.
# `get_remote_date` is the public entry point; the double-underscore
# functions below are module-private helpers.
import hashlib
import json
import os
import platform
import sys
import uuid
from os.path import join

from flyai.processor.download import download_model

__DOMAIN = "https://www.flyai.com"


def __genearteMD5(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8."""
    # The misspelled name is kept so existing callers keep working.
    digest = hashlib.md5()
    digest.update(text.encode(encoding='utf-8'))
    return digest.hexdigest()


def __check_dir(path):
    """Return True if ``path`` contains no spaces and only ASCII characters."""
    if " " in path:
        return False
    return all(ord(c) < 128 for c in path)


def __get_home_path():
    """Return the ``.flyai`` directory (with a trailing separator), creating it if missing.

    On Windows, ``%HOMEPATH%`` is used unless it contains spaces or
    non-ASCII characters, in which case ``C://`` is used as the base instead.
    On other systems ``$HOME`` is the base.
    """
    system = platform.system()
    if system == "Windows":
        home = os.environ['HOMEPATH']
        base = home if __check_dir(home) else "C://"
    else:
        base = os.environ['HOME']
    path = join(base, '.flyai', "")
    if not os.path.exists(path):
        # exist_ok guards against another process creating it in between.
        os.makedirs(path, exist_ok=True)
    return path


def __get_mac():
    """Return ``uuid.getnode()`` as dash-separated hex pairs, or "unknown" on failure."""
    try:
        address = hex(uuid.getnode())[2:]
    except Exception:
        return "unknown"
    # NOTE: hex() drops leading zeros, so the pairs may be misaligned. This is
    # kept as-is because the result is hashed into the login file name
    # (see __get_token); changing it would break lookup of existing files.
    return '-'.join(address[i:i + 2] for i in range(0, len(address), 2))


def __read_json_token(file_path):
    """Return the ``token`` field of the JSON login file at ``file_path``."""
    with open(file_path) as f:
        return json.loads(f.read())['token']


def __get_token():
    """Return the FlyAI login token, or None if no credential file is found.

    Sources are tried in order:
      1. ``$HOME/.flyai_flyai`` (``%HOMEPATH%`` on Windows), containing the raw token;
      2. ``<home>/.flyai/.<md5(mac + domain)>``, a JSON file with a ``token`` key;
      3. ``train.json`` in ``sys.path[0]``, a JSON file with a ``token`` key.
    """
    if platform.system() == "Windows":
        token_path = os.path.join(os.environ['HOMEPATH'], '.flyai_flyai')
    else:
        token_path = os.path.join(os.environ['HOME'], '.flyai_flyai')
    if os.path.exists(token_path):
        with open(token_path, 'r') as f:
            # Strip surrounding whitespace/newlines: the token is appended to a URL.
            return f.read().strip()

    # Note: __get_home_path() creates the .flyai directory as a side effect.
    login_path = join(__get_home_path(), "." + __genearteMD5(__get_mac() + __DOMAIN))
    if os.path.exists(login_path):
        return __read_json_token(login_path)

    train_path = os.path.join(sys.path[0], 'train.json')
    if os.path.exists(train_path):
        return __read_json_token(train_path)
    return None


def get_remote_date(remote_name, save_dir=None):
    """Download a remote FlyAI model with the user's token.

    Args:
        remote_name: Model URL copied from the FlyAI site, e.g.
            ``"https://www.flyai.com/m/v0.8|NASNet-mobile.h5"``.
        save_dir: Directory to save the model into. Defaults to
            ``<sys.path[0]>/data/input/model``; pass your own location here
            instead of editing this file.

    Returns:
        Whatever ``download_model`` returns (presumably the local model
        path -- confirm against ``flyai.processor.download``), or None when
        ``remote_name`` does not contain "http" or no login token is found.
    """
    if "http" not in remote_name:
        return None
    token = __get_token()
    if token is None:
        return None
    if save_dir is None:
        save_dir = os.path.join(sys.path[0], 'data', 'input', 'model')
    # Reuse the token fetched above rather than re-reading the credential files.
    return download_model(remote_name + "?token=" + token, save_dir, is_print=True)