博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
VGG16提取图像特征 (torch7)
阅读量:7086 次
发布时间:2019-06-28

本文共 2742 字,大约阅读时间需要 9 分钟。

VGG16提取图像特征 (torch7)

  1. 下载pretrained model,保存到当前目录下

  1. th> caffemodel_url = 'http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel' 

  2. th> proto_url='https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt' 

  3. th> os.execute('wget VGG_ILSVRC_16_layers.caffemodel' .. caffemodel_url) 

  4. th> os.execute('wget VGG_ILSVRC_16_layers_deploy.prototxt' .. proto_url) 

  1. 使用loadcaffe提取图像特征

  1.  

  2. require 'torch' -- 使用th命令,可忽略 

  3. require 'nn' -- 修改model用到nn包 

  4. require 'loadcaffe' -- 加在caffe训练的包 

  5. require 'image' -- 加载图像,处理图像,可以使用cv中函数替代 

  6.  

  7. local loadSize = {
    3,256,256} -- 加载图像scale尺寸 

  8. local sampleSize = {
    3,224,224} -- 样本尺寸,其实就是选取scale后图像的中间一块 

  9.  

  10. local function loadImage(input) 

  11. -- 将图像缩放到loadSize尺寸,为了保证aspect ratio不变,将短边缩放后,长边按比例缩放 

  12. if input:size(3) < input:size(2) then 

  13. input = image.scale(input,loadSize[2],loadSize[3]*input:size(2)/input:size(3)) 

  14. -- 注意image.scale(src,width,height),width对应的是input:size[3],height对应的是input:size[2] 

  15. else 

  16. input = image.scale(input,loadSize[2]*input:size(3)/input:size(2),loadSize[3]) 

  17. end 

  18. return input 

  19. end 

  20.  

  21. local bgr_means = {
    103.939,116.779,123.68} --VGG预训练中的均值 

  22. local function vggPreProcessing(img) 

  23. local img2=img:clone() 

  24. img2[{
    {
    1}}] =img2[{
    {
    3}}] -- image.load 加载图像是rgb格式,需转化维bgr 

  25. img2[{
    {
    3}}] = img[{
    {
    1}}] 

  26. img2:mul(255) -- image.load()加载的图像 pixel value \in [0,1] 

  27. for i=1,3 do 

  28. img2[i]:add(-bgr_means[i]) -- 中心化 

  29. end 

  30. return img2 

  31. end 

  32.  

  33. local function centerCrop(input) 

  34. local oH = sampleSize[2

  35. local oW = sampleSize[3

  36. local iW = input:size(3

  37. local iH = input:size(2

  38. local w1 = math.ceil((iW-oW)/2

  39. local h1 = math.ceil((iH-oH)/2

  40. local out = image.crop(input,w1,h1,w1+oW,h1+oH) 

  41. return out 

  42. end 

  43.  

  44. local function getPretrainedModel() 

  45. local proto = 'VGG_ILSVRC_16_layers_deploy.prototxt' 

  46. local caffemodel = '/home/zwzhou/modelZoo/VGG_ILSVRC_16_layers.caffemodel' 

  47.  

  48. local model = loadcaffe.load(proto,caffemodel,'nn') -- 加载pretrained model 

  49. for i=1,3 do -- 将最后3层舍掉 

  50. model.modules[#model.modules]=nil 

  51. end 

  52. -- 删除pretrained model的一些层官方方法 

  53. -- ========================== 

  54. -- for i= 40,38,-1 do 

  55. -- model:remove(i) 

  56. -- end 

  57. -- ========================== 

  58. model:add(nn.Normalize(2)) -- 添加一层正则化层,将输出向量归一化 

  59.  

  60. model:evaluate() -- self.training=false ,非训练,让网络参数不变 

  61. return model 

  62. end 

  63.  

  64. torch.setdefaulttensortype('torch.FloatTensor'

  65. model = getPretrainedModel() 

  66.  

  67. filepath = '/home/zwzhou/MOT16/train/MOT16-02/img1/000001.jpg' 

  68. local img1=image.load(filepath) -- rgb图像 

  69. local input = image.crop(img1,910,480,910+97,480+110) -- 里面参数时选择原图像的一个区域,boundingbox 

  70.  

  71. input = loadImage(input) 

  72. local vggPreProcessed = vggPreProcessing(input) 

  73. local out = centerCrop(vggPreProcessed) 

  74.  

  75. local outputs = model:forward(out) 

  76.  

  77. print(outputs) 

  78. print(#outputs) 

  1. 参考项目

转载于:https://www.cnblogs.com/YiXiaoZhou/p/6757165.html

你可能感兴趣的文章
前端面试闲谈
查看>>
android SwipeRefreshLayout嵌套Webview滑动冲突问题解决
查看>>
css3属性和静态页面细节
查看>>
函数式编程在前端权限管理中的应用
查看>>
Javascript — Promise
查看>>
你真的懂CSS3的伪元素::before吗?
查看>>
ES6系列教程之Set和Map
查看>>
mac python3 轻松安装教程
查看>>
死磕 java并发包之AtomicStampedReference源码分析(ABA问题详解)
查看>>
如何隐藏 shortcut bar
查看>>
LeetCode 448 Find All Numbers Disappeared in an Array
查看>>
css-position
查看>>
iOS应用签名(上)
查看>>
有趣的border-radius
查看>>
Ubuntu远程Windows
查看>>
Zilliqa生态构建资助计划第三批获奖名单
查看>>
javaSE基础知识 1.1代码注释
查看>>
分享一些 Broadcast 使用技巧
查看>>
Flutter audioplayers使用小结
查看>>
iOS 状态栏的图标
查看>>