-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
piruto
committed
Jul 26, 2023
1 parent
3c51510
commit b11ea10
Showing
7 changed files
with
652 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# -*- coding: utf-8 -*- | ||
# @ModuleName: run_sequence_example | ||
# @Author: wk | ||
# @Email: 306178200@qq.com | ||
# @Time: 2023/3/5 18:00 | ||
import sys | ||
|
||
sys.path.append('../../') | ||
import torch | ||
from rec_pangu.dataset import get_dataloader, get_sequence_dataloader_v2 | ||
from rec_pangu.models.sequence import (ComirecSA, ComirecDR, MIND, CMI, Re4, STAMP, GRU4Rec, SINE, ContraRec, | ||
NARM, YotubeDNN, SRGNN, GCSAN, SASRec, NISER, NextItNet, CLRec, IOCRec) | ||
from custom_model import CustomModel,CustomMOEModel | ||
from rec_pangu.trainer import SequenceTrainer | ||
from rec_pangu.utils import set_device | ||
import pandas as pd | ||
|
||
if __name__ == '__main__': | ||
# 声明数据schema | ||
schema = { | ||
'user_col': 'user_id', | ||
'item_col': 'item_id', | ||
'cate_cols': ['genre'], | ||
'max_length': 20, | ||
'time_col': 'timestamp', | ||
'task_type': 'sequence' | ||
} | ||
# 模型配置 | ||
config = { | ||
'embedding_dim': 64, | ||
'lr': 0.001, | ||
'K': 4, | ||
'device': -1, | ||
} | ||
config['device'] = set_device(config['device']) | ||
config.update(schema) | ||
|
||
# wandb配置 | ||
wandb_config = { | ||
'key': 'ca0a80eab60eff065b8c16ab3f41dec4783e60ae', | ||
'project': 'pangu_sequence_example', | ||
'name': 'exp_1', | ||
'config': config | ||
} | ||
|
||
# 样例数据 | ||
train_df = pd.read_csv('./sample_data/sample_train.csv') | ||
valid_df = pd.read_csv('./sample_data/sample_valid.csv') | ||
test_df = pd.read_csv('./sample_data/sample_test.csv') | ||
|
||
df = pd.concat([train_df,valid_df,test_df],axis=0).reset_index(drop=True) | ||
|
||
|
||
# 声明使用的device | ||
device = torch.device('cpu') | ||
# 获取dataloader | ||
# train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema, | ||
# batch_size=50) | ||
|
||
train_loader, valid_loader, test_loader, enc_dict = get_sequence_dataloader_v2(df, schema,batch_size=50) | ||
|
||
# 声明模型,序列召回模型模型目前支持: ComirecSA,ComirecDR,MIND,CMI,Re4,NARM,YotubeDNN,SRGNN | ||
model = IOCRec(enc_dict=enc_dict, config=config) | ||
# 声明Trainer | ||
# trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt',wandb_config=wandb_config) | ||
trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt') | ||
# 训练模型 | ||
# trainer.fit(model, train_loader, valid_loader, epoch=500, lr=1e-3, device=device, log_rounds=10, | ||
# use_earlystoping=True, max_patience=5, monitor_metric='recall@20', ) | ||
trainer.fit(model, train_loader, valid_loader, epoch=500, lr=1e-3, device=device, log_rounds=10, | ||
use_earlystoping=True, max_patience=5, monitor_metric='recall@20', | ||
lr_scheduler_type='CosineAnnealingLR', scheduler_params={"T_max": 7, "eta_min": 0}) | ||
# 保存模型权重和enc_dict | ||
trainer.save_all(model, enc_dict, './model_ckpt') | ||
# 模型验证 | ||
test_metric = trainer.evaluate_model(model, test_loader, device=device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ | |
from .sine import SINE | ||
from .contrarec import ContraRec | ||
from .clrec import CLRec | ||
from .iocrec import IOCRec |
Oops, something went wrong.