Skip to content

Commit

Permalink
LIB: Add mindspore backend
Browse files Browse the repository at this point in the history
1. Add mindspore backend
2. If both GPU and NPU environments are available, the NPU environment is preferred

Signed-off-by: chou-shun <[email protected]>
  • Loading branch information
unknown authored and chou-shun committed Aug 24, 2021
1 parent c3475e2 commit 10a2ca7
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
12 changes: 11 additions & 1 deletion lib/sedna/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,20 @@ def set_backend(estimator=None, config=None):
if config is None:
config = BaseConfig()
use_cuda = False
use_ascend = False
backend_type = os.getenv(
'BACKEND_TYPE', config.get("backend_type", "UNKNOWN")
)
backend_type = str(backend_type).upper()
device_category = os.getenv(
'DEVICE_CATEGORY', config.get("device_category", "CPU")
)
if 'CUDA_VISIBLE_DEVICES' in os.environ:

# NPU>GPU>CPU
if device_category == "ASCEND":
use_ascend = True
os.environ['DEVICE_CATEGORY'] = "ASCEND"
elif 'CUDA_VISIBLE_DEVICES' in os.environ:
os.environ['DEVICE_CATEGORY'] = 'GPU'
use_cuda = True
else:
Expand All @@ -44,14 +50,18 @@ def set_backend(estimator=None, config=None):
from sedna.backend.tensorflow import TFBackend as REGISTER
elif backend_type == "KERAS":
from sedna.backend.tensorflow import KerasBackend as REGISTER
elif backend_type == "MINDSPORE":
from sedna.backend.mindspore import MSBackend as REGISTER
else:
warnings.warn(f"{backend_type} Not Support yet, use itself")
from sedna.backend.base import BackendBase as REGISTER

model_save_url = config.get("model_url")
base_model_save = config.get("base_model_url") or model_save_url
model_save_name = config.get("model_name")
return REGISTER(
estimator=estimator, use_cuda=use_cuda,
use_ascend=use_ascend,
model_save_path=base_model_save,
model_name=model_save_name,
model_save_url=model_save_url
Expand Down
3 changes: 2 additions & 1 deletion lib/sedna/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BackendBase:
def __init__(self, estimator, fine_tune=True, **kwargs):
self.framework = ""
self.estimator = estimator
self.use_ascend = True if kwargs.get("use_ascend") else False
self.use_cuda = True if kwargs.get("use_cuda") else False
self.fine_tune = fine_tune
self.model_save_path = kwargs.get("model_save_path") or "/tmp"
Expand All @@ -35,7 +36,7 @@ def model_name(self):
if self.default_name:
return self.default_name
model_postfix = {"pytorch": ".pth",
"keras": ".pb", "tensorflow": ".pb"}
"keras": ".pb", "tensorflow": ".pb", "mindspore": ".ckpt"}
continue_flag = "_finetune_" if self.fine_tune else ""
post_fix = model_postfix.get(self.framework, ".pkl")
return f"model{continue_flag}{self.framework}{post_fix}"
Expand Down
90 changes: 90 additions & 0 deletions lib/sedna/backend/mindspore/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import mindspore.context as context
from sedna.backend.base import BackendBase
from sedna.common.file_ops import FileOps


class MSBackend(BackendBase):
    """MindSpore framework backend.

    Sets up the MindSpore graph-mode execution context for the detected
    device (NPU/Ascend preferred over GPU, falling back to CPU) and
    delegates train/predict/evaluate calls to the wrapped estimator.
    """

    def __init__(self, estimator, fine_tune=True, **kwargs):
        super(MSBackend, self).__init__(
            estimator=estimator, fine_tune=fine_tune, **kwargs)
        self.framework = "mindspore"

        # Device preference mirrors set_backend: NPU > GPU > CPU.
        if self.use_ascend:
            device = "Ascend"
        elif self.use_cuda:
            device = "GPU"
        else:
            device = "CPU"
        context.set_context(mode=context.GRAPH_MODE, device_target=device)

        # An estimator factory (callable) is instantiated eagerly here.
        if callable(self.estimator):
            self.estimator = self.estimator()

    def train(self, train_data, valid_data=None, **kwargs):
        """Train the estimator, fine-tuning from a saved model if present."""
        if callable(self.estimator):
            self.estimator = self.estimator()
        if self.fine_tune and FileOps.exists(self.model_save_path):
            self.finetune()
        self.has_load = True
        fit_kwargs = self.parse_kwargs(self.estimator.train, **kwargs)
        return self.estimator.train(
            train_data=train_data,
            valid_data=valid_data,
            **fit_kwargs)

    def predict(self, data, **kwargs):
        """Run inference, lazily loading saved weights on first use."""
        if not self.has_load:
            self.load()
        predict_kwargs = self.parse_kwargs(self.estimator.predict, **kwargs)
        return self.estimator.predict(data=data, **predict_kwargs)

    def evaluate(self, data, **kwargs):
        """Evaluate the estimator, lazily loading saved weights on first use."""
        if not self.has_load:
            self.load()
        eval_kwargs = self.parse_kwargs(self.estimator.evaluate, **kwargs)
        return self.estimator.evaluate(data, **eval_kwargs)

    def finetune(self):
        """todo: no support yet"""

    def load_weights(self):
        """Restore estimator weights from disk if a checkpoint file exists."""
        weights_file = FileOps.join_path(self.model_save_path, self.model_name)
        if os.path.exists(weights_file):
            self.estimator.load_weights(weights_file)

    def get_weights(self):
        """todo: no support yet"""

    def set_weights(self, weights):
        """todo: no support yet"""

    def model_info(self, model, relpath=None, result=None):
        """Describe a saved model as a one-element list of metadata dicts.

        ``relpath``, when given, is stripped from the front of the model
        path so the reported URL is relative; ``result`` carries metrics.
        """
        _, ext = os.path.splitext(model)
        url = FileOps.remove_path_prefix(model, relpath) if relpath else model
        return [{
            "format": ext.lstrip(".").lower(),
            "url": url,
            "metrics": result,
        }]

0 comments on commit 10a2ca7

Please sign in to comment.