سلام
برای برای ویژوالایز کردن قبلا اینجا سوال داشتیم که جواب دادم . ولی به هر حال برای ویژوالایز کردن وزنها میتونید از توابع زیر استفاده کنید بصورت زیر.
import matplotlib.pyplot as plt
%matplotlib inline
def vis_square(data):
"""Take an array of shape (n, height, width) or (n, height, width, 3)
and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
# normalize data for display
data = (data - data.min()) / (data.max() - data.min())
# force the number of filters to be square
n = int(np.ceil(np.sqrt(data.shape[0])))
padding = (((0, n ** 2 - data.shape[0]),
(0, 1), (0, 1)) # add some space between filters
+ ((0, 0),) * (data.ndim - 3)) # don't pad the last dimension (if there is one)
data = np.pad(data, padding, mode='constant', constant_values=1) # pad with ones (white)
# tile the filters into an image
data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
plt.show()
plt.imshow(data);
plt.axis('off')
return plt
def show_weights(net, layer_name, msg=''):
print(msg)
filters = net.params[layer_name][0].data
plt = vis_square(filters.transpose(0, 2, 3, 1))
#save_plot(plt, "{0}_{1}_weights".format(msg, layer_name))
plt.show()
def show_all_weights(net, layer_name, msg=''):
print(msg)
filters = net.params[layer_name][0].data
print(filters.shape)
for i in range(filters.shape[1]//3):
j = i
vis_square(filters[:,i:j+3,:,:].reshape(-1,3,3,3).transpose(0, 2, 3, 1))
print(i,j+3)
plt.show()
و برای نمایش هم کافیه یکی از این توابع اخری رو اجرا کنید و بهش پارامترهای مورد نیازش رو بدید.
اینا مربوط به بخشی از کارهای قبلی من بودن که مجبور شدم بعضی بخشا رو پاک/کامنت کنم دلیل بخشای کامنت شده اینه.
برای نمایش فیچرمپها هم باز تقریبا به همین شکله
def show_filters(net, layer_name, msg='',start=0, end=36):
print(msg)
#plt.rcParams["figure.figsize"] = (10,10)
feat = net.blobs[layer_name].data[start, :end]
plt = vis_square(feat)
#plt.savefig("{0}_{1}_featuremaps.jpeg".format(msg, layer_name), dpi = 1200)
#save_plot(plt, "{0}_{1}_featuremaps".format(msg, layer_name))
plt.show()
رلو الزاما وزنها رو صفر نمیکنه و اگه دقیق بشید میبینید در عمل شاید خیلی کم این قضیه رودر شبکه ببینید. اگه منظور شما بحث دد رلو هست که اون باعث هرز رفتن ظرفیت پردازشی شبکه میشه و برای رفعش استفاده از انواع دیگه خانواده رلو مثل PReLU و یا Leaky ReLU و .... و یا استفاده از دراپ اوت بین همه لایه ها میتونید استفاده کنید.
برای تکه اخر هم باز همینجا سرچ کنید من قبلا اسکریپتی که برای نمایش training/va accuracy نوشتم رو اینجا گذاشتم. یه سرچ کنید پیدا میکنید و از همون میتونید استفاده کنید.