基于分子图的 BERT 模型。原文:MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction;原文解析:MG-BERT | 利用无监督原子表示学习预测分子性质 | 在分子图上应用 BERT | GNN | 无监督学习(掩蔽原子预训练) | attention;代码:Molecular-graph-BERT。其中缺少的数据以 logD7.4 为例,处理方式与上一篇文章类似,可以删除 Index 列。代码解析从 pretrain 开始,模型整体框架如下:
# Pre-training configuration for the molecular-graph BERT model.
# The GPU-related environment variables are set first so that they are already
# in place when the Keras session is reset and the model is built.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
keras.backend.clear_session()
optimizer = tf.keras.optimizers.Adam(1e-4)

# Architecture presets. 'path' is the directory prefix used when saving weights,
# 'addH' selects whether hydrogens are added explicitly when parsing SMILES.
small = {'name': 'Small', 'num_layers': 3, 'num_heads': 4, 'd_model': 128, 'path': 'small_weights', 'addH': True}
medium = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights', 'addH': True}
medium3 = {'name': 'Medium', 'num_layers': 6, 'num_heads': 4, 'd_model': 256, 'path': 'medium_weights3', 'addH': True}
large = {'name': 'Large', 'num_layers': 12, 'num_heads': 12, 'd_model': 576, 'path': 'large_weights', 'addH': True}
medium_balanced = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'weights_balanced', 'addH': True}
medium_without_H = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'weights_without_H', 'addH': False}

# Selected preset: medium3 = 6 layers, 4 heads, d_model 256.
arch = medium3
num_layers = arch['num_layers']
num_heads = arch['num_heads']
d_model = arch['d_model']
addH = arch['addH']
dff = d_model * 2   # width of the feed-forward sub-layer
vocab_size = 17     # token ids 0..16 of the atom vocabulary
dropout_rate = 0.1

# dropout_rate is passed explicitly so that editing the value above actually
# takes effect (it used to be ignored in favour of the model's own default).
model = BertModel(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads,
                  vocab_size=vocab_size, dropout_rate=dropout_rate)
train_dataset, test_dataset = Graph_Bert_Dataset(path='data/chem.txt', smiles_field='CAN_SMILES', addH=addH).get_data()
"""
{'O': 5000757, 'C': 34130255, 'N': 5244317, 'F': 641901, 'H': 37237224, 'S': 648962,
'Cl': 373453, 'P': 26195, 'Br': 76939, 'B': 2895, 'I': 9203, 'Si': 1990, 'Se': 1860,
'Te': 104, 'As': 202, 'Al': 21, 'Zn': 6, 'Ca': 1, 'Ag': 3}
H C N O F S Cl P Br B I Si Se
"""
# Atom vocabulary: element symbol or special token -> integer id.
# Ids 1-13 are element symbols. The special tokens are:
#   '<pad>'    0   padding (0 is also what the padding-mask code tests for)
#   '<unk>'   14   any element that is not in the vocabulary
#   '<mask>'  15   replacement for an atom hidden during pre-training
#   '<global>' 16  extra node prepended to every molecule
# The special-token names had been stripped to empty strings, which collapsed
# the four entries into a single '' key; they are restored here.
str2num = {'<pad>': 0, 'H': 1, 'C': 2, 'N': 3, 'O': 4, 'F': 5, 'S': 6, 'Cl': 7, 'P': 8, 'Br': 9,
           'B': 10, 'I': 11, 'Si': 12, 'Se': 13, '<unk>': 14, '<mask>': 15, '<global>': 16}
# Inverse mapping: integer id -> symbol / special token.
num2str = {i: j for j, i in str2num.items()}


