Skip to content

Commit

Permalink
update new sequence_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
piruto committed Jul 26, 2023
1 parent 3c51510 commit b11ea10
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/sequence_recall/run_sequence_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.sequence import (ComirecSA, ComirecDR, MIND, CMI, Re4, STAMP, GRU4Rec, SINE, ContraRec,
NARM, YotubeDNN, SRGNN, GCSAN, SASRec, NISER, NextItNet, CLRec)
NARM, YotubeDNN, SRGNN, GCSAN, SASRec, NISER, NextItNet, CLRec, IOCRec)
from custom_model import CustomModel,CustomMOEModel
from rec_pangu.trainer import SequenceTrainer
from rec_pangu.utils import set_device
Expand All @@ -29,7 +29,7 @@
config = {
'embedding_dim': 64,
'lr': 0.001,
'K': 1,
'K': 4,
'device': -1,
}
config['device'] = set_device(config['device'])
Expand All @@ -54,7 +54,7 @@
train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema,
batch_size=50)
# 声明模型,序列召回模型模型目前支持: ComirecSA,ComirecDR,MIND,CMI,Re4,NARM,YotubeDNN,SRGNN
model = GCSAN(enc_dict=enc_dict, config=config)
model = IOCRec(enc_dict=enc_dict, config=config)
# 声明Trainer
# trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt',wandb_config=wandb_config)
trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt')
Expand Down
76 changes: 76 additions & 0 deletions examples/sequence_recall/run_sequence_example_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# @ModuleName: run_sequence_example
# @Author: wk
# @Email: 306178200@qq.com
# @Time: 2023/3/5 18:00
"""Example: train a sequence-recall model with the v2 sequence dataloader.

Unlike the v1 example, the v2 dataloader takes ONE combined dataframe and
performs its own per-user leave-one-out split into train/valid/test.
"""
import os
import sys

sys.path.append('../../')
import torch
import pandas as pd
from rec_pangu.dataset import get_dataloader, get_sequence_dataloader_v2
from rec_pangu.models.sequence import (ComirecSA, ComirecDR, MIND, CMI, Re4, STAMP, GRU4Rec, SINE, ContraRec,
                                       NARM, YotubeDNN, SRGNN, GCSAN, SASRec, NISER, NextItNet, CLRec, IOCRec)
from custom_model import CustomModel, CustomMOEModel
from rec_pangu.trainer import SequenceTrainer
from rec_pangu.utils import set_device

if __name__ == '__main__':
    # Data schema: which columns hold user/item/category/timestamp, and the
    # maximum history length kept per user.
    schema = {
        'user_col': 'user_id',
        'item_col': 'item_id',
        'cate_cols': ['genre'],
        'max_length': 20,
        'time_col': 'timestamp',
        'task_type': 'sequence',
    }
    # Model hyper-parameters; 'device' = -1 requests CPU via set_device.
    config = {
        'embedding_dim': 64,
        'lr': 0.001,
        'K': 4,
        'device': -1,
    }
    config['device'] = set_device(config['device'])
    config.update(schema)

    # Optional Weights & Biases tracking config.
    # SECURITY: never commit a real API key to version control — the key that
    # was previously hard-coded here was leaked and must be revoked. Read it
    # from the environment instead (wandb's conventional variable).
    wandb_config = {
        'key': os.environ.get('WANDB_API_KEY', ''),
        'project': 'pangu_sequence_example',
        'name': 'exp_1',
        'config': config,
    }

    # Sample data: the three CSVs are concatenated into one frame because the
    # v2 dataloader does its own leave-one-out splitting per user.
    train_df = pd.read_csv('./sample_data/sample_train.csv')
    valid_df = pd.read_csv('./sample_data/sample_valid.csv')
    test_df = pd.read_csv('./sample_data/sample_test.csv')
    df = pd.concat([train_df, valid_df, test_df], axis=0).reset_index(drop=True)

    # Device used for training/evaluation.
    device = torch.device('cpu')

    # Build train/valid/test dataloaders from the single combined frame.
    train_loader, valid_loader, test_loader, enc_dict = get_sequence_dataloader_v2(df, schema, batch_size=50)

    # Supported sequence-recall models include: ComirecSA, ComirecDR, MIND,
    # CMI, Re4, NARM, YotubeDNN, SRGNN, GCSAN, SASRec, NISER, NextItNet,
    # CLRec, IOCRec.
    model = IOCRec(enc_dict=enc_dict, config=config)

    # Pass wandb_config=wandb_config here to enable experiment tracking.
    trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt')
    # Train with early stopping on recall@20 and a cosine LR schedule.
    trainer.fit(model, train_loader, valid_loader, epoch=500, lr=1e-3, device=device, log_rounds=10,
                use_earlystoping=True, max_patience=5, monitor_metric='recall@20',
                lr_scheduler_type='CosineAnnealingLR', scheduler_params={"T_max": 7, "eta_min": 0})
    # Persist model weights together with the feature-encoding dict.
    trainer.save_all(model, enc_dict, './model_ckpt')
    # Final evaluation on the held-out test split.
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
4 changes: 2 additions & 2 deletions rec_pangu/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# @Time: 2022/1/20 8:21 下午

