diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h index c085960ebad3fc32de3c0bd7c973fc23f049ed2f..7d294afca45fa25ecb1c3860b8501915da1f4372 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h @@ -476,6 +476,9 @@ class VSO { friend class ExecutionSession; friend class MaterializationResponsibility; public: + using FallbackDefinitionGeneratorFunction = + std::function; + using AsynchronousSymbolQuerySet = std::set>; @@ -495,6 +498,14 @@ public: /// Get a reference to the ExecutionSession for this VSO. ExecutionSessionBase &getExecutionSession() const { return ES; } + /// Set a fallback defenition generator. If set, lookup and lookupFlags will + /// pass the unresolved symbols set to the fallback definition generator, + /// allowing it to add a new definition to the VSO. + void setFallbackDefinitionGenerator( + FallbackDefinitionGeneratorFunction FallbackDefinitionGenerator) { + this->FallbackDefinitionGenerator = std::move(FallbackDefinitionGenerator); + } + /// Define all symbols provided by the materialization unit to be part /// of the given VSO. template @@ -567,9 +578,17 @@ private: SymbolMap Symbols; UnmaterializedInfosMap UnmaterializedInfos; MaterializingInfosMap MaterializingInfos; + FallbackDefinitionGeneratorFunction FallbackDefinitionGenerator; Error defineImpl(MaterializationUnit &MU); + SymbolNameSet lookupFlagsImpl(SymbolFlagsMap &Flags, + const SymbolNameSet &Names); + + void lookupImpl(std::shared_ptr &Q, + std::vector> &MUs, + SymbolNameSet &Unresolved); + void detachQueryHelper(AsynchronousSymbolQuery &Q, const SymbolNameSet &QuerySymbols); diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp index 0d92f0a5f7748f138b08cded1b845622f39b8a57..bbb0b1ba26c4b25022bf934380339c5b722c80e2 100644 --- a/llvm/lib/ExecutionEngine/Orc/Core.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp @@ -639,84 +639,57 @@ void VSO::notifyFailed(const SymbolNameSet &FailedSymbols) { SymbolNameSet VSO::lookupFlags(SymbolFlagsMap &Flags, const SymbolNameSet &Names) { return ES.runSessionLocked([&, this]() { - SymbolNameSet Unresolved; - - for (auto &Name : Names) { - auto I = Symbols.find(Name); - if (I == Symbols.end()) { - Unresolved.insert(Name); - continue; + auto Unresolved = lookupFlagsImpl(Flags, Names); + if (FallbackDefinitionGenerator && !Unresolved.empty()) { + auto FallbackDefs = FallbackDefinitionGenerator(*this, Unresolved); + if (!FallbackDefs.empty()) { + auto Unresolved2 = lookupFlagsImpl(Flags, FallbackDefs); + (void)Unresolved2; + assert(Unresolved2.empty() && + "All fallback defs should have been found by lookupFlagsImpl"); + for (auto &D : FallbackDefs) + Unresolved.erase(D); } + }; + return Unresolved; + }); +} + +SymbolNameSet VSO::lookupFlagsImpl(SymbolFlagsMap &Flags, + const SymbolNameSet &Names) { + SymbolNameSet Unresolved; + + for (auto &Name : Names) { + auto I = Symbols.find(Name); - assert(!Flags.count(Name) && "Symbol already present in Flags map"); - Flags[Name] = JITSymbolFlags::stripTransientFlags(I->second.getFlags()); + if (I == Symbols.end()) { + Unresolved.insert(Name); + continue; } - return Unresolved; - }); + assert(!Flags.count(Name) && "Symbol already present in Flags map"); + Flags[Name] = JITSymbolFlags::stripTransientFlags(I->second.getFlags()); + } + + return Unresolved; } SymbolNameSet VSO::lookup(std::shared_ptr Q, SymbolNameSet Names) { - SymbolNameSet Unresolved = std::move(Names); std::vector> MUs; + SymbolNameSet Unresolved = std::move(Names); ES.runSessionLocked([&, this]() { - for (auto I = Unresolved.begin(), E = Unresolved.end(); I != E;) { - auto TmpI = I++; - auto Name = *TmpI; - - // Search for the name in Symbols. Skip it if not found. - auto SymI = Symbols.find(Name); - if (SymI == Symbols.end()) - continue; - - // If we found Name in V, remove it frome the Unresolved set and add it - // to the dependencies set. - Unresolved.erase(TmpI); - - // If the symbol has an address then resolve it. - if (SymI->second.getAddress() != 0) - Q->resolve(Name, SymI->second); - - // If the symbol is lazy, get the MaterialiaztionUnit for it. - if (SymI->second.getFlags().isLazy()) { - assert(SymI->second.getAddress() == 0 && - "Lazy symbol should not have a resolved address"); - assert(!SymI->second.getFlags().isMaterializing() && - "Materializing and lazy should not both be set"); - auto UMII = UnmaterializedInfos.find(Name); - assert(UMII != UnmaterializedInfos.end() && - "Lazy symbol should have UnmaterializedInfo"); - auto MU = std::move(UMII->second->MU); - assert(MU != nullptr && "Materializer should not be null"); - - // Kick all symbols associated with this MaterializationUnit into - // materializing state. - for (auto &KV : MU->getSymbols()) { - auto SymK = Symbols.find(KV.first); - auto Flags = SymK->second.getFlags(); - Flags &= ~JITSymbolFlags::Lazy; - Flags |= JITSymbolFlags::Materializing; - SymK->second.setFlags(Flags); - UnmaterializedInfos.erase(KV.first); - } - - // Add MU to the list of MaterializationUnits to be materialized. - MUs.push_back(std::move(MU)); - } else if (!SymI->second.getFlags().isMaterializing()) { - // The symbol is neither lazy nor materializing. Finalize it and - // continue. - Q->notifySymbolReady(); - continue; + lookupImpl(Q, MUs, Unresolved); + if (FallbackDefinitionGenerator && !Unresolved.empty()) { + auto FallbackDefs = FallbackDefinitionGenerator(*this, Unresolved); + if (!FallbackDefs.empty()) { + for (auto &D : FallbackDefs) + Unresolved.erase(D); + lookupImpl(Q, MUs, FallbackDefs); + assert(FallbackDefs.empty() && + "All fallback defs should have been found by lookupImpl"); } - - // Add the query to the PendingQueries list. - assert(SymI->second.getFlags().isMaterializing() && - "By this line the symbol should be materializing"); - auto &MI = MaterializingInfos[Name]; - MI.PendingQueries.push_back(Q); - Q->addQueryDependence(*this, Name); } }); @@ -733,6 +706,67 @@ SymbolNameSet VSO::lookup(std::shared_ptr Q, return Unresolved; } +void VSO::lookupImpl(std::shared_ptr &Q, + std::vector> &MUs, + SymbolNameSet &Unresolved) { + for (auto I = Unresolved.begin(), E = Unresolved.end(); I != E;) { + auto TmpI = I++; + auto Name = *TmpI; + + // Search for the name in Symbols. Skip it if not found. + auto SymI = Symbols.find(Name); + if (SymI == Symbols.end()) + continue; + + // If we found Name in V, remove it frome the Unresolved set and add it + // to the dependencies set. + Unresolved.erase(TmpI); + + // If the symbol has an address then resolve it. + if (SymI->second.getAddress() != 0) + Q->resolve(Name, SymI->second); + + // If the symbol is lazy, get the MaterialiaztionUnit for it. + if (SymI->second.getFlags().isLazy()) { + assert(SymI->second.getAddress() == 0 && + "Lazy symbol should not have a resolved address"); + assert(!SymI->second.getFlags().isMaterializing() && + "Materializing and lazy should not both be set"); + auto UMII = UnmaterializedInfos.find(Name); + assert(UMII != UnmaterializedInfos.end() && + "Lazy symbol should have UnmaterializedInfo"); + auto MU = std::move(UMII->second->MU); + assert(MU != nullptr && "Materializer should not be null"); + + // Kick all symbols associated with this MaterializationUnit into + // materializing state. + for (auto &KV : MU->getSymbols()) { + auto SymK = Symbols.find(KV.first); + auto Flags = SymK->second.getFlags(); + Flags &= ~JITSymbolFlags::Lazy; + Flags |= JITSymbolFlags::Materializing; + SymK->second.setFlags(Flags); + UnmaterializedInfos.erase(KV.first); + } + + // Add MU to the list of MaterializationUnits to be materialized. + MUs.push_back(std::move(MU)); + } else if (!SymI->second.getFlags().isMaterializing()) { + // The symbol is neither lazy nor materializing. Finalize it and + // continue. + Q->notifySymbolReady(); + continue; + } + + // Add the query to the PendingQueries list. + assert(SymI->second.getFlags().isMaterializing() && + "By this line the symbol should be materializing"); + auto &MI = MaterializingInfos[Name]; + MI.PendingQueries.push_back(Q); + Q->addQueryDependence(*this, Name); + } +} + void VSO::dump(raw_ostream &OS) { ES.runSessionLocked([&, this]() { OS << "VSO \"" << VSOName diff --git a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp index bc7a3622a0656102af457ad90d4bcf4bd89b5902..01f81d9a80fd68ccb690add1ffbf41a62cc4012b 100644 --- a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp @@ -508,6 +508,33 @@ TEST(CoreAPIsTest, DefineMaterializingSymbol) { EXPECT_TRUE(BarResolved) << "Bar should have been resolved"; } +TEST(CoreAPIsTest, FallbackDefinitionGeneratorTest) { + constexpr JITTargetAddress FakeFooAddr = 0xdeadbeef; + constexpr JITTargetAddress FakeBarAddr = 0xcafef00d; + + ExecutionSession ES; + auto Foo = ES.getSymbolStringPool().intern("foo"); + auto Bar = ES.getSymbolStringPool().intern("bar"); + + auto FooSym = JITEvaluatedSymbol(FakeFooAddr, JITSymbolFlags::Exported); + auto BarSym = JITEvaluatedSymbol(FakeBarAddr, JITSymbolFlags::Exported); + + auto &V = ES.createVSO("V"); + + cantFail(V.define(absoluteSymbols({{Foo, FooSym}}))); + + V.setFallbackDefinitionGenerator([&](VSO &W, const SymbolNameSet &Names) { + cantFail(W.define(absoluteSymbols({{Bar, BarSym}}))); + return SymbolNameSet({Bar}); + }); + + auto Result = cantFail(lookup({&V}, {Foo, Bar})); + + EXPECT_EQ(Result.count(Bar), 1U) << "Expected to find fallback def for 'bar'"; + EXPECT_EQ(Result[Bar].getAddress(), FakeBarAddr) + << "Expected address of fallback def for 'bar' to be " << FakeBarAddr; +} + TEST(CoreAPIsTest, FailResolution) { ExecutionSession ES; auto Foo = ES.getSymbolStringPool().intern("foo");