/*-
 * Copyright 2016 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package org.eclipse.january.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers for broadcasting: shape padding and matching, stride
 * calculation for broadcast views, and item-size compatibility checks.
 */
public final class BroadcastUtils {

	/**
	 * Calculate shapes for broadcasting
	 * <p>
	 * Both shapes are left-padded with ones to a common rank. Every dimension pair
	 * must then be equal or contain a one.
	 * @param oldShape old shape
	 * @param size dataset size
	 * @param newShape new shape
	 * @return broadcasted shape and full new shape or null if it cannot be done
	 * @throws IllegalArgumentException if size does not match the size of the old shape
	 */
	public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
		if (newShape == null) {
			return null;
		}

		int brank = newShape.length;
		if (brank == 0) {
			// a zero-rank target shape can only be reached from a single-valued dataset
			if (size == 1) {
				return new int[][] {oldShape, newShape};
			}
			return null;
		}

		if (Arrays.equals(oldShape, newShape)) {
			return new int[][] {oldShape, newShape};
		}

		if (ShapeUtils.calcSize(oldShape) != size) {
			throw new IllegalArgumentException("Size must match old shape");
		}

		int offset = brank - oldShape.length;
		if (offset < 0) { // when new shape is incomplete
			newShape = padShape(newShape, -offset);
			offset = 0;
		}

		int[] bshape = padShape(oldShape, offset); // new shape has extra dimensions

