Pruning Filters For Efficient ConvNets: Pruning Code Summary
2021-01-17 02:14
The code below is from https://github.com/tyui592/Pruning_filters_for_efficient_convnets, which prunes a VGG network on CIFAR100. Throughout the code, args carries the run configuration, e.g. the path to the pretrained VGG weights. Original post: https://www.cnblogs.com/zonechen/p/13373259.html
1. Code Reference
import torch

# VGG and train_network are defined elsewhere in the repository.

def prune_network(args, network=None):
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")  # select GPU/CPU

    if network is None:
        network = VGG(args.vgg, args.data_set)  # build the VGG model
        if args.load_path:
            check_point = torch.load(args.load_path)
            network.load_state_dict(check_point['state_dict'])

    # prune network
    network = prune_step(network, args.prune_layers, args.prune_channels, args.independent_prune_flag)
    network = network.to(device)
    print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)

    if args.retrain_flag:
        # update arguments for retraining the pruned network
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay the learning rate
        network = train_network(args, network)

    return network
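For context, here is a minimal sketch of how prune_network might be invoked. This is not part of the repo: the field names are inferred from the usage above, and the values ('trained_vgg.pth', 'vgg16', etc.) are placeholders.

from types import SimpleNamespace

# Hypothetical configuration; field names inferred from prune_network above.
args = SimpleNamespace(
    gpu_no=0,                         # -1 to stay on CPU
    vgg='vgg16', data_set='CIFAR100', # placeholder model/dataset names
    load_path='trained_vgg.pth',      # hypothetical checkpoint path
    prune_layers=['conv1', 'conv8'],  # which conv layers to prune
    prune_channels=[32, 256],         # how many filters to remove from each
    independent_prune_flag=False,     # True enables the 'independent' strategy
    retrain_flag=False,               # True would fine-tune via train_network
    retrain_epoch=20, retrain_lr=0.001,
)
pruned = prune_network(args)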
def prune_step(network, prune_layers, prune_channels, independent_prune_flag):
    network = network.cpu()  # pruning is performed on the CPU

    count = 0       # index into 'prune_channels'
    conv_count = 1  # conv-layer counter, used to index 'prune_layers'
    dim = 0         # which dim of the filter weight [out_ch, in_ch, k1, k2] to prune:
                    # 0 means remove this conv's output filters (the following BN and
                    # the next conv's input channels must then be pruned as well);
                    # 1 means remove input channels, because the preceding feature
                    # maps were pruned.
    residue = None  # leftover weights, needed by the 'independent' pruning strategy

    for i in range(len(network.features)):
        if isinstance(network.features[i], torch.nn.Conv2d):
            if dim == 1:
                # dim == 1: the previous layer's filters were pruned, so prune this
                # layer's input channels at the same 'channel_index'
                new_, residue = get_new_conv(network.features[i], dim, channel_index, independent_prune_flag)
                network.features[i] = new_
                dim ^= 1

            # dim == 0: prune this layer's output filters and record 'channel_index'
            if 'conv%d' % conv_count in prune_layers:
                channel_index = get_channel_index(network.features[i].weight.data, prune_channels[count], residue)
                new_ = get_new_conv(network.features[i], dim, channel_index, independent_prune_flag)
                network.features[i] = new_
                dim ^= 1
                count += 1
            else:
                residue = None
            conv_count += 1

        # BatchNorm layers also have per-channel parameters, so prune them as well
        elif dim == 1 and isinstance(network.features[i], torch.nn.BatchNorm2d):
            new_ = get_new_norm(network.features[i], channel_index)
            network.features[i] = new_

    # if the last conv layer ('conv13' in VGG16) was pruned, the first fully
    # connected layer must shrink to match
    if 'conv13' in prune_layers:
        network.classifier[0] = get_new_linear(network.classifier[0], channel_index)

    return network
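To make the bookkeeping concrete, here is a small sanity check on a toy two-conv network. TinyNet is my own stand-in, not the repo's VGG; it just mimics the features/classifier layout prune_step expects.

# Toy sanity check (TinyNet is a stand-in, not part of the repo):
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1),   # 'conv1'
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(8, 16, 3, padding=1),  # 'conv2'
        )
        self.classifier = torch.nn.Sequential(torch.nn.Linear(16, 10))

net = prune_step(TinyNet(), prune_layers=['conv1'], prune_channels=[2],
                 independent_prune_flag=False)
print(net.features[0])  # Conv2d(3, 6, ...): two output filters removed
print(net.features[1])  # BatchNorm2d(6): pruned to match
print(net.features[3])  # Conv2d(6, 16, ...): matching input channels removed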
def get_channel_index(kernel, num_elimination, residue=None):
    # rank filters by the L1 norm of their weights and return the indices of the
    # 'num_elimination' smallest ones; residue from the independent strategy is
    # added back so the ranking covers the full original kernel
    sum_of_kernel = torch.sum(torch.abs(kernel.view(kernel.size(0), -1)), dim=1)
    if residue is not None:
        sum_of_kernel += torch.sum(torch.abs(residue.view(residue.size(0), -1)), dim=1)

    _, indices = torch.sort(sum_of_kernel)

    return indices[:num_elimination].tolist()
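As a quick illustration (my own example, assuming get_channel_index is in scope): rank four random filters and keep the indices of the two with the smallest L1 norm.

# Pick the 2 filters with the smallest L1 norm out of 4
torch.manual_seed(0)
kernel = torch.randn(4, 2, 3, 3)
print(get_channel_index(kernel, 2))  # indices of the two weakest filters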
def index_remove(tensor, dim, index, removed=False):
    # drop the entries listed in 'index' along dimension 'dim'
    if tensor.is_cuda:
        tensor = tensor.cpu()
    size_ = list(tensor.size())
    new_size = tensor.size(dim) - len(index)
    size_[dim] = new_size

    # keep every index that is not being removed (sorted, so channel order is preserved)
    select_index = sorted(set(range(tensor.size(dim))) - set(index))
    new_tensor = torch.index_select(tensor, dim, torch.tensor(select_index))

    if removed:
        return new_tensor, torch.index_select(tensor, dim, torch.tensor(index))

    return new_tensor
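A small example of what index_remove does (my own illustration):

t = torch.arange(12.).view(3, 4)
print(index_remove(t, 0, [1]))              # drops row 1, result shape [2, 4]
kept, dropped = index_remove(t, 1, [0, 2], removed=True)
print(kept.shape, dropped.shape)            # [3, 2] and [3, 2]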
def get_new_conv(conv, dim, channel_index, independent_prune_flag=False):
    if dim == 0:
        # prune output filters: shrink out_channels and the matching bias entries
        new_conv = torch.nn.Conv2d(in_channels=conv.in_channels,
                                   out_channels=int(conv.out_channels - len(channel_index)),
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride, padding=conv.padding, dilation=conv.dilation)

        new_conv.weight.data = index_remove(conv.weight.data, dim, channel_index)
        new_conv.bias.data = index_remove(conv.bias.data, dim, channel_index)

        return new_conv

    elif dim == 1:
        # prune input channels: out_channels and the bias are unchanged
        new_conv = torch.nn.Conv2d(in_channels=int(conv.in_channels - len(channel_index)),
                                   out_channels=conv.out_channels,
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride, padding=conv.padding, dilation=conv.dilation)

        new_weight = index_remove(conv.weight.data, dim, channel_index, independent_prune_flag)
        residue = None
        if independent_prune_flag:
            new_weight, residue = new_weight
        new_conv.weight.data = new_weight
        new_conv.bias.data = conv.bias.data

        return new_conv, residue
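The residue mechanism is easiest to see directly. In this example (mine, not from the repo), pruning two input channels with independent_prune_flag=True returns both the shrunken conv and the removed weight slices:

conv = torch.nn.Conv2d(8, 16, 3, padding=1)
pruned, residue = get_new_conv(conv, dim=1, channel_index=[0, 3],
                               independent_prune_flag=True)
print(pruned.weight.shape)   # torch.Size([16, 6, 3, 3])
print(residue.shape)         # torch.Size([16, 2, 3, 3]), the removed slices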
def get_new_norm(norm, channel_index):
    new_norm = torch.nn.BatchNorm2d(num_features=int(norm.num_features - len(channel_index)),
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)
    new_norm.weight.data = index_remove(norm.weight.data, 0, channel_index)
    new_norm.bias.data = index_remove(norm.bias.data, 0, channel_index)

    if norm.track_running_stats:
        # the running statistics are per-channel too and must be pruned as well
        new_norm.running_mean.data = index_remove(norm.running_mean.data, 0, channel_index)
        new_norm.running_var.data = index_remove(norm.running_var.data, 0, channel_index)

    return new_norm
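A one-liner check (my own example) showing that the affine parameters and running statistics shrink together:

bn = torch.nn.BatchNorm2d(8)
new_bn = get_new_norm(bn, channel_index=[2, 5])
print(new_bn.num_features, new_bn.running_mean.shape)  # 6 torch.Size([6])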
# The first fully connected layer also has to shrink because the number of filters
# feeding it changed. (For the CIFAR-sized VGG here the conv output is 1x1 per
# channel, so in_features maps one-to-one to the pruned channels.)
def get_new_linear(linear, channel_index):
    new_linear = torch.nn.Linear(in_features=int(linear.in_features - len(channel_index)),
                                 out_features=linear.out_features,
                                 bias=linear.bias is not None)
    new_linear.weight.data = index_remove(linear.weight.data, 1, channel_index)
    new_linear.bias.data = linear.bias.data

    return new_linear