lc308-Range Sum Query 2D - Mutable

这个版本太慢了,改天再优化。
另外,这是我第一次遇到 MLE(Memory Limit Exceeded)的问题:原因是 update 每次都调用 sumTrees,其中 new 出来的节点一直没有释放。后来改成在 update 时直接原地更新已有节点、不再新建 node。以后要注意这类内存泄漏。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/// Node of a one-dimensional segment tree over an inclusive range of columns.
struct row_node {
    row_node* left = nullptr;   // child covering columns [down, mid]; null for a leaf
    row_node* right = nullptr;  // child covering columns [mid + 1, up]; null for a leaf
    int down, up;               // inclusive column bounds covered by this node
    int val = 0;                // sum of the values covered by this node

    // Initialisers are listed in declaration order (down, then up).
    row_node(int l, int u) : down(l), up(u) {}
};

/// Node of the outer segment tree over an inclusive range of rows. Each node
/// carries a column tree (`root`) holding, per column range, the sum over all
/// rows in [down, up].
struct col_node {
    col_node* left = nullptr;   // child covering rows [down, mid]; null for a leaf
    col_node* right = nullptr;  // child covering rows [mid + 1, up]; null for a leaf
    int down, up;               // inclusive row bounds covered by this node
    row_node* root = nullptr;   // column segment tree for the covered rows

    // Initialisers are listed in declaration order (down, then up).
    col_node(int l, int u) : down(l), up(u) {}
};

/// 2-D range-sum structure with point updates, implemented as a segment tree
/// over rows whose nodes each hold a segment tree over columns.
///
/// The input matrix is expected to be rectangular: the column count is taken
/// from the first row only. The matrix is read while the trees are built; it
/// is not modified by `update`.
class NumMatrix {
    int col_num = 0;
    int row_num = 0;
    col_node* root = nullptr;  // stays null when the matrix has no rows or no columns
    std::vector<std::vector<int>>& mat;

    // Value of a possibly-missing child.
    static int nodeSum(const row_node* node) { return node ? node->val : 0; }

    // Frees a column tree and all of its descendants.
    static void destroyRowTree(row_node* node) {
        if (!node) return;
        destroyRowTree(node->left);
        destroyRowTree(node->right);
        delete node;
    }

    // Frees a row tree, its descendants, and the column tree of every node.
    static void destroyColTree(col_node* node) {
        if (!node) return;
        destroyColTree(node->left);
        destroyColTree(node->right);
        destroyRowTree(node->root);
        delete node;
    }

    // Recomputes `node` from the matching nodes of `a` and `b`, but only along
    // the root-to-leaf path that contains column j. A point update changes no
    // other node, so the rest of the merged tree is already correct.
    static void refreshSumPath(row_node* node, const row_node* a, const row_node* b, int j) {
        while (true) {
            node->val = a->val + b->val;
            if (node->down == node->up) return;
            const int mid = (node->down + node->up) / 2;
            if (j <= mid) {
                node = node->left;
                a = a->left;
                b = b->left;
            } else {
                node = node->right;
                a = a->right;
                b = b->right;
            }
        }
    }

public:
    /// Builds the row tree for rows [rl, ru]; every node gets a column tree
    /// over columns [cl, cu]. The caller owns the returned tree.
    col_node* buildBigTree(int rl, int ru, int cl, int cu) {
        col_node* r = new col_node(rl, ru);
        if (rl > ru) return r;  // degenerate range: bare node without children
        if (rl == ru) {
            r->root = buildSmallTree(cl, cu, mat[ru]);
            return r;
        }
        const int mid = (rl + ru) / 2;
        r->left = buildBigTree(rl, mid, cl, cu);
        r->right = buildBigTree(mid + 1, ru, cl, cu);
        r->root = sumTrees(r->left->root, r->right->root);
        return r;
    }

    /// Allocates a new column tree whose every node is the sum of the matching
    /// nodes of `l` and `r`. Both inputs must have the same shape.
    row_node* sumTrees(row_node* l, row_node* r) {
        row_node* node = new row_node(l->down, l->up);
        node->val = l->val + r->val;
        if (l->left)
            node->left = sumTrees(l->left, r->left);
        if (l->right)
            node->right = sumTrees(l->right, r->right);
        return node;
    }

    /// Rewrites every node of `node` in place as the sum of the matching nodes
    /// of `l` and `r`. Allocates nothing. `updateTree` no longer calls this;
    /// it is kept for callers that need a full resync.
    void updateSumTrees(row_node* node, row_node* l, row_node* r) {
        node->val = l->val + r->val;
        if (l->left)
            updateSumTrees(node->left, l->left, r->left);
        if (l->right)
            updateSumTrees(node->right, l->right, r->right);
    }

