import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg
from PIL import Image
from sklearn.decomposition import PCA
import matplotlib
import os
import pandas as pd
# Cell: PCA of DINOv2 ViT-B/14 patch tokens for one cropped image.
# The input is sized (patch_h * 14, patch_w * 14) pixels so that the backbone's
# 14-pixel patches form a patch_h x patch_w token grid.
patch_h = 28
patch_w = 28
feat_dim = 768  # embedding width of each ViT-B/14 token
transform = T.Compose([
    # NOTE(review): with a sigma *range*, GaussianBlur samples a new sigma on
    # every call, so extracted features differ between runs. Confirm this blur
    # is intended (it reads like a training-time augmentation).
    T.GaussianBlur(9, sigma=(0.1, 2.0)),
    T.Resize((patch_h * 14, patch_w * 14)),
    T.CenterCrop((patch_h * 14, patch_w * 14)),
    T.ToTensor(),
    # Usual ImageNet channel mean/std.
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Load the model from a local hub checkout (repo_or_dir='' presumably resolves
# hubconf.py relative to the working directory) and put it in inference mode.
dinov2_vitb14 = torch.hub.load('', 'dinov2_vitb14', source='local').cuda().eval()

img_paths = [
    '/home/wangzhenkuan/val_cropped/cropped_(0, 0, 7, 26)_obj365_val_000000605687.jpg',
]
img_batch = []
for img_path in img_paths:
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(img_path) as img:
        img_batch.append(transform(img.convert('RGB'))[:3])
# The batch holds only real images. Padding it with all-zero images would add
# identical "blank" patch tokens that dominate the PCA fit below.
imgs_tensor = torch.stack(img_batch).cuda()
n_imgs = imgs_tensor.shape[0]

with torch.no_grad():
    features_dict = dinov2_vitb14.forward_features(imgs_tensor)
# Expected layout (n_imgs, patch_h * patch_w, feat_dim); flatten to one row per patch.
features = features_dict['x_norm_patchtokens']
features = features.reshape(n_imgs * patch_h * patch_w, feat_dim).cpu()

# Project every patch token onto the top 3 principal components.
pca = PCA(n_components=3)
pca.fit(features)
pca_features = pca.transform(features)

# Min-max rescale only the first component to [0, 1]; components 1 and 2 stay in
# raw PCA units. A constant component would give 0/0 (NaN), so map it to 0.
pc0 = pca_features[:, 0]
pc0_min = pc0.min()
pc0_range = pc0.max() - pc0_min
pca_features[:, 0] = (pc0 - pc0_min) / pc0_range if pc0_range > 0 else 0.0

# Row-major flatten: the 3 component values of patch 0, then patch 1, and so on.
new_pca_features = pca_features.flatten()
print(new_pca_features, new_pca_features.shape)
输出的结果为:
[ 0.77485222 3.1461922 -2.36750582 ... 0.44878434 9.83799508
23.6097603 ] (9408,)
from PIL import Image
# Print the pixel size of one cropped image.
image_path = "/home/wangzhenkuan/val_cropped/cropped_(25, 140, 39, 143)_obj365_val_000000685822.jpg"
# Image.open is lazy and keeps the file open until the image is closed or
# garbage-collected; reading .size needs only the header, so close it right away.
with Image.open(image_path) as img:
    width, height = img.size
# Prints "image size: width = ...px, height = ...px".
print(f"图片尺寸:宽度 = {width}px, 高度 = {height}px")
输出的结果为:
图片尺寸:宽度 = 14px,高度 = 3px
# Inspect the output keys of `forward_features` from the preceding cell. The bare
# expression is only displayed in an interactive (notebook/REPL) session; run as
# a script it evaluates and discards the value.
features_dict.keys()
输出结果为:
dict_keys(['x_norm_clstoken', 'x_norm_patchtokens', 'x_prenorm', 'masks'])
# Cell: DINOv2 ViT-B/14 CLS embedding for a single cropped image.
patch_h = 28
patch_w = 28
feat_dim = 768  # embedding width of each ViT-B/14 token
transform = T.Compose([
    # NOTE(review): a sigma range makes GaussianBlur pick a random sigma per call,
    # so the embedding is not reproducible between runs. Confirm this is wanted.
    T.GaussianBlur(9, sigma=(0.1, 2.0)),
    T.Resize((patch_h * 14, patch_w * 14)),
    T.CenterCrop((patch_h * 14, patch_w * 14)),
    T.ToTensor(),
    # Usual ImageNet channel mean/std.
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Load from a local hub checkout and switch to inference mode.
dinov2_vitb14 = torch.hub.load('', 'dinov2_vitb14', source='local').cuda().eval()
img_path = '/home/wangzhenkuan/val_cropped/cropped_(25, 140, 39, 143)_obj365_val_000000685822.jpg'
with Image.open(img_path) as src:
    # NOTE(review): `transform` resizes to (patch_h*14, patch_w*14) anyway; this
    # extra 28x28 resize throws away detail for crops larger than 28 px.
    img = src.convert('RGB').resize((28, 28))
# A one-image batch. Zero-filled padding slots would only compute an extra,
# unused blank-image embedding.
imgs_tensor = transform(img)[:3].unsqueeze(0).cuda()
with torch.no_grad():
    features_dict = dinov2_vitb14.forward_features(imgs_tensor)
features = features_dict['x_norm_clstoken'][0].cpu()  # CLS embedding, length feat_dim
new_features = features.tolist()
print(new_features)
输出结果为:
[-1.782302737236023, -2.4636621475219727, 2.9943976402282715, -0.5234289169311523, 1.691330075263977, -1.6833631992340088, -1.330706238746643, 1.0088632106781006, -1.913888931274414, 0.23068946599960327, 1.7929106950759888, -0.27120286226272583, -0.49883294105529785, -1.2453603744506836, 2.3303046226501465, -0.6939842700958252, -2.5716118812561035, -1.8246359825134277, 1.2785545587539673, 2.3685169219970703, -2.226424217224121, -0.6657404899597168, -0.3867187798023224, 0.2776636779308319, -2.4077162742614746, 0.6283755898475647, 0.36297979950904846, 1.4644601345062256, 0.6084825992584229, 0.04457835853099823, -2.8952596187591553, -1.4791420698165894, -1.6125147342681885, 1.0907434225082397, 1.4983779191970825, -1.2263838052749634, -0.21449723839759827, -3.396991729736328, 0.14435461163520813, -1.3464086055755615, -3.393324613571167, 0.29084402322769165, -0.5454743504524231, -2.327784776687622, 0.7393595576286316, -0.021981626749038696, -2.108546257019043, -1.1266191005706787, -0.755315899848938, 1.9216516017913818, -1.1222351789474487, -0.778069019317627, -1.8038049936294556, 3.842763662338257, -2.166956663131714, -0.18735790252685547, 0.5142223238945007, 0.4452049434185028, 0.3436106741428375, 2.238577365875244, -0.49933183193206787, 2.982692003250122, -1.7857329845428467, 0.31839874386787415, 0.493878573179245, -0.43247273564338684, 0.7188404202461243, -0.48016902804374695, -1.4780107736587524, -2.2374534606933594, -1.2843281030654907, 2.1578004360198975, -0.5961135625839233, 5.014259338378906, -1.3459324836730957, -0.07546994090080261, 2.8282198905944824, -1.489823579788208, 0.8014816641807556, 2.3728225231170654, 0.9512873888015747, -1.3640193939208984, -0.08151863515377045, -1.4603017568588257, -0.8666728734970093, -0.6388664841651917, 0.26002177596092224, 0.3265385329723358, 0.05554609000682831, -0.25209227204322815, -2.3126885890960693, -3.8923213481903076, 1.4711894989013672, -0.3243894577026367, 0.04903823509812355, 1.0250012874603271, 0.9641492366790771, 
0.6563885807991028, -1.0517845153808594, -0.03144245967268944, 0.8328980803489685, 1.000225305557251, -0.3773617744445801, -0.20199733972549438, 0.6676586270332336, -0.12478910386562347, -1.5276193618774414, -0.3590487837791443, -1.273734211921692, -1.2506481409072876, 1.1392039060592651, -0.734968364238739, 2.1350884437561035, -0.7077086567878723, -1.9917738437652588, -1.3919944763183594, -1.2059870958328247, -1.1604094505310059, 2.529306173324585, 1.526973843574524, -1.525469422340393, 0.03462134301662445, -1.2334036827087402, 1.4593768119812012, -1.5264039039611816, -0.6068321466445923, 2.037444591522217, 2.0417494773864746, 3.0911359786987305, -0.5778148174285889, -0.5444684028625488, 1.1871347427368164, -1.918009638786316, -0.7611539363861084, -1.476572036743164, 1.595043420791626, -2.813671588897705, -1.9813976287841797, -0.9153470396995544, 0.17812523245811462, -0.8469423651695251, 0.0019542728550732136, 1.3013651371002197, 1.0463727712631226, -0.5279057621955872, 0.6109235286712646, 1.013996958732605, -0.8054614663124084, 1.4095799922943115, -0.42996102571487427, -0.18834379315376282, -0.3168187737464905, -0.6558595299720764, -1.1568105220794678, 1.170350193977356, 0.33856046199798584, -0.7012094855308533, -2.375239849090576, 0.27717700600624084, 1.1880180835723877, -0.25741657614707947, -2.8391549587249756, 0.38533759117126465, 0.9957401156425476, 1.236900806427002, -0.21157701313495636, -2.0252935886383057, -1.5473968982696533, -0.821943998336792, 1.1762545108795166, 0.46207451820373535, -1.7723321914672852, -1.6779030561447144, -0.14319078624248505, -1.7182550430297852, 1.468103051185608, -1.592165470123291, 0.2622267007827759, -0.1855563521385193, -1.8285667896270752, 0.524232029914856, 1.2373623847961426, -1.9170277118682861, 0.7803347110748291, 1.1997917890548706, -0.7289928793907166, 0.9738048315048218, -1.1404708623886108, -0.06302018463611603, 0.5733993649482727, -0.8244834542274475, -2.9138309955596924, -2.609675168991089, 0.909052848815918, 
1.4253618717193604, 1.0942734479904175, 0.20974589884281158, 0.9608197808265686, -0.13692501187324524, -0.480473130941391, -1.6277155876159668, 3.6779026985168457, 0.5295201539993286, 0.7025696635246277, -0.3715479075908661, 0.7490695714950562, 2.0802066326141357, -0.280375599861145, 1.3140792846679688, -3.235884428024292, -0.8986896872520447, -1.6405057907104492, -0.3694521188735962, 1.251508116722107, -2.051178455352783, -1.1959902048110962, -2.896202325820923, -1.0829169750213623, -2.469109058380127, 0.06871689110994339, 1.1800856590270996, 0.6479581594467163, -0.12025940418243408, -0.8590918183326721, 0.447131484746933, 2.0885043144226074, 0.20625364780426025, 1.277632236480713, 0.06128314509987831, -0.856684148311615, -4.042417049407959, -0.8839811086654663, -2.3334860801696777, -2.824632167816162, 2.012399673461914, -1.6131731271743774, 0.03637152910232544, 1.3474061489105225, 1.0460636615753174, 1.7871677875518799, -1.2456921339035034, -1.2018588781356812, -1.2330501079559326, -0.731390118598938, -2.53403377532959, 0.8381593823432922, 0.8807539939880371, 1.925523281097412, 0.782750129699707, -0.15729139745235443, 1.3447301387786865, -0.1672334372997284, 0.9798004031181335, 1.5715116262435913, 3.090116024017334, 0.621734082698822, 1.733325719833374, 1.5695395469665527, 0.17969544231891632, 1.3452461957931519, -0.002069800393655896, 3.495548725128174, -0.5111413598060608, -0.5300322771072388, 2.1211509704589844, 1.5424525737762451, -0.4201726019382477, -1.6935588121414185, -3.400237798690796, 2.746851682662964, -2.273819923400879, -0.7426249384880066, -0.6773086190223694, 0.858296811580658, -0.8582549095153809, 2.606393337249756, 1.1274425983428955, 1.1355055570602417, -0.9266760945320129, -1.6875674724578857, 0.6409508585929871, -0.5526147484779358, 0.809528112411499, 0.4231944978237152, 2.7432198524475098, 0.02784748561680317, 1.4140912294387817, -0.6977792978286743, 1.1178771257400513, -1.3610061407089233, -1.1296662092208862, -0.21126246452331543, 
1.8089278936386108, -2.900038719177246, 3.099142074584961, 0.6932424306869507, 1.3892556428909302, -0.24539850652217865, -0.9480887651443481, -0.736323893070221, -0.5573472380638123, 0.2089473307132721, 0.28838279843330383, -1.9168763160705566, 3.3088252544403076, 1.3098756074905396, 0.9232526421546936, 1.546655297279358, -0.9164364337921143, -2.630692958831787, -3.5181384086608887, -1.8235714435577393, 0.17165851593017578, 2.639941453933716, 1.415391445159912, -0.7797530293464661, 0.13846322894096375, 1.328677773475647, -1.44327974319458, 0.5074266791343689, -0.03981078043580055, 2.3744325637817383, 1.3007115125656128, 0.714430570602417, -1.7004998922348022, -1.0638532638549805, -0.8481042385101318, -1.8069591522216797, 1.261594295501709, -1.10498046875, 2.0613486766815186, 1.4846584796905518, 1.7253369092941284, -0.7635220289230347, 0.22499537467956543, 0.09731481969356537, 2.057244300842285, -0.7979379892349243, -2.6087210178375244, 1.0734786987304688, -0.018653851002454758, 1.5032217502593994, 0.12841078639030457, 0.6642627120018005, -0.24752289056777954, -0.09741762280464172, 0.26423177123069763, -0.13321749866008759, -0.6164339184761047, -1.5291261672973633, -0.5502980947494507, -0.1252671331167221, -0.5465449690818787, 2.9258711338043213, 2.6053402423858643, 0.2354675680398941, 0.00872958917170763, 0.773603081703186, -0.735529363155365, -0.9919795989990234, -1.5949963331222534, 2.6893129348754883, 0.7358701825141907, -1.1472917795181274, 2.359499216079712, -2.2209489345550537, 0.3828812837600708, 0.5850518941879272, 0.8996657729148865, -0.5891192555427551, -1.1196407079696655, -0.5374260544776917, 0.2961849272251129, 2.75605845451355, -0.7352980971336365, -0.5794591903686523, -3.153204917907715, -1.4226027727127075, 0.8484690189361572, -1.476483941078186, 0.8465580940246582, 1.8315489292144775, -1.2356393337249756, -1.098958134651184, 2.0075953006744385, -3.0050487518310547, 0.6299751996994019, -0.36766427755355835, 1.3836698532104492, 0.23282812535762787, 
0.7143048644065857, 1.014121413230896, -0.0469597727060318, -2.7185325622558594, 0.29532307386398315, 1.040669322013855, -0.26178792119026184, -1.7334434986114502, 0.7760761976242065, 1.1974889039993286, -0.4802558720111847, -2.4174513816833496, 0.87935870885849, 0.8979024291038513, 0.44834524393081665, -1.0819308757781982, -0.5651684403419495, 1.2345761060714722, 3.2679555416107178, -1.6266891956329346, 1.2048085927963257, 0.08094020187854767, 2.277104139328003, 1.2699872255325317, 0.45772409439086914, 4.867003440856934, -0.5402204990386963, -1.437784194946289, 1.6940743923187256, -0.27067723870277405, 0.28807947039604187, -1.7572873830795288, 3.902473211288452, 1.0111587047576904, -2.020484685897827, -0.377560555934906, -0.14685222506523132, 1.0295541286468506, 2.5080201625823975, 0.7465159893035889, -1.6725640296936035, 2.120771884918213, 1.8752940893173218, 0.13404321670532227, -0.9339718222618103, -1.0764025449752808, -1.120705008506775, -1.3446354866027832, 1.01795494556427, -0.9050564169883728, 0.6137460470199585, 0.260585218667984, 1.133894681930542, -0.05889385566115379, -0.6519827842712402, 1.1598161458969116, -1.7704927921295166, -1.4216445684432983, 1.1419371366500854, -0.19676364958286285, -0.12388337403535843, 1.676942229270935, -0.2588338553905487, 1.1000467538833618, 1.587817907333374, 0.6045276522636414, -1.9270410537719727, 1.7428301572799683, -1.0723416805267334, -0.27404242753982544, 2.9821128845214844, -0.06173550710082054, 0.7838605046272278, -0.20503199100494385, 0.18047350645065308, 1.2533975839614868, 4.150588512420654, -2.07193660736084, 2.7952287197113037, 1.7451260089874268, -1.5707392692565918, -1.210213541984558, 1.4766311645507812, -0.8516893982887268, -3.074359893798828, -0.5563936829566956, 0.592707633972168, -0.9391210675239563, -0.12514294683933258, -1.3752027750015259, 0.8371762633323669, 0.17783881723880768, 1.2013310194015503, -1.8072235584259033, 0.008545350283384323, -0.8450319766998291, -0.1838560849428177, 
0.04656567424535751, 1.8328806161880493, -1.337045669555664, 0.10181967914104462, -2.5168120861053467, -0.8175782561302185, 1.0487390756607056, 0.7896212339401245, 0.8936480283737183, -0.295285701751709, -0.11428146064281464, 1.2843492031097412, 0.21505780518054962, 0.9035468697547913, 1.7623189687728882, 2.5622332096099854, 0.7209411263465881, -4.4749250411987305, 0.06234496459364891, 0.5075461864471436, 0.6106500029563904, -0.7091167569160461, 0.3739960193634033, 2.325373649597168, 1.6865546703338623, -1.2292370796203613, 0.002017478458583355, -1.5167118310928345, 2.3719675540924072, -0.12257850915193558, -0.2742123603820801, 1.1188805103302002, 1.1275124549865723, 1.969099760055542, 1.0678647756576538, -0.6574950218200684, -0.04013944789767265, -0.7438542246818542, 0.01077490858733654, -0.36018097400665283, 0.765718936920166, 0.9758625626564026, -1.5729905366897583, 1.5433160066604614, -1.8463612794876099, 1.5079660415649414, -0.6314525008201599, 1.556646704673767, -0.9214202761650085, -0.5937618017196655, 0.5787580609321594, -0.544879674911499, 2.2051501274108887, 2.1140949726104736, 1.3734196424484253, 1.8830134868621826, 1.1051276922225952, 0.010084696114063263, -1.2945022583007812, 0.5964868068695068, -0.574560821056366, -0.7125140428543091, -1.4985989332199097, 1.1931395530700684, -3.4669835567474365, 2.1407690048217773, 0.9654586315155029, -0.9334999322891235, -1.0985212326049805, 1.9350329637527466, -0.0157951470464468, 0.8070175051689148, -1.4314080476760864, 1.643994688987732, 1.4210337400436401, 0.4065215289592743, -1.2252724170684814, 0.6296184062957764, -1.015487551689148, 0.95037841796875, -1.2684987783432007, -3.35871958732605, 0.7230661511421204, -3.321512460708618, -2.866903066635132, -1.6048551797866821, 0.9972029328346252, 1.5359357595443726, 0.7939863204956055, 1.3079074621200562, -0.40280815958976746, -1.0989586114883423, -0.8992764949798584, -1.0560827255249023, -0.22213973104953766, 2.229839563369751, 0.9596742987632751, 
-0.5350444912910461, 0.4394168257713318, -0.2593521773815155, -1.6894360780715942, -3.510387897491455, 0.9091895222663879, 0.3822038173675537, 0.6567855477333069, -1.2245944738388062, 1.14584219455719, -2.105818271636963, -1.683109998703003, 0.8035480380058289, -1.9503647089004517, 2.902534246444702, -1.0506266355514526, 3.05086088180542, 0.6392022371292114, -2.083533525466919, -2.2651782035827637, 4.038768291473389, -0.6145167946815491, -0.10916915535926819, -1.8031373023986816, 1.2886639833450317, 1.1107773780822754, -1.218652606010437, 2.6412882804870605, -0.37478649616241455, -2.4777863025665283, 0.5200248956680298, 2.220710039138794, -0.20018260180950165, -0.37540286779403687, 1.422378659248352, -1.3118668794631958, -0.28698980808258057, 1.3089606761932373, 0.2861316502094269, -1.9453247785568237, 0.5933266878128052, 0.4733945429325104, 0.2217245101928711, 0.6716511845588684, 0.478162556886673, -2.9247443675994873, 0.6833447217941284, 0.6265335083007812, 0.17462027072906494, 1.3712408542633057, -0.9348064064979553, -1.5202499628067017, 0.6559357643127441, 1.2336370944976807, 1.6443251371383667, -1.6532347202301025, -1.308686375617981, -0.5601414442062378, -1.253258466720581, 0.7729612588882446, 0.9917702674865723, 5.5190300941467285, -2.008385181427002, -1.3119657039642334, 1.905112862586975, -1.4413495063781738, -1.022444725036621, -1.4267915487289429, -3.0826029777526855, -2.507009506225586, -0.3101186752319336, 0.9453191757202148, -1.4448578357696533, -5.463378429412842, 1.685957431793213, 0.9716243147850037, -1.4983900785446167, -10.368620872497559, 0.8836182355880737, -0.901665210723877, -1.001763105392456, -0.7415831089019775, 2.982067584991455, -1.173464298248291, 0.37221759557724, 0.8291199803352356, 0.3778395354747772, -0.28813591599464417, 1.6545953750610352, 1.1575158834457397, 0.46889442205429077, 0.8416760563850403, -1.6701411008834839, -2.469136953353882, -1.473692536354065, -1.8790534734725952, 1.8253601789474487, 1.193148136138916, 
-0.9570328593254089, 1.5213717222213745, -0.058223407715559006, -1.1784520149230957, 0.7866590619087219, -3.971035957336426, -4.860577583312988, 0.4887443482875824, -0.8394114375114441, -1.2669053077697754, 0.4229445159435272, 0.438689261674881, -0.6596183180809021, 2.205631971359253, -0.4488086700439453, 2.098450183868408, 0.0931340754032135, 1.8314093351364136, -1.1382595300674438, -3.118795394897461, -0.962195098400116, -0.43586546182632446, 0.3552641272544861, 1.1865928173065186, 0.15520243346691132, 2.439142942428589, -6.7766876220703125, 1.0746952295303345, 1.0203089714050293, -1.586142659187317, -1.255768895149231, -3.907606840133667, -1.913203239440918, -0.11117686331272125, -0.8351805210113525, -1.964456558227539, 0.04278327524662018, -0.3471674919128418, -2.366966962814331, 0.18251053988933563, -2.861321210861206, -0.6874016523361206, 1.8516937494277954, 1.390030860900879, -0.21073535084724426, 3.342275619506836, 0.4797917902469635, -2.315653085708618, -2.963287353515625, -1.9324268102645874, -1.4271326065063477, 2.3010289669036865, 0.08775099366903305, 1.0504770278930664, 1.8316768407821655, 1.092706322669983, 0.24372225999832153, 1.3356677293777466, 0.0620085634291172, -4.382974147796631, 2.237565517425537, -2.0954740047454834, -0.36105191707611084, 0.5060109496116638, 0.8488038182258606, -0.9324450492858887, -2.5846095085144043, -0.25201156735420227, -1.9630687236785889, 2.7448606491088867, 0.23988445103168488, -2.1208527088165283, -0.38995805382728577, 1.1672645807266235, 1.0231435298919678, -0.7268855571746826, 1.912980556488037, -0.004989832639694214, -0.19368630647659302, 0.6368448734283447, 0.5241373777389526, 0.2989656627178192, 1.4544707536697388, -0.1921188235282898, 2.232924222946167, -2.444467306137085]
# os.listdir returns names in arbitrary, filesystem-dependent order. Sort them so
# that `image_names[:5]` picks the same five crops on every run and machine.
image_names = sorted(os.listdir('/home/wangzhenkuan/val_cropped/'))
image_names[:5]  # displayed only in an interactive session
输出结果为:
['cropped_(259, 581, 318, 630)_obj365_val_000000440619.jpg',
'cropped_(340, 185, 354, 213)_obj365_val_000000247906.jpg',
'cropped_(560, 192, 567, 200)_obj365_val_000000037054.jpg',
'cropped_(25, 140, 39, 143)_obj365_val_000000685822.jpg',
'cropped_(143, 379, 188, 524)_obj365_val_000000560071.jpg']
# Extract, print and collect the DINOv2 CLS embedding of the first five crops.
# Relies on `image_names`, `transform`, `dinov2_vitb14` and `feat_dim` from
# earlier cells.
image_dir = '/home/wangzhenkuan/val_cropped/'
cls_features = []
for image_name in image_names[:5]:
    image_path = os.path.join(image_dir, image_name)
    with Image.open(image_path) as src:
        # NOTE(review): `transform` resizes again to (patch_h*14, patch_w*14);
        # this 28x28 pre-resize discards detail for crops larger than 28 px.
        img = src.convert('RGB').resize((28, 28))
    # One-image batch. Zero-filled padding slots would only add the same
    # blank-image embedding as extra rows on every iteration.
    imgs_tensor = transform(img)[:3].unsqueeze(0).cuda()
    with torch.no_grad():
        features_dict = dinov2_vitb14.forward_features(imgs_tensor)
    features = features_dict['x_norm_clstoken'].cpu()  # (1, feat_dim)
    cls_features.append(features[0])
    print(features, features.shape)
# Stack into a (num_images, feat_dim) tensor; torch.stack rejects an empty list.
cls_features = torch.stack(cls_features) if cls_features else torch.empty(0, feat_dim)
输出结果为:
tensor([[-1.4513, 0.9261, 1.6200, ..., -0.1176, -0.5844, -3.1325],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517]]) torch.Size([4, 768])
tensor([[ 0.9361, -1.5400, 1.4137, ..., 0.8697, -0.9790, -1.2595],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517]]) torch.Size([4, 768])
tensor([[-0.5115, 0.1547, 2.0663, ..., -0.0101, 1.2684, -1.3007],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517]]) torch.Size([4, 768])
tensor([[ 5.1416e-04, -1.4544e+00, 3.4188e+00, ..., 2.0114e-01,
7.3515e-01, -1.5456e+00],
[ 2.6113e+00, -6.3915e+00, 1.7829e+00, ..., -2.1981e+00,
-2.3696e-01, -3.0517e+00],
[ 2.6113e+00, -6.3915e+00, 1.7829e+00, ..., -2.1981e+00,
-2.3696e-01, -3.0517e+00],
[ 2.6113e+00, -6.3915e+00, 1.7829e+00, ..., -2.1981e+00,
-2.3695e-01, -3.0517e+00]]) torch.Size([4, 768])
tensor([[ 0.7783, -0.3085, 0.3504, ..., -1.7902, -1.7831, -0.5644],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517],
[ 2.6113, -6.3915, 1.7829, ..., -2.1981, -0.2370, -3.0517]]) torch.Size([4, 768])