很开心找到一个题目和我的问题对应的文档,但是真是读不懂。
求助,救救孩子!
import os
import mxnet as mx
import numpy as np
class Softmax(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
x = in_data[0].asnumpy()
y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
y /= y.sum(axis=1).reshape((x.shape[0], 1))
self.assign(out_data[0], req[0], mx.nd.array(y))
这里面的0和1都是什么啊?对应关系有地方介绍么?
看起来indata[0]就是输入了?那这个输入是什么维度啊?包不包括batchsize啊?有没有indata[100]啊?
req又是啥?req【0】?为啥是0啊,1又是啥?到几结束啊?
out_data[0],又是0?为什么啊
更可怕的是bp:
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
l = in_data[1].asnumpy().ravel().astype(np.int)
y = out_data[0].asnumpy()
y[np.arange(l.shape[0]), l] -= 1.0
self.assign(in_grad[0], req[0], mx.nd.array(y))
l是loss吗?还是label啊?为啥又是indata[1]了?1又是啥啊
为啥又要把y计算一下放到in_grad[0]啊?不放到[out_grad]么
在infer shape里面,
def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = (in_shape[0][0],)
output_shape = in_shape[0]
return [data_shape, label_shape], [output_shape], []
这一堆0,1index都是啥啊
小白刚要入坑,发现坑深,要摔死了啊,求助