utils.py中的accuracy函数如何定义

在多类逻辑回归的gluon版实现中,提到可以使用utils.py中的accuracy函数进行精度计算,但我下载的版本里并没有定义accuracy函数,而且我在gluon版中使用scratch版的accuracy函数后会报错,在我把数据集获取从utils.load_data_fashion_mnist(batch_size)变为scratch版的获取方式后又能运行成功了…

想知道这个utils.py中的accuracy函数是怎么定义的。

scratch中的accuracy函数如下:
def accuracy(output,label):
return nd.mean(output.argmax(axis=1) == label).asscalar()

scratch中的数据获取如下:
mnist_train = gluon.data.vision.FashionMNIST(train=True,transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False,transform=transform)
train_data = gluon.data.DataLoader(mnist_train,batch_size,shuffle=True)
test_data = gluon.data.DataLoader(mnist_test,batch_size,shuffle=False)

没事了,我寻思我之前老纠结accuracy干嘛,明明一个evaluate_accuracy就能解决的事情…