python新手,代码不规范之处敬请见谅,第一次发帖排版也不太懂,各位将就看。
参考:
SIR模型和Python实现_阿丢是丢心心的博客-CSDN博客_sir模型
【保姆级教程】使用python实现SIR模型(包含数据集的制作与导入及最终结果的可视化)_自然卷的悲伤的博客-CSDN博客_sir模型python
与参考文章相比,本文增加了多次重复模拟并对结果取平均的步骤,而且传播过程中节点状态的更新方式也与参考文章有所不同(见第4节的“注意”)。代码注释大家看参考文章的就行,SIR模型在上述参考文章中已经介绍得很清楚了,咱就不赘述,直接上代码:
1.用到的包
import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time
2. 将连边文件输入,建立Graph(最好是csv文件)
def parseEdgeLine(line):
    """Parse one line of an edge-list file into a ``(source, target)`` pair.

    Columns may be separated by commas and/or any run of whitespace (spaces,
    tabs), including mixtures such as ``"1, 2"``. Columns after the second
    (e.g. an edge weight) are ignored. Node ids are kept as strings.

    Returns:
        A ``(source, target)`` tuple of strings, or ``None`` for a blank line.

    Raises:
        ValueError: if a non-blank line has fewer than two columns.
    """
    # Splitting on a *run* of separators avoids empty fields for "1, 2", and
    # treating tabs as separators (instead of deleting them) keeps "1\t2" as
    # two columns rather than merging them into "12".
    fields = [field for field in re.split(r"[,\s]+", line.strip()) if field]
    if not fields:
        return None
    if len(fields) < 2:
        raise ValueError("edge line needs two node ids, got %r" % line)
    return fields[0], fields[1]


def readEdgeslistToGraph(edges_filename):
    """Build an undirected ``nx.Graph`` from an edge-list text/CSV file.

    Each non-blank line describes one edge; its first two columns are the
    endpoint node ids (see ``parseEdgeLine`` for the accepted separators).
    The file is decoded as UTF-8, and a leading BOM is stripped.

    Args:
        edges_filename: path of the edge-list file.

    Returns:
        An ``nx.Graph`` containing every edge in the file.
    """
    edges = []
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        for line in f:
            edge = parseEdgeLine(line)
            if edge is not None:  # skip blank lines, e.g. a trailing newline
                edges.append(edge)
    G = nx.Graph()
    G.add_edges_from(edges)
    return G
注意:输入的是连边文件,或者说边集:每行描述一条边,前两列是这条边两端的节点编号,用逗号或空格分隔,比如 `1,2` 表示节点1和节点2之间有一条边。
3. 初始化网络并随机选择初始“I”态节点
def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    """Reset every node of ``G`` to "S", then mark a random sample as "I".

    The node attribute ``"state"`` is written in place; nothing is returned.

    Args:
        G: graph whose nodes receive a ``"state"`` attribute.
        initial_infected_nodes_num: number of distinct nodes that start in
            state "I". ``random.sample`` raises ``ValueError`` if this exceeds
            the number of nodes.
    """
    # Draw the seeds before touching any state, exactly one sample call.
    seeds = set(random.sample(list(G.nodes), k=initial_infected_nodes_num))
    for node_id, attrs in G.nodes(data=True):
        attrs["state"] = "I" if node_id in seeds else "S"
4.更新节点状态
# Update the state of a single node.
def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    """Advance one node by a single SIR step, writing the result into ``G``.

    The transition is decided only from ``old_G`` (the states before this
    step), so changes already written to ``G`` during the same step do not
    influence it.

    * "I" nodes become "R" with probability ``recover_rate``.
    * "S" nodes with ``k`` "I" neighbours in ``old_G`` become "I" with
      probability ``1 - (1 - infected_rate) ** k``, i.e. each infected
      neighbour independently transmits with probability ``infected_rate``.
    * Nodes in any other state are left untouched.
    """
    previous = old_G.nodes[node]["state"]
    if previous == "I":
        if random.random() < recover_rate:
            G.nodes[node]["state"] = "R"
    elif previous == "S":
        infected_neighbours = sum(
            1 for neighbour in old_G.adj[node]
            if old_G.nodes[neighbour]["state"] == "I"
        )
        # Exactly one random draw per susceptible node, even when it has no
        # infected neighbours (the probability is then 0).
        if random.random() < 1 - (1 - infected_rate) ** infected_neighbours:
            G.nodes[node]["state"] = "I"
