-
Notifications
You must be signed in to change notification settings - Fork 166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lifelong learning supporting non-structure #352
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: luosiqi <[email protected]>
Signed-off-by: luosiqi <[email protected]>
Sedna lifelong learning supports unstructured data based on semantic segmentation example
Signed-off-by: luosiqi <[email protected]>
Code check and base model improvement of unstructured lifelong learning framework
@luosiqi Delete code files irrelevant to the scenarios in example folder. |
return CPA | ||
|
||
|
||
if __name__ == '__main__': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luosiqi put the test code elsewhere, e.g. ./test/test_basemodel.py
|
||
|
||
def train_args(): | ||
parser = argparse.ArgumentParser(description="PyTorch RFNet Training") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luosiqi
The command-line parsing module argparse should not be used, because it dose not use in this scense. It's easy to misunderstand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luosiqi suggest that:
Class TrainArgs:
def __init__(self, **kwargs):
self.depth = kwargs.get('depth', False)
self.dateaset = Context.get_parameters('dataset', 'cityscapes')
``
'best_pred': self.trainer.best_pred, | ||
}, is_best) | ||
|
||
# if not self.trainer.args.no_val and \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luosiqi delete comment code
return args | ||
|
||
|
||
def accuracy(y_true, y_pred, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luosiqi the ./accuracy.py
has this func accuracy
in the project, so you can import it.
from dataloaders import make_data_loader | ||
from dataloaders import custom_transforms as tr | ||
|
||
def preprocess(image_urls): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this func may be the class(Model
)‘s private func
from utils.metrics import Evaluator | ||
from tqdm import tqdm | ||
from dataloaders import make_data_loader | ||
from sedna.common.class_factory import ClassType, ClassFactory |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note the order of import
@@ -0,0 +1,38 @@ | |||
from basemodel import val_args |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The import of the relative path should be adjusted.
__all__ = ('accuracy') | ||
|
||
@ClassFactory.register(ClassType.GENERAL) | ||
def accuracy(y_true, y_pred, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Common keyword. Use alias while register.
_, _, test_loader, num_class = make_data_loader(args, test_data=y_true) | ||
evaluator = Evaluator(num_class) | ||
|
||
tbar = tqdm(test_loader, desc='\r') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useless
if args.cuda: | ||
image, target = image.cuda(), target.cuda() | ||
if args.depth: | ||
depth = depth.cuda() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check whether the device supports GPU.
'cityrand', | ||
'target', | ||
'xrlab', | ||
'e1', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the meanning of xrlab
and e1
|
||
if args.checkname is None: | ||
args.checkname = 'RFNet' | ||
print(args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
relace print by using logger
choices=[ | ||
'citylostfound', | ||
'cityscapes', | ||
'xrlab', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems to be inconsistent with the training
from dataloaders import custom_transforms as tr | ||
|
||
class CityscapesSegmentation(data.Dataset): | ||
NUM_CLASSES = 19 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
magic number
def __init__(self, args, root=Path.db_root_dir('cityscapes'), data=None, split="train"): | ||
|
||
# self.root = root | ||
self.root = "/home/lsq/Dataset/" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mask
@@ -0,0 +1,27 @@ | |||
import torch.nn as nn | |||
from itertools import chain # 串联多个迭代对象 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replace with english will be more general
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[2, 2, 2, 2] Hyperparameters are restricted.
@@ -0,0 +1,88 @@ | |||
# -*- coding: utf-8 -*- | |||
# File : replicate.py | |||
# Author : Jiayuan Mao |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Be aware of the use of other people's code under community constraints
label_colours = get_cityscapes_labels() | ||
elif dataset == 'target': | ||
n_classes = 24 | ||
label_colours = get_cityscapes_labels() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
switch Statements
@JoeyHwong-gk: changing LGTM is restricted to collaborators In response to this: Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
Please add kubeedge copyright at the top |
self.extractor_key = KBResourceConstant.EXTRACTOR.value | ||
|
||
ModelLoadingThread(self, self.task_index).start() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots of repetitive code, please move up.
try: | ||
task_index = FileOps.load(task_index_url) | ||
except Exception as err: | ||
self.log.error(f"{err}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
proposed merge
@@ -37,6 +37,11 @@ class ClassType: | |||
DATASET = 'data_process' | |||
CALLBACK = 'post_process_callback' | |||
|
|||
# TODO |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what the todo
tags for?
|
||
def __init__(self, task_extractor, **kwargs): | ||
self.task_extractor = task_extractor | ||
self.log = LOGGER |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the reasons to define self.log
?
for i in range(self.n_class): | ||
# sample = BaseDataSource() | ||
# sample.x = samples.x[i * partition_length: (i + 1) * partition_length] | ||
# sample.y = samples.y[i * partition_length: (i + 1) * partition_length] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleanup
@luosiqi |
self.val_args.label_save_path = os.path.join(label_save_dir, "label") | ||
self.val_args.save_predicted_image = kwargs.get( | ||
"save_predicted_image", "true").lower() | ||
self.validator = Validator(self.val_args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not recommended that self.validator = Validator(self.val_args)
be placed in the initialization phase.
from dataloaders import custom_transforms as tr | ||
|
||
class CityscapesSegmentation(data.Dataset): | ||
NUM_CLASSES = 24 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
magic number
txt file which contain image list parser | ||
""" | ||
|
||
def __init__(self, data_type, func=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use func to handle it!
# func may use this
# func = _data_feature_process
def _data_feature_process(line: str):
res = line.strip().split()
return res[:-1], res[-1]
KBResourceConstant.EDGE_KB_DIR.value), | ||
task_index=KBResourceConstant.KB_INDEX_NAME.value) | ||
|
||
self.cloud_knowledge_management = CloudKnowledgeManagement( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't put it in the this func.
you can put it in "train func" and "eval func"
self.cloud_knowledge_management = CloudKnowledgeManagement( | ||
config, estimator=e) | ||
|
||
self.edge_knowledge_management = EdgeKnowledgeManagement( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't put it in the this func.
you can put it in "infer func"
self.cloud_knowledge_management, | ||
self.edge_knowledge_management, | ||
unseen_task_allocation) | ||
|
||
task_index = FileOps.join_path(config['output_url'], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CloudKnowledgeManagement
also has this command, delete it?
|
||
seen_samples, unseen_samples = unseen_sample_re_recognition(train_data) | ||
|
||
# TODO: retrain temporarily |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Delete the omment
relpath=self.config.data_path_prefix) | ||
self.report_task_info( | ||
None, K8sResourceKindStatus.COMPLETED.value, task_info_res) | ||
self.log.info(f"Lifelong learning Train task Finished, " | ||
f"KB idnex save in {self.config.task_index}") | ||
f"KB index save in {task_index}") | ||
return callback_func(self.estimator, res) if callback_func else res | ||
|
||
def update(self, train_data, valid_data=None, post_process=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
combine this funcupdate
and train
into an external func, e,g,:
def train(self):
if not has_completed_initial_training:
return self._initial_train()
return self._update(self)
Combine A and B into an external function.
train_data = IndexDataParse(data_type="train", func=_load_txt_dataset) | ||
train_data.parse(train_dataset_url, use_raw=False) | ||
|
||
is_completed_initilization = str(Context.get_parameters("HAS_COMPLETED_INITIAL_TRAINING", "false")).lower() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put this judgment in the sedna lib
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only one train interface is exposed to users.
self.estimator = set_backend(estimator=estimator, config=config) | ||
self.cloud_knowledge_management = cloud_knowledge_management | ||
self.edge_knowledge_management = edge_knowledge_management | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put parameters(cloud_knowledge_management
and edge_knowledge_management
)to other funs instead of initial func.
|
||
feedback = {} | ||
for i, task in enumerate(task_groups): | ||
LOGGER.info(f"MTL Train start {i} : {task.entry}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
task.samples
may be [ ]
self.task_update_decision = task_update_decision or { | ||
"method": "UpdateStrategyDefault" | ||
} | ||
self.task_update_decision_param = e._parse_param( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if task_update_decision
is a callable module instance, then there's no need to set its param by _parse_param
.
seen_samples.y = np.concatenate( | ||
(seen_samples.y, unseen_samples.y), axis=0) | ||
|
||
task_update_decision = ClassFactory.get_cls( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if task_update_decision
is callable, then skip ClassFactory.get_cls
method and set task index instead.
|
||
|
||
@ClassFactory.register(ClassType.KM) | ||
class UpdateStrategyDefault: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add set
method to set task index.
Signed-off-by: JimmyYang20 <[email protected]>
Add pylint in ci
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
@jaypume: PR needs rebase. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
No description provided.