DL4J基本操作
文章目录
- DL4J基本操作
- 1. 创建矩阵
- 2. 矩阵元素读取
- 3. 矩阵行元素读取
- 4. 矩阵运算
导入依赖
<nd4j.version>1.0.0-beta2</nd4j.version><dependency><groupId>org.nd4j</groupId><artifactId>nd4j-native-platform</artifactId><version>${nd4j.version}</version></dependency>
1. 创建矩阵
/*构造一个3行5列的全0 ndarray*/System.out.println("构造一个3行5列的全0 ndarray");INDArray zeros = Nd4j.zeros(3, 5);System.out.println(zeros);
/*构造一个3行5列的全1 ndarray*/System.out.println("构造一个3行5列的全1 ndarray");INDArray ones = Nd4j.ones(3, 5);System.out.println(ones);
/*构造一个3行5列,数组元素均为随机产生的ndarray*/System.out.println("构造一个3行5列,数组元素均为随机产生的ndarray");INDArray rands = Nd4j.rand(3, 5);System.out.println(rands);
/*构造一个3行5列,数组元素服从高斯分布(平均值为0,标准差为1)的ndarray*/System.out.println("构造一个3行5列,数组元素服从高斯分布(平均值为0,标准差为1)的ndarray");INDArray randns = Nd4j.randn(3, 5);System.out.println(randns);
/*给一个一维数据,根据shape创造ndarray*/System.out.println("给一个一维数据,根据shape创造ndarray");System.out.println("创建一个值全是2,一行四列的的ndarray");INDArray array1 = Nd4j.create(new float[]{2, 2, 2, 2}, new int[]{1, 4});System.out.println(array1);System.out.println("创建一个值全是2,2行2列的的ndarray");INDArray array2 = Nd4j.create(new float[]{2, 2, 2, 2}, new int[]{2, 2});System.out.println(array2);
2. 矩阵元素读取
System.out.println("把一维数组转换成2行6列的 ndarray");INDArray nd = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[]{2, 6});System.out.println("打印原有数组");System.out.println(nd);/*获取指定索引的值*/System.out.println("获取数组下标为0, 3的值");double value = nd.getDouble(0, 3);System.out.println(value);/*修改指定索引的值*/System.out.println("修改数组下标为0, 3的值");//scalar 标量nd.putScalar(0, 3, 100);System.out.println(nd);/*使用索引迭代器遍历ndarray,使用c order*/System.out.println("使用索引迭代器遍历ndarray");NdIndexIterator iter = new NdIndexIterator(2, 6);while (iter.hasNext()) {long[] nextIndex = iter.next();double nextVal = nd.getDouble(nextIndex);System.out.println(nextVal);}
3. 矩阵行元素读取
INDArray nd = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[]{2, 6});System.out.println("原始数组");System.out.println(nd);/*获取一行*/System.out.println("获取数组中的一行:第0行");INDArray singleRow = nd.getRow(0);System.out.println(singleRow);/*获取多行*/System.out.println("获取数组中的多行:0行和1行");INDArray multiRows = nd.getRows(0, 1);System.out.println(multiRows);/*替换其中的一行*/System.out.println("替换原有数组中的一行:把第0行替换掉");INDArray replaceRow = Nd4j.create(new float[]{1, 3, 5, 7, 9, 11});nd.putRow(0, replaceRow);System.out.println(nd);
4. 矩阵运算
// 1x2的行向量INDArray nd = Nd4j.create(new float[]{1,2},new int[]{1, 2});// 2x1的列向量INDArray nd2 = Nd4j.create(new float[]{3,4},new int[]{2, 1}); //vector as column// 创造两个2x2的矩阵INDArray nd3 = Nd4j.create(new float[]{1,3,2,4},new int[]{2,2}); //elements arranged column majorINDArray nd4 = Nd4j.create(new float[]{3,4,5,6},new int[]{2, 2});//打印System.out.println("打印1x2的行向量");System.out.println(nd);System.out.println("打印2x1的列向量");System.out.println(nd2);System.out.println("打印创造2x2的矩阵");System.out.println(nd3);System.out.println("打印创造2x2的矩阵");System.out.println(nd4);System.out.println("---------------");//1x2 and 2x1 -> 1x1System.out.println("打印1x2 and 2x1 -> 1x1");INDArray ndv = nd.mmul(nd2);System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));//1x2 and 2x2 -> 1x2System.out.println("打印1x2 and 2x2 -> 1x2");ndv = nd.mmul(nd4);System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));//2x2 and 2x2 -> 2x2System.out.println("打印2x2 and 2x2 -> 2x2");ndv = nd3.mmul(nd4);System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));