完了~
import base64
from copy import deepcopy
import cloudpickle
import numpy as np
import os
import os.path as osp
import psutil
import string
import subprocess
from subprocess import CalledProcessError
import sys
from textwrap import dedent
import time
import zlib
# 导入待执行的函数
from spinup_utils.mpi_tools import mpi_fork
from spinup_utils.tune_func import func
DIV_LINE_WIDTH = 80
def call_experiment(thunk, params_dict_list, cpu_num=None, **kwargs):
    """Run ``thunk`` on a batch of parameter dicts in a freshly launched subprocess.

    ``thunk`` and the batch are wrapped in a zero-argument closure
    (``thunk_plus``) that first calls ``mpi_fork(cpu_num)`` and then
    ``thunk(params_dict_list)``. The closure is serialized with cloudpickle,
    zlib-compressed and base64-encoded, then passed as a command-line argument
    to ``run_entrypoint.py`` (which must sit in the same directory as this
    file). That script is run with the current Python interpreter.

    A plain ``lambda`` wrapper is deliberately not used: according to the
    original author it made the MPI import in ``tune_func.py`` fail to
    initialise.

    :param thunk: the function to launch; it is called as
        ``thunk(params_dict_list)`` after the MPI fork.
    :param params_dict_list: list of parameter dicts for this batch.
    :param cpu_num: number of MPI processes to fork. Defaults to
        ``len(params_dict_list)`` so that every dict in the batch gets its own
        process. (This used to be hard-coded to 4, which silently skipped
        every dict past the 4th in a batch.)
    :param kwargs: unused; accepted only for backward compatibility.
    :raises CalledProcessError: if the subprocess exits with a non-zero
        status (an error banner is printed first).
    """
    if not params_dict_list:
        # Nothing to dispatch for an empty batch.
        return
    if cpu_num is None:
        cpu_num = len(params_dict_list)

    def thunk_plus():
        # Fork into cpu_num processes, then run the thunk on the batch.
        mpi_fork(cpu_num)
        thunk(params_dict_list)

    pickled_thunk = cloudpickle.dumps(thunk_plus)
    encoded_thunk = base64.b64encode(zlib.compress(pickled_thunk)).decode('utf-8')
    # run_entrypoint.py must live next to this script; otherwise change this path.
    entrypoint = osp.join(osp.abspath(osp.dirname(__file__)), 'run_entrypoint.py')
    # argv list, exactly as the command would be typed on a shell.
    cmd = [sys.executable or 'python', entrypoint, encoded_thunk]
    print("tune_exps_pid:", os.getpid())
    try:
        subprocess.check_call(cmd, env=os.environ)
    except CalledProcessError:
        err_msg = '\n'*3 + '='*DIV_LINE_WIDTH + '\n' + dedent("""
            Check the traceback above to see what actually went wrong.
            """) + '='*DIV_LINE_WIDTH + '\n'*3
        print(err_msg)
        raise
def expand_param_grid(params_dict):
    """Expand a dict of candidate-value lists into every combination.

    Each combination is returned as its own dict, in ``itertools.product``
    order (the last key varies fastest). For example,
    ``{'lr': [2, 3], 'batch': [10]}`` becomes
    ``[{'lr': 2, 'batch': 10}, {'lr': 3, 'batch': 10}]``.
    An empty ``params_dict`` yields ``[{}]``; any empty value list yields ``[]``.
    """
    import itertools
    keys = list(params_dict)
    return [dict(zip(keys, combo)) for combo in itertools.product(*params_dict.values())]


if __name__ == '__main__':
    # Number of parameter dicts handed to call_experiment per batch.
    cpu_num = 5
    params_dict = {
        'lr': [2, 3, 4],
        "batch": [10, 20, 30],
        "epoch": [9, 8, 7],
    }
    # One dict per point of the parameter grid.
    params_dict_list = expand_param_grid(params_dict)
    print(params_dict_list)
    # Dispatch the grid in batches of cpu_num dicts.
    for i in range(0, len(params_dict_list), cpu_num):
        cur_params_dict_list = params_dict_list[i:i + cpu_num]
        print("cur_params_dict_list:", cur_params_dict_list)
        call_experiment(thunk=func, params_dict_list=cur_params_dict_list)
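打印结果:
可以看到主进程的PID(tune_exps_pid)一直不变,每一批参数都会重新启动一组MPI子进程(proc_id 0~3),参数能正常传入各子进程。
注意:这份输出是在 mpi_fork(4) 写死为4个进程时得到的,而每批传入的是 cpu_num=5 个参数字典。所以每批的第5个字典其实没有被执行,例如第一批中的 {'lr': 2, 'batch': 20, 'epoch': 8} 就没有出现在输出里。最后一批只有2个字典,多出来的进程2、3直接 sys.exit() 了。mpi_fork 的进程数应当等于本批参数字典的个数。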
[{'lr': 2, 'batch': 10, 'epoch': 9}, {'lr': 2, 'batch': 10, 'epoch': 8}, {'lr': 2, 'batch': 10, 'epoch': 7}, {'lr': 2, 'batch': 20, 'epoch': 9}, {'lr': 2, 'batch': 20, 'epoch': 8}, {'lr': 2, 'batch': 20, 'epoch': 7}, {'lr': 2, 'batch': 30, 'epoch': 9}, {'lr': 2, 'batch': 30, 'epoch': 8}, {'lr': 2, 'batch': 30, 'epoch': 7}, {'lr': 3, 'batch': 10, 'epoch': 9}, {'lr': 3, 'batch': 10, 'epoch': 8}, {'lr': 3, 'batch': 10, 'epoch': 7}, {'lr': 3, 'batch': 20, 'epoch': 9}, {'lr': 3, 'batch': 20, 'epoch': 8}, {'lr': 3, 'batch': 20, 'epoch': 7}, {'lr': 3, 'batch': 30, 'epoch': 9}, {'lr': 3, 'batch': 30, 'epoch': 8}, {'lr': 3, 'batch': 30, 'epoch': 7}, {'lr': 4, 'batch': 10, 'epoch': 9}, {'lr': 4, 'batch': 10, 'epoch': 8}, {'lr': 4, 'batch': 10, 'epoch': 7}, {'lr': 4, 'batch': 20, 'epoch': 9}, {'lr': 4, 'batch': 20, 'epoch': 8}, {'lr': 4, 'batch': 20, 'epoch': 7}, {'lr': 4, 'batch': 30, 'epoch': 9}, {'lr': 4, 'batch': 30, 'epoch': 8}, {'lr': 4, 'batch': 30, 'epoch': 7}]
cur_params_dict_list: [{'lr': 2, 'batch': 10, 'epoch': 9}, {'lr': 2, 'batch': 10, 'epoch': 8}, {'lr': 2, 'batch': 10, 'epoch': 7}, {'lr': 2, 'batch': 20, 'epoch': 9}, {'lr': 2, 'batch': 20, 'epoch': 8}]
tune_exps_pid: 15411
proc_id: 1
params_dict: {'lr': 2, 'batch': 10, 'epoch': 8}
--------------------
proc_id: 0
params_dict: {'lr': 2, 'batch': 10, 'epoch': 9}
--------------------
proc_id: 2
params_dict: {'lr': 2, 'batch': 10, 'epoch': 7}
--------------------
proc_id: 3
params_dict: {'lr': 2, 'batch': 20, 'epoch': 9}
--------------------
proc_id: 0
params_dict: {'lr': 2, 'batch': 10, 'epoch': 9}
--------------------
cur_params_dict_list: [{'lr': 2, 'batch': 20, 'epoch': 7}, {'lr': 2, 'batch': 30, 'epoch': 9}, {'lr': 2, 'batch': 30, 'epoch': 8}, {'lr': 2, 'batch': 30, 'epoch': 7}, {'lr': 3, 'batch': 10, 'epoch': 9}]
tune_exps_pid: 15411
proc_id: 3
params_dict: {'lr': 2, 'batch': 30, 'epoch': 7}
--------------------
proc_id: 2
params_dict: {'lr': 2, 'batch': 30, 'epoch': 8}
--------------------
proc_id: 0
params_dict: {'lr': 2, 'batch': 20, 'epoch': 7}
--------------------
proc_id: 1
params_dict: {'lr': 2, 'batch': 30, 'epoch': 9}
--------------------
proc_id: 0
params_dict: {'lr': 2, 'batch': 20, 'epoch': 7}
--------------------
cur_params_dict_list: [{'lr': 3, 'batch': 10, 'epoch': 8}, {'lr': 3, 'batch': 10, 'epoch': 7}, {'lr': 3, 'batch': 20, 'epoch': 9}, {'lr': 3, 'batch': 20, 'epoch': 8}, {'lr': 3, 'batch': 20, 'epoch': 7}]
tune_exps_pid: 15411
proc_id: 2
params_dict: {'lr': 3, 'batch': 20, 'epoch': 9}
--------------------
proc_id: 3
params_dict: {'lr': 3, 'batch': 20, 'epoch': 8}
--------------------
proc_id: 0
params_dict: {'lr': 3, 'batch': 10, 'epoch': 8}
--------------------
proc_id: 1
params_dict: {'lr': 3, 'batch': 10, 'epoch': 7}
--------------------
proc_id: 0
params_dict: {'lr': 3, 'batch': 10, 'epoch': 8}
--------------------
cur_params_dict_list: [{'lr': 3, 'batch': 30, 'epoch': 9}, {'lr': 3, 'batch': 30, 'epoch': 8}, {'lr': 3, 'batch': 30, 'epoch': 7}, {'lr': 4, 'batch': 10, 'epoch': 9}, {'lr': 4, 'batch': 10, 'epoch': 8}]
tune_exps_pid: 15411
proc_id: 1
params_dict: {'lr': 3, 'batch': 30, 'epoch': 8}
--------------------
proc_id: 0
params_dict: {'lr': 3, 'batch': 30, 'epoch': 9}
--------------------
proc_id: 2
params_dict: {'lr': 3, 'batch': 30, 'epoch': 7}
--------------------
proc_id: 3
params_dict: {'lr': 4, 'batch': 10, 'epoch': 9}
--------------------
proc_id: 0
params_dict: {'lr': 3, 'batch': 30, 'epoch': 9}
--------------------
cur_params_dict_list: [{'lr': 4, 'batch': 10, 'epoch': 7}, {'lr': 4, 'batch': 20, 'epoch': 9}, {'lr': 4, 'batch': 20, 'epoch': 8}, {'lr': 4, 'batch': 20, 'epoch': 7}, {'lr': 4, 'batch': 30, 'epoch': 9}]
tune_exps_pid: 15411
proc_id: 1
params_dict: {'lr': 4, 'batch': 20, 'epoch': 9}
--------------------
proc_id: 3
params_dict: {'lr': 4, 'batch': 20, 'epoch': 7}
--------------------
proc_id: 0
params_dict: {'lr': 4, 'batch': 10, 'epoch': 7}
--------------------
proc_id: 2
params_dict: {'lr': 4, 'batch': 20, 'epoch': 8}
--------------------
proc_id: 0
params_dict: {'lr': 4, 'batch': 10, 'epoch': 7}
--------------------
cur_params_dict_list: [{'lr': 4, 'batch': 30, 'epoch': 8}, {'lr': 4, 'batch': 30, 'epoch': 7}]
tune_exps_pid: 15411
proc_id: 1
params_dict: {'lr': 4, 'batch': 30, 'epoch': 7}
--------------------
proc_id: 0
params_dict: {'lr': 4, 'batch': 30, 'epoch': 8}
--------------------
proc_id: 2
sys.exit()
proc_id: 3
sys.exit()
proc_id: 0
params_dict: {'lr': 4, 'batch': 30, 'epoch': 8}
--------------------
诡异bug:
今天遇到了一个极度诡异的bug;
在虚拟环境1中,我的MPI.COMM_WORLD.Get_rank()
拿到的值一直都是0!
不管mpi_fork()几次,都是0,极度离谱;
debug什么的都无法解决,搞得我以为我电脑坏了,因为上午还是好好的;
后来我想起来,可能是我的虚拟环境换了?
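换回了另一个虚拟环境2,终于跑成功了。这样就明白了,是我的虚拟环境1坏掉了;但我找不到bug在哪儿,它也不报错——不报错的bug是最恶心的;
我只看了一下mpi4py的版本:虚拟环境1的是3.0.3,虚拟环境2的是3.0.2,应该不是这个原因吧?
一个常见的原因(供参考,这里未验证):mpi4py 所链接的MPI实现(例如MPICH)与PATH中实际调用的 mpirun/mpiexec(例如OpenMPI)不一致。这时每个进程都会以为自己是单独启动的,COMM_WORLD 的 size 为1,Get_rank() 恒为0,而且不会报错。可以在两个环境里分别对比 `which mpirun` 和 `python -c "from mpi4py import MPI; print(MPI.Get_library_version())"` 的输出。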
联系方式:
ps: 欢迎做强化的同学加群一起学习:
深度强化学习-DRL:799378128
欢迎关注知乎帐号:未入门的炼丹学徒
CSDN帐号:https://blog.csdn.net/hehedadaq
极简spinup+HER+PER代码实现,两小时之内配置完毕:https://github.com/kaixindelele/DRLib