1+ import functools
2+ from collections import OrderedDict
3+
4+
5+ # using wonder's beautiful simplification:
6+ # https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
7+ def rgetattr (obj , attr , * args ):
8+ def _getattr (obj , attr ):
9+ return getattr (obj , attr , * args )
10+
11+ return functools .reduce (_getattr , [obj ] + attr .split ('.' ))
12+
13+
14+ class IntermediateLayerGetter :
15+ def __init__ (self , model , return_layers , keep_output = True ):
16+ """Wraps a Pytorch module to get intermediate values
17+
18+ Arguments:
19+ model {nn.module} -- The Pytorch module to call
20+ return_layers {dict} -- Dictionary with the selected submodules
21+ to return the output (format: {[current_module_name]: [desired_output_name]},
22+ current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)
23+
24+ Keyword Arguments:
25+ keep_output {bool} -- If True model_output contains the final model's output
26+ in the other case model_output is None (default: {True})
27+
28+ Returns:
29+ (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
30+ your desired_output_name (s) and their values are the returned tensors
31+ of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
32+ See keep_output argument for model_output description.
33+ In case a submodule is called more than one time, all it's outputs are
34+ stored in a list.
35+ """
36+ self ._model = model
37+ self .return_layers = return_layers
38+ self .keep_output = keep_output
39+
40+ def __call__ (self , * args , ** kwargs ):
41+ ret = OrderedDict ()
42+ handles = []
43+ for name , new_name in self .return_layers .items ():
44+ layer = rgetattr (self ._model , name )
45+
46+ def hook (module , input , output , new_name = new_name ):
47+ if new_name in ret :
48+ if type (ret [new_name ]) is list :
49+ ret [new_name ].append (output )
50+ else :
51+ ret [new_name ] = [ret [new_name ], output ]
52+ else :
53+ ret [new_name ] = output
54+
55+ try :
56+ h = layer .register_forward_hook (hook )
57+ except AttributeError as e :
58+ raise AttributeError (f'Module { name } not found' )
59+ handles .append (h )
60+
61+ if self .keep_output :
62+ output = self ._model (* args , ** kwargs )
63+ else :
64+ self ._model (* args , ** kwargs )
65+ output = None
66+
67+ for h in handles :
68+ h .remove ()
69+
70+ return ret , output
71+
72+
73+ def main (args , config ):
74+ from models import build_model
75+ import torchvision .transforms as T
76+ from PIL import Image
77+
78+ model = build_model (config )
79+ checkpoint = torch .load (config .MODEL .RESUME , map_location = 'cpu' )
80+ model .load_state_dict (checkpoint ['model' ], strict = False )
81+ model .cuda ()
82+
83+ # examples:
84+ # return_layers = {
85+ # 'patch_embed': 'patch_embed',
86+ # 'levels.0.downsample': 'levels.0.downsample',
87+ # 'levels.0.blocks.0.dcn': 'levels.0.blocks.0.dcn',
88+ # }
89+ return_layers = {k : k for k in args .keys }
90+ mid_getter = IntermediateLayerGetter (model , return_layers = return_layers , keep_output = True )
91+
92+ image = Image .open (args .img )
93+
94+ transforms = T .Compose ([
95+ T .Resize (config .DATA .IMG_SIZE ),
96+ T .ToTensor (),
97+ T .Normalize (config .AUG .MEAN , config .AUG .STD )
98+ ])
99+ image = transforms (image )
100+ image = image .unsqueeze (0 )
101+ image = image .cuda ()
102+
103+ mid_outputs , model_output = mid_getter (image )
104+
105+ for k , v in mid_outputs .items ():
106+ print (k , v .shape )
107+
108+ return mid_outputs , model_output
109+
110+
111+ if __name__ == '__main__' :
112+ import argparse
113+ import torch
114+ from config import get_config
115+
116+ parser = argparse .ArgumentParser ('Get Intermediate Layer Output' )
117+ parser .add_argument ('--cfg' , type = str , required = True , metavar = "FILE" , help = 'Path to config file' )
118+ parser .add_argument ('--img' , type = str , required = True , metavar = "FILE" , help = 'Path to img file' )
119+ parser .add_argument ("--keys" , default = None , nargs = '+' , help = "The intermediate layer's keys you want to save." )
120+ parser .add_argument ('--resume' , help = 'resume from checkpoint' )
121+ parser .add_argument ('--save' , action = 'store_true' , help = 'Save the results.' )
122+ args = parser .parse_args ()
123+ config = get_config (args )
124+
125+ mid_outputs , model_output = main (args , config )
126+
127+ if args .save :
128+ torch .save (mid_outputs , args .img [:- 3 ] + '.pth' )
0 commit comments