class Graph_Bert_Dataset(object):
    """Build masked-atom pre-training datasets from a table of SMILES strings.

    Args:
        path: Table file. Files ending in '.txt' or '.tsv' are read as
            tab-separated, anything else with the pandas CSV defaults.
        smiles_field: Name of the column holding the SMILES strings.
        addH: Forwarded to ``smiles2adjoin`` as ``explicit_hydrogens``.
    """

    def __init__(self, path, smiles_field='Smiles', addH=True):
        if path.endswith(('.txt', '.tsv')):
            self.df = pd.read_csv(path, sep='\t')
        else:
            self.df = pd.read_csv(path)
        self.smiles_field = smiles_field
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH

    def get_data(self):
        """Split the table 90/10 at random and return ``(train, test)`` datasets.

        Each dataset yields ``(x, adjoin_matrix, y, weight)`` batches that are
        padded to the longest molecule in the batch. The training set uses
        batches of 256, the test set batches of 512; both prefetch 50 batches.
        """
        data = self.df
        train_index = data.sample(frac=0.9).index
        in_train = data.index.isin(train_index)
        data1 = data[in_train]
        data2 = data[~in_train]

        padded_shapes = (tf.TensorShape([None]), tf.TensorShape([None, None]),
                         tf.TensorShape([None]), tf.TensorShape([None]))

        self.dataset1 = tf.data.Dataset.from_tensor_slices(data1[self.smiles_field].tolist())
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(
            256, padded_shapes=padded_shapes).prefetch(50)

        self.dataset2 = tf.data.Dataset.from_tensor_slices(data2[self.smiles_field].tolist())
        self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(
            512, padded_shapes=padded_shapes).prefetch(50)
        return self.dataset1, self.dataset2

    def tf_numerical_smiles(self, data):
        """Wrap ``numerical_smiles`` so it can be used inside ``Dataset.map``.

        ``tf.py_function`` calls ``numerical_smiles``, which parses one SMILES
        string into four arrays; ``set_shape`` then fills in the rank
        information that ``py_function`` does not provide.
        """
        # The source also carried a commented-out call to a
        # `balanced_numerical_smiles` variant that is not defined in this class.
        x, adjoin_matrix, y, weight = tf.py_function(
            self.numerical_smiles, [data],
            [tf.int64, tf.float32, tf.int64, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        weight.set_shape([None])
        return x, adjoin_matrix, y, weight

    def numerical_smiles(self, smiles):
        """Encode one SMILES string and apply the masked-atom corruption.

        Args:
            smiles: Scalar string tensor (decoded here with ``.numpy().decode()``).

        Returns:
            x: int64 token ids after masking, with '<global>' at index 0.
            adjoin_matrix: additive attention bias, 0 where two nodes are
                connected (the '<global>' node is connected to everything)
                and -1e9 elsewhere.
            y: int64 token ids before masking (the prediction targets).
            weight: float32 vector, 1 at the positions selected for masking.
        """
        smiles = smiles.numpy().decode()
        atoms_list, adjoin_matrix = smiles2adjoin(smiles, explicit_hydrogens=self.addH)
        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]

        # Row/column 0 stay 1 so the global node sees, and is seen by, every atom.
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        adjoin_matrix = (1 - temp) * (-1e9)

        # Pick about 15% of the positions (at least one); "+ 1" skips index 0,
        # so the global node is never selected.
        choices = np.random.permutation(len(nums_list) - 1)[:max(int(len(nums_list) * 0.15), 1)] + 1
        y = np.array(nums_list).astype('int64')
        weight = np.zeros(len(nums_list))
        for i in choices:
            rand = np.random.rand()
            weight[i] = 1
            if rand < 0.8:
                # 80%: replace with the mask token.
                nums_list[i] = str2num['<mask>']
            elif rand < 0.9:
                # 10%: replace with a random id in 1..14.
                nums_list[i] = int(np.random.rand() * 14 + 1)
            # remaining 10%: keep the original atom.
        x = np.array(nums_list).astype('int64')
        weight = weight.astype('float32')
        return x, adjoin_matrix, y, weight
import numpy as np
from utils import smiles2adjoin
import tensorflow as tf
# Atom vocabulary: element symbol or special token -> integer id.
# The special-token names ('<pad>', '<unk>', '<mask>', '<global>') had been
# stripped to empty strings, collapsing four entries into one; restored here.
str2num = {'<pad>': 0, 'H': 1, 'C': 2, 'N': 3, 'O': 4, 'F': 5, 'S': 6, 'Cl': 7, 'P': 8, 'Br': 9,
           'B': 10, 'I': 11, 'Si': 12, 'Se': 13, '<unk>': 14, '<mask>': 15, '<global>': 16}
# Inverse mapping: integer id -> symbol / special token.
num2str = {i: j for j, i in str2num.items()}


def numerical_smiles(smiles, addH=True):
    """Encode one SMILES string and apply the masked-atom corruption.

    Standalone version of the dataset method for experimenting: it takes a
    plain Python string instead of a string tensor.

    Args:
        smiles: SMILES string.
        addH: Forwarded to ``smiles2adjoin`` as ``explicit_hydrogens``.

    Returns:
        x: int64 token ids after masking, with '<global>' at index 0.
        adjoin_matrix: additive attention bias, 0 where two nodes are connected
            (the '<global>' node is connected to everything), -1e9 elsewhere.
        y: int64 token ids before masking.
        weight: float32 vector, 1 at the positions selected for masking.
    """
    atoms_list, adjoin_matrix = smiles2adjoin(smiles, explicit_hydrogens=addH)
    atoms_list = ['<global>'] + atoms_list
    nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]

    # Row/column 0 stay 1 so the global node is connected to every atom.
    temp = np.ones((len(nums_list), len(nums_list)))
    temp[1:, 1:] = adjoin_matrix
    adjoin_matrix = (1 - temp) * (-1e9)

    # About 15% of the positions (at least one); "+ 1" skips the global node.
    choices = np.random.permutation(len(nums_list) - 1)[:max(int(len(nums_list) * 0.15), 1)] + 1
    y = np.array(nums_list).astype('int64')
    weight = np.zeros(len(nums_list))
    for i in choices:
        rand = np.random.rand()
        weight[i] = 1
        if rand < 0.8:
            # 80%: replace with the mask token.
            nums_list[i] = str2num['<mask>']
        elif rand < 0.9:
            # 10%: replace with a random id in 1..14.
            nums_list[i] = int(np.random.rand() * 14 + 1)
        # remaining 10%: keep the original atom.
    x = np.array(nums_list).astype('int64')
    weight = weight.astype('float32')
    return x, adjoin_matrix, y, weight
