在上一單元中,您建立了主題向量。在本單元中,您將建置和部署內容推薦單元,該單元將保留主題向量的索引。
首先,建立字典,將經過重新整理的標籤連結至訓練資料中的原始標籤。在筆記本中,複製並貼上以下程式碼,然後選擇執行。
labels = newidx
labeldict = dict(zip(newidx,idx))
接著,使用以下程式碼將訓練資料存放在 S3 儲存貯體中:
import io
import sagemaker.amazon.common as smac
print('train_features shape = ', predictions.shape)
print('train_labels shape = ', labels.shape)
buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(buf, predictions, labels)
buf.seek(0)
bucket = BUCKET
prefix = PREFIX
key = 'knn/train'
fname = os.path.join(prefix, key)
print(fname)
boto3.resource('s3').Bucket(bucket).Object(fname).upload_fileobj(buf)
s3_train_data = 's3://{}/{}/{}'.format(bucket, prefix, key)
print('uploaded training data location: {}'.format(s3_train_data))
接下來,使用以下協助程式函數來建立 k-NN 估算器,就像在單元 3 中建立的 NTM 估算器一樣。
def trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path, s3_test_data=None):
"""
Create an Estimator from the given hyperparams, fit to training data,
and return a deployed predictor
"""
# set up the estimator
knn = sagemaker.estimator.Estimator(get_image_uri(boto3.Session().region_name, "knn"),
get_execution_role(),
train_instance_count=1,
train_instance_type='ml.c4.xlarge',
output_path=output_path,
sagemaker_session=sagemaker.Session())
knn.set_hyperparameters(**hyperparams)
# train a model. fit_input contains the locations of the train and test data
fit_input = {'train': s3_train_data}
knn.fit(fit_input)
return knn
hyperparams = {
'feature_dim': predictions.shape[1],
'k': NUM_NEIGHBORS,
'sample_size': predictions.shape[0],
'predictor_type': 'classifier' ,
'index_metric':'COSINE'
}
output_path = 's3://' + bucket + '/' + prefix + '/knn/output'
knn_estimator = trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path)
在執行訓練任務時,請仔細查看協助程式函數中的參數。
Amazon SageMaker k-NN 演算法提供許多不同的距離指標,來計算最近芳鄰。在自然語言處理中使用的一種常用指標是餘弦距離。數學上,兩個向量 A 和 B 之間的餘弦「相似度」由以下方程式得出:
透過將 index_metric 設定為 COSINE,Amazon SageMaker 會自動使用餘弦相似度來計算最近芳鄰。預設距離為 L2 範數,這是標準的歐幾里得距離。請注意,在發佈時,僅 faiss.IVFFlat 索引類型,而非 faiss.IVFPQ 索引方法支援 COSINE。
Completed - Training job completed
成功! 由於您希望該模型傳回指定的特定測試主題的最近鄰居,因此您需要將其部署為即時託管端點。