# Update the states of all nodes in the network.
def updateNetworkState(G, infected_rate, recover_rate):
    """Apply one synchronous SIR step to every node of ``G`` in place.

    A copy of ``G`` is taken first and passed to ``updateNodeState`` so that
    each node's transition is decided from the states at the start of the
    step rather than from partially updated ones.
    """
    snapshot = G.copy()
    for current in G.nodes:
        updateNodeState(G, current, infected_rate, recover_rate, snapshot)
注意: 这里与参考文章不一致,参考文章中没有记录传播前的状态,假如某个节点从“I”态转为“R”态,那么这个节点的“S”态邻居转为“I”态的概率就会偏小(参考文章算这个概率的时候,此时已经把这个节点当成“R”态而非“I”态),因此这里引入old_G来记录每次传播前的状态。
5.计算各个态的节点数
def countSIRnum(G):
    """Return the ``(S, I, R)`` node counts of ``G``.

    Every node whose ``"state"`` is neither "S" nor "I" is counted as "R".
    """
    S = I = 0
    for node in G:
        state = G.nodes[node]["state"]  # single lookup per node
        if state == "S":
            S += 1
        elif state == "I":
            I += 1
    R = len(G.nodes) - S - I
    return S, I, R


def eachIterateSIRnum(days, G=None, infected_rate=None, recover_rate=None, verbose=True):
    """Run ``days`` SIR steps on ``G`` and return the daily ``[S, I, R]`` counts.

    Args:
        days: number of steps (days) to simulate.
        G: graph updated in place. Defaults to the module-level ``G`` so that
            existing ``eachIterateSIRnum(days)`` calls keep working.
        infected_rate: per-infected-neighbour transmission probability;
            defaults to the module-level ``infected_rate``.
        recover_rate: per-step recovery probability of an "I" node; defaults
            to the module-level ``recover_rate``.
        verbose: print one line of counts per day.

    Returns:
        A list with one ``[S, I, R]`` list per simulated day.
    """
    # The original version read these values from globals that exist only
    # after the __main__ block has run. Explicit arguments make the function
    # reusable; the fallback keeps old call sites working unchanged.
    module_globals = globals()
    if G is None:
        G = module_globals["G"]
    if infected_rate is None:
        infected_rate = module_globals["infected_rate"]
    if recover_rate is None:
        recover_rate = module_globals["recover_rate"]

    eachiterate_SIR_list = []
    for day in range(1, days + 1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        if verbose:
            print("day%s:\tS:%s\tI:%s\tR:%s" % (day, tuple_sir[0], tuple_sir[1], tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list
6.画图
def plotSIR(SIR_num_list):
    """Plot the S, I and R curves against the day index and show the figure.

    Args:
        SIR_num_list: sequence of ``[S, I, R]`` rows, one row per day.
    """
    palette = {"S": "blue", "I": "red", "R": "green"}
    frame = pd.DataFrame(SIR_num_list, columns=["S", "I", "R"])
    line_colors = [palette.get(column) for column in frame.columns]
    axes = frame.plot(figsize=(9, 6), color=line_colors)
    axes.set_ylabel("number")
    axes.set_xlabel("day")
    plt.show()
7.完整代码
import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time


def parseEdgeLine(line):
    """Parse one line of an edge-list file into a ``(source, target)`` pair.

    Columns may be separated by commas and/or any run of whitespace (spaces,
    tabs), including mixtures such as ``"1, 2"``. Columns after the second
    (e.g. an edge weight) are ignored. Node ids are kept as strings.

    Returns:
        A ``(source, target)`` tuple of strings, or ``None`` for a blank line.

    Raises:
        ValueError: if a non-blank line has fewer than two columns.
    """
    # Splitting on a *run* of separators avoids empty fields for "1, 2", and
    # treating tabs as separators (instead of deleting them) keeps "1\t2" as
    # two columns rather than merging them into "12".
    fields = [field for field in re.split(r"[,\s]+", line.strip()) if field]
    if not fields:
        return None
    if len(fields) < 2:
        raise ValueError("edge line needs two node ids, got %r" % line)
    return fields[0], fields[1]


def readEdgeslistToGraph(edges_filename):
    """Build an undirected ``nx.Graph`` from an edge-list text/CSV file.

    Each non-blank line describes one edge; its first two columns are the
    endpoint node ids. The file is decoded as UTF-8, and a leading BOM is
    stripped.
    """
    edges = []
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        for line in f:
            edge = parseEdgeLine(line)
            if edge is not None:  # skip blank lines, e.g. a trailing newline
                edges.append(edge)
    G = nx.Graph()
    G.add_edges_from(edges)
    return G


def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    """Set every node's ``"state"`` to "S", then a random sample of them to "I"."""
    infected_nodes = random.sample(list(G.nodes), k=initial_infected_nodes_num)
    for node in G.nodes():
        G.nodes[node]["state"] = "S"
    for node in infected_nodes:
        G.nodes[node]["state"] = "I"


def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    """Advance one node by a single SIR step, deciding from ``old_G`` and writing into ``G``.

    "I" -> "R" with probability ``recover_rate``; "S" -> "I" with probability
    ``1 - (1 - infected_rate) ** k`` where ``k`` is the node's number of "I"
    neighbours in ``old_G``.
    """
    if old_G.nodes[node]["state"] == "I":
        p = random.random()
        if p < recover_rate:
            G.nodes[node]["state"] = "R"
    elif old_G.nodes[node]["state"] == "S":
        k = 0
        for neighbor in old_G.adj[node]:
            if old_G.nodes[neighbor]["state"] == "I":
                k += 1
        p = random.random()
        if p < (1 - (1 - infected_rate) ** k):
            G.nodes[node]["state"] = "I"


def updateNetworkState(G, infected_rate, recover_rate):
    """Apply one synchronous SIR step to every node of ``G`` in place.

    The copy preserves the pre-step states so every transition in this step
    is decided from the same snapshot.
    """
    old_G = G.copy()
    for node in G:
        updateNodeState(G, node, infected_rate, recover_rate, old_G)


def countSIRnum(G):
    """Return ``(S, I, R)`` node counts; any state other than "S"/"I" counts as "R"."""
    S = I = 0
    for node in G:
        state = G.nodes[node]["state"]
        if state == "S":
            S += 1
        elif state == "I":
            I += 1
    R = len(G.nodes) - S - I
    return S, I, R


def eachIterateSIRnum(days, G=None, infected_rate=None, recover_rate=None, verbose=False):
    """Run ``days`` SIR steps on ``G`` and return one ``[S, I, R]`` list per day.

    ``G``, ``infected_rate`` and ``recover_rate`` default to the module-level
    names of the same name, so the original ``eachIterateSIRnum(days)`` call
    still works; pass them explicitly to reuse the function elsewhere.
    Set ``verbose=True`` to print the counts of each day.
    """
    module_globals = globals()
    if G is None:
        G = module_globals["G"]
    if infected_rate is None:
        infected_rate = module_globals["infected_rate"]
    if recover_rate is None:
        recover_rate = module_globals["recover_rate"]

    eachiterate_SIR_list = []
    for day in range(1, days + 1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        if verbose:
            print("day%s:\tS:%s\tI:%s\tR:%s" % (day, tuple_sir[0], tuple_sir[1], tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list


def plotSIR(SIR_num_list):
    """Plot the S, I and R curves against the day index and show the figure."""
    color_dict = {"S": "blue", "I": "red", "R": "green"}
    df = pd.DataFrame(SIR_num_list, columns=["S", "I", "R"])
    df.plot(figsize=(9, 6), color=[color_dict.get(x) for x in df.columns])
    plt.ylabel("number")
    plt.xlabel("day")
    plt.show()


def main():
    """Average ``iterate_num`` independent SIR runs on one network and plot them."""
    start = time.perf_counter()
    edges_filename = 'l_1.txt'
    initial_infected_nodes_num = 10
    infected_rate = 0.1
    recover_rate = 0.02
    days = 100
    iterate_num = 100

    # Parse the edge list once: randomChoiseInfectedNodes resets every node's
    # state at the start of each run, so the same graph can be reused.
    G = readEdgeslistToGraph(edges_filename)

    totals = [[0, 0, 0] for _ in range(days)]
    for _ in range(iterate_num):
        randomChoiseInfectedNodes(G, initial_infected_nodes_num)
        run = eachIterateSIRnum(days, G, infected_rate, recover_rate)
        for day_totals, day_counts in zip(totals, run):
            for state, count in enumerate(day_counts):
                day_totals[state] += count

    average_SIR_list = [
        [total / iterate_num for total in day_totals] for day_totals in totals
    ]
    plotSIR(average_SIR_list)

    end = time.perf_counter()
    # "%.5s" truncated the float's text (e.g. 123456.7 -> "12345"); format it
    # as a number instead.
    print("running time: %.3f s" % (end - start))


if __name__ == '__main__':
    main()
8.运行结果
9.数据集
https://wwxd.lanzoue.com/iBdSw0i5uo6j
密码:h6s4