package org.simantics.pythonlink;

import java.util.Arrays;

/**
 * A dense n-dimensional array of {@code double} values stored in a single flat
 * vector in row-major order (the last index varies fastest).
 *
 * <p>The arrays handed to the constructors are stored by reference, and
 * {@link #dims()} / {@link #getValues()} return the internal arrays without
 * copying. Modifying any of them afterwards changes this object.
 *
 * <p>NOTE(review): the fields are left package-private and non-final exactly as
 * before, in case other code in this package reads or writes them directly --
 * confirm before tightening their visibility.
 */
public class NDArray {
    /**
     * Saturation value for element counts: one more than the largest possible
     * Java array length, so a saturated count can never match a real array.
     */
    private static final long SATURATED_COUNT = Integer.MAX_VALUE + 1L;

    /** Extent of each axis; {@code dims.length} is the number of dimensions. */
    int[] dims;

    /** Element data in row-major order. */
    double[] value;

    /**
     * Creates a one-dimensional array over the given data.
     *
     * @param value the element data; its length becomes the single dimension
     * @throws NullPointerException if {@code value} is {@code null}
     */
    public NDArray(double[] value) {
        this.value = value;
        dims = new int[] { value.length };
    }

    /**
     * Creates a two-dimensional {@code m x n} array over the given data.
     *
     * @param m number of rows
     * @param n number of columns
     * @param value the element data in row-major order; must hold exactly {@code m * n} values
     * @throws IllegalArgumentException if a dimension is negative or {@code m * n}
     *         differs from {@code value.length}
     * @throws NullPointerException if {@code value} is {@code null}
     */
    public NDArray(int m, int n, double[] value) {
        if (m < 0 || n < 0) {
            throw new IllegalArgumentException("Negative dimension in " + m + "x" + n);
        }
        // Multiply as long: an int product can wrap around and accidentally
        // match value.length (e.g. 65536 * 65536 == 0 in int arithmetic).
        if ((long) m * n != value.length) {
            throw new IllegalArgumentException("Invalid dimensions for data vector");
        }

        this.value = value;
        dims = new int[] { m, n };
    }

    /**
     * Creates an array with arbitrary dimensions over the given data.
     *
     * <p>An empty {@code dims} denotes an array without elements, so
     * {@code value} must then be empty as well.
     *
     * @param dims extent of each axis; every entry must be non-negative
     * @param value the element data in row-major order; its length must equal
     *        the product of all entries of {@code dims}
     * @throws IllegalArgumentException if a dimension is negative or the product
     *         of the dimensions differs from {@code value.length}
     * @throws NullPointerException if {@code dims} or {@code value} is {@code null}
     */
    public NDArray(int[] dims, double[] value) {
        if (elementCount(dims) != value.length) {
            throw new IllegalArgumentException("Invalid dimensions for data vector");
        }

        this.dims = dims;
        this.value = value;
    }

    /**
     * Computes the number of elements described by a dimension vector.
     *
     * @return the product of all extents, 0 for an empty vector, or
     *         {@link #SATURATED_COUNT} if the product exceeds any array length
     * @throws IllegalArgumentException if an extent is negative
     */
    private static long elementCount(int[] dims) {
        long count = dims.length > 0 ? 1 : 0;
        for (int d : dims) {
            if (d < 0) {
                throw new IllegalArgumentException(
                        "Negative dimension " + d + " in " + Arrays.toString(dims));
            }
            // count <= 2^31 and d < 2^31, so the product always fits in a long.
            // Saturating keeps that invariant for the next step, while a later
            // zero extent still collapses the count to 0 as it should.
            count = Math.min(count * d, SATURATED_COUNT);
        }
        return count;
    }

    /** Fails if {@code index} does not address a position on the given axis. */
    private static void checkAxisIndex(int axis, int index, int extent) {
        if (index < 0 || index >= extent) {
            throw new ArrayIndexOutOfBoundsException(
                    "Index " + index + " out of bounds for axis " + axis + " with size " + extent);
        }
    }

    /** Returns the total number of elements. */
    public int size() {
        return value.length;
    }

    /** Returns the extent of each axis (the internal array, not a copy). */
    public int[] dims() {
        return dims;
    }

    /** Returns the element data in row-major order (the internal array, not a copy). */
    public double[] getValues() {
        return value;
    }

    /**
     * Returns the element at a position of the flat row-major data vector,
     * regardless of the number of dimensions.
     *
     * @param index position in the flat data vector
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the data vector
     */
    public double getValue(int index) {
        return value[index];
    }

    /**
     * Returns the element at row {@code i}, column {@code j} of a two-dimensional array.
     *
     * @throws IllegalArgumentException if this array is not two-dimensional
     * @throws ArrayIndexOutOfBoundsException if {@code i} or {@code j} is outside its axis
     */
    public double getValue(int i, int j) {
        if (dims.length != 2) {
            throw new IllegalArgumentException("Invalid indices for array of dimension " + dims.length);
        }

        // Check each axis separately: an oversized column index would otherwise
        // silently address an element of a following row.
        checkAxisIndex(0, i, dims[0]);
        checkAxisIndex(1, j, dims[1]);
        return value[dims[1] * i + j];
    }

    /**
     * Returns the element addressed by one index per axis.
     *
     * <p>A call with a single {@code int} argument resolves to
     * {@link #getValue(int)} and a call with two to {@link #getValue(int, int)};
     * pass an {@code int[]} to reach this method with fewer than three indices.
     *
     * @param is one index for every axis, in axis order
     * @throws IllegalArgumentException if the number of indices differs from the
     *         number of dimensions
     * @throws ArrayIndexOutOfBoundsException if an index is outside its axis
     */
    public double getValue(int... is) {
        if (dims.length != is.length) {
            throw new IllegalArgumentException("Invalid indices for array of dimension " + dims.length);
        }

        // Horner-style accumulation of the row-major offset.
        int index = 0;
        for (int k = 0; k < dims.length; k++) {
            checkAxisIndex(k, is[k], dims[k]);
            index = dims[k] * index + is[k];
        }

        return value[index];
    }

    /**
     * Two arrays are equal when both their dimensions and their element data
     * are equal as defined by {@link Arrays#equals(int[], int[])} and
     * {@link Arrays#equals(double[], double[])}.
     */
    @Override
    public boolean equals(Object o) {
        return o instanceof NDArray
                && Arrays.equals(dims, ((NDArray) o).dims)
                && Arrays.equals(value, ((NDArray) o).value);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(dims) + 11 * Arrays.hashCode(value);
    }

    /**
     * Renders the shape followed by the nested element lists, for example
     * {@code ndarray(2x3) [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("ndarray(");
        for (int i = 0; i < dims.length; i++) {
            if (i > 0) {
                sb.append('x');
            }
            sb.append(dims[i]);
        }
        sb.append(") ");
        if (dims.length > 0) {
            buildString(sb, 0, 0);
        } else {
            sb.append("[]");
        }
        return sb.toString();
    }

    /**
     * Appends the sub-array reached after fixing the first {@code d} indices.
     *
     * @param sb destination buffer
     * @param d  axis to expand next; equal to {@code dims.length} at a single element
     * @param i  row-major offset accumulated from the indices fixed so far
     */
    private void buildString(StringBuilder sb, int d, int i) {
        if (d == dims.length) {
            sb.append(value[i]);
        } else {
            i *= dims[d];
            sb.append('[');
            for (int j = 0; j < dims[d]; j++) {
                if (j > 0) {
                    sb.append(", ");
                }
                buildString(sb, d + 1, i + j);
            }
            sb.append(']');
        }
    }
}