Skip to content

Commit d69f111

Browse files
authored
Support extract intermediate feature (#126)
1 parent ae6c343 commit d69f111

2 files changed

Lines changed: 151 additions & 0 deletions

File tree

classification/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
This folder contains the implementation of the InternImage for image classification.
44

5+
<!-- TOC -->
6+
* [Install](#install)
7+
* [Data Preparation](#data-preparation)
8+
* [Evaluation](#evaluation)
9+
* [Training from Scratch on ImageNet-1K](#training-from-scratch-on-imagenet-1k)
10+
* [Manage Jobs with Slurm](#manage-jobs-with-slurm)
11+
* [Training with Deepspeed](#training-with-deepspeed)
12+
* [Extracting Intermediate Features](#extracting-intermediate-features)
13+
* [Export](#export)
14+
<!-- TOC -->
15+
516
## Usage
617

718
### Install
@@ -259,6 +270,18 @@ Then, you could use `best.pth` as usual, e.g., `model.load_state_dict(torch.load
259270

260271
> Due to the lack of computational resources, the deepspeed training scripts are currently only verified for the first few epochs. Please file an issue if you have problems reproducing the whole training.
261272
273+
### Extracting Intermediate Features
274+
275+
To extract the features of an intermediate layer, you could use `extract_feature.py`.
276+
277+
For example, extract features of `b.png` from layers `patch_embed` and `levels.0.downsample` and save them to `b.pth`.
278+
279+
```bash
280+
python extract_feature.py --cfg configs/internimage_t_1k_224.yaml --img b.png --keys patch_embed levels.0.downsample --save --resume internimage_t_1k_224.pth
281+
```
282+
283+
284+
262285
### Export
263286

264287
To export `InternImage-T` from PyTorch to ONNX, run:

classification/extract_feature.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import functools
2+
from collections import OrderedDict
3+
4+
5+
# using wonder's beautiful simplification:
6+
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
7+
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path on ``obj``.

    ``rgetattr(model, 'levels.0.downsample')`` walks nested attributes one
    segment at a time. Any extra positional argument is used as the default
    when an attribute along the path is missing (same contract as getattr).
    """
    target = obj
    for part in attr.split('.'):
        target = getattr(target, part, *args)
    return target
12+
13+
14+
class IntermediateLayerGetter:
    """Wraps a Pytorch module to get intermediate values via forward hooks."""

    def __init__(self, model, return_layers, keep_output=True):
        """Wraps a Pytorch module to get intermediate values

        Arguments:
            model {nn.module} -- The Pytorch module to call
            return_layers {dict} -- Dictionary with the selected submodules
            to return the output (format: {[current_module_name]: [desired_output_name]},
            current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)

        Keyword Arguments:
            keep_output {bool} -- If True model_output contains the final model's output
            in the other case model_output is None (default: {True})

        Returns:
            (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
            your desired_output_name (s) and their values are the returned tensors
            of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
            See keep_output argument for model_output description.
            In case a submodule is called more than one time, all it's outputs are
            stored in a list.
        """
        self._model = model
        self.return_layers = return_layers
        self.keep_output = keep_output

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, capturing the selected submodule outputs.

        Returns (mid_outputs, model_output); see __init__ for details.
        Raises AttributeError if a key in return_layers does not resolve to
        a submodule of the model.
        """
        ret = OrderedDict()
        handles = []
        try:
            for name, new_name in self.return_layers.items():
                layer = rgetattr(self._model, name)

                # new_name is bound as a default to avoid the late-binding
                # closure pitfall across loop iterations.
                def hook(module, input, output, new_name=new_name):
                    # A submodule may fire several times in one forward pass;
                    # collect all of its outputs in a list, in call order.
                    if new_name in ret:
                        if type(ret[new_name]) is list:
                            ret[new_name].append(output)
                        else:
                            ret[new_name] = [ret[new_name], output]
                    else:
                        ret[new_name] = output

                try:
                    h = layer.register_forward_hook(hook)
                except AttributeError as e:
                    # rgetattr resolved to something that is not an nn.Module
                    # (or returned a non-module default); chain the cause.
                    raise AttributeError(f'Module {name} not found') from e
                handles.append(h)

            if self.keep_output:
                output = self._model(*args, **kwargs)
            else:
                self._model(*args, **kwargs)
                output = None
        finally:
            # Always detach the hooks — even when hook registration or the
            # forward pass raises — so the wrapped model is left unmodified.
            for h in handles:
                h.remove()

        return ret, output
71+
72+
73+
def main(args, config):
    """Build the model from *config*, run *args.img* through it, and return
    the requested intermediate features.

    Arguments:
        args -- parsed CLI namespace; uses args.img (image path) and
            args.keys (list of dotted submodule names to capture)
        config -- experiment config; MODEL.RESUME must point to a checkpoint

    Returns:
        (mid_outputs {OrderedDict}, model_output) as produced by
        IntermediateLayerGetter.
    """
    # Local imports keep heavy deps out of module import time; torch is
    # imported here too so main() also works when imported as a module
    # (previously it relied on the import inside the __main__ guard).
    import torch
    import torchvision.transforms as T
    from PIL import Image
    from models import build_model

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()
    # Inference mode: BatchNorm/Dropout must not behave as in training,
    # otherwise the extracted features are not reproducible.
    model.eval()

    # examples:
    # return_layers = {
    #     'patch_embed': 'patch_embed',
    #     'levels.0.downsample': 'levels.0.downsample',
    #     'levels.0.blocks.0.dcn': 'levels.0.blocks.0.dcn',
    # }
    return_layers = {k: k for k in args.keys}
    mid_getter = IntermediateLayerGetter(model, return_layers=return_layers, keep_output=True)

    # convert('RGB') guards against grayscale/palette/RGBA inputs, which
    # would break the 3-channel Normalize below.
    image = Image.open(args.img).convert('RGB')

    transforms = T.Compose([
        T.Resize(config.DATA.IMG_SIZE),
        T.ToTensor(),
        T.Normalize(config.AUG.MEAN, config.AUG.STD)
    ])
    image = transforms(image)
    image = image.unsqueeze(0)
    image = image.cuda()

    # No autograd graph is needed for feature extraction.
    with torch.no_grad():
        mid_outputs, model_output = mid_getter(image)

    for k, v in mid_outputs.items():
        # A submodule hooked more than once yields a list of tensors.
        if isinstance(v, list):
            print(k, [t.shape for t in v])
        else:
            print(k, v.shape)

    return mid_outputs, model_output
109+
110+
111+
if __name__ == '__main__':
    import argparse
    import os
    import torch
    from config import get_config

    parser = argparse.ArgumentParser('Get Intermediate Layer Output')
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='Path to config file')
    parser.add_argument('--img', type=str, required=True, metavar="FILE", help='Path to img file')
    parser.add_argument("--keys", default=None, nargs='+', help="The intermediate layer's keys you want to save.")
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--save', action='store_true', help='Save the results.')
    args = parser.parse_args()
    config = get_config(args)

    mid_outputs, model_output = main(args, config)

    if args.save:
        # Replace the image extension with .pth. The previous
        # `args.img[:-3] + '.pth'` kept the dot (b.png -> 'b..pth') and
        # broke for extensions that aren't three characters (e.g. .jpeg).
        torch.save(mid_outputs, os.path.splitext(args.img)[0] + '.pth')

0 commit comments

Comments
 (0)