from .base_dataset import BaseDataset
from .process_data import get_dataloader, get_single_dataloader
from .process_data import get_dataloader, get_single_dataloader, get_sequence_dataloader_v2
from .multi_task_dataset import MultiTaskDataset
from .graph_dataset import GeneralGraphDataset
from .sequence_dataset import SequenceDataset,seq_collate
from .sequence_dataset import SequenceDataset, seq_collate
17 changes: 16 additions & 1 deletion rec_pangu/dataset/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# @Time: 2022/6/10 7:40 PM
from .base_dataset import BaseDataset
from .multi_task_dataset import MultiTaskDataset
from .sequence_dataset import SequenceDataset
from .sequence_dataset import SequenceDataset, SequenceDatasetV2
import torch.utils.data as D


Expand Down Expand Up @@ -50,6 +50,21 @@ def get_sequence_dataloader(train_df, valid_df, test_df, schema, batch_size=512

return train_loader, valid_loader, test_loader, enc_dict

def get_sequence_dataloader_v2(df, schema, batch_size=512 * 3):
    """Build train/valid/test dataloaders from ONE combined dataframe.

    The v2 sequence dataset performs its own per-user leave-one-out split,
    so all three loaders are constructed from the same ``df``; the encoding
    dict fitted on the train phase is reused for valid/test.

    Args:
        df: combined interactions dataframe (train + valid + test rows).
        schema: dataset schema/config dict (user_col, item_col, max_length, ...).
        batch_size: samples per batch for every loader.

    Returns:
        (train_loader, valid_loader, test_loader, enc_dict)
    """
    train_dataset = SequenceDatasetV2(schema, df=df, phase='train')
    enc_dict = train_dataset.get_enc_dict()
    valid_dataset = SequenceDatasetV2(schema, df=df, enc_dict=enc_dict, phase='valid')
    test_dataset = SequenceDatasetV2(schema, df=df, enc_dict=enc_dict, phase='test')

    train_loader = D.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=0, pin_memory=True, drop_last=True)
    # FIX: evaluation loaders must keep every sample. drop_last=True here
    # silently discarded the final partial batch of valid/test users (with the
    # large default batch_size a small eval set could be dropped entirely),
    # biasing the reported metrics.
    valid_loader = D.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=0, pin_memory=True, drop_last=False)
    test_loader = D.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                               num_workers=0, pin_memory=True, drop_last=False)

    return train_loader, valid_loader, test_loader, enc_dict


def get_dataloader(train_df, valid_df, test_df, schema, batch_size=512 * 3):
if schema['task_type'] == 'ranking':
Expand Down
72 changes: 72 additions & 0 deletions rec_pangu/dataset/sequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,78 @@ def get_test_gd(self):
return self.test_gd


