在前面的模块中,您创建了主题向量。在本模块中,您将构建并部署其中保留了主题向量索引的内容推荐模块。
首先,创建一个字典将随机化标签与训练数据中的原始标签关联。在您的笔记本中,复制并粘贴以下代码,然后选择运行。
labels = newidx
labeldict = dict(zip(newidx,idx))
接下来,使用以下代码将训练数据存储在您的 S3 存储桶中:
import io
import sagemaker.amazon.common as smac
print('train_features shape = ', predictions.shape)
print('train_labels shape = ', labels.shape)
buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(buf, predictions, labels)
buf.seek(0)
bucket = BUCKET
prefix = PREFIX
key = 'knn/train'
fname = os.path.join(prefix, key)
print(fname)
boto3.resource('s3').Bucket(bucket).Object(fname).upload_fileobj(buf)
s3_train_data = 's3://{}/{}/{}'.format(bucket, prefix, key)
print('uploaded training data location: {}'.format(s3_train_data))
接下来,使用下面的帮助程序函数创建 k-NN 估算器,很像您在模块 3 中创建的 NTM 估算器。
def trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path, s3_test_data=None):
"""
Create an Estimator from the given hyperparams, fit to training data,
and return a deployed predictor
"""
# set up the estimator
knn = sagemaker.estimator.Estimator(get_image_uri(boto3.Session().region_name, "knn"),
get_execution_role(),
train_instance_count=1,
train_instance_type='ml.c4.xlarge',
output_path=output_path,
sagemaker_session=sagemaker.Session())
knn.set_hyperparameters(**hyperparams)
# train a model. fit_input contains the locations of the train and test data
fit_input = {'train': s3_train_data}
knn.fit(fit_input)
return knn
hyperparams = {
'feature_dim': predictions.shape[1],
'k': NUM_NEIGHBORS,
'sample_size': predictions.shape[0],
'predictor_type': 'classifier' ,
'index_metric':'COSINE'
}
output_path = 's3://' + bucket + '/' + prefix + '/knn/output'
knn_estimator = trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path)
训练作业运行时,仔细看看帮助程序函数中的参数。
Amazon SageMaker k-NN 算法提供很多不同的距离指标来计算最近的邻居。自然语言处理中常用的一个指标为 cosine 距离。在数学上,向量 A 与 B 之间的余弦“相似度”由以下等式给出:
通过将 index_metric 设置为 COSINE,Amazon SageMaker 自动使用余弦相似度计算最近的邻居。默认距离为 L2 norm,它是标准的欧氏距离。请注意,在发布时,COSINE 仅支持 faiss.IVFFlat 索引类型,不支持 faiss.IVFPQ 索引方法。
Completed - Training job completed
成功! 由于您希望此模型在已知特定测试主题的情况下返回最近的邻居,您需要将其部署为实时托管终端节点。