1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
| #include <bits/stdc++.h> #include "railroad2.h" using namespace std;
// Modulus applied to the final answer: 1'000'000'007 (== 1e9 + 7).
static constexpr std::int64_t Mod = 1'000'000'007;
// INT64_MAX / 4 -- presumably an "infinity" sentinel with headroom for sums.
// NOTE(review): not referenced anywhere in this file.
static constexpr std::int64_t INF = std::numeric_limits<std::int64_t>::max() >> 2;
// Disjoint-set union (union-find) over elements 0..n-1, using path
// compression in findf and union by size in merge.
struct dsu {
    std::vector<std::size_t> fa;   // parent link; a root is its own parent
    std::vector<std::size_t> siz;  // component size, meaningful only at roots

    // Creates n singleton sets {0}, {1}, ..., {n-1}.
    dsu(std::size_t n) : fa(n), siz(n, 1) {
        for (std::size_t i = 0; i < n; ++i) fa[i] = i;
    }

    // Returns the root of x's set; every node on the walked path is
    // re-pointed directly at that root.
    std::size_t findf(std::size_t x) {
        std::size_t root = x;
        while (fa[root] != root) root = fa[root];
        while (fa[x] != root) {
            const std::size_t next = fa[x];
            fa[x] = root;
            x = next;
        }
        return root;
    }

    // Unites the sets containing u and v.
    // Returns {size of smaller set, size of larger set} as they were before
    // the union, or {0, 0} if u and v were already in the same set.
    // On equal sizes, u's root is attached under v's root.
    std::pair<std::size_t, std::size_t> merge(std::size_t u, std::size_t v) {
        std::size_t child_root = findf(u);
        std::size_t parent_root = findf(v);
        if (child_root == parent_root) return {0, 0};
        if (siz[parent_root] < siz[child_root]) std::swap(child_root, parent_root);
        const std::pair<std::size_t, std::size_t> before{siz[child_root], siz[parent_root]};
        fa[child_root] = parent_root;
        siz[parent_root] += siz[child_root];
        return before;
    }
};
// Fills `depth` with weighted distances from root `u`: for every vertex x
// reachable from u without stepping back into `fa`,
//   depth[x] = depth[u] + (sum of edge weights on the u..x path).
// depth[u] itself is left untouched, so callers pass a zero-initialised vector.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// recursion depth equals the vertex count, which overflows the call stack
// for large inputs.
//
// u     : root vertex
// fa    : vertex treated as u's parent and never descended into (callers pass u)
// tree  : adjacency lists of (neighbour, edge weight); assumed to be a tree --
//         a cycle would make this loop forever
// depth : output indexed by vertex; must have at least tree.size() entries
void get_depth(std::size_t u, std::size_t fa,
               const std::vector<std::vector<std::pair<std::size_t, std::int64_t>>> &tree,
               std::vector<std::int64_t> &depth) {
    std::vector<std::pair<std::size_t, std::size_t>> todo;  // (vertex, its parent)
    todo.emplace_back(u, fa);
    while (!todo.empty()) {
        const auto [x, parent] = todo.back();
        todo.pop_back();
        for (const auto &[y, w] : tree[x]) {
            if (y == parent) continue;
            // In a tree each vertex is reached exactly once, so the visiting
            // order does not affect the final values.
            depth[y] = depth[x] + w;
            todo.emplace_back(y, x);
        }
    }
}
int travel(vector<int> U, vector<int> V, vector<int> W) { size_t n = U.size() + 1; vector<vector<pair<size_t, int64_t>>> tree(n); for (size_t i = 0; i < n - 1; ++i) { size_t u = U[i], v = V[i]; int64_t w = W[i]; tree[u].emplace_back(v, w); tree[v].emplace_back(u, w); } size_t u = [&]() -> size_t { size_t uu = 0; vector<int64_t> depth(n); get_depth(0, 0, tree, depth); for (size_t i = 0; i < n; ++i) if (depth[i] > depth[uu]) uu = i; return uu; }(); vector<int64_t> depth_from_u(n); get_depth(u, u, tree, depth_from_u); size_t v = [&]() -> size_t { size_t vv = 0; for (size_t i = 0; i < n; ++i) if (depth_from_u[i] > depth_from_u[vv]) vv = i; return vv; }(); vector<int64_t> depth_from_v(n); get_depth(v, v, tree, depth_from_v); vector<tuple<int64_t, size_t, size_t>> es; es.reserve(n); for (size_t i = 0; i < n; ++i) { if (depth_from_u[i] > depth_from_v[i]) es.emplace_back(depth_from_u[i], i, u); else es.emplace_back(depth_from_v[i], i, v); } sort(es.rbegin(), es.rend()); dsu dd(n); int64_t sum = 0; for (const auto &[w, uu, vv] : es) { auto [sx, sy] = dd.merge(uu, vv); (sum += (sx * sy % Mod * (w % Mod)) % Mod) %= Mod; } return (sum << 1) % Mod; }
|