from kaggle_secrets import UserSecretsClient #kaggle 可忽略
import wandb
##### Kaggle-only: authenticate with Weights & Biases using a Kaggle secret.
# NOTE(review): UserSecretsClient comes from the Kaggle-specific `kaggle_secrets`
# module; outside Kaggle this block presumably needs another way to supply the key.
user_secrets = UserSecretsClient()  # Kaggle secrets client
secret_value_0 = user_secrets.get_secret("wandb_key")  # Kaggle secret labelled "wandb_key"; here it holds the W&B API key
wandb.login(key=secret_value_0)  # log in to W&B with that key before any wandb.init()
##### Initialization: W&B Keras callbacks and run setup
from wandb.keras import WandbCallback, WandbMetricsLogger
# Training hyperparameters, kept in one place so the exact values used for
# compile/fit are also recorded in the W&B run config.
train_config = {
    "learning_rate": 0.001,
    "loss": "mae",
    "batch_size": 32,
    "epochs": 30,
}

# The run is used as a context manager so it is always finished when the block
# exits, even if model construction or training raises. Previously an exception
# skipped the trailing run.finish() and left the run open.
with wandb.init(
    project="open_problems",  # W&B project name; created automatically if missing
    save_code=True,
    name="tabtransformer",
    config=train_config,  # record the hyperparameters alongside the run
) as run:
    #### Model / training code goes here ####
    tabTransformer = TabTransformer(
        categories=nu,  # number of unique elements in each categorical feature
        num_continuous=5,  # number of numerical features
        dim=16,  # embedding/transformer dimension
        dim_out=35,  # dimension of the model output
        depth=6,  # number of transformer layers in the stack
        heads=8,  # number of attention heads
        attn_dropout=0.1,  # attention layer dropout in transformers
        ff_dropout=0.1,  # feed-forward layer dropout in transformers
        mlp_hidden=[(32, 'relu'), (16, 'relu')],  # mlp layer dimensions and activations
    )
    tabTransformer.compile(
        Adam(train_config["learning_rate"]),
        train_config["loss"],
        metrics=['mae'],
    )
    tabTransformer.fit(
        X_train,
        y_train,
        validation_data=(X_val, y_val),
        batch_size=train_config["batch_size"],
        epochs=train_config["epochs"],
        callbacks=[WandbMetricsLogger()],  # logs Keras training/validation metrics to the active W&B run
    )
    ##############
# Leaving the `with` block finishes the run (replaces the explicit run.finish()).
# Reference: Kaggle notebook "[Keras]TabTransformer+W&B"