|
| 1 | +package DataMining_BayesNetwork; |
| 2 | + |
| 3 | +import java.io.BufferedReader; |
| 4 | +import java.io.File; |
| 5 | +import java.io.FileReader; |
| 6 | +import java.io.IOException; |
| 7 | +import java.util.ArrayList; |
| 8 | +import java.util.HashMap; |
| 9 | + |
| 10 | +/** |
| 11 | + * 贝叶斯网络算法工具类 |
| 12 | + * |
| 13 | + * @author lyq |
| 14 | + * |
| 15 | + */ |
| 16 | +public class BayesNetWorkTool { |
| 17 | + // 联合概率分布数据文件地址 |
| 18 | + private String dataFilePath; |
| 19 | + // 事件关联数据文件地址 |
| 20 | + private String attachFilePath; |
| 21 | + // 属性列列数 |
| 22 | + private int columns; |
| 23 | + // 概率分布数据 |
| 24 | + private String[][] totalData; |
| 25 | + // 关联数据对 |
| 26 | + private ArrayList<String[]> attachData; |
| 27 | + // 节点存放列表 |
| 28 | + private ArrayList<Node> nodes; |
| 29 | + // 属性名与列数之间的对应关系 |
| 30 | + private HashMap<String, Integer> attr2Column; |
| 31 | + |
| 32 | + public BayesNetWorkTool(String dataFilePath, String attachFilePath) { |
| 33 | + this.dataFilePath = dataFilePath; |
| 34 | + this.attachFilePath = attachFilePath; |
| 35 | + |
| 36 | + initDatas(); |
| 37 | + } |
| 38 | + |
| 39 | + /** |
| 40 | + * 初始化关联数据和概率分布数据 |
| 41 | + */ |
| 42 | + private void initDatas() { |
| 43 | + String[] columnValues; |
| 44 | + String[] array; |
| 45 | + ArrayList<String> datas; |
| 46 | + ArrayList<String> adatas; |
| 47 | + |
| 48 | + // 从文件中读取数据 |
| 49 | + datas = readDataFile(dataFilePath); |
| 50 | + adatas = readDataFile(attachFilePath); |
| 51 | + |
| 52 | + columnValues = datas.get(0).split(" "); |
| 53 | + // 从数据中取出属性名称行,列数值存入图中 |
| 54 | + this.attr2Column = new HashMap<>(); |
| 55 | + for (int i = 0; i < columnValues.length; i++) { |
| 56 | + this.attr2Column.put(columnValues[i], i); |
| 57 | + } |
| 58 | + |
| 59 | + this.columns = columnValues.length; |
| 60 | + this.totalData = new String[datas.size()][columns]; |
| 61 | + for (int i = 0; i < datas.size(); i++) { |
| 62 | + this.totalData[i] = datas.get(i).split(" "); |
| 63 | + } |
| 64 | + |
| 65 | + this.attachData = new ArrayList<>(); |
| 66 | + // 解析关联数据对 |
| 67 | + for (String str : adatas) { |
| 68 | + array = str.split(" "); |
| 69 | + this.attachData.add(array); |
| 70 | + } |
| 71 | + |
| 72 | + // 构造贝叶斯网络结构图 |
| 73 | + constructDAG(); |
| 74 | + } |
| 75 | + |
| 76 | + /** |
| 77 | + * 从文件中读取数据 |
| 78 | + */ |
| 79 | + private ArrayList<String> readDataFile(String filePath) { |
| 80 | + File file = new File(filePath); |
| 81 | + ArrayList<String> dataArray = new ArrayList<String>(); |
| 82 | + |
| 83 | + try { |
| 84 | + BufferedReader in = new BufferedReader(new FileReader(file)); |
| 85 | + String str; |
| 86 | + while ((str = in.readLine()) != null) { |
| 87 | + dataArray.add(str); |
| 88 | + } |
| 89 | + in.close(); |
| 90 | + } catch (IOException e) { |
| 91 | + e.getStackTrace(); |
| 92 | + } |
| 93 | + |
| 94 | + return dataArray; |
| 95 | + } |
| 96 | + |
| 97 | + /** |
| 98 | + * 根据关联数据构造贝叶斯网络无环有向图 |
| 99 | + */ |
| 100 | + private void constructDAG() { |
| 101 | + // 节点存在标识 |
| 102 | + boolean srcExist; |
| 103 | + boolean desExist; |
| 104 | + String name1; |
| 105 | + String name2; |
| 106 | + Node srcNode; |
| 107 | + Node desNode; |
| 108 | + |
| 109 | + this.nodes = new ArrayList<>(); |
| 110 | + for (String[] array : this.attachData) { |
| 111 | + srcExist = false; |
| 112 | + desExist = false; |
| 113 | + |
| 114 | + name1 = array[0]; |
| 115 | + name2 = array[1]; |
| 116 | + |
| 117 | + // 新建节点 |
| 118 | + srcNode = new Node(name1); |
| 119 | + desNode = new Node(name2); |
| 120 | + |
| 121 | + for (Node temp : this.nodes) { |
| 122 | + // 如果找到相同节点,则取出 |
| 123 | + if (srcNode.isEqual(temp)) { |
| 124 | + srcExist = true; |
| 125 | + srcNode = temp; |
| 126 | + } else if (desNode.isEqual(temp)) { |
| 127 | + desExist = true; |
| 128 | + desNode = temp; |
| 129 | + } |
| 130 | + |
| 131 | + // 如果2个节点都已找到,则跳出循环 |
| 132 | + if (srcExist && desExist) { |
| 133 | + break; |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + // 将2个节点进行连接 |
| 138 | + srcNode.connectNode(desNode); |
| 139 | + |
| 140 | + // 根据标识判断是否需要加入列表容器中 |
| 141 | + if (!srcExist) { |
| 142 | + this.nodes.add(srcNode); |
| 143 | + } |
| 144 | + |
| 145 | + if (!desExist) { |
| 146 | + this.nodes.add(desNode); |
| 147 | + } |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + /** |
| 152 | + * 查询条件概率 |
| 153 | + * |
| 154 | + * @param attrValues |
| 155 | + * 条件属性值 |
| 156 | + * @return |
| 157 | + */ |
| 158 | + private double queryConditionPro(ArrayList<String[]> attrValues) { |
| 159 | + // 判断是否满足先验属性值条件 |
| 160 | + boolean hasPrior; |
| 161 | + // 判断是否满足后验属性值条件 |
| 162 | + boolean hasBack; |
| 163 | + int priorIndex; |
| 164 | + int attrIndex; |
| 165 | + double backPro; |
| 166 | + double totalPro; |
| 167 | + double pro; |
| 168 | + double currentPro; |
| 169 | + // 先验属性 |
| 170 | + String[] priorValue; |
| 171 | + String[] tempData; |
| 172 | + |
| 173 | + pro = 0; |
| 174 | + totalPro = 0; |
| 175 | + backPro = 0; |
| 176 | + attrValues.get(0); |
| 177 | + priorValue = attrValues.get(0); |
| 178 | + // 得到后验概率 |
| 179 | + attrValues.remove(0); |
| 180 | + |
| 181 | + // 取出先验属性的列数 |
| 182 | + priorIndex = this.attr2Column.get(priorValue[0]); |
| 183 | + // 跳过第一行的属性名称行 |
| 184 | + for (int i = 1; i < this.totalData.length; i++) { |
| 185 | + tempData = this.totalData[i]; |
| 186 | + |
| 187 | + hasPrior = false; |
| 188 | + hasBack = true; |
| 189 | + |
| 190 | + // 当前行的概率 |
| 191 | + currentPro = Double.parseDouble(tempData[this.columns - 1]); |
| 192 | + // 判断是否满足先验条件 |
| 193 | + if (tempData[priorIndex].equals(priorValue[1])) { |
| 194 | + hasPrior = true; |
| 195 | + } |
| 196 | + |
| 197 | + for (String[] array : attrValues) { |
| 198 | + attrIndex = this.attr2Column.get(array[0]); |
| 199 | + |
| 200 | + // 判断值是否满足条件 |
| 201 | + if (!tempData[attrIndex].equals(array[1])) { |
| 202 | + hasBack = false; |
| 203 | + break; |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数 |
| 208 | + if (hasBack) { |
| 209 | + backPro += currentPro; |
| 210 | + if (hasPrior) { |
| 211 | + totalPro += currentPro; |
| 212 | + } |
| 213 | + } else if (hasPrior && attrValues.size() == 0) { |
| 214 | + // 如果只有先验概率则为纯概率的计算 |
| 215 | + totalPro += currentPro; |
| 216 | + backPro = 1.0; |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + // 计算总的概率=都发生概率/只发生后验条件的时间概率 |
| 221 | + pro = totalPro / backPro; |
| 222 | + |
| 223 | + return pro; |
| 224 | + } |
| 225 | + |
| 226 | + /** |
| 227 | + * 根据贝叶斯网络计算概率 |
| 228 | + * |
| 229 | + * @param queryStr |
| 230 | + * 查询条件串 |
| 231 | + * @return |
| 232 | + */ |
| 233 | + public double calProByNetWork(String queryStr) { |
| 234 | + double temp; |
| 235 | + double pro; |
| 236 | + String[] array; |
| 237 | + // 先验条件值 |
| 238 | + String[] preValue; |
| 239 | + // 后验条件值 |
| 240 | + String[] backValue; |
| 241 | + // 所有先验条件和后验条件值的属性值的汇总 |
| 242 | + ArrayList<String[]> attrValues; |
| 243 | + |
| 244 | + // 判断是否满足网络结构 |
| 245 | + if (!satisfiedNewWork(queryStr)) { |
| 246 | + return -1; |
| 247 | + } |
| 248 | + |
| 249 | + pro = 1; |
| 250 | + // 首先做查询条件的分解 |
| 251 | + array = queryStr.split(","); |
| 252 | + |
| 253 | + // 概率的初值等于第一个事件发生的随机概率 |
| 254 | + attrValues = new ArrayList<>(); |
| 255 | + attrValues.add(array[0].split("=")); |
| 256 | + pro = queryConditionPro(attrValues); |
| 257 | + |
| 258 | + for (int i = 0; i < array.length - 1; i++) { |
| 259 | + attrValues.clear(); |
| 260 | + |
| 261 | + // 下标小的在前面的属于后验属性 |
| 262 | + backValue = array[i].split("="); |
| 263 | + preValue = array[i + 1].split("="); |
| 264 | + attrValues.add(preValue); |
| 265 | + attrValues.add(backValue); |
| 266 | + |
| 267 | + // 算出此种情况的概率值 |
| 268 | + temp = queryConditionPro(attrValues); |
| 269 | + // 进行积的相乘 |
| 270 | + pro *= temp; |
| 271 | + } |
| 272 | + |
| 273 | + return pro; |
| 274 | + } |
| 275 | + |
| 276 | + /** |
| 277 | + * 验证事件的查询因果关系是否满足贝叶斯网络 |
| 278 | + * |
| 279 | + * @param queryStr |
| 280 | + * 查询字符串 |
| 281 | + * @return |
| 282 | + */ |
| 283 | + private boolean satisfiedNewWork(String queryStr) { |
| 284 | + String attrName; |
| 285 | + String[] array; |
| 286 | + boolean isExist; |
| 287 | + boolean isSatisfied; |
| 288 | + // 当前节点 |
| 289 | + Node currentNode; |
| 290 | + // 候选节点列表 |
| 291 | + ArrayList<Node> nodeList; |
| 292 | + |
| 293 | + isSatisfied = true; |
| 294 | + currentNode = null; |
| 295 | + // 做查询字符串的分解 |
| 296 | + array = queryStr.split(","); |
| 297 | + nodeList = this.nodes; |
| 298 | + |
| 299 | + for (String s : array) { |
| 300 | + // 开始时默认属性对应的节点不存在 |
| 301 | + isExist = false; |
| 302 | + // 得到属性事件名 |
| 303 | + attrName = s.split("=")[0]; |
| 304 | + |
| 305 | + for (Node n : nodeList) { |
| 306 | + if (n.name.equals(attrName)) { |
| 307 | + isExist = true; |
| 308 | + |
| 309 | + currentNode = n; |
| 310 | + // 下一轮的候选节点为当前节点的孩子节点 |
| 311 | + nodeList = currentNode.childNodes; |
| 312 | + |
| 313 | + break; |
| 314 | + } |
| 315 | + } |
| 316 | + |
| 317 | + // 如果存在未找到的节点,则说明不满足依赖结构跳出循环 |
| 318 | + if (!isExist) { |
| 319 | + isSatisfied = false; |
| 320 | + break; |
| 321 | + } |
| 322 | + } |
| 323 | + |
| 324 | + return isSatisfied; |
| 325 | + } |
| 326 | +} |
0 commit comments