/*-
 *******************************************************************************
 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Peter Chang - initial API and implementation and/or initial documentation
 *******************************************************************************/

package org.eclipse.january.dataset;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

import org.apache.commons.lang.ArrayUtils;

/**
 * A location described two ways at once: {@code realPoint} holds its values in
 * the x/y coordinate arrays, and {@code coordPoint} holds its (possibly
 * fractional) index position within those arrays.
 */
class InterpolatedPoint {

	Dataset realPoint;
	Dataset coordPoint;

	public InterpolatedPoint(Dataset realPoint, Dataset coordPoint) {
		this.realPoint = realPoint;
		this.coordPoint = coordPoint;
	}

	public Dataset getRealPoint() {
		return realPoint;
	}

	public Dataset getCoordPoint() {
		return coordPoint;
	}

	@Override
	public String toString() {
		return format(realPoint) + " : " + format(coordPoint);
	}

	/**
	 * Formats the elements along the first axis as {@code [ a , b , ... ]}.
	 */
	private static String format(Dataset point) {
		StringBuilder sb = new StringBuilder("[ ").append(point.getDouble(0));
		for (int i = 1; i < point.getShapeRef()[0]; i++) {
			sb.append(" , ").append(point.getDouble(i));
		}
		return sb.append(" ]").toString();
	}

}

/**
 * Utilities for resampling datasets onto new axes and grids.
 */
public class InterpolatorUtils {

	/**
	 * Regrids 2D data sampled at coordinates (x, y) onto the grid {@code gridX} x {@code gridY},
	 * point by point: for each output position the nearest input grid point is located, and the
	 * value is interpolated bilinearly inside the surrounding 3x3 neighbourhood.
	 *
	 * @param data
	 *            input values; assumed to have the same 2D shape as {@code x} and {@code y}
	 * @param x
	 *            x coordinate of each input value
	 * @param y
	 *            y coordinate of each input value
	 * @param gridX
	 *            1D output x positions
	 * @param gridY
	 *            1D output y positions
	 * @return a dataset of shape [gridY length, gridX length] (the same orientation that
	 *         {@link #regrid} produces); NaN where no enclosing cell could be found
	 * @throws Exception
	 *             declared for backward compatibility
	 */
	public static Dataset regridOld(Dataset data, Dataset x, Dataset y,
			Dataset gridX, Dataset gridY) throws Exception {

		int nx = gridX.getShapeRef()[0];
		int ny = gridY.getShapeRef()[0];

		// indexed as (y, x): rows follow gridY, columns follow gridX
		DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, ny, nx);

