001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016
017import org.eclipse.january.DatasetException;
018import org.eclipse.january.IMonitor;
019import org.eclipse.january.metadata.StatisticsMetadata;
020import org.eclipse.january.metadata.internal.StatisticsMetadataImpl;
021
022/**
023 * Generic container class for data that is compound in nature
024 * 
025 * Each subclass has an array of compound types, items of this array are composed of primitive types
026 * 
027 * Data items can be Complex, Vector, etc
028 * 
029 */
030public abstract class AbstractCompoundDataset extends AbstractDataset implements CompoundDataset {
031        // pin UID to base class
032        private static final long serialVersionUID = Dataset.serialVersionUID;
033
034        protected int isize; // number of elements per item
035
036        @Override
037        public int getElementsPerItem() {
038                return isize;
039        }
040
041        @Override
042        protected int get1DIndex(final int i) {
043                int n = super.get1DIndex(i);
044                return stride == null ? isize * n : n;
045        }
046
047        @Override
048        protected int get1DIndex(final int i, final int j) {
049                int n = super.get1DIndex(i, j);
050                return stride == null ? isize * n : n;
051        }
052
053        @Override
054        protected int get1DIndexFromShape(final int[] n) {
055                return isize * super.get1DIndexFromShape(n);
056        }
057
058        @Override
059        public Dataset getUniqueItems() {
060                throw new UnsupportedOperationException("Cannot sort compound datasets");
061        }
062
063        @Override
064        public IndexIterator getIterator(final boolean withPosition) {
065                if (stride != null) {
066                        return base.getSize() == 1 ? 
067                                        (withPosition ? new PositionIterator(offset, shape) :  new SingleItemIterator(offset, size)) : new StrideIterator(isize, shape, stride, offset);
068                }
069                return withPosition ? getSliceIterator(null, null, null) :
070                        new ContiguousIterator(size, isize);
071        }
072
073        /**
074         * Get an iterator that picks out the chosen element from all items
075         * @param element
076         * @return an iterator
077         */
078        public IndexIterator getIterator(int element) {
079                if (element < 0)
080                        element += isize;
081                if (element < 0 || element > isize) {
082                        logger.error("Invalid choice of element: {}/{}", element, isize);
083                        throw new IllegalArgumentException("Invalid choice of element: " + element + "/" + isize);
084                }
085
086                final IndexIterator it;
087                if (stride != null) {
088                        it = base.getSize() == 1 ? new SingleItemIterator(offset + element, size) : new StrideIterator(isize, shape, stride, offset, element);
089                } else {
090                        it = new ContiguousIterator(size, isize, element);
091                }
092
093                return it;
094        }
095
096        @Override
097        public IndexIterator getSliceIterator(SliceND slice) {
098                if (ShapeUtils.calcLongSize(slice.getShape()) == 0) {
099                        return new NullIterator(shape, slice.getShape());
100                }
101                if (stride != null) {
102                        return new StrideIterator(isize, shape, stride, offset, slice);
103                }
104
105                return new SliceIterator(shape, size, isize, slice);
106        }
107
108        /**
109         * Constructor required for serialisation.
110         */
111        public AbstractCompoundDataset() {
112        }
113
114        @Override
115        public boolean equals(Object obj) {
116                if (!super.equals(obj)) {
117                        return false;
118                }
119
120                CompoundDataset other = (CompoundDataset) obj;
121                return isize == other.getElementsPerItem();
122        }
123
124        @Override
125        public int hashCode() {
126                return getCompoundStats().getHash(shape);
127        }
128
129        @Override
130        public CompoundDataset cast(boolean repeat, int dtype, int isize) {
131                return (CompoundDataset) super.cast(repeat, dtype, isize);
132        }
133
134        @Override
135        public CompoundDataset cast(int dtype) {
136                return (CompoundDataset) super.cast(dtype);
137        }
138
139        @Override
140        abstract public AbstractCompoundDataset clone();
141
142        @Override
143        public CompoundDataset flatten() {
144                return (CompoundDataset) super.flatten();
145        }
146
147        @Override
148        public CompoundDataset getBy1DIndex(IntegerDataset index) {
149                return (CompoundDataset) super.getBy1DIndex(index);
150        }
151
152        @Override
153        public CompoundDataset getByBoolean(Dataset selection) {
154                return (CompoundDataset) super.getByBoolean(selection);
155        }
156
157        @Override
158        public CompoundDataset getByIndexes(Object... indexes) {
159                return (CompoundDataset) super.getByIndexes(indexes);
160        }
161
162        @Override
163        public CompoundDataset getSlice(IMonitor mon, int[] start, int[] stop, int[] step) {
164                return (CompoundDataset) super.getSlice(mon, start, stop, step);
165        }
166
167        @Override
168        public CompoundDataset getSlice(IMonitor mon, Slice... slice) {
169                return (CompoundDataset) super.getSlice(mon, slice);
170        }
171
172        @Override
173        public CompoundDataset getSlice(IMonitor mon, SliceND slice) {
174                return (CompoundDataset) super.getSlice(mon, slice);
175        }
176
177        @Override
178        public CompoundDataset getSlice(int[] start, int[] stop, int[] step) {
179                return (CompoundDataset) super.getSlice(start, stop, step);
180        }
181
182        @Override
183        public CompoundDataset getSlice(Slice... slice) {
184                return (CompoundDataset) super.getSlice(slice);
185        }
186
187        @Override
188        public CompoundDataset getSlice(SliceND slice) {
189                return (CompoundDataset) super.getSlice(slice);
190        }
191
192        @Override
193        abstract public AbstractCompoundDataset getSlice(SliceIterator iterator);
194
195        @Override
196        public CompoundDataset getSliceView(int[] start, int[] stop, int[] step) {
197                return (CompoundDataset) super.getSliceView(start, stop, step);
198        }
199
200        @Override
201        public CompoundDataset getSliceView(Slice... slice) {
202                return (CompoundDataset) super.getSliceView(slice);
203        }
204
205        @Override
206        public CompoundDataset getSliceView(SliceND slice) {
207                return (CompoundDataset) super.getSliceView(slice);
208        }
209
210        @Override
211        public CompoundDataset getTransposedView(int... axes) {
212                return (CompoundDataset) super.getTransposedView(axes);
213        }
214
215        @Override
216        abstract public AbstractCompoundDataset getView(boolean deepCopyMetadata);
217
218        @Override
219        public CompoundDataset getBroadcastView(int... broadcastShape) {
220                return (CompoundDataset) super.getBroadcastView(broadcastShape);
221        }
222
223        @Override
224        public CompoundDataset ifloorDivide(Object o) {
225                return (CompoundDataset) super.ifloorDivide(o);
226        }
227
228        @Override
229        public CompoundDataset reshape(int... shape) {
230                return (CompoundDataset) super.reshape(shape);
231        }
232
233        @Override
234        public CompoundDataset setSlice(Object obj, int[] start, int[] stop, int[] step) {
235                return (CompoundDataset) super.setSlice(obj, start, stop, step);
236        }
237
238        @Override
239        public CompoundDataset setSlice(Object object, Slice... slice) {
240                return (CompoundDataset) super.setSlice(object, slice);
241        }
242
243        @Override
244        public CompoundDataset sort(Integer axis) {
245                throw new UnsupportedOperationException("Cannot sort dataset");
246        }
247
248        @Override
249        public CompoundDataset squeezeEnds() {
250                return (CompoundDataset) super.squeezeEnds();
251        }
252
253        @Override
254        public CompoundDataset squeeze() {
255                return (CompoundDataset) super.squeeze();
256        }
257
258        @Override
259        public CompoundDataset squeeze(boolean onlyFromEnd) {
260                return (CompoundDataset) super.squeeze(onlyFromEnd);
261        }
262
263        @Override
264        public CompoundDataset swapAxes(int axis1, int axis2) {
265                return (CompoundDataset) super.swapAxes(axis1, axis2);
266        }
267
268        @Override
269        public synchronized CompoundDataset synchronizedCopy() {
270                return clone();
271        }
272
273        @Override
274        public CompoundDataset transpose(int... axes) {
275                return (CompoundDataset) super.transpose(axes);
276        }
277
278        /**
279         * @since 2.0
280         * @return first value
281         */
282        abstract protected double getFirstValue();
283
284        abstract protected double getFirstValue(final int i);
285
286        abstract protected double getFirstValue(final int i, final int j);
287
288        abstract protected double getFirstValue(final int...pos);
289
290        @Override
291        public boolean getBoolean() {
292                return getFirstValue() != 0;
293        }
294
295        @Override
296        public boolean getBoolean(final int i) {
297                return getFirstValue(i) != 0;
298        }
299
300        @Override
301        public boolean getBoolean(final int i, final int j) {
302                return getFirstValue(i, j) != 0;
303        }
304
305        @Override
306        public boolean getBoolean(final int... pos) {
307                return getFirstValue(pos) != 0;
308        }
309
310        @Override
311        public byte getByte() {
312                return (byte) getFirstValue();
313        }
314
315        @Override
316        public byte getByte(final int i) {
317                return (byte) getFirstValue(i);
318        }
319
320        @Override
321        public byte getByte(final int i, final int j) {
322                return (byte) getFirstValue(i, j);
323        }
324
325        @Override
326        public byte getByte(final int... pos) {
327                return (byte) getFirstValue(pos);
328        }
329
330        @Override
331        public short getShort() {
332                return (short) getFirstValue();
333        }
334
335        @Override
336        public short getShort(final int i) {
337                return (short) getFirstValue(i);
338        }
339
340        @Override
341        public short getShort(final int i, final int j) {
342                return (short) getFirstValue(i, j);
343        }
344
345        @Override
346        public short getShort(final int... pos) {
347                return (short) getFirstValue(pos);
348        }
349
350        @Override
351        public int getInt() {
352                return (int) getFirstValue();
353        }
354
355        @Override
356        public int getInt(final int i) {
357                return (int) getFirstValue(i);
358        }
359
360        @Override
361        public int getInt(final int i, final int j) {
362                return (int) getFirstValue(i, j);
363        }
364
365        @Override
366        public int getInt(final int... pos) {
367                return (int) getFirstValue(pos);
368        }
369
370        @Override
371        public long getLong() {
372                return (long) getFirstValue();
373        }
374
375        @Override
376        public long getLong(final int i) {
377                return (long) getFirstValue(i);
378        }
379
380        @Override
381        public long getLong(final int i, final int j) {
382                return (long) getFirstValue(i, j);
383        }
384
385        @Override
386        public long getLong(final int... pos) {
387                return (long) getFirstValue(pos);
388        }
389
390        @Override
391        public float getFloat() {
392                return (float) getFirstValue();
393        }
394
395        @Override
396        public float getFloat(final int i) {
397                return (float) getFirstValue(i);
398        }
399
400        @Override
401        public float getFloat(final int i, final int j) {
402                return (float) getFirstValue(i, j);
403        }
404
405        @Override
406        public float getFloat(final int... pos) {
407                return (float) getFirstValue(pos);
408        }
409
410        @Override
411        public double getDouble() {
412                return getFirstValue();
413        }
414
415        @Override
416        public double getDouble(final int i) {
417                return getFirstValue(i);
418        }
419
420        @Override
421        public double getDouble(final int i, final int j) {
422                return getFirstValue(i, j);
423        }
424
425        @Override
426        public double getDouble(final int... pos) {
427                return getFirstValue(pos);
428        }
429
430        @Override
431        public void getDoubleArray(final double[] darray) {
432                getDoubleArrayAbs(getFirst1DIndex(), darray);
433        }
434
435        @Override
436        public void getDoubleArray(final double[] darray, final int i) {
437                getDoubleArrayAbs(get1DIndex(i), darray);
438        }
439
440        @Override
441        public void getDoubleArray(final double[] darray, final int i, final int j) {
442                getDoubleArrayAbs(get1DIndex(i, j), darray);
443        }
444
445        @Override
446        public void getDoubleArray(final double[] darray, final int... pos) {
447                getDoubleArrayAbs(get1DIndex(pos), darray);
448        }
449
450        /**
451         * @return statistics metadata
452         * @since 2.0
453         */
454        @SuppressWarnings("unchecked")
455        protected StatisticsMetadata<double[]> getCompoundStats() {
456                StatisticsMetadata<double[]> md = getFirstMetadata(StatisticsMetadata.class);
457                if (md == null || md.isDirty(this)) {
458                        md = new StatisticsMetadataImpl<double[]>();
459                        md.initialize(this);
460                        setMetadata(md);
461                }
462                return md;
463        }
464
465        @Override
466        public IntegerDataset argMax(int axis, boolean... ignoreInvalids) {
467                logger.error("Cannot compare compound numbers");
468                throw new UnsupportedOperationException("Cannot compare compound numbers");
469        }
470
471        @Override
472        public IntegerDataset argMin(int axis, boolean... ignoreInvalids) {
473                logger.error("Cannot compare compound numbers");
474                throw new UnsupportedOperationException("Cannot compare compound numbers");
475        }
476
477        @Override
478        public Number max(boolean... ignoreInvalids) {
479                logger.error("Cannot compare compound numbers");
480                throw new UnsupportedOperationException("Cannot compare compound numbers");
481        }
482
483        @Override
484        public CompoundDataset max(int axis, boolean... ignoreInvalids) {
485                logger.error("Cannot compare compound numbers");
486                throw new UnsupportedOperationException("Cannot compare compound numbers");
487        }
488
489        @Override
490        public Number min(boolean... ignoreInvalids) {
491                logger.error("Cannot compare compound numbers");
492                throw new UnsupportedOperationException("Cannot compare compound numbers");
493        }
494
495        @Override
496        public CompoundDataset min(int axis, boolean... ignoreInvalids) {
497                logger.error("Cannot compare compound numbers");
498                throw new UnsupportedOperationException("Cannot compare compound numbers");
499        }
500
501
502        @Override
503        public int[] maxPos(boolean... ignoreNaNs) {
504                logger.error("Cannot compare compound numbers");
505                throw new UnsupportedOperationException("Cannot compare compound numbers");
506        }
507
508        @Override
509        public int[] minPos(boolean... ignoreNaNs) {
510                logger.error("Cannot compare compound numbers");
511                throw new UnsupportedOperationException("Cannot compare compound numbers");
512        }
513
514        @Override
515        public CompoundDataset peakToPeak(int axis, boolean... ignoreInvalids) {
516                logger.error("Cannot compare compound numbers");
517                throw new UnsupportedOperationException("Cannot compare compound numbers");
518        }
519
520        @Override
521        public double[] maxItem() {
522                return getCompoundStats().getMaximum();
523        }
524
525        @Override
526        public double[] minItem() {
527                return getCompoundStats().getMinimum();
528        }
529
530        @Override
531        public Object mean(boolean... ignoreInvalids) {
532                return getCompoundStats().getMean();
533        }
534
535        @Override
536        public CompoundDataset mean(int axis, boolean... ignoreInvalids) {
537                return (CompoundDataset) super.mean(axis, ignoreInvalids);
538        }
539
540        @Override
541        public CompoundDataset product(int axis, boolean... ignoreInvalids) {
542                return (CompoundDataset) super.product(axis, ignoreInvalids);
543        }
544
545        @Override
546        public CompoundDataset rootMeanSquare(int axis, boolean... ignoreInvalids) {
547                return (CompoundDataset) super.rootMeanSquare(axis, ignoreInvalids);
548        }
549
550        @Override
551        public CompoundDataset stdDeviation(int axis) {
552                return (CompoundDataset) super.stdDeviation(axis, false);
553        }
554
555        @Override
556        public CompoundDataset stdDeviation(int axis, boolean isWholePopulation, boolean... ignoreInvalids) {
557                return (CompoundDataset) super.stdDeviation(axis, isWholePopulation, ignoreInvalids);
558        }
559
560        @Override
561        public Object sum(boolean... ignoreInvalids) {
562                return getCompoundStats().getSum();
563        }
564
565        @Override
566        public CompoundDataset sum(int axis, boolean... ignoreInvalids) {
567                return (CompoundDataset) super.sum(axis, ignoreInvalids);
568        }
569
570        @Override
571        public double variance(boolean isWholePopulation, boolean... ignoreInvalids) {
572                return getCompoundStats().getVariance(isWholePopulation, ignoreInvalids);
573        }
574
575        @Override
576        public CompoundDataset variance(int axis) {
577                return (CompoundDataset) super.variance(axis, false);
578        }
579
580        @Override
581        public CompoundDataset variance(int axis, boolean isWholePopulation, boolean... ignoreInvalids) {
582                return (CompoundDataset) super.variance(axis, isWholePopulation, ignoreInvalids);
583        }
584
585        @Override
586        public double rootMeanSquare(boolean... ignoreInvalids) {
587                StatisticsMetadata<double[]> stats = getCompoundStats();
588
589                double[] mean = stats.getMean(ignoreInvalids);
590                double result = 0;
591                for (int i = 0; i < isize; i++) {
592                        double m = mean[i];
593                        result += m * m;
594                }
595                return Math.sqrt(result + stats.getVariance(true));
596        }
597
598        /**
599         * @return error
600         */
601        private CompoundDataset getInternalError() {
602                ILazyDataset led = super.getErrors();
603                if (led == null)
604                        return null;
605
606                Dataset ed = null;
607                try {
608                        ed = DatasetUtils.sliceAndConvertLazyDataset(led);
609                } catch (DatasetException e) {
610                        logger.error("Could not get data from lazy dataset", e);
611                }
612
613                CompoundDataset ced; // ensure it has the same number of elements
614                if (!(ed instanceof CompoundDataset) || ed.getElementsPerItem() != isize) {
615                        ced = new CompoundDoubleDataset(isize, true, ed);
616                } else {
617                        ced = (CompoundDataset) ed;
618                }
619                
620                if (led != ced) {
621                        setErrors(ced); // set back
622                }
623                return ced;
624        }
625
626        @Override
627        public CompoundDataset getErrors() {
628                CompoundDataset ed = getInternalError();
629                if (ed == null)
630                        return null;
631
632                return ed.getBroadcastView(shape);
633        }
634
635        @Override
636        public double getError(final int i) {
637                return calcError(getInternalErrorArray(true, i));
638        }
639
640        @Override
641        public double getError(final int i, final int j) {
642                return calcError(getInternalErrorArray(true, i, j));
643        }
644
645        @Override
646        public double getError(final int... pos) {
647                return calcError(getInternalErrorArray(true, pos));
648        }
649
650        private double calcError(double[] es) {
651                if (es == null)
652                        return 0;
653
654                // assume elements are independent
655                double e = 0;
656                for (int k = 0; k < isize; k++) {
657                        e += es[k];
658                }
659
660                return Math.sqrt(e);
661        }
662
663        @Override
664        public double[] getErrorArray(final int i) {
665                return getInternalErrorArray(false, i);
666        }
667
668        @Override
669        public double[] getErrorArray(final int i, final int j) {
670                return getInternalErrorArray(false, i, j);
671        }
672
673        @Override
674        public double[] getErrorArray(final int... pos) {
675                return getInternalErrorArray(false, pos);
676        }
677
678        private Dataset getInternalError(final boolean squared) {
679                Dataset sed = squared ? getInternalSquaredError() : getInternalError();
680                if (sed == null)
681                        return null;
682
683                return sed.getBroadcastView(shape);
684        }
685
686        private double[] getInternalErrorArray(final boolean squared, final int i) {
687                Dataset sed = getInternalError(squared);
688                if (sed == null)
689                        return null;
690
691                double[] es;
692                if (sed instanceof CompoundDoubleDataset) {
693                        es = ((CompoundDoubleDataset) sed).getDoubleArray(i);
694                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
695                                Arrays.fill(es, es[0]);
696                        }
697                } else {
698                        es = new double[isize];
699                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(i));
700                }
701                return es;
702        }
703
704        private double[] getInternalErrorArray(final boolean squared, final int i, final int j) {
705                Dataset sed = getInternalError(squared);
706                if (sed == null)
707                        return null;
708
709                double[] es;
710                if (sed instanceof CompoundDoubleDataset) {
711                        es = ((CompoundDoubleDataset) sed).getDoubleArray(i, j);
712                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
713                                Arrays.fill(es, es[0]);
714                        }
715                } else {
716                        es = new double[isize];
717                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(i, j));
718                }
719                return es;
720        }
721
722        private double[] getInternalErrorArray(final boolean squared, final int... pos) {
723                Dataset sed = getInternalError(squared);
724                if (sed == null)
725                        return null;
726
727                double[] es = new double[isize];
728                if (sed instanceof CompoundDoubleDataset) {
729                        es = ((CompoundDoubleDataset) sed).getDoubleArray(pos);
730                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
731                                Arrays.fill(es, es[0]);
732                        }
733                } else {
734                        es = new double[isize];
735                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(pos));
736                }
737                return es;
738        }
739}
740