# Example: encode a single SMILES string with the standalone numerical_smiles().
smiles='CC(C)OC(=O)C(C)NP(=O)(OCC1C(C(C(O1)N2C=CC(=O)NC2=O)(C)F)O)OC3=CC=CC=C3'
x, adjoin_matrix, y, weight=numerical_smiles(smiles)
# Bare expression: presumably a notebook cell, where it displays the four arrays.
x, adjoin_matrix, y, weight
"""
(array([16, 2, 2, 2, 4, 2, 4, 2, 2, 3, 8, 4, 4, 2, 2, 2, 2,
2, 4, 3, 2, 2, 2, 4, 3, 2, 4, 2, 15, 4, 4, 2, 2, 2,
15, 15, 2, 1, 15, 1, 1, 1, 1, 15, 1, 1, 1, 1, 1, 15, 1,
1, 1, 1, 1, 1, 1, 1, 1, 15, 1, 1, 1, 1, 1, 1]),
array([[-0.e+00, -0.e+00, -0.e+00, ..., -0.e+00, -0.e+00, -0.e+00],
[-0.e+00, -0.e+00, -0.e+00, ..., -1.e+09, -1.e+09, -1.e+09],
[-0.e+00, -0.e+00, -0.e+00, ..., -1.e+09, -1.e+09, -1.e+09],
...,
[-0.e+00, -1.e+09, -1.e+09, ..., -0.e+00, -1.e+09, -1.e+09],
[-0.e+00, -1.e+09, -1.e+09, ..., -1.e+09, -0.e+00, -1.e+09],
[-0.e+00, -1.e+09, -1.e+09, ..., -1.e+09, -1.e+09, -0.e+00]]),
array([16, 2, 2, 2, 4, 2, 4, 2, 2, 3, 8, 4, 4, 2, 2, 2, 2,
2, 4, 3, 2, 2, 2, 4, 3, 2, 4, 2, 5, 4, 4, 2, 2, 2,
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
dtype=float32))
"""
def smiles2adjoin(smiles, explicit_hydrogens=True, canonical_atom_order=False):
    """Convert a SMILES string into atom symbols and an adjacency matrix.

    Args:
        smiles: SMILES string to parse with RDKit.
        explicit_hydrogens: If true, hydrogens are added with ``Chem.AddHs``;
            otherwise they are removed with ``Chem.RemoveHs``.
        canonical_atom_order: If true, atoms are renumbered using
            ``rdmolfiles.CanonicalRankAtoms`` before the matrix is built.

    Returns:
        ``(atoms_list, adjoin_matrix)``: the element symbol of every atom in
        atom-index order, and a ``(num_atoms, num_atoms)`` float array with 1.0
        on the diagonal and for every bonded pair, 0.0 elsewhere.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``. Previously only 'error'
            was printed and the function then failed inside RDKit on ``None``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError('{!r} is not a valid SMILES string'.format(smiles))

    if explicit_hydrogens:
        mol = Chem.AddHs(mol)
    else:
        mol = Chem.RemoveHs(mol)

    if canonical_atom_order:
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)

    atoms_list = [atom.GetSymbol() for atom in mol.GetAtoms()]

    # Identity matrix: every atom is adjacent to itself.
    adjoin_matrix = np.eye(mol.GetNumAtoms())
    # Add an undirected edge for every bond.
    for bond in mol.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        adjoin_matrix[u, v] = 1.0
        adjoin_matrix[v, u] = 1.0
    return atoms_list, adjoin_matrix
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
# Parse an example SMILES and print its element symbols in atom-index order.
mol = Chem.MolFromSmiles('OC1C2C1CC2')
num_atoms = mol.GetNumAtoms()
symbols = [mol.GetAtomWithIdx(idx).GetSymbol() for idx in range(num_atoms)]
print(''.join(symbols), end='')  # OCCCCC
"""
class Graph_Bert_Dataset(object):
def __init__(self,path,smiles_field='Smiles',addH=True):
if path.endswith('.txt') or path.endswith('.tsv'):
self.df = pd.read_csv(path,sep='\t')  # 改为 sep=','
else:
self.df = pd.read_csv(path)
self.smiles_field = smiles_field
self.vocab = str2num
self.devocab = num2str
self.addH = addH
"""
from dataset import Graph_Bert_Dataset
# Build the logD datasets and print the first three training batches.
addH = True
dataset_builder = Graph_Bert_Dataset(path='data/logD.txt', smiles_field='SMILES', addH=addH)
train_dataset, test_dataset = dataset_builder.get_data()
for i, (x, adjoin_matrix, y, char_weight) in enumerate(train_dataset):
    labelled = (
        ("x=\n", x),
        ("adjoin_matrix=\n", adjoin_matrix),
        ("y=\n", y),
        ("char_weight=\n", char_weight),
    )
    for label, tensor in labelled:
        print(label, tensor)
    # Stop after the third batch (indices 0, 1, 2).
    if i == 2:
        break
"""
x=
tf.Tensor(
[[16 5 2 ... 0 0 0]
[16 6 4 ... 0 0 0]
[16 15 2 ... 0 0 0]
...
[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
[16 15 2 ... 0 0 0]], shape=(256, 115), dtype=int64)
adjoin_matrix=
tf.Tensor(
[[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
...
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]], shape=(256, 115, 115), dtype=float32)
y=
tf.Tensor(
[[16 5 2 ... 0 0 0]
[16 6 4 ... 0 0 0]
[16 4 2 ... 0 0 0]
...
[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]], shape=(256, 115), dtype=int64)
char_weight=
tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]], shape=(256, 115), dtype=float32)
x=
tf.Tensor(
[[16 7 2 ... 0 0 0]
[16 7 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
...
[16 4 2 ... 0 0 0]
[16 15 15 ... 0 0 0]
[16 6 2 ... 0 0 0]], shape=(256, 132), dtype=int64)
adjoin_matrix=
tf.Tensor(
[[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
...
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]], shape=(256, 132, 132), dtype=float32)
y=
tf.Tensor(
[[16 7 2 ... 0 0 0]
[16 7 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
...
[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
[16 6 2 ... 0 0 0]], shape=(256, 132), dtype=int64)
char_weight=
tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(256, 132), dtype=float32)
x=
tf.Tensor(
[[16 4 2 ... 0 0 0]
[16 4 15 ... 0 0 0]
[16 7 2 ... 0 0 0]
...
[16 15 2 ... 0 0 0]
[16 4 8 ... 0 0 0]
[16 3 2 ... 0 0 0]], shape=(256, 130), dtype=int64)
adjoin_matrix=
tf.Tensor(
[[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
...
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]
[[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
[-0. -0. -0. ... 0. 0. 0.]
...
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]
[ 0. 0. 0. ... 0. 0. 0.]]], shape=(256, 130, 130), dtype=float32)
y=
tf.Tensor(
[[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
[16 7 2 ... 0 0 0]
...
[16 4 2 ... 0 0 0]
[16 4 2 ... 0 0 0]
[16 3 2 ... 0 0 0]], shape=(256, 130), dtype=int64)
char_weight=
tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(256, 130), dtype=float32)
"""
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[-0.e+00, -0.e+00, -0.e+00, -0.e+00, -0.e+00],
[-0.e+00, -0.e+00, -0.e+00, -1.e+09, -1.e+09],
[-0.e+00, -0.e+00, -0.e+00, -0.e+00, -1.e+09],
[-0.e+00, -1.e+09, -0.e+00, -0.e+00, -0.e+00],
[-0.e+00, -1.e+09, -1.e+09, -0.e+00, -0.e+00]], dtype=float32)>
class BertModel(tf.keras.Model):
    """Graph encoder followed by a per-token classification head.

    The head is Dense(d_model, gelu) -> LayerNormalization -> Dense(vocab_size),
    so the output holds one logit per vocabulary entry for every input token.
    """

    def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1):
        super().__init__()
        self.encoder = Encoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            input_vocab_size=vocab_size,
            maximum_position_encoding=200,
            rate=dropout_rate,
        )
        self.fc1 = tf.keras.layers.Dense(d_model, activation=gelu)
        self.layernorm = tf.keras.layers.LayerNormalization(axis=-1)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

    def call(self, x, adjoin_matrix, mask, training=False):
        """Return token logits for ids `x` given the attention bias and padding mask."""
        encoded = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
        hidden = self.layernorm(self.fc1(encoded))
        return self.fc2(hidden)
class Encoder(tf.keras.Model):
    """Stack of `num_layers` EncoderLayer blocks on top of a token embedding.

    No positional encoding is added. `maximum_position_encoding` is accepted
    for interface compatibility but is not used.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask, adjoin_matrix):
        """Encode token ids `x`; returns (batch_size, input_seq_len, d_model)."""
        # Insert a heads axis so the bias broadcasts over all attention heads.
        adjoin_matrix = adjoin_matrix[:, tf.newaxis, :, :]

        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = self.dropout(x, training=training)

        # The per-layer attention weights are not needed here and are discarded.
        for layer in self.enc_layers:
            x, _ = layer(x, training, mask, adjoin_matrix)
        return x  # (batch_size, input_seq_len, d_model)
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: self-attention and a feed-forward network, each
    followed by dropout, a residual connection and layer normalization."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask, adjoin_matrix):
        """Return (output, attention_weights); output has the same shape as `x`."""
        # Self-attention sub-layer: x serves as value, key and query.
        attended, attention_weights = self.mha(x, x, x, mask, adjoin_matrix)
        attended = self.dropout1(attended, training=training)
        after_attention = self.layernorm1(x + attended)

        # Position-wise feed-forward sub-layer.
        transformed = self.ffn(after_attention)
        transformed = self.dropout2(transformed, training=training)
        after_ffn = self.layernorm2(after_attention + transformed)
        return after_ffn, attention_weights
def point_wise_feed_forward_network(d_model, dff):
    """Return a two-layer network: Dense(dff, gelu) followed by Dense(d_model)."""
    expand = tf.keras.layers.Dense(dff, activation=gelu)  # (batch_size, seq_len, dff)
    project = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([expand, project])
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention whose logits also receive the padding mask and the
    additive adjacency bias."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        # d_model is divided evenly between the heads.
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into
        (batch_size, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, adjoin_matrix):
        """Return (output, attention_weights).

        output is (batch_size, seq_len_q, d_model); attention_weights is
        (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        n_batch = tf.shape(q)[0]

        # Project, then split into heads: (batch_size, num_heads, seq_len, depth).
        q = self.split_heads(self.wq(q), n_batch)
        k = self.split_heads(self.wk(k), n_batch)
        v = self.split_heads(self.wv(v), n_batch)

        context, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, adjoin_matrix)

        # Move the heads axis back next to depth and merge the two.
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        merged = tf.reshape(context, (n_batch, -1, self.d_model))
        return self.dense(merged), attention_weights
def scaled_dot_product_attention(q, k, v, mask, adjoin_matrix):
    """Compute scaled dot-product attention with two optional additive terms.

    q, k, v must have matching leading dimensions, and k, v must have matching
    penultimate dimension (seq_len_k == seq_len_v).

    Args:
        q: query, shape (..., seq_len_q, depth)
        k: key, shape (..., seq_len_k, depth)
        v: value, shape (..., seq_len_v, depth_v)
        mask: tensor broadcastable to (..., seq_len_q, seq_len_k), or None.
            It is multiplied by -1e9 and added to the logits.
        adjoin_matrix: tensor broadcastable to (..., seq_len_q, seq_len_k),
            or None. It is added to the logits as-is.

    Returns:
        output, attention_weights
    """
    logits = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # Scale by the square root of the key depth.
    depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = logits / tf.math.sqrt(depth)

    if mask is not None:
        logits += (mask * -1e9)
    if adjoin_matrix is not None:
        logits += adjoin_matrix

    # Softmax over the key axis, so each query's weights sum to 1.
    attention_weights = tf.nn.softmax(logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    weighted_values = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return weighted_values, attention_weights
# Padding mask: 1.0 where the token id is 0, 0.0 everywhere else.
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
# Insert two singleton axes after the first one so the mask can broadcast
# against the attention logits.
mask = seq[:, tf.newaxis, tf.newaxis, :]
import tensorflow as tf
# Small worked example of the padding mask on a 2 x 3 batch of token ids.
x=tf.convert_to_tensor([ #batch_size=2,seq_len=3
[1,0,3],
[0,5,6]
])
# 1.0 where the id is 0, 0.0 elsewhere.
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
# Two singleton axes are inserted: (2, 3) becomes (2, 1, 1, 3).
mask = seq[:, tf.newaxis, tf.newaxis, :]
# Bare expression: presumably a notebook cell, where it displays the mask.
mask
"""
"""
# Tensor specs for the four arguments of a training step, in order:
# token ids, adjacency bias, target ids and per-token weights.
# Every dimension is None, i.e. left unspecified.
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None,None), dtype=tf.float32),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.float32),
]
# Running mean of the training loss.
train_loss = tf.keras.metrics.Mean(name='train_loss')
# Running accuracy metrics; test_accuracy is only referenced by commented-out
# evaluation code in the visible training loop.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
# Cross-entropy on raw logits (from_logits=True) with integer class targets.
loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Pre-training loop: 10 passes over train_dataset, with progress printed every
# 500 batches and the model weights written under arch['path'].
# NOTE(review): indentation is missing in this copy, so the nesting of the
# statements below cannot be read off this listing. In particular, confirm
# against the original script whether train_accuracy.reset_states() belongs to
# the `if batch % 500 == 0` branch or to the epoch body.
for epoch in range(10):
start = time.time()
# Restart the running loss mean.
train_loss.reset_states()
for (batch, (x, adjoin_matrix ,y , char_weight)) in enumerate(train_dataset):
train_step(x, adjoin_matrix, y , char_weight)
if batch % 500 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(
epoch + 1, batch, train_loss.result()))
print('Accuracy: {:.4f}'.format(train_accuracy.result()))
#
# for x, adjoin_matrix ,y , char_weight in test_dataset:
# test_step(x, adjoin_matrix, y , char_weight)
# print('Test Accuracy: {:.4f}'.format(test_accuracy.result()))
# test_accuracy.reset_states()
train_accuracy.reset_states()
# Print the checkpoint path that save_weights() is given below.
print(arch['path'] + '/bert_weights{}_{}.h5'.format(arch['name'], epoch+1))
print('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
print('Accuracy: {:.4f}'.format(train_accuracy.result()))
# NOTE(review): presumably the arch['path'] directory must already exist — confirm.
model.save_weights(arch['path']+'/bert_weights{}_{}.h5'.format(arch['name'],epoch+1))
print('Saving checkpoint')
# Compiled with the declared input signature so that batches of different
# padded lengths reuse a single traced graph; train_step_signature was
# otherwise defined but never used.
@tf.function(input_signature=train_step_signature)
def train_step(x, adjoin_matrix, y, char_weight):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        x: int64 token ids after masking, shape (batch, seq_len).
        adjoin_matrix: float32 additive attention bias, shape (batch, seq_len, seq_len).
        y: int64 target token ids, shape (batch, seq_len).
        char_weight: float32 per-token weights. They are passed as
            `sample_weight`, so positions with weight 0 do not contribute to
            the loss or to the accuracy.
    """
    # Padding mask: 1.0 where the token id is 0, broadcastable over heads and queries.
    seq = tf.cast(tf.math.equal(x, 0), tf.float32)
    mask = seq[:, tf.newaxis, tf.newaxis, :]

    with tf.GradientTape() as tape:
        predictions = model(x, adjoin_matrix=adjoin_matrix, mask=mask, training=True)
        loss = loss_function(y, predictions, sample_weight=char_weight)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_accuracy.update_state(y, predictions, sample_weight=char_weight)