Where is the loss function specified in the symbolic interface?

Take the simplest official mnist example:

The model here only outputs the softmax probabilities at the end:

def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.sym.Flatten(data=data)
    fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
    fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
    fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
    mlp  = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
    return mlp

Then fit is called directly to train:

model.fit(train,
          begin_epoch=args.load_epoch if args.load_epoch else 0,
          num_epoch=args.num_epochs,
          eval_data=val,
          eval_metric=eval_metrics,
          kvstore=kv,
          optimizer=args.optimizer,
          optimizer_params=optimizer_params,
          initializer=initializer,
          arg_params=arg_params,
          aux_params=aux_params,
          batch_end_callback=batch_end_callbacks,
          epoch_end_callback=checkpoint,
          allow_missing=True,
          monitor=monitor)

I searched all over the source code and couldn't find where the negative log loss is computed.

If I use the symbolic interface and train with the fit function, where exactly is the model's loss function supposed to be specified?

I have the same question. Is it really just the two of us? :joy:
(Update: after checking the docs, SoftmaxOutput already applies a cross-entropy loss by default.)
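To convince myself, here is a minimal check I put together (my own sketch, not taken from the docs or the example): the gradient that SoftmaxOutput sends back to its input comes out as softmax(x) - one_hot(label), which is exactly the gradient of the cross-entropy / negative log-likelihood loss, so the loss really is baked into that op:

import mxnet as mx
import numpy as np

x = mx.sym.Variable('data')
out = mx.sym.SoftmaxOutput(data=x, name='softmax')

# bind with a single example of 3 logits; the label input is auto-named 'softmax_label'
exe = out.simple_bind(ctx=mx.cpu(), data=(1, 3))
exe.arg_dict['data'][:] = mx.nd.array([[1.0, 2.0, 3.0]])
exe.arg_dict['softmax_label'][:] = mx.nd.array([2])  # true class index

exe.forward(is_train=True)
exe.backward()

p = exe.outputs[0].asnumpy()            # softmax probabilities
grad = exe.grad_dict['data'].asnumpy()  # gradient w.r.t. the logits
onehot = np.array([[0.0, 0.0, 1.0]])
print(np.allclose(grad, p - onehot))    # True -> d(-log p_y)/dx = p - onehot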

1. SoftmaxOutput already computes the loss internally. See the official API docs for mxnet.symbol.SoftmaxOutput, which state explicitly that the cross-entropy is computed there.
2. The output (loss) symbol then needs to be bound to a model, e.g. model = mx.mod.Module(symbol=mlp, context=…)
3. If you don't define the loss with SoftmaxOutput and want a custom loss instead, you need the MakeLoss API. Example code below:

import logging
logging.getLogger().setLevel(logging.DEBUG) 
import mxnet as mx
import numpy as np
mnist = mx.test_utils.get_mnist()

batch_size = 100
# one-hot encode the training labels so they can be multiplied with log-softmax directly
weighted_train_labels = np.zeros((mnist['train_label'].shape[0], np.max(mnist['train_label']) + 1))
weighted_train_labels[np.arange(mnist['train_label'].shape[0]), mnist['train_label']] = 1
train_iter = mx.io.NDArrayIter(mnist['train_data'], {'label':weighted_train_labels}, batch_size, shuffle=True)

weighted_test_labels = np.zeros((mnist['test_label'].shape[0],np.max(mnist['test_label'])+ 1))
weighted_test_labels[np.arange(mnist['test_label'].shape[0]),mnist['test_label']] = 1
val_iter = mx.io.NDArrayIter(mnist['test_data'], {'label':weighted_test_labels}, batch_size)

data = mx.sym.var('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fullc layer
flatten = mx.sym.flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
#lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

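# Custom-loss pattern: log_softmax + BlockGrad exposes the scores for prediction and
# metrics without back-propagating through that branch, while MakeLoss marks the
# cross-entropy expression below as the quantity whose gradient actually drives training.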
label = mx.sym.var('label')
softmax = mx.sym.log_softmax(data=fc2)
softmax_output = mx.sym.BlockGrad(data = softmax,name = 'softmax')
ce = -mx.sym.sum(mx.sym.sum(mx.sym.broadcast_mul(softmax,label),1))
lenet = mx.symbol.MakeLoss(ce, normalization='batch')

# group both outputs: the blocked predictions and the loss that is trained on
sym = mx.sym.Group([softmax_output, lenet])
print(sym.list_outputs())

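# accuracy metric computed from the blocked 'softmax_output' branch against the one-hot labels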
def custom_metric(label,softmax):
    return len(np.where(np.argmax(softmax,1)==np.argmax(label,1))[0])/float(label.shape[0])

eval_metrics = mx.metric.CustomMetric(custom_metric,name='custom-accuracy', output_names=['softmax_output'],label_names=['label'])

lenet_model = mx.mod.Module(symbol=sym, context=mx.gpu(),data_names=['data'], label_names=['label'])
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate':0.1},
                eval_metric=eval_metrics,#mx.metric.Loss(),#'acc',
                #batch_end_callback = mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)
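As a follow-up usage sketch (my own addition, with the behaviour hedged rather than guaranteed): since the grouped symbol has two outputs, predict() should return a list of two arrays; the first comes from the blocked log-softmax branch, so class predictions can be read off with argmax:

# assumed usage: outputs[0] from the BlockGrad'ed 'softmax' branch (log-softmax scores),
# outputs[1] from the MakeLoss branch (the cross-entropy value)
outputs = lenet_model.predict(val_iter)
log_probs = outputs[0].asnumpy()
pred_labels = np.argmax(log_probs, axis=1)
print(pred_labels[:10])  # predicted digits for the first few test images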