代码
package com.dam.heuristic.sa.test; import java.util.Arrays; import java.util.Random; /** * 模拟退火 */ public class SaApi { //序列长度 private int sequenceLen; //初始温度 private double startT; //结束温度 private double endT; //每个温度的迭代次数 private int timeInPerTemperature; //公式中的常数k private double k; //降温系数 private double coolingCoefficient; //距离矩阵 private double[][] distanceMatrix; /** * 构造函数,用于创建对象 */ public SaApi(double startT, double endT, int timeInPerTemperature, double k, double coolingCoefficient, double[][] distanceMatrix) { this.sequenceLen = distanceMatrix[0].length; this.startT = startT; this.endT = endT; this.timeInPerTemperature = timeInPerTemperature; this.k = k; this.coolingCoefficient = coolingCoefficient; this.distanceMatrix = distanceMatrix; } /** * 模拟退火算法接口 */ public void solve() { 定义变量 long startTime = System.currentTimeMillis(); //当前温度 double curT = this.startT; ///最优解 //最优温度 double bestT = this.startT; //序列 int[] localSequence = new int[this.sequenceLen]; //最优序列 int[] bestSequence; //存储求得最优解的时间 long bestTime = 0; //最优序列对应的目标函数值 double bestObjectValue = 0; //上一序列的目标函数值 double lastObjectValue = Double.MAX_VALUE; ///对象 Random random = new Random(); 生成初始序列 this.generateInitialSequence(localSequence); //初始化bestSequence,刚开始的最优序列为初始序列 bestSequence = localSequence.clone(); bestObjectValue = this.getObjectValue(bestSequence); // System.out.println("初始序列:" + Arrays.toString(bestSequence)); // System.out.println("初始目标函数值:" + bestObjectValue); 对序列进行迭代优化 while (curT > this.endT) { // System.out.println("当前温度:" + curT + ",--------------------------------------------------------------"); for (int i = 0; i < this.timeInPerTemperature; i++) { int[] tempSequence = this.generateNewSequence(localSequence); double tempObjectValue = this.getObjectValue(tempSequence); double de = tempObjectValue - lastObjectValue; // System.out.println("温度:" + bestT); if (de < 0) { localSequence = tempSequence.clone(); lastObjectValue = tempObjectValue; if (tempObjectValue<bestObjectValue){ //更新最优解 bestT = curT; bestObjectValue = tempObjectValue; bestSequence = tempSequence.clone(); // bestTime = (System.currentTimeMillis() - startTime); // System.out.println("找到更优解:" + bestObjectValue + ",计算时间:" + bestTime + "ms"); } } else { double p = Math.exp(-de / (this.k * curT)); if (p > random.nextDouble()) { // System.out.println("当前接受差解为:" + tempObjectValue); //替换序列 localSequence = tempSequence.clone(); lastObjectValue = tempObjectValue; } else { // System.out.println("不接受"); } } } curT *= this.coolingCoefficient; } System.out.println("-----------------------------------------------------------------------------------------------------------------------------------"); System.out.println("最佳温度:" + bestT); System.out.println("最优目标函数值:" + bestObjectValue); System.out.println("最优解对应序列:" + Arrays.toString(bestSequence)); System.out.println("求解时间:" + (System.currentTimeMillis() - startTime) + "ms"); } /** * 生成初始序列 */ public void generateInitialSequence(int[] sequence) { // HashSet<Integer> sequenceSet = new HashSet<>(); // for (int i = 1; i < sequence.length; i++) { // sequenceSet.add(i); // } // // //贪婪算法获取初始序列,从城市0开始旅行,即城市0为起点城市 // sequence[0] = 0; // //每次获取离当前城市最近的城市,并加入到sequence // for (int i = 1; i < sequence.length; i++) { // //寻找离i-1城市最近的城市,即确定第i个城市是哪个 // double smallDistance = Double.MAX_VALUE; // int curCity = 0; // for (Integer j : sequenceSet) { // if (this.distanceMatrix[sequence[i - 1]][j] < smallDistance && j != sequence[i - 1]) { // smallDistance = this.distanceMatrix[sequence[i - 1]][j]; // curCity = j; // } // } // sequence[i] = curCity; // sequenceSet.remove(curCity); // } for (int i = 0; i < sequence.length; i++) { sequence[i] = i; } } /** * 根据当前序列获取目标函数值 * * @param sequence * @return */ public double getObjectValue(int[] sequence) { double objectValue = 0; //计算从第0个城市到最后一个城市的路程 for (int i = 0; i < sequence.length - 1; i++) { objectValue += this.distanceMatrix[sequence[i]][sequence[i + 1]]; } //计算最后一个城市到第0个城市的路程 objectValue += this.distanceMatrix[sequence[0]][sequence[sequence.length - 1]]; return objectValue; } /** * 生产新序列 * * @param sequence * @return */ public int[] generateNewSequence(int[] sequence) { int[] sequenceClone = sequence.clone(); //对序列中的元素进行打乱,即可产生新的序列 Random random = new Random(); int i = random.nextInt(sequence.length); int j = random.nextInt(sequence.length); while (i == j) { j = random.nextInt(sequence.length); } int temp = sequenceClone[i]; sequenceClone[i] = sequenceClone[j]; sequenceClone[j] = temp; return sequenceClone; } }
模拟退火过程
测试
package com.dam.heuristic.sa.test; import com.dam.heuristic.sa.improve.ImproveSaApi; import java.io.File; import java.io.FileInputStream; public class SAMainRun { public static void main(String[] args) throws Exception { 声明变量 //距离矩阵,可以直接获取任意两个编号城市的距离 double[][] distanceMatrix; 读取数据 String data = read(new File("src/main/java/com/data/tsp/att48.txt"), "GBK"); String[] cityDataArr = data.split("\n"); //初始化数组 distanceMatrix = new double[cityDataArr.length][cityDataArr.length]; for (int i = 0; i < cityDataArr.length; i++) { String[] city1Arr = cityDataArr[i].split(" "); int cityOne = Integer.valueOf(city1Arr[0]); for (int j = 0; j < i; j++) { String[] city2Arr = cityDataArr[j].split(" "); int cityTwo = Integer.valueOf(city2Arr[0]); if (cityOne == cityTwo) { distanceMatrix[cityOne - 1][cityTwo - 1] = 0; } else { distanceMatrix[cityOne - 1][cityTwo - 1] = getDistance(Double.valueOf(city1Arr[1]), Double.valueOf(city1Arr[2]), Double.valueOf(city2Arr[1]), Double.valueOf(city2Arr[2])); //对称赋值 distanceMatrix[cityTwo - 1][cityOne - 1] = distanceMatrix[cityOne - 1][cityTwo - 1]; } } } /* System.out.println("输出距离矩阵-------------------------------------------------------------------"); for (double[] matrix : distanceMatrix) { System.out.println(Arrays.toString(matrix)); }*/ System.out.println("saApi求解----------------------------------------------------------------------------"); SaApi saApi = new SaApi(100000, 1e-8, 3000, 10000.0, 0.98, distanceMatrix); saApi.solve(); } /** * 给定两个城市坐标,获取两个城市的直线距离 * * @param x1 * @param y1 * @param x2 * @param y2 * @return */ private static double getDistance(double x1, double y1, double x2, double y2) { return Math.sqrt((Math.pow((x1 - x2), 2) + Math.pow((y1 - y2), 2)) / 10); } private static String read(File f, String charset) throws Exception { FileInputStream fstream = new FileInputStream(f); try { int fileSize = (int) f.length(); if (fileSize > 1024 * 512) { throw new Exception("File too large to read! size=" + fileSize); } byte[] buffer = new byte[fileSize]; fstream.read(buffer); return new String(buffer, charset); } finally { try { fstream.close(); } catch (Exception e) { } } } }
saApi求解---------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- 最佳温度:0.003086887433472753 最优目标函数值:11085.716110897114 最优解对应序列:[3, 25, 41, 1, 28, 4, 47, 33, 40, 15, 21, 39, 45, 35, 29, 42, 16, 26, 18, 36, 5, 27, 6, 17, 43, 30, 37, 8, 7, 0, 2, 22, 10, 11, 14, 32, 19, 46, 20, 12, 13, 24, 38, 31, 23, 9, 44, 34] 求解时间:1944ms Process finished with exit code 0