Using your own data: there is only one category, and that category is divided into just two parts

train_partseg.py

if mydata == 1:
    seg_classes = {'Nut': [0, 1]}
else:
    seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
                   'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
                   'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
                   'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
                   'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
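
In the stock train_partseg.py, seg_classes appears to be inverted into a seg_label_to_cat lookup. The evaluation loop then takes the argmax only over the part labels of each shape's own category. The NumPy sketch below illustrates that logic for the single 'Nut' category. It assumes logits are shaped (num_points, num_part) and that each category's labels are contiguous. The helpers build_label_to_cat and predict_parts are hypothetical names, not the repo's API.

```python
import numpy as np

# Custom single-category setup from the issue: one category, two part labels.
seg_classes = {'Nut': [0, 1]}


def build_label_to_cat(seg_classes):
    """Invert {category: [part labels]} into {part label: category}."""
    label_to_cat = {}
    for cat, labels in seg_classes.items():
        for label in labels:
            label_to_cat[label] = cat
    return label_to_cat


def predict_parts(logits, cat, seg_classes):
    """Argmax over only the part labels belonging to `cat`.

    logits: array of shape (num_points, num_part).
    Returns global part labels; adding labels[0] back assumes the
    category's labels are contiguous (true for both dicts above).
    """
    labels = seg_classes[cat]
    return np.argmax(logits[:, labels], axis=1) + labels[0]


seg_label_to_cat = build_label_to_cat(seg_classes)
logits = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
print(seg_label_to_cat)                            # {0: 'Nut', 1: 'Nut'}
print(predict_parts(logits, 'Nut', seg_classes))   # [0 1 1]
```

With a single category, the restriction is a no-op: both labels belong to 'Nut'. It only starts to matter once more categories are added.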


    if mydata == 1:
        num_classes = 1
        num_part = 2
    else:
        num_classes = 16
        num_part = 50
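
Hardcoding num_classes and num_part separately from seg_classes makes it easy for the values to drift apart. The sketch below derives both values from the dict instead. It assumes the part labels across all categories are exactly the integers 0..num_part-1, which holds for both dicts above. derive_counts is a hypothetical helper.

```python
def derive_counts(seg_classes):
    """Return (num_classes, num_part) computed from {category: [part labels]}."""
    all_labels = sorted(label for labels in seg_classes.values() for label in labels)
    num_classes = len(seg_classes)
    num_part = len(all_labels)
    # The final layer emits num_part scores per point, so the labels must be
    # exactly 0..num_part-1 with no gaps or duplicates.
    if all_labels != list(range(num_part)):
        raise ValueError(f"part labels must be 0..{num_part - 1}, got {all_labels}")
    return num_classes, num_part


print(derive_counts({'Nut': [0, 1]}))  # (1, 2)
```

Applied to the 16-category ShapeNet dict, the same function yields (16, 50), matching the else branch.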

ShapeNetDataLoader.py

        if mydata == 1:
            self.seg_classes = {'Nut': [0, 1]}
        else:
            # Mapping from category ('Chair') to a list of ints [12, 13, 14, 15] used as segmentation labels
            self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
                                'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
                                'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
                                'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
                                'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
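
For the custom data files themselves, the sketch below assumes the ShapeNet-style text layout that the repo's PartNormalDataset appears to read. Each row is "x y z nx ny nz part_label": columns 0-2 hold xyz, columns 3-5 hold normals, and the last column holds the per-point part label. The sketch parses one such file and checks that every label belongs to the category in seg_classes. load_part_sample and check_labels are hypothetical names, and an in-memory string stands in for a real file.

```python
import io

import numpy as np

seg_classes = {'Nut': [0, 1]}


def load_part_sample(source, normal_channel=False):
    """Parse one shape; each row is assumed to be x y z nx ny nz part_label."""
    data = np.loadtxt(source).astype(np.float32)
    # xyz only, or xyz + normals when normal_channel is True
    point_set = data[:, 0:6] if normal_channel else data[:, 0:3]
    seg = data[:, -1].astype(np.int32)
    return point_set, seg


def check_labels(seg, cat, seg_classes):
    """Raise if any point carries a label outside the category's parts."""
    bad = np.setdiff1d(np.unique(seg), seg_classes[cat])
    if bad.size:
        raise ValueError(f"labels {bad.tolist()} not in {seg_classes[cat]} for {cat}")


sample = io.StringIO(
    "0.0 0.0 0.0 0 0 1 0\n"
    "0.1 0.0 0.0 0 0 1 0\n"
    "0.0 0.1 0.2 0 1 0 1\n"
)
points, seg = load_part_sample(sample)
check_labels(seg, 'Nut', seg_classes)
print(points.shape, seg.tolist())  # (3, 3) [0, 0, 1]
```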

pointnet2_part_seg_msg.py

        # Stock value is 150 = 128 + 16 + 6, where 16 is num_classes; with one class it becomes 128 + 1 + 6
        if mydata == 1:
            self.fp1 = PointNetFeaturePropagation(in_channel=128 + 1 + 6 + additional_channel, mlp=[128, 128])
        else:
            self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128])
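
The 150 decomposes as follows, assuming the stock forward pass feeds fp1 three inputs:

- the one-hot class label (num_classes channels),
- l0_xyz (3 channels) and l0_points (3 channels, or 6 with normals),
- the 128-channel l1 features.

That gives 128 + num_classes + 3 + 3, plus 3 more when normals are used, which is where 128 + 1 + 6 comes from. fp1_in_channel is a hypothetical helper that checks this arithmetic.

```python
def fp1_in_channel(num_classes, normal_channel=False):
    """Input width of fp1 in the MSG part-seg model (assumed layout).

    128: features propagated up from the l1 level
    num_classes: one-hot category label repeated per point
    3: l0_xyz
    3 (or 6 with normals): l0_points
    """
    additional_channel = 3 if normal_channel else 0
    return 128 + num_classes + 3 + 3 + additional_channel


print(fp1_in_channel(16))  # 150, the stock ShapeNet value
print(fp1_in_channel(1))   # 135, i.e. 128 + 1 + 6 for the single 'Nut' class
```

Writing the width as 128 + num_classes + 6 + additional_channel, with num_classes passed into the model, would remove the need for the mydata branch.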

        # The one-hot category label has num_classes channels (16 for ShapeNet, 1 here)
        if mydata == 1:
            cls_label_one_hot = cls_label.view(B,1,1).repeat(1,1,N)
        else:
            cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)
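
Both branches perform the same operation with a different channel count. The NumPy sketch below mimics that one-hot-and-repeat step. It assumes the class label arrives as a (B, 1) integer array, as in the inference code later in this thread. class_label_channels is a hypothetical helper, and NumPy stands in for the torch calls .view and .repeat.

```python
import numpy as np


def class_label_channels(labels, num_classes, n_points):
    """One-hot encode per-shape category labels and repeat them per point.

    labels: int array of shape (B, 1), one category index per shape.
    Returns a float32 array of shape (B, num_classes, n_points).
    """
    batch = labels.shape[0]
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]  # (B, 1, num_classes)
    one_hot = one_hot.reshape(batch, num_classes, 1)          # like .view(B, C, 1)
    return np.repeat(one_hot, n_points, axis=2)               # like .repeat(1, 1, N)


single = class_label_channels(np.array([[0]]), num_classes=1, n_points=4)
print(single.shape)  # (1, 1, 4)
multi = class_label_channels(np.array([[2], [0]]), num_classes=16, n_points=4)
print(multi.shape)   # (2, 16, 4)
```

With a single category, this channel is constant (all ones) and carries no information. Keeping it simply lets the same model code serve both setups.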

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 5
  • Comments: 7

Top GitHub Comments

yangninghua commented, Aug 21, 2020 (1 reaction)

inference.py code

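import copy
import os

import numpy as np
import torch

# Repo helpers (import paths assumed; adjust to your checkout):
import pointnet2_part_seg_msg  # from the repo's models/ directory, which must be on sys.path
from data_utils.ShapeNetDataLoader import pc_normalize
# to_categorical(y, num_classes) is the one-hot helper defined in train_partseg.py


class COUPLING(object):
    def __init__(self, device_num):
        # e.g. 'cuda:0' -> '0'; this only takes effect if CUDA has not been initialized yet
        self.device_num = device_num.split(':')[-1]
        os.environ["CUDA_VISIBLE_DEVICES"] = self.device_num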
        self.num_classes = 1
        self.num_part = 2
        self.num_votes = 3
        self.checkpoint = torch.load('./checkpoints/best_model.pth')
        self.model = pointnet2_part_seg_msg.get_model(self.num_part, normal_channel=False).cuda()

    def pipeline(self, ~~~~~~~):
        classifier = self.model
        classifier.load_state_dict(self.checkpoint['model_state_dict'])
        classifier = classifier.eval()

        with torch.no_grad():
            for pointcloud in points_list:
                ori_point = pointcloud
                point_set = copy.deepcopy(ori_point)
                # Normalize xyz (center and scale into the unit sphere)
                point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
                # Resample with replacement; ori_point is reindexed with the same choice so predictions line up
                choice = np.random.choice(len(ori_point), len(ori_point), replace=True)
                point_set = point_set[choice, :]
                ori_point = pointcloud[choice, :]

                # Add a batch dimension to match the network input: (N, C) -> (1, N, C)
                point_set = np.expand_dims(point_set, axis=0)

                # Convert to a CUDA float tensor laid out as (B, C, N)
                point_set = torch.from_numpy(point_set)
                points_cuda = point_set.float().cuda()
                points_cuda = points_cuda.transpose(2, 1)

                # This network also takes the category label as input (one class, index 0)
                label = np.array([0])
                label = np.expand_dims(label, axis=0)
                label = torch.from_numpy(label)
                label = label.long().cuda()

                vote_pool = torch.zeros(1, point_set.shape[1], self.num_part).cuda()
                for _ in range(self.num_votes):
                    seg_pred, _ = classifier(points_cuda, to_categorical(label, self.num_classes))
                    vote_pool += seg_pred
                seg_pred = vote_pool / self.num_votes

                cur_pred_val = seg_pred.cpu().data.numpy()
                logits = cur_pred_val[0, :, :]
                seg_result = np.argmax(logits, axis=1)
                nut_point = ori_point[seg_result ==1]
                ~~~
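
The voting loop above runs the network num_votes times, averages the per-point scores, and keeps the points whose argmax is part 1. The NumPy sketch below covers just that post-processing step, with random scores standing in for seg_pred. vote_and_select is a hypothetical helper. It assumes the points were reordered with the same choice index as the network input, as ori_point is above.

```python
import numpy as np


def vote_and_select(points, vote_scores, part_label=1):
    """Average scores over votes, take the per-point argmax, keep one part.

    points: (N, C) array, reordered by the same `choice` index that was
            applied to the network input.
    vote_scores: (num_votes, N, num_part) per-point scores from each pass.
    """
    mean_scores = vote_scores.mean(axis=0)       # same as vote_pool / num_votes
    seg_result = np.argmax(mean_scores, axis=1)  # (N,) predicted part per point
    return points[seg_result == part_label], seg_result


rng = np.random.default_rng(0)
points = rng.random((5, 3))
scores = rng.random((3, 5, 2))  # 3 votes, 5 points, 2 parts
nut_point, seg_result = vote_and_select(points, scores)
print(seg_result.shape, nut_point.shape[1])  # (5,) 3
```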

madinwei commented, Dec 12, 2022 (0 reactions)

@yangninghua, I am very grateful for your code contribution. However, I am a novice. If possible, could you guide me in implementing the inference code? And can it also work for classification?

Thank you again, and I hope you have time to help.
