10 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
14 #include "llvm/Config/llvm-config.h"
16 #ifdef LLVM_HAVE_TF_API
28 class ModelUnderTrainingRunner final :
public MLModelRunner {
31 ModelUnderTrainingRunner(
const ModelUnderTrainingRunner &) =
delete;
32 ModelUnderTrainingRunner &
33 operator=(
const ModelUnderTrainingRunner &) =
delete;
35 const std::vector<LoggedFeatureSpec> &outputLoggedFeatureSpecs()
const {
39 const Optional<TFModelEvaluator::EvaluationResult> &
40 lastEvaluationResult()
const {
41 return LastEvaluationResult;
43 static bool classof(
const MLModelRunner *R) {
47 static std::unique_ptr<ModelUnderTrainingRunner>
48 createAndEnsureValid(LLVMContext &Ctx,
const std::string &ModelPath,
50 const std::vector<TensorSpec> &InputSpecs,
51 StringRef OutputSpecsPathOverride =
"");
52 static std::unique_ptr<ModelUnderTrainingRunner>
53 createAndEnsureValid(LLVMContext &Ctx,
const std::string &ModelPath,
55 const std::vector<TensorSpec> &InputSpecs,
56 const std::vector<LoggedFeatureSpec> &OutputSpecs);
59 ModelUnderTrainingRunner(LLVMContext &Ctx,
const std::string &ModelPath,
60 const std::vector<TensorSpec> &InputSpecs,
61 const std::vector<LoggedFeatureSpec> &OutputSpecs);
63 std::unique_ptr<TFModelEvaluator>
Evaluator;
64 const std::vector<LoggedFeatureSpec> OutputSpecs;
65 Optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
66 void *evaluateUntyped()
override;
#endif // defined(LLVM_HAVE_TF_API)
72 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H