|
CurveFitter |
|
package ij.measure; import ij.*; import ij.gui.*; /** Curve fitting class based on the Simplex method described * in the article "Fitting Curves to Data" in the May 1984 * issue of Byte magazine, pages 340-362. * * 2001/02/14: Midified to handle a gamma variate curve. * Uses altered Simplex method based on method in "Numerical Recipes in C". * This method tends to converge closer in less iterations. * Has the option to restart the simplex at the initial best solution in * case it is "stuck" in a local minimum (by default, restarted once). Also includes * settings dialog option for user control over simplex parameters and functions to * evaluate the goodness-of-fit. The results can be easily reported with the * getResultString() method. * * @author Kieran Holland (email: holki659@student.otago.ac.nz) * @version 1.0 * */ public class CurveFitter { public static final int STRAIGHT_LINE=0,POLY2=1,POLY3=2,POLY4=3, EXPONENTIAL=4,POWER=5,LOG=6,RODBARD=7,GAMMA_VARIATE=8, LOG2=9, RODBARD2=10, EXP_WITH_OFFSET=11; public static final int IterFactor = 500; public static final String[] fitList = {"Straight Line","2nd Degree Polynomial", "3rd Degree Polynomial", "4th Degree Polynomial","Exponential","Power", "log","Rodbard", "Gamma Variate", "y = a+b*ln(x-c)","Rodbard (NIH Image)", "Exponential with Offset"}; public static final String[] fList = {"y = a+bx","y = a+bx+cx^2", "y = a+bx+cx^2+dx^3", "y = a+bx+cx^2+dx^3+ex^4","y = a*exp(bx)","y = ax^b", "y = a*ln(bx)", "y = d+(a-d)/(1+(x/c)^b)", "y = a*(x-b)^c*exp(-(x-b)/d)", "y = a+b*ln(x-c)", "y = d+(a-d)/(1+(x/c)^b)", "y = a*exp(-bx) + c"}; private static final double alpha = -1.0; // reflection coefficient private static final double beta = 0.5; // contraction coefficient private static final double gamma = 2.0; // expansion coefficient private static final double root2 = 1.414214; // square root of 2 private int fit; // Number of curve type to fit private double[] xData, yData; // x,y data to fit private int numPoints; // number of data points private int numParams; // number of parametres private int numVertices; // numParams+1 (includes sumLocalResiduaalsSqrd) private int worst; // worst current parametre estimates private int nextWorst; // 2nd worst current parametre estimates private int best; // best current parametre estimates private double[][] simp; // the simplex (the last element of the array at each vertice is the sum of the square of the residuals) private double[] next; // new vertex to be tested private int numIter; // number of iterations so far private int maxIter; // maximum number of iterations per restart private int restarts; // number of times to restart simplex after first soln. private double maxError; // maximum error tolerance private double[] initialParams; // user specified initial parameters /** Construct a new CurveFitter. */ public CurveFitter (double[] xData, double[] yData) { this.xData = xData; this.yData = yData; numPoints = xData.length; } /** Perform curve fitting with the simplex method * doFit(fitType) just does the fit * doFit(fitType, true) pops up a dialog allowing control over simplex parameters * alpha is reflection coefficient (-1) * beta is contraction coefficient (0.5) * gamma is expansion coefficient (2) */ public void doFit(int fitType) { doFit(fitType, false); } public void doFit(int fitType, boolean showSettings) { if (fitType < STRAIGHT_LINE || fitType > EXP_WITH_OFFSET) throw new IllegalArgumentException("Invalid fit type"); int saveFitType = fitType; if (fitType==RODBARD2) { double[] temp; temp = xData; xData = yData; yData = temp; fitType = RODBARD; } fit = fitType; initialize(); if (initialParams!=null) { for (int i=0; i<numParams; i++) simp[0][i] = initialParams[i]; initialParams = null; } if (showSettings) settingsDialog(); restart(0); numIter = 0; boolean done = false; double[] center = new double[numParams]; // mean of simplex vertices while (!done) { numIter++; for (int i = 0; i < numParams; i++) center[i] = 0.0; // get mean "center" of vertices, excluding worst for (int i = 0; i < numVertices; i++) if (i != worst) for (int j = 0; j < numParams; j++) center[j] += simp[i][j]; // Reflect worst vertex through centre for (int i = 0; i < numParams; i++) { center[i] /= numParams; next[i] = center[i] + alpha*(simp[worst][i] - center[i]); } sumResiduals(next); // if it's better than the best... if (next[numParams] <= simp[best][numParams]) { newVertex(); // try expanding it for (int i = 0; i < numParams; i++) next[i] = center[i] + gamma * (simp[worst][i] - center[i]); sumResiduals(next); // if this is even better, keep it if (next[numParams] <= simp[worst][numParams]) newVertex(); } // else if better than the 2nd worst keep it... else if (next[numParams] <= simp[nextWorst][numParams]) { newVertex(); } // else try to make positive contraction of the worst else { for (int i = 0; i < numParams; i++) next[i] = center[i] + beta*(simp[worst][i] - center[i]); sumResiduals(next); // if this is better than the second worst, keep it. if (next[numParams] <= simp[nextWorst][numParams]) { newVertex(); } // if all else fails, contract simplex in on best else { for (int i = 0; i < numVertices; i++) { if (i != best) { for (int j = 0; j < numVertices; j++) simp[i][j] = beta*(simp[i][j]+simp[best][j]); sumResiduals(simp[i]); } } } } order(); double rtol = 2 * Math.abs(simp[best][numParams] - simp[worst][numParams]) / (Math.abs(simp[best][numParams]) + Math.abs(simp[worst][numParams]) + 0.0000000001); if (numIter >= maxIter) done = true; else if (rtol < maxError) { //System.out.print(getResultString()); restarts--; if (restarts < 0) { done = true; } else { restart(best); } } } fitType = saveFitType; } /** Initialise the simplex */ void initialize() { // Calculate some things that might be useful for predicting parametres numParams = getNumParams(); numVertices = numParams + 1; // need 1 more vertice than parametres, simp = new double[numVertices][numVertices]; next = new double[numVertices]; double firstx = xData[0]; double firsty = yData[0]; double lastx = xData[numPoints-1]; double lasty = yData[numPoints-1]; double xmean = (firstx+lastx)/2.0; double ymean = (firsty+lasty)/2.0; double slope; if ((lastx - firstx) != 0.0) slope = (lasty - firsty)/(lastx - firstx); else slope = 1.0; double yintercept = firsty - slope * firstx; maxIter = IterFactor * numParams * numParams; // Where does this estimate come from? restarts = 1; maxError = 1e-9; switch (fit) { case STRAIGHT_LINE: simp[0][0] = yintercept; simp[0][1] = slope; break; case POLY2: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; break; case POLY3: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; simp[0][3] = 0.0; break; case POLY4: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; simp[0][3] = 0.0; simp[0][4] = 0.0; break; case EXPONENTIAL: simp[0][0] = 0.1; simp[0][1] = 0.01; break; case EXP_WITH_OFFSET: simp[0][0] = 0.1; simp[0][1] = 0.01; simp[0][2] = 0.1; break; case POWER: simp[0][0] = 0.0; simp[0][1] = 1.0; break; case LOG: simp[0][0] = 0.5; simp[0][1] = 0.05; break; case RODBARD: case RODBARD2: simp[0][0] = firsty; simp[0][1] = 1.0; simp[0][2] = xmean; simp[0][3] = lasty; break; case GAMMA_VARIATE: // First guesses based on following observations: // t0 [b] = time of first rise in gamma curve - so use the user specified first limit // tm = t0 + a*B [c*d] where tm is the time of the peak of the curve // therefore an estimate for a and B is sqrt(tm-t0) // K [a] can now be calculated from these estimates simp[0][0] = firstx; double ab = xData[getMax(yData)] - firstx; simp[0][2] = Math.sqrt(ab); simp[0][3] = Math.sqrt(ab); simp[0][1] = yData[getMax(yData)] / (Math.pow(ab, simp[0][2]) * Math.exp(-ab/simp[0][3])); break; case LOG2: simp[0][0] = 0.5; simp[0][1] = 0.05; simp[0][2] = 0.0; break; } } /** Pop up a dialog allowing control over simplex starting parameters */ private void settingsDialog() { GenericDialog gd = new GenericDialog("Simplex Fitting Options", IJ.getInstance()); gd.addMessage("Function name: " + fitList[fit] + "\n" + "Formula: " + fList[fit]); char pChar = 'a'; for (int i = 0; i < numParams; i++) { gd.addNumericField("Initial "+(new Character(pChar)).toString()+":", simp[0][i], 2); pChar++; } gd.addNumericField("Maximum iterations:", maxIter, 0); gd.addNumericField("Number of restarts:", restarts, 0); gd.addNumericField("Error tolerance [1*10^(-x)]:", -(Math.log(maxError)/Math.log(10)), 0); gd.showDialog(); if (gd.wasCanceled() || gd.invalidNumber()) { IJ.error("Parameter setting canceled.\nUsing default parameters."); } // Parametres: for (int i = 0; i < numParams; i++) { simp[0][i] = gd.getNextNumber(); } maxIter = (int) gd.getNextNumber(); restarts = (int) gd.getNextNumber(); maxError = Math.pow(10.0, -gd.getNextNumber()); } /** Restart the simplex at the nth vertex */ void restart(int n) { // Copy nth vertice of simplex to first vertice for (int i = 0; i < numParams; i++) { simp[0][i] = simp[n][i]; } sumResiduals(simp[0]); // Get sum of residuals^2 for first vertex double[] step = new double[numParams]; for (int i = 0; i < numParams; i++) { step[i] = simp[0][i] / 2.0; // Step half the parametre value if (step[i] == 0.0) // We can't have them all the same or we're going nowhere step[i] = 0.01; } // Some kind of factor for generating new vertices double[] p = new double[numParams]; double[] q = new double[numParams]; for (int i = 0; i < numParams; i++) { p[i] = step[i] * (Math.sqrt(numVertices) + numParams - 1.0)/(numParams * root2); q[i] = step[i] * (Math.sqrt(numVertices) - 1.0)/(numParams * root2); } // Create the other simplex vertices by modifing previous one. for (int i = 1; i < numVertices; i++) { for (int j = 0; j < numParams; j++) { simp[i][j] = simp[i-1][j] + q[j]; } simp[i][i-1] = simp[i][i-1] + p[i-1]; sumResiduals(simp[i]); } // Initialise current lowest/highest parametre estimates to simplex 1 best = 0; worst = 0; nextWorst = 0; order(); } // Display simplex [Iteration: s0(p1, p2....), s1(),....] in ImageJ window void showSimplex(int iter) { ij.IJ.write("" + iter); for (int i = 0; i < numVertices; i++) { String s = ""; for (int j=0; j < numVertices; j++) s += " "+ ij.IJ.d2s(simp[i][j], 6); ij.IJ.write(s); } } /** Get number of parameters for current fit function */ public int getNumParams() { switch (fit) { case STRAIGHT_LINE: return 2; case POLY2: return 3; case POLY3: return 4; case POLY4: return 5; case EXPONENTIAL: return 2; case POWER: return 2; case LOG: return 2; case RODBARD: case RODBARD2: return 4; case GAMMA_VARIATE: return 4; case LOG2: return 3; case EXP_WITH_OFFSET: return 3; } return 0; } /** Returns "fit" function value for parameters "p" at "x" */ public static double f(int fit, double[] p, double x) { double y; switch (fit) { case STRAIGHT_LINE: return p[0] + p[1]*x; case POLY2: return p[0] + p[1]*x + p[2]* x*x; case POLY3: return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x; case POLY4: return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x + p[4]*x*x*x*x; case EXPONENTIAL: return p[0]*Math.exp(p[1]*x); case EXP_WITH_OFFSET: return p[0]*Math.exp(p[1]*x*-1)+p[2]; case POWER: if (x == 0.0) return 0.0; else return p[0]*Math.exp(p[1]*Math.log(x)); //y=ax^b case LOG: if (x == 0.0) x = 0.5; return p[0]*Math.log(p[1]*x); case RODBARD: double ex; if (x == 0.0) ex = 0.0; else ex = Math.exp(Math.log(x/p[2])*p[1]); y = p[0]-p[3]; y = y/(1.0+ex); return y+p[3]; case GAMMA_VARIATE: if (p[0] >= x) return 0.0; if (p[1] <= 0) return -100000.0; if (p[2] <= 0) return -100000.0; if (p[3] <= 0) return -100000.0; double pw = Math.pow((x - p[0]), p[2]); double e = Math.exp((-(x - p[0]))/p[3]); return p[1]*pw*e; case LOG2: double tmp = x-p[2]; if (tmp<0.001) tmp = 0.001; return p[0]+p[1]*Math.log(tmp); case RODBARD2: if (x<=p[0]) y = 0.0; else { y = (p[0]-x)/(x-p[3]); y = Math.exp(Math.log(y)*(1.0/p[1])); //y=y**(1/b) y = y*p[2]; } return y; default: return 0.0; } } /** Get the set of parameter values from the best corner of the simplex */ public double[] getParams() { order(); return simp[best]; } /** Returns residuals array ie. differences between data and curve. */ public double[] getResiduals() { int saveFit = fit; if (fit==RODBARD2) fit=RODBARD; double[] params = getParams(); double[] residuals = new double[numPoints]; for (int i = 0; i < numPoints; i++) residuals[i] = yData[i] - f(fit, params, xData[i]); fit = saveFit; return residuals; } /* Last "parametre" at each vertex of simplex is sum of residuals * for the curve described by that vertex */ public double getSumResidualsSqr() { double sumResidualsSqr = (getParams())[getNumParams()]; return sumResidualsSqr; } /** Returns the standard deviation of the residuals. */ public double getSD() { double[] residuals = getResiduals(); int n = residuals.length; double sum=0.0, sum2=0.0; for (int i=0; i<n; i++) { sum += residuals[i]; sum2 += residuals[i]*residuals[i]; } double stdDev = (n*sum2-sum*sum)/n; return Math.sqrt(stdDev/(n-1.0)); } /** Returns R^2, where 1.0 is best. <pre> r^2 = 1 - SSE/SSD where: SSE = sum of the squares of the errors SSD = sum of the squares of the deviations about the mean. </pre> */ public double getRSquared() { double sumY = 0.0; for (int i=0; i<numPoints; i++) sumY += yData[i]; double mean = sumY/numPoints; double sumMeanDiffSqr = 0.0; for (int i=0; i<numPoints; i++) sumMeanDiffSqr += sqr(yData[i]-mean); double rSquared = 0.0; if (sumMeanDiffSqr>0.0) rSquared = 1.0 - getSumResidualsSqr()/sumMeanDiffSqr; return rSquared; } /** Get a measure of "goodness of fit" where 1.0 is best. */ public double getFitGoodness() { double sumY = 0.0; for (int i = 0; i < numPoints; i++) sumY += yData[i]; double mean = sumY / numPoints; double sumMeanDiffSqr = 0.0; int degreesOfFreedom = numPoints - getNumParams(); double fitGoodness = 0.0; for (int i = 0; i < numPoints; i++) { sumMeanDiffSqr += sqr(yData[i] - mean); } if (sumMeanDiffSqr > 0.0 && degreesOfFreedom != 0) fitGoodness = 1.0 - (getSumResidualsSqr() / degreesOfFreedom) * ((numPoints) / sumMeanDiffSqr); return fitGoodness; } /** Get a string description of the curve fitting results * for easy output. */ public String getResultString() { StringBuffer results = new StringBuffer("\nNumber of iterations: " + getIterations() + "\nMaximum number of iterations: " + getMaxIterations() + "\nSum of residuals squared: " + IJ.d2s(getSumResidualsSqr(),4) + "\nStandard deviation: " + IJ.d2s(getSD(),4) + "\nR^2: " + IJ.d2s(getRSquared(),4) + "\nParameters:"); char pChar = 'a'; double[] pVal = getParams(); for (int i = 0; i < numParams; i++) { results.append("\n " + pChar + " = " + IJ.d2s(pVal[i],4)); pChar++; } return results.toString(); } double sqr(double d) { return d * d; } /** Adds sum of square of residuals to end of array of parameters */ void sumResiduals (double[] x) { x[numParams] = 0.0; for (int i = 0; i < numPoints; i++) { x[numParams] = x[numParams] + sqr(f(fit,x,xData[i])-yData[i]); // if (IJ.debugMode) ij.IJ.log(i+" "+x[n-1]+" "+f(fit,x,xData[i])+" "+yData[i]); } } /** Keep the "next" vertex */ void newVertex() { for (int i = 0; i < numVertices; i++) simp[worst][i] = next[i]; } /** Find the worst, nextWorst and best current set of parameter estimates */ void order() { for (int i = 0; i < numVertices; i++) { if (simp[i][numParams] < simp[best][numParams]) best = i; if (simp[i][numParams] > simp[worst][numParams]) worst = i; } nextWorst = best; for (int i = 0; i < numVertices; i++) { if (i != worst) { if (simp[i][numParams] > simp[nextWorst][numParams]) nextWorst = i; } } // IJ.write("B: " + simp[best][numParams] + " 2ndW: " + simp[nextWorst][numParams] + " W: " + simp[worst][numParams]); } /** Get number of iterations performed */ public int getIterations() { return numIter; } /** Get maximum number of iterations allowed */ public int getMaxIterations() { return maxIter; } /** Set maximum number of iterations allowed */ public void setMaxIterations(int x) { maxIter = x; } /** Get number of simplex restarts to do */ public int getRestarts() { return restarts; } /** Set number of simplex restarts to do */ public void setRestarts(int x) { restarts = x; } /** Sets the initial parameters, which override the default initial parameters. */ public void setInitialParameters(double[] params) { initialParams = params; } /** * Gets index of highest value in an array. * * @param Double array. * @return Index of highest value. */ public static int getMax(double[] array) { double max = array[0]; int index = 0; for(int i = 1; i < array.length; i++) { if(max < array[i]) { max = array[i]; index = i; } } return index; } }
|
CurveFitter |
|