		// After padding, both shapes share a common rank. It can exceed the original
		// rank of newShape, so iterate over the padded length, not brank.
		final int rank = newShape.length;
		for (int i = 0; i < rank; i++) {
			if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
				return null;
			}
		}

		return new int[][] {bshape, newShape};
	}

	/**
	 * Pad shape by prefixing with ones
	 * @param shape to pad
	 * @param padding number of dimensions to add
	 * @return new shape or old shape if padding is zero
	 * @throws IllegalArgumentException if padding is negative
	 */
	public static int[] padShape(final int[] shape, final int padding) {
		if (padding < 0) {
			throw new IllegalArgumentException("Padding must be zero or greater");
		}

		if (padding == 0) {
			return shape;
		}

		final int[] nshape = new int[shape.length + padding];
		Arrays.fill(nshape, 1);
		System.arraycopy(shape, 0, nshape, padding, shape.length);
		return nshape;
	}

	/**
	 * Take in shapes and broadcast them to same rank
	 * <p>
	 * If every shape is null, the returned list holds {@code shapes.length + 1}
	 * nulls, one of them standing in for the maximum shape.
	 * @param shapes null shapes are ignored and passed through
	 * @return list whose first entry is the maximum shape, followed by each input
	 * shape padded to the maximum rank (in input order)
	 * @throws IllegalArgumentException if a dimension is neither one nor equal to the maximum
	 */
	public static List<int[]> broadcastShapes(int[]... shapes) {
		int maxRank = -1;
		for (int[] s : shapes) {
			if (s == null) {
				continue;
			}

			int r = s.length;
			if (r > maxRank) {
				maxRank = r;
			}
		}

		List<int[]> newShapes = new ArrayList<int[]>();
		if (maxRank < 0) {
			for (int i = 0; i <= shapes.length; i++) { // note the extra null
				newShapes.add(null);
			}
			return newShapes;
		}

		for (int[] s : shapes) {
			newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
		}

		int[] maxShape = new int[maxRank];
		for (int i = 0; i < maxRank; i++) {
			int m = -1;
			for (int[] s : newShapes) {
				if (s == null) {
					continue;
				}
				int l = s[i];
				if (l > m) {
					// a larger length after a non-unit one can never broadcast
					if (m > 1) {
						throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
					}
					m = l;
				}
			}
			maxShape[i] = m;
		}

		// also catches a non-unit length that is smaller than a previously seen maximum
		checkShapes(maxShape, newShapes);
		newShapes.add(0, maxShape);
		return newShapes;
	}

	/**
	 * Take in shapes and broadcast them to maximum shape
	 * @param maxShape maximum shape (if null, only null shapes are accepted)
	 * @param shapes inputs (null shapes are passed through)
	 * @return list of broadcasted shapes
	 * @throws IllegalArgumentException if a shape's rank exceeds that of the maximum shape
	 * or a dimension is neither one nor equal to the maximum
	 */
	public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
		int maxRank = maxShape == null ? -1 : maxShape.length;
		for (int[] s : shapes) {
			if (s == null) {
				continue;
			}

			int r = s.length;
			if (r > maxRank) {
				throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
			}
		}

		List<int[]> newShapes = new ArrayList<int[]>();
		for (int[] s : shapes) {
			newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
		}

		if (maxShape != null) {
			checkShapes(maxShape, newShapes);
		}
		return newShapes;
	}

	/**
	 * Check that each dimension of every non-null shape is one or equal to the maximum
	 * @param maxShape maximum shape
	 * @param newShapes shapes already padded to the rank of maxShape
	 * @throws IllegalArgumentException on the first incompatible dimension
	 */
	private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
		for (int i = 0; i < maxShape.length; i++) {
			int m = maxShape[i];
			for (int[] s : newShapes) {
				if (s == null) {
					continue;
				}
				int l = s[i];
				if (l != 1 && l != m) {
					throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
				}
			}
		}
	}

	/**
	 * Create a zero-filled output dataset for a binary operation on a and b
	 * <p>
	 * The type is the best common interface of both inputs. If exactly one input
	 * is zero-rank, its type is ignored unless it holds floating-point elements.
	 * The item size is the larger of the two inputs'.
	 * @param a first input
	 * @param b second input
	 * @param shape output shape
	 * @return new dataset
	 */
	static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
		final Class<? extends Dataset> rc;
		final int ar = a.getRank();
		final int br = b.getRank();
		Class<? extends Dataset> tc = InterfaceUtils.getBestInterface(a.getClass(), b.getClass());
		if ((ar == 0) ^ (br == 0)) { // ignore type of zero-rank dataset unless it's floating point
			if (ar == 0) {
				rc = a.hasFloatingPointElements() ? tc : b.getClass();
			} else {
				rc = b.hasFloatingPointElements() ? tc : a.getClass();
			}
		} else {
			rc = tc;
		}
		final int ia = a.getElementsPerItem();
		final int ib = b.getElementsPerItem();

		return DatasetFactory.zeros(Math.max(ia, ib), rc, shape);
	}

	/**
	 * Check if dataset item sizes are compatible
	 * <p>
	 * Dataset a is considered compatible with the output dataset if any of the
	 * conditions are true:
	 * <ul>
	 * <li>o is undefined</li>
	 * <li>a has item size equal to o's</li>
	 * <li>a has item size equal to 1</li>
	 * <li>o has item size equal to 1</li>
	 * </ul>
	 * @param a input dataset a
	 * @param o output dataset (can be null)
	 * @throws IllegalArgumentException if item sizes are incompatible
	 */
	static void checkItemSize(Dataset a, Dataset o) {
		final int isa = a.getElementsPerItem();
		if (o != null) {
			final int iso = o.getElementsPerItem();
			if (isa != iso && isa != 1 && iso != 1) {
				throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
			}
		}
	}

	/**
	 * Check if dataset item sizes are compatible
	 * <p>
	 * Dataset a is considered compatible with dataset b if any of the
	 * conditions are true:
	 * <ul>
	 * <li>a has item size equal to b's</li>
	 * <li>a has item size equal to 1</li>
	 * <li>b has item size equal to 1</li>
	 * <li>a or b are single-valued</li>
	 * </ul>
	 * and, o is undefined or boolean, or any of the following are true:
	 * <ul>
	 * <li>o has item size equal to maximum of a and b's</li>
	 * <li>o has item size equal to 1</li>
	 * <li>a and b have item sizes of 1</li>
	 * </ul>
	 * @param a input dataset a
	 * @param b input dataset b
	 * @param o output dataset (can be null)
	 * @throws IllegalArgumentException if item sizes are incompatible
	 */
	static void checkItemSize(Dataset a, Dataset b, Dataset o) {
		final int isa = a.getElementsPerItem();
		final int isb = b.getElementsPerItem();
		if (isa != isb && isa != 1 && isb != 1) {
			// exempt single-value dataset case too
			if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) {
				throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
			}
		}
		// Boolean outputs are exempt; for all other outputs, check the item size.
		// The check was previously applied only to BooleanDataset outputs. Those always
		// have a single element per item, so it could never fire.
		if (o != null && !BooleanDataset.class.isAssignableFrom(o.getClass())) {
			final int ism = Math.max(isa, isb);
			final int iso = o.getElementsPerItem();
			if (iso != ism && iso != 1 && ism != 1) {
				throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
			}
		}
	}

	/**
	 * Create a stride array from a dataset to a broadcast shape
	 * @param a dataset
	 * @param broadcastShape shape to broadcast
	 * @return stride array
	 */
	public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
		return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
	}

	/**
	 * Create a stride array from a dataset to a broadcast shape
	 * <p>
	 * Dimensions whose length matches the broadcast shape keep their stride. All
	 * other (broadcast) dimensions get a stride of zero, so the same elements are
	 * reused along them.
	 * @param isize item size
	 * @param oShape original shape
	 * @param oStride original stride (if null, a contiguous row-major layout is assumed)
	 * @param broadcastShape shape to broadcast
	 * @return stride array, or null if both shapes are null
	 * @throws IllegalArgumentException if only the original shape is null, or the ranks differ
	 */
	public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
		if (oShape == null) {
			if (broadcastShape == null) {
				return null;
			}
			throw new IllegalArgumentException("Broadcast shape must be null if original shape is null");
		}
		int rank = oShape.length;
		if (broadcastShape.length != rank) {
			throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
		}

		int[] stride = new int[rank];
		if (oStride == null) {
			// build contiguous strides from the last dimension backwards
			int s = isize;
			for (int j = rank - 1; j >= 0; j--) {
				if (broadcastShape[j] == oShape[j]) {
					stride[j] = s;
					s *= oShape[j];
				} else {
					stride[j] = 0;
				}
			}
		} else {
			for (int j = 0; j < rank; j++) {
				if (broadcastShape[j] == oShape[j]) {
					stride[j] = oStride[j];
				} else {
					stride[j] = 0;
				}
			}
		}

		return stride;
	}

	/**
	 * Converts and broadcast all objects as datasets of same shape
	 * @param objects to convert and broadcast
	 * @return all as broadcasted to same shape
	 * @throws IllegalArgumentException if the shapes cannot be broadcast together
	 */
	public static Dataset[] convertAndBroadcast(Object... objects) {
		final int n = objects.length;

		Dataset[] datasets = new Dataset[n];
		int[][] shapes = new int[n][];
		for (int i = 0; i < n; i++) {
			Dataset d = DatasetFactory.createFromObject(objects[i]);
			datasets[i] = d;
			shapes[i] = d.getShapeRef();
		}

		List<int[]> nShapes = broadcastShapes(shapes);
		int[] mshape = nShapes.get(0);
		for (int i = 0; i < n; i++) {
			datasets[i] = datasets[i].getBroadcastView(mshape);
		}

		return datasets;
	}
}