    /// Builds a column tree over `row[l..u]`. The caller owns the returned tree.
    row_node* buildSmallTree(int l, int u, std::vector<int>& row) {
        row_node* r = new row_node(l, u);
        if (l > u) return r;  // degenerate range: bare node with val == 0
        if (l == u) {
            r->val = row[l];
            return r;
        }
        const int mid = (l + u) / 2;
        r->left = buildSmallTree(l, mid, row);
        r->right = buildSmallTree(mid + 1, u, row);
        r->val = nodeSum(r->left) + nodeSum(r->right);
        return r;
    }

    /// Sets column j to `val` in the column tree rooted at `r` and refreshes
    /// the sums on the way back up.
    void updateSmallTree(int j, int val, row_node* r) {
        if (r->down == r->up) {
            r->val = val;
            return;
        }
        const int mid = (r->down + r->up) / 2;
        updateSmallTree(j, val, j <= mid ? r->left : r->right);
        r->val = nodeSum(r->left) + nodeSum(r->right);
    }

    /// Sets cell (i, j) to `val` in the row tree rooted at `r`. After the
    /// child holding row i is updated, only the path for column j in this
    /// node's column tree is recomputed.
    void updateTree(int i, int j, int val, col_node* r) {
        if (r->down == r->up) {
            updateSmallTree(j, val, r->root);
            return;
        }
        const int mid = (r->down + r->up) / 2;
        updateTree(i, j, val, i <= mid ? r->left : r->right);
        refreshSumPath(r->root, r->left->root, r->right->root, j);
    }

    /// Sum over rows [r1, r2] and columns [c1, c2], where `r` covers rows
    /// [l, u]. Requires l <= r1 <= r2 <= u.
    int sumRangeFromBigTree(col_node* r, int l, int u, int r1, int c1, int r2, int c2) {
        if (l == r1 && u == r2)
            return sumRangeFromSmallTree(r->root, 0, col_num - 1, c1, c2);
        const int m = (l + u) / 2;
        int total = 0;
        if (r1 <= m)
            total += sumRangeFromBigTree(r->left, l, m, r1, c1, std::min(r2, m), c2);
        if (r2 > m)
            total += sumRangeFromBigTree(r->right, m + 1, u, std::max(r1, m + 1), c1, r2, c2);
        return total;
    }

    /// Sum over columns [i, j], where `r` covers columns [l, u].
    /// Requires l <= i <= j <= u.
    int sumRangeFromSmallTree(row_node* r, int l, int u, int i, int j) {
        if (l == i && u == j)
            return r->val;
        const int m = (l + u) / 2;
        int total = 0;
        if (i <= m)
            total += sumRangeFromSmallTree(r->left, l, m, i, std::min(j, m));
        if (j > m)
            total += sumRangeFromSmallTree(r->right, m + 1, u, std::max(i, m + 1), j);
        return total;
    }

    /// Builds the trees from `matrix`. With no rows or no columns nothing is
    /// allocated and `root` stays null.
    NumMatrix(std::vector<std::vector<int>>& matrix) : mat(matrix) {
        row_num = static_cast<int>(matrix.size());
        if (row_num) {
            col_num = static_cast<int>(matrix[0].size());
            if (col_num)
                root = buildBigTree(0, row_num - 1, 0, col_num - 1);
        }
    }

    /// Releases every node reachable from `root`.
    ~NumMatrix() { destroyColTree(root); }

    // The object owns its trees through raw pointers, so a member-wise copy
    // would lead to a double delete. Copying is disabled; moving hands the
    // tree over and leaves the source empty.
    NumMatrix(const NumMatrix&) = delete;
    NumMatrix& operator=(const NumMatrix&) = delete;
    NumMatrix(NumMatrix&& other) noexcept
        : col_num(other.col_num), row_num(other.row_num), root(other.root), mat(other.mat) {
        other.root = nullptr;
    }
    NumMatrix& operator=(NumMatrix&&) = delete;

    /// Sets cell (i, j) to `val`. A call on an empty structure or with an
    /// index outside the matrix is ignored.
    void update(int i, int j, int val) {
        if (!root || i < 0 || i >= row_num || j < 0 || j >= col_num) return;
        updateTree(i, j, val, root);
    }

    /// Sum of the cells in rows [row1, row2] and columns [col1, col2], both
    /// inclusive. The rectangle is clipped to the matrix; an empty
    /// intersection (or an empty matrix) yields 0.
    int sumRegion(int row1, int col1, int row2, int col2) {
        if (!root) return 0;
        row1 = std::max(row1, 0);
        col1 = std::max(col1, 0);
        row2 = std::min(row2, row_num - 1);
        col2 = std::min(col2, col_num - 1);
        if (row1 > row2 || col1 > col2) return 0;
        return sumRangeFromBigTree(root, 0, row_num - 1, row1, col1, row2, col2);
    }
};
// Your NumMatrix object will be instantiated and called as such:
// NumMatrix numMatrix(matrix);
// numMatrix.sumRegion(0, 1, 2, 3);
// numMatrix.update(1, 1, 10);
// numMatrix.sumRegion(1, 2, 3, 4);