Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit eb24cbd

Browse files
author
linyiqun
committed Jun 28, 2015
贝叶斯网络算法工具类
贝叶斯网络算法工具类
1 parent 03c577f commit eb24cbd

File tree

1 file changed

+326
-0
lines changed

1 file changed

+326
-0
lines changed
 
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,326 @@
1+
package DataMining_BayesNetwork;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.io.IOException;
7+
import java.util.ArrayList;
8+
import java.util.HashMap;
9+
10+
/**
11+
* 贝叶斯网络算法工具类
12+
*
13+
* @author lyq
14+
*
15+
*/
16+
public class BayesNetWorkTool {
17+
// 联合概率分布数据文件地址
18+
private String dataFilePath;
19+
// 事件关联数据文件地址
20+
private String attachFilePath;
21+
// 属性列列数
22+
private int columns;
23+
// 概率分布数据
24+
private String[][] totalData;
25+
// 关联数据对
26+
private ArrayList<String[]> attachData;
27+
// 节点存放列表
28+
private ArrayList<Node> nodes;
29+
// 属性名与列数之间的对应关系
30+
private HashMap<String, Integer> attr2Column;
31+
32+
public BayesNetWorkTool(String dataFilePath, String attachFilePath) {
33+
this.dataFilePath = dataFilePath;
34+
this.attachFilePath = attachFilePath;
35+
36+
initDatas();
37+
}
38+
39+
/**
40+
* 初始化关联数据和概率分布数据
41+
*/
42+
private void initDatas() {
43+
String[] columnValues;
44+
String[] array;
45+
ArrayList<String> datas;
46+
ArrayList<String> adatas;
47+
48+
// 从文件中读取数据
49+
datas = readDataFile(dataFilePath);
50+
adatas = readDataFile(attachFilePath);
51+
52+
columnValues = datas.get(0).split(" ");
53+
// 从数据中取出属性名称行,列数值存入图中
54+
this.attr2Column = new HashMap<>();
55+
for (int i = 0; i < columnValues.length; i++) {
56+
this.attr2Column.put(columnValues[i], i);
57+
}
58+
59+
this.columns = columnValues.length;
60+
this.totalData = new String[datas.size()][columns];
61+
for (int i = 0; i < datas.size(); i++) {
62+
this.totalData[i] = datas.get(i).split(" ");
63+
}
64+
65+
this.attachData = new ArrayList<>();
66+
// 解析关联数据对
67+
for (String str : adatas) {
68+
array = str.split(" ");
69+
this.attachData.add(array);
70+
}
71+
72+
// 构造贝叶斯网络结构图
73+
constructDAG();
74+
}
75+
76+
/**
77+
* 从文件中读取数据
78+
*/
79+
private ArrayList<String> readDataFile(String filePath) {
80+
File file = new File(filePath);
81+
ArrayList<String> dataArray = new ArrayList<String>();
82+
83+
try {
84+
BufferedReader in = new BufferedReader(new FileReader(file));
85+
String str;
86+
while ((str = in.readLine()) != null) {
87+
dataArray.add(str);
88+
}
89+
in.close();
90+
} catch (IOException e) {
91+
e.getStackTrace();
92+
}
93+
94+
return dataArray;
95+
}
96+
97+
/**
98+
* 根据关联数据构造贝叶斯网络无环有向图
99+
*/
100+
private void constructDAG() {
101+
// 节点存在标识
102+
boolean srcExist;
103+
boolean desExist;
104+
String name1;
105+
String name2;
106+
Node srcNode;
107+
Node desNode;
108+
109+
this.nodes = new ArrayList<>();
110+
for (String[] array : this.attachData) {
111+
srcExist = false;
112+
desExist = false;
113+
114+
name1 = array[0];
115+
name2 = array[1];
116+
117+
// 新建节点
118+
srcNode = new Node(name1);
119+
desNode = new Node(name2);
120+
121+
for (Node temp : this.nodes) {
122+
// 如果找到相同节点,则取出
123+
if (srcNode.isEqual(temp)) {
124+
srcExist = true;
125+
srcNode = temp;
126+
} else if (desNode.isEqual(temp)) {
127+
desExist = true;
128+
desNode = temp;
129+
}
130+
131+
// 如果2个节点都已找到,则跳出循环
132+
if (srcExist && desExist) {
133+
break;
134+
}
135+
}
136+
137+
// 将2个节点进行连接
138+
srcNode.connectNode(desNode);
139+
140+
// 根据标识判断是否需要加入列表容器中
141+
if (!srcExist) {
142+
this.nodes.add(srcNode);
143+
}
144+
145+
if (!desExist) {
146+
this.nodes.add(desNode);
147+
}
148+
}
149+
}
150+
151+
/**
152+
* 查询条件概率
153+
*
154+
* @param attrValues
155+
* 条件属性值
156+
* @return
157+
*/
158+
private double queryConditionPro(ArrayList<String[]> attrValues) {
159+
// 判断是否满足先验属性值条件
160+
boolean hasPrior;
161+
// 判断是否满足后验属性值条件
162+
boolean hasBack;
163+
int priorIndex;
164+
int attrIndex;
165+
double backPro;
166+
double totalPro;
167+
double pro;
168+
double currentPro;
169+
// 先验属性
170+
String[] priorValue;
171+
String[] tempData;
172+
173+
pro = 0;
174+
totalPro = 0;
175+
backPro = 0;
176+
attrValues.get(0);
177+
priorValue = attrValues.get(0);
178+
// 得到后验概率
179+
attrValues.remove(0);
180+
181+
// 取出先验属性的列数
182+
priorIndex = this.attr2Column.get(priorValue[0]);
183+
// 跳过第一行的属性名称行
184+
for (int i = 1; i < this.totalData.length; i++) {
185+
tempData = this.totalData[i];
186+
187+
hasPrior = false;
188+
hasBack = true;
189+
190+
// 当前行的概率
191+
currentPro = Double.parseDouble(tempData[this.columns - 1]);
192+
// 判断是否满足先验条件
193+
if (tempData[priorIndex].equals(priorValue[1])) {
194+
hasPrior = true;
195+
}
196+
197+
for (String[] array : attrValues) {
198+
attrIndex = this.attr2Column.get(array[0]);
199+
200+
// 判断值是否满足条件
201+
if (!tempData[attrIndex].equals(array[1])) {
202+
hasBack = false;
203+
break;
204+
}
205+
}
206+
207+
// 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数
208+
if (hasBack) {
209+
backPro += currentPro;
210+
if (hasPrior) {
211+
totalPro += currentPro;
212+
}
213+
} else if (hasPrior && attrValues.size() == 0) {
214+
// 如果只有先验概率则为纯概率的计算
215+
totalPro += currentPro;
216+
backPro = 1.0;
217+
}
218+
}
219+
220+
// 计算总的概率=都发生概率/只发生后验条件的时间概率
221+
pro = totalPro / backPro;
222+
223+
return pro;
224+
}
225+
226+
/**
227+
* 根据贝叶斯网络计算概率
228+
*
229+
* @param queryStr
230+
* 查询条件串
231+
* @return
232+
*/
233+
public double calProByNetWork(String queryStr) {
234+
double temp;
235+
double pro;
236+
String[] array;
237+
// 先验条件值
238+
String[] preValue;
239+
// 后验条件值
240+
String[] backValue;
241+
// 所有先验条件和后验条件值的属性值的汇总
242+
ArrayList<String[]> attrValues;
243+
244+
// 判断是否满足网络结构
245+
if (!satisfiedNewWork(queryStr)) {
246+
return -1;
247+
}
248+
249+
pro = 1;
250+
// 首先做查询条件的分解
251+
array = queryStr.split(",");
252+
253+
// 概率的初值等于第一个事件发生的随机概率
254+
attrValues = new ArrayList<>();
255+
attrValues.add(array[0].split("="));
256+
pro = queryConditionPro(attrValues);
257+
258+
for (int i = 0; i < array.length - 1; i++) {
259+
attrValues.clear();
260+
261+
// 下标小的在前面的属于后验属性
262+
backValue = array[i].split("=");
263+
preValue = array[i + 1].split("=");
264+
attrValues.add(preValue);
265+
attrValues.add(backValue);
266+
267+
// 算出此种情况的概率值
268+
temp = queryConditionPro(attrValues);
269+
// 进行积的相乘
270+
pro *= temp;
271+
}
272+
273+
return pro;
274+
}
275+
276+
/**
277+
* 验证事件的查询因果关系是否满足贝叶斯网络
278+
*
279+
* @param queryStr
280+
* 查询字符串
281+
* @return
282+
*/
283+
private boolean satisfiedNewWork(String queryStr) {
284+
String attrName;
285+
String[] array;
286+
boolean isExist;
287+
boolean isSatisfied;
288+
// 当前节点
289+
Node currentNode;
290+
// 候选节点列表
291+
ArrayList<Node> nodeList;
292+
293+
isSatisfied = true;
294+
currentNode = null;
295+
// 做查询字符串的分解
296+
array = queryStr.split(",");
297+
nodeList = this.nodes;
298+
299+
for (String s : array) {
300+
// 开始时默认属性对应的节点不存在
301+
isExist = false;
302+
// 得到属性事件名
303+
attrName = s.split("=")[0];
304+
305+
for (Node n : nodeList) {
306+
if (n.name.equals(attrName)) {
307+
isExist = true;
308+
309+
currentNode = n;
310+
// 下一轮的候选节点为当前节点的孩子节点
311+
nodeList = currentNode.childNodes;
312+
313+
break;
314+
}
315+
}
316+
317+
// 如果存在未找到的节点,则说明不满足依赖结构跳出循环
318+
if (!isExist) {
319+
isSatisfied = false;
320+
break;
321+
}
322+
}
323+
324+
return isSatisfied;
325+
}
326+
}

0 commit comments

Comments
 (0)
Please sign in to comment.