		// plain position loops: getDouble(int)/set(..., int...) take positions, not buffer indices
		for (int xi = 0; xi < nx; xi++) {
			double xPos = gridX.getDouble(xi);
			for (int yi = 0; yi < ny; yi++) {
				double yPos = gridY.getDouble(yi);
				result.set(getInterpolated(data, x, y, xPos, yPos), yi, xi);
			}
		}
		return result;
	}

	/**
	 * Extracts a window of (2*xSize+1) x (2*ySize+1) elements from a 2D dataset, centred on
	 * (x, y). Near an edge the window is shifted inwards so that it keeps its full size; it is only
	 * truncated when the dataset itself is smaller than the window.
	 *
	 * @param dataset
	 *            2D dataset to slice
	 * @param x
	 *            centre index along axis 0
	 * @param y
	 *            centre index along axis 1
	 * @param xSize
	 *            half-width of the window along axis 0
	 * @param ySize
	 *            half-width of the window along axis 1
	 * @return the sliced window
	 */
	public static Dataset selectDatasetRegion(Dataset dataset, int x, int y, int xSize, int ySize) {
		int shapeX = dataset.getShapeRef()[0];
		int shapeY = dataset.getShapeRef()[1];

		int startX = windowStart(x, xSize, shapeX);
		int startY = windowStart(y, ySize, shapeY);
		int endX = Math.min(startX + 2 * xSize + 1, shapeX);
		int endY = Math.min(startY + 2 * ySize + 1, shapeY);

		int[] start = new int[] { startX, startY };
		int[] stop = new int[] { endX, endY };

		return dataset.getSlice(start, stop, null);
	}

	/**
	 * Returns the first index of a window of 2*halfWidth+1 elements centred on {@code centre},
	 * shifted so that it lies inside [0, length) whenever the window fits, and never negative.
	 */
	private static int windowStart(int centre, int halfWidth, int length) {
		int width = 2 * halfWidth + 1;
		int start = centre - halfWidth;
		if (start < 0) {
			start = 0;
		}
		if (start + width > length) {
			start = Math.max(0, length - width);
		}
		return start;
	}

	/**
	 * Interpolates {@code val} at the real position (xPos, yPos), given 2D coordinate arrays
	 * {@code x} and {@code y} of the same shape as {@code val}.
	 */
	private static double getInterpolated(Dataset val, Dataset x, Dataset y, double xPos, double yPos) {

		int[] xShape = x.getShapeRef();
		int[] yShape = y.getShapeRef();

		// Initial guess: the row whose first x value is nearest xPos, then the column in that row
		// whose y value is nearest yPos. Squared distances are used so that the nearest value wins,
		// not merely the smallest one.
		Dataset xDist = x.getSlice(new int[] { 0, 0 }, new int[] { xShape[0], 1 }, null).isubtract(xPos).ipower(2);
		int xGuess = xDist.minPos()[0];
		Dataset yDist = y.getSlice(new int[] { xGuess, 0 }, new int[] { xGuess + 1, yShape[1] }, null)
				.isubtract(yPos).ipower(2);
		// yDist has shape [1, n], so the column index is the second component of minPos
		int yGuess = yDist.minPos()[1];

		// now search around there 5x5 for the grid point nearest to (xPos, yPos)
		int searchHalfWidth = 2;
		Dataset xClipped = selectDatasetRegion(x, xGuess, yGuess, searchHalfWidth, searchHalfWidth);
		Dataset yClipped = selectDatasetRegion(y, xGuess, yGuess, searchHalfWidth, searchHalfWidth);

		Dataset xSquare = Maths.subtract(xClipped, xPos).ipower(2);
		Dataset ySquare = Maths.subtract(yClipped, yPos).ipower(2);
		Dataset total = Maths.add(xSquare, ySquare);

		// minPos is relative to the clipped window; convert it back to indices into the full arrays
		int[] local = total.minPos();
		int i = windowStart(xGuess, searchHalfWidth, xShape[0]) + local[0];
		int j = windowStart(yGuess, searchHalfWidth, xShape[1]) + local[1];

		// now pull out the region around that point, as a 3x3 grid
		Dataset xReduced = selectDatasetRegion(x, i, j, 1, 1);
		Dataset yReduced = selectDatasetRegion(y, i, j, 1, 1);
		Dataset valReduced = selectDatasetRegion(val, i, j, 1, 1);

		return getInterpolatedResultFromNinePoints(valReduced, xReduced, yReduced, xPos, yPos);
	}

	/**
	 * Bilinearly interpolates a 3x3 block of values at the real position (xPos, yPos).
	 * <p>
	 * Every lattice edge is intersected with the line x == xPos; the crossings are sorted by y,
	 * and the first consecutive pair that brackets yPos gives the fractional index position used
	 * for the bilinear weights.
	 *
	 * @return the interpolated value, or NaN if (xPos, yPos) is not enclosed by the block
	 */
	private static double getInterpolatedResultFromNinePoints(Dataset val, Dataset x, Dataset y,
			double xPos, double yPos) {

		// First build the nine points
		InterpolatedPoint[][] grid = new InterpolatedPoint[3][3];
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				grid[i][j] = makePoint(x, y, i, j);
			}
		}

		// Collect crossings of x == xPos. Edges are visited column by column: the two axis-0
		// edges in column j, then the three edges joining column j to column j+1. The order
		// matters because the stable sort below keeps insertion order for equal y values.
		ArrayList<InterpolatedPoint> points = new ArrayList<InterpolatedPoint>();
		for (int j = 0; j < 3; j++) {
			for (int i = 0; i < 2; i++) {
				addCrossing(points, grid[i][j], grid[i + 1][j], xPos);
			}
			if (j < 2) {
				for (int i = 0; i < 3; i++) {
					addCrossing(points, grid[i][j], grid[i][j + 1], xPos);
				}
			}
		}

		// if no intercepts, then return NaN
		if (points.isEmpty()) {
			return Double.NaN;
		}

		// sort the points by y; Double.compare gives a consistent total order even with NaN
		Collections.sort(points, new Comparator<InterpolatedPoint>() {

			@Override
			public int compare(InterpolatedPoint o1, InterpolatedPoint o2) {
				return Double.compare(o1.getRealPoint().getDouble(1), o2.getRealPoint().getDouble(1));
			}
		});

		// now we have all the points which fit the x criteria, find the first pair bracketing yPos
		InterpolatedPoint bestPoint = null;
		for (int a = 1; a < points.size() && bestPoint == null; a++) {
			bestPoint = get1DInterpolatedPoint(points.get(a - 1), points.get(a), 1, yPos);
		}

		if (bestPoint == null) {
			return Double.NaN;
		}

		// now we have the best point, we can calculate the weights, and positions
		double xCoord = bestPoint.getCoordPoint().getDouble(0);
		double yCoord = bestPoint.getCoordPoint().getDouble(1);
		int xs = (int) Math.floor(xCoord);
		int ys = (int) Math.floor(yCoord);
		double xoff = xCoord - xs;
		double yoff = yCoord - ys;

		// a point on the far edge of the block belongs to the last cell, at offset 1
		if (xs == 2) {
			xs = 1;
			xoff = 1.0;
		}
		if (ys == 2) {
			ys = 1;
			yoff = 1.0;
		}

		double w00 = (1 - xoff) * (1 - yoff);
		double w10 = xoff * (1 - yoff);
		double w01 = (1 - xoff) * yoff;
		double w11 = xoff * yoff;

		// now using the weights, we can get the final interpolated value
		double result = val.getDouble(xs, ys) * w00;
		result += val.getDouble(xs + 1, ys) * w10;
		result += val.getDouble(xs, ys + 1) * w01;
		result += val.getDouble(xs + 1, ys + 1) * w11;

		return result;
	}

	/**
	 * Adds the point where the segment p1-p2 crosses x == xPos to {@code points}, if it does.
	 */
	private static void addCrossing(ArrayList<InterpolatedPoint> points, InterpolatedPoint p1,
			InterpolatedPoint p2, double xPos) {
		InterpolatedPoint crossing = get1DInterpolatedPoint(p1, p2, 0, xPos);
		if (crossing != null) {
			points.add(crossing);
		}
	}

	/**
	 * Builds the point at index (i, j): real coordinates {x(i,j), y(i,j)} and index coordinates {i, j}.
	 */
	private static InterpolatedPoint makePoint(Dataset x, Dataset y, int i, int j) {
		Dataset realPoint = DatasetFactory.createFromObject(new double[] { x.getDouble(i, j), y.getDouble(i, j) });
		Dataset coordPoint = DatasetFactory.createFromObject(new double[] { i, j });
		return new InterpolatedPoint(realPoint, coordPoint);
	}

	/**
	 * Gets an interpolated position when only dealing with 1 dimension for the interpolation.
	 *
	 * @param p1
	 *            Point 1
	 * @param p2
	 *            Point 2
	 * @param interpolationDimension
	 *            The dimension in which the interpolation should be carried out
	 * @param interpolatedValue
	 *            The value at which the interpolated point should be at in the chosen dimension
	 * @return the new interpolated point, or null if the value lies outside the segment (or the
	 *         segment has no extent in that dimension)
	 * @throws IllegalArgumentException
	 *             if the points do not match or the dimension is out of range
	 */
	private static InterpolatedPoint get1DInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2,
			int interpolationDimension, double interpolatedValue) throws IllegalArgumentException {

		checkPoints(p1, p2);

		if (interpolationDimension >= p1.getRealPoint().getShapeRef()[0]) {
			throw new IllegalArgumentException("Dimention is too large for these datasets");
		}

		double p1n = p1.getRealPoint().getDouble(interpolationDimension);
		double p2n = p2.getRealPoint().getDouble(interpolationDimension);
		double max = Math.max(p1n, p2n);
		double min = Math.min(p1n, p2n);

		if (interpolatedValue < min || interpolatedValue > max || min == max) {
			return null;
		}

		// Measure the proportion from p1 towards p2, which is how getInterpolatedPoint applies it;
		// measuring from min would be wrong whenever p1n > p2n.
		double proportion = (interpolatedValue - p1n) / (p2n - p1n);

		return getInterpolatedPoint(p1, p2, proportion);
	}

	/**
	 * Gets an interpolated point between 2 points given a certain proportion
	 *
	 * @param p1
	 *            the initial point
	 * @param p2
	 *            the final point
	 * @param proportion
	 *            how far the new point is along the path between P1(0.0) and P2(1.0)
	 * @return a new point which is the interpolated point, with both real and index coordinates
	 *         linearly interpolated
	 * @throws IllegalArgumentException
	 *             if the points do not match or the proportion is outside [0, 1]
	 */
	private static InterpolatedPoint getInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2, double proportion) {

		checkPoints(p1, p2);

		if (proportion < 0 || proportion > 1.0) {
			throw new IllegalArgumentException("Proportion must be between 0 and 1");
		}

		Dataset realPoint = Maths.add(Maths.multiply(p1.getRealPoint(), 1.0 - proportion),
				Maths.multiply(p2.getRealPoint(), proportion));

		Dataset coordPoint = Maths.add(Maths.multiply(p1.getCoordPoint(), 1.0 - proportion),
				Maths.multiply(p2.getCoordPoint(), proportion));

		return new InterpolatedPoint(realPoint, coordPoint);
	}

	/**
	 * Checks to see if 2 points have the same dimensionality (both real and index coordinates).
	 *
	 * @param p1
	 *            Point 1
	 * @param p2
	 *            Point 2
	 * @throws IllegalArgumentException
	 *             if either pair of coordinate datasets is incompatible
	 */
	private static void checkPoints(InterpolatedPoint p1, InterpolatedPoint p2) throws IllegalArgumentException {
		if (!p1.getCoordPoint().isCompatibleWith(p2.getCoordPoint())
				|| !p1.getRealPoint().isCompatibleWith(p2.getRealPoint())) {
			throw new IllegalArgumentException("Datasets do not match");
		}
	}

	/**
	 * Slices {@code axis} to the elements lying between the two points' values in dimension
	 * {@code axisIndex}.
	 * <p>
	 * NOTE(review): not called anywhere in this file. getTrimmedAxisStart can return -1 and
	 * getTrimmedAxisEnd returns i-1 as an exclusive stop, which looks off by one; confirm the
	 * intended bounds before using this.
	 */
	private static Dataset getTrimmedAxis(Dataset axis, int axisIndex, InterpolatedPoint p1, InterpolatedPoint p2) {
		double startPoint = Math.min(p1.getRealPoint().getDouble(axisIndex), p2.getRealPoint().getDouble(axisIndex));
		double endPoint = Math.max(p1.getRealPoint().getDouble(axisIndex), p2.getRealPoint().getDouble(axisIndex));

		int start = getTrimmedAxisStart(axis, startPoint);
		int end = getTrimmedAxisEnd(axis, start, endPoint);

		return axis.getSlice(new int[] { start }, new int[] { end }, null);
	}

	/**
	 * Returns the first index whose value is strictly greater than {@code startPoint}, or -1 if none.
	 */
	private static int getTrimmedAxisStart(Dataset axis, double startPoint) {
		for (int i = 0; i < axis.getShapeRef()[0]; i++) {
			if (axis.getDouble(i) > startPoint) return i;
		}
		// if we get to here then the start point is higher than the whole system
		return -1;
	}

	/**
	 * Returns one less than the first index (from {@code startPos}) whose value exceeds
	 * {@code endPoint}, or the axis length if no value does.
	 */
	private static int getTrimmedAxisEnd(Dataset axis, int startPos, double endPoint) {
		for (int i = startPos; i < axis.getShapeRef()[0]; i++) {
			if (axis.getDouble(i) > endPoint) return i - 1;
		}
		// if we get to here then the end point is higher than the whole system
		return axis.getShapeRef()[0];
	}

	/**
	 * Resamples a 1D dataset, defined against {@code axis}, onto {@code outputAxis} by linear
	 * interpolation.
	 *
	 * @param dataset
	 *            1D values
	 * @param axis
	 *            1D positions of {@code dataset} (monotonic increasing or decreasing)
	 * @param outputAxis
	 *            1D positions to sample at
	 * @return a double dataset shaped like {@code outputAxis}; NaN where the output position is
	 *         not bracketed by {@code axis}
	 */
	public static Dataset remap1D(Dataset dataset, Dataset axis, Dataset outputAxis) {
		Dataset data = DatasetFactory.zeros(DoubleDataset.class, outputAxis.getShapeRef());
		for (int i = 0; i < outputAxis.getShapeRef()[0]; i++) {
			double position = getRealPositionAsIndex(axis, outputAxis.getDouble(i));
			data.set(position >= 0.0 ? Maths.interpolate(dataset, position) : Double.NaN, i);
		}

		return data;
	}

	/**
	 * Converts a value into a fractional index along a 1D axis. Both ascending and descending
	 * segments are handled. The interval [start, end) of each segment is searched (start
	 * inclusive, end exclusive), so a value equal only to the final axis element is not found.
	 *
	 * @return the fractional index, or -1.0 if no segment brackets {@code point}
	 */
	private static double getRealPositionAsIndex(Dataset dataset, double point) {
		for (int j = 0; j < dataset.getShapeRef()[0] - 1; j++) {
			double end = dataset.getDouble(j + 1);
			double start = dataset.getDouble(j);
			boolean bracketed = start < end
					? (end > point) && (start <= point)
					: (end < point) && (start >= point);
			if (bracketed) {
				double proportion = (point - start) / (end - start);
				return j + proportion;
			}
		}
		return -1.0;
	}

	/**
	 * Remaps one axis of {@code dataset} onto {@code outputAxis}. For each 1D line along
	 * {@code axisIndex}, the source axis is {@code originalAxisForCorrection} shifted by the
	 * matching entry of {@code corrections}.
	 *
	 * @param dataset
	 *            N-D input
	 * @param axisIndex
	 *            the axis to remap
	 * @param corrections
	 *            (N-1)-D offsets, indexed by the position with {@code axisIndex} removed
	 * @param originalAxisForCorrection
	 *            1D axis before correction
	 * @param outputAxis
	 *            1D axis to resample onto
	 * @return a double dataset whose {@code axisIndex} dimension has the length of {@code outputAxis}
	 */
	public static Dataset remapOneAxis(Dataset dataset, int axisIndex, Dataset corrections,
			Dataset originalAxisForCorrection, Dataset outputAxis) {
		int[] stop = dataset.getShape();
		int rank = stop.length;
		int[] start = new int[rank]; // all zeros
		int[] step = new int[rank];
		for (int i = 0; i < rank; i++) {
			step[i] = 1;
		}
		int[] resultSize = stop.clone();
		resultSize[axisIndex] = outputAxis.getShapeRef()[0];
		DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);

		// stepping by the full length visits exactly one start position per 1D line along axisIndex
		step[axisIndex] = dataset.getShapeRef()[axisIndex];
		IndexIterator iter = dataset.getSliceIterator(start, stop, step);

		int[] pos = iter.getPos(); // updated in place by hasNext()
		int[] posEnd = new int[rank];
		while (iter.hasNext()) {
			for (int i = 0; i < rank; i++) {
				posEnd[i] = pos[i] + 1;
			}
			posEnd[axisIndex] = stop[axisIndex];

			// get the dataset
			Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();

			// the correction is indexed by the position with the remapped axis removed
			int[] correctionPos = new int[rank - 1];
			int index = 0;
			for (int j = 0; j < rank; j++) {
				if (j != axisIndex) {
					correctionPos[index++] = pos[j];
				}
			}
			Dataset axis = Maths.subtract(originalAxisForCorrection, corrections.getDouble(correctionPos));
			Dataset remapped = remap1D(slice, axis, outputAxis);

			int[] ref = ArrayUtils.clone(pos);
			for (int k = 0; k < resultSize[axisIndex]; k++) {
				ref[axisIndex] = k;
				result.set(remapped.getDouble(k), ref);
			}
		}

		return result;
	}

	/**
	 * Remaps one axis of {@code dataset} onto {@code outputAxis}, where the source axis positions
	 * for every 1D line are given element-wise by {@code originalAxisForCorrection}.
	 *
	 * @param dataset
	 *            N-D input
	 * @param axisIndex
	 *            the axis to remap
	 * @param originalAxisForCorrection
	 *            source positions, same shape as {@code dataset}
	 * @param outputAxis
	 *            1D axis to resample onto
	 * @return a double dataset whose {@code axisIndex} dimension has the length of {@code outputAxis}
	 * @throws IllegalArgumentException
	 *             if {@code dataset} and {@code originalAxisForCorrection} are not compatible
	 */
	public static Dataset remapAxis(Dataset dataset, int axisIndex, Dataset originalAxisForCorrection, Dataset outputAxis) {
		if (!dataset.isCompatibleWith(originalAxisForCorrection)) {
			throw new IllegalArgumentException("Datasets must be of the same shape");
		}

		int[] stop = dataset.getShape();
		int rank = stop.length;
		int[] start = new int[rank]; // all zeros
		int[] step = new int[rank];
		for (int i = 0; i < rank; i++) {
			step[i] = 1;
		}
		int[] resultSize = stop.clone();
		resultSize[axisIndex] = outputAxis.getShapeRef()[0];
		DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);

		// stepping by the full length visits exactly one start position per 1D line along axisIndex
		step[axisIndex] = dataset.getShapeRef()[axisIndex];
		IndexIterator iter = dataset.getSliceIterator(start, stop, step);

		int[] pos = iter.getPos(); // updated in place by hasNext()
		int[] posEnd = new int[rank];
		while (iter.hasNext()) {
			for (int i = 0; i < rank; i++) {
				posEnd[i] = pos[i] + 1;
			}
			posEnd[axisIndex] = stop[axisIndex];

			// get the dataset and its matching axis line
			Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();
			Dataset axis = originalAxisForCorrection.getSlice(pos, posEnd, null).squeeze();

			Dataset remapped = remap1D(slice, axis, outputAxis);

			int[] ref = ArrayUtils.clone(pos);
			for (int k = 0; k < resultSize[axisIndex]; k++) {
				ref[axisIndex] = k;
				result.set(remapped.getDouble(k), ref);
			}
		}

		return result;
	}

	/**
	 * Regrids 2D data by remapping axis 1 against {@code x} onto {@code gridX}, then axis 0
	 * against {@code y} onto {@code gridY}.
	 *
	 * @param data
	 *            2D input values
	 * @param x
	 *            x positions, same shape as {@code data}
	 * @param y
	 *            y positions, assumed to match the shape of the intermediate result (axis 0
	 *            unchanged, axis 1 of length gridX) — TODO confirm against callers
	 * @param gridX
	 *            1D output x axis
	 * @param gridY
	 *            1D output y axis
	 * @return a dataset of shape [gridY length, gridX length]
	 */
	public static Dataset regrid(Dataset data, Dataset x, Dataset y, Dataset gridX, Dataset gridY) {

		// apply X then Y regridding
		Dataset result = remapAxis(data, 1, x, gridX);
		result = remapAxis(result, 0, y, gridY);

		return result;
	}
}