class SequenceDatasetV2(SequenceDataset):
    """Sequence dataset with per-user leave-one-out splitting.

    Built from a single combined dataframe: for each user's chronological item
    list, the last item is the test target, the second-to-last the validation
    target, and training samples a random split point inside the prefix.
    """

    def __init__(self, config, df, enc_dict=None, phase='train'):
        super().__init__(config, df, enc_dict, phase)

    def get_test_gd(self):
        """Return {user_id(str): [ground-truth item]} for the current phase.

        Valid phase targets the second-to-last item; test (and any other
        non-valid phase) targets the last item of each user's sequence.
        """
        self.test_gd = dict()
        for user, item_list in self.user2item.items():
            if self.phase == 'valid':
                target_index = len(item_list) - 2
            else:
                target_index = len(item_list) - 1
            self.test_gd[str(user)] = [item_list[target_index]]
        return self.test_gd

    def _build_hist(self, user_id, item_list, k):
        """Truncate-or-pad the history before position ``k`` to max_length.

        Returns (hist_item, hist_mask, cate_hists) where cate_hists maps each
        category column to its equally truncated/padded sequence. Padding uses
        item/category id 0 with mask 0.0.
        """
        if k >= self.max_length:
            hist_item = item_list[k - self.max_length: k]
            hist_mask = [1.0] * self.max_length
            cate_hists = {
                col: getattr(self, f'user2{col}')[user_id][k - self.max_length: k]
                for col in self.cate_cols
            }
        else:
            pad = [0] * (self.max_length - k)
            hist_item = item_list[:k] + pad
            hist_mask = [1.0] * k + [0.0] * (self.max_length - k)
            cate_hists = {
                col: getattr(self, f'user2{col}')[user_id][:k] + pad
                for col in self.cate_cols
            }
        return hist_item, hist_mask, cate_hists

    def __getitem__(self, index):
        """Return one sample dict; train items get a target, eval items a user key.

        FIX: the previous implementation stashed per-item category sequences on
        ``self`` via setattr inside __getitem__, making the dataset stateful —
        items could leak state into each other and the pattern is unsafe with
        concurrent DataLoader access. All intermediates are now local.
        """
        user_id = self.user_list[index]
        item_list = self.user2item[user_id]
        if self.phase == 'train':
            # Random split point in [3, len(item_list)); the item at k becomes
            # the target, everything before it the history.
            # NOTE(review): assumes every user has >= 4 interactions, else
            # random.choice raises on an empty range — confirm upstream filtering.
            k = random.choice(range(3, len(item_list)))
            hist_item, hist_mask, cate_hists = self._build_hist(user_id, item_list, k)
            data = {
                'hist_item_list': torch.Tensor([hist_item]).squeeze(0).long(),
                'hist_mask_list': torch.Tensor([hist_mask]).squeeze(0).long(),
                'target_item': torch.Tensor([item_list[k]]).long(),
            }
        else:
            # Leave-one-out: valid predicts the second-to-last item, test the last.
            if self.phase == 'valid':
                k = len(item_list) - 2
            else:
                k = len(item_list) - 1
            hist_item, hist_mask, cate_hists = self._build_hist(user_id, item_list, k)
            data = {
                'user': str(user_id),
                'hist_item_list': torch.Tensor([hist_item]).squeeze(0).long(),
                'hist_mask_list': torch.Tensor([hist_mask]).squeeze(0).long(),
            }
        for col, seq in cate_hists.items():
            data[f'hist_{col}_list'] = torch.Tensor(seq).squeeze(0).long()
        return data


def seq_collate(batch):
hist_item = torch.rand(len(batch), batch[0][0].shape[0])
hist_mask = torch.rand(len(batch), batch[0][0].shape[0])
Expand Down
1 change: 1 addition & 0 deletions rec_pangu/models/sequence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .sine import SINE
from .contrarec import ContraRec
from .clrec import CLRec
from .iocrec import IOCRec
Loading

0 comments on commit b11ea10

Please sign in to comment.