@@ -70,6 +70,9 @@ typedef struct {
7070 int numOutputs = 0 ;
7171 int numInputArgs = 0 ;
7272 uint32_t scratchSize = 0 ;
73+ #ifdef EXTERNAL_MEM
74+ uint32_t sramScratchSize = 0 ;
75+ #endif
7376 uint32_t profileSize = 0 ;
7477 uint32_t debugSize = 0 ;
7578 NeutronModelConfig mcfg;
@@ -79,7 +82,18 @@ typedef struct {
7982 const uint8_t * outputTranspositionFlags;
8083 const uint8_t * inputMap;
8184 const uint8_t * outputMap;
82- } NeutronConfig;
85+ } NeutronExecutorchConfig;
86+
#ifdef EXTERNAL_MEM
// Neutron compute has no access to FLASH.
// For a model converted with --fetch_constants_to_sram the weights are
// prefetched from FLASH to SRAM using memcpy. The two callbacks below are
// bundled into neutronMemCopyConfig, which is passed to neutronSetConfig().
// Both have internal linkage so their generic names cannot collide with
// symbols of the same name in other translation units.

// Copies `size` bytes from `src` to `dst` with memcpy.
// The copy is finished when this function returns; `channel` is ignored.
static void copy(void* dst, void* src, uint32_t size, uint32_t channel) {
  (void)channel;
  memcpy(dst, src, size);
}

// No-op: copy() is synchronous, so there is never a pending transfer
// to wait for. `channel` is ignored.
static void wait(uint32_t channel) {
  (void)channel;
}

static NeutronConfig neutronMemCopyConfig = {copy, wait};
#endif
8397
8498// Applied on outputs.
8599template <typename T>
@@ -258,7 +272,7 @@ class NeutronBackend final : public PyTorchBackendInterface {
258272 ArrayRef<CompileSpec> compile_specs) const override {
259273 MemoryAllocator* allocator = context.get_runtime_allocator ();
260274
261- auto * cfg = allocator->allocateInstance <NeutronConfig >();
275+ auto * cfg = allocator->allocateInstance <NeutronExecutorchConfig >();
262276
263277 // The following data is read from the "processed" data blob.
264278 // cfg->numInputs
@@ -293,15 +307,22 @@ class NeutronBackend final : public PyTorchBackendInterface {
293307 switch (payloadVersion) {
294308 case 0 :
295309 cfg->scratchSize = buffer[9 ];
310+ #ifdef EXTERNAL_MEM
311+ cfg->sramScratchSize = buffer[10 ];
312+ #endif
296313 cfg->profileSize = 0 ;
297314 cfg->debugSize = 0 ;
298315 cfg->numInputs = buffer[11 ];
299316 cfg->numOutputs = buffer[12 ];
300317 break ;
301318 case 1 :
302319 cfg->scratchSize = buffer[9 ];
303- cfg->profileSize = buffer[10 ];
320+ // The highest bit has special meaning in NS >= 2.2.3
321+ cfg->profileSize = buffer[10 ] & 0x7FFFFFFF ;
304322 cfg->debugSize = buffer[11 ];
323+ #ifdef EXTERNAL_MEM
324+ cfg->sramScratchSize = buffer[12 ];
325+ #endif
305326 cfg->numInputs = buffer[13 ];
306327 cfg->numOutputs = buffer[14 ];
307328 break ;
@@ -351,6 +372,14 @@ class NeutronBackend final : public PyTorchBackendInterface {
351372 return Error::InvalidProgram;
352373 }
353374
375+ #ifdef EXTERNAL_MEM
376+ neutronRC = neutronSetConfig (&neutronMemCopyConfig);
377+ if (neutronRC != ENONE) {
378+ ET_LOG (Error, " Neutron set config failed with error code %ld" , neutronRC);
379+ return Error::InvalidProgram;
380+ }
381+ #endif
382+
354383 return cfg;
355384 }
356385
@@ -365,7 +394,8 @@ class NeutronBackend final : public PyTorchBackendInterface {
365394 BackendExecutionContext& context,
366395 DelegateHandle* input_handle,
367396 Span<EValue*> args) const override {
368- NeutronConfig* cfg = static_cast <NeutronConfig*>(input_handle);
397+ NeutronExecutorchConfig* cfg =
398+ static_cast <NeutronExecutorchConfig*>(input_handle);
369399
370400 // Allocate place for input and output pointers.
371401 cfg->dcfg .inputs = static_cast <const void **>(
@@ -381,6 +411,12 @@ class NeutronBackend final : public PyTorchBackendInterface {
381411 cfg->dcfg .outputs [cfg->numOutputs + 2 ] =
382412 static_cast <void *>(context.allocate (cfg->debugSize , 16 ));
383413
414+ #ifdef EXTERNAL_MEM
415+ // Allocate the space in SRAM to prefetch weights from FLASH.
416+ cfg->dcfg .scratchWeights =
417+ static_cast <void *>(context.allocate (cfg->sramScratchSize , 16 ));
418+ #endif
419+
384420 // Set inputs from args.
385421 // Transpose inputs if needed.
386422 for (int i = 0 ; i < cfg->numInputs ; i++) {
@@ -527,7 +563,8 @@ class NeutronBackend final : public PyTorchBackendInterface {
527563 }
528564
529565 void destroy (DelegateHandle* handle) const override {
530- NeutronConfig* cfg = reinterpret_cast <NeutronConfig*>(handle);
566+ NeutronExecutorchConfig* cfg =
567+ reinterpret_cast <NeutronExecutorchConfig*>(handle);
531568
532569 // Unprepare to free resources in neutron driver.
533570 NeutronError neutronRC = neutronModelUnprepare (cfg->nmh );
0 commit comments