알고리즘

BOJ 1504 : 특정한 최단 경로

minbear 2023. 7. 28. 18:07

#BOJ 1504 : 특정한 최단 경로

 

1504번: 특정한 최단 경로

첫째 줄에 정점의 개수 N과 간선의 개수 E가 주어진다. (2 ≤ N ≤ 800, 0 ≤ E ≤ 200,000) 둘째 줄부터 E개의 줄에 걸쳐서 세 개의 정수 a, b, c가 주어지는데, a번 정점에서 b번 정점까지 양방향 길이 존

www.acmicpc.net

#풀이

1부터 N까지 가는데 무조건 정점 v1,v2를 거처서 가야 한다.

먼저 1->N 까지 가는데 v1,v2를 거쳐서 가는 방법은

 

1. 1->v1->v2->N

2. 1->v2->v1->N

 

이렇게 2가지 방법이 있다.

각각 최단거리를 구해여 둘 중에 작은 비용을 출력하면 된다.

 

출력하지 못하는 경우는 길이 없는 경우 인데 INF로 초기화를 하기 때문에 ans가 INF보다 크면 -1를 출력하면 된다.

 

#코드

#include<bits/stdc++.h>

using namespace std;
#define X first
#define Y second
int n, e;
const int INF = 0x3f3f3f3f;
vector<pair<int,int>> adj[801];
vector<int> d_st(801, INF);
vector<int> d_v1(801,INF);
vector<int> d_v2(801,INF);
int main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin >> n >> e;
	for(int i = 0; i < e; i++){
		int a, b, c;
		cin >> a>> b>>c;
		adj[a].push_back({c,b});
		adj[b].push_back({c,a});
	}
	int v1,v2;
	cin >>v1 >>v2;
	
	d_st[1] = 0;
	d_v1[v1] = 0;
	d_v2[v2] = 0;
	
	priority_queue<pair<int,int>, vector<pair<int,int>>, greater<pair<int,int>>>pq;
	pq.push({d_st[1], 1});
	while(!pq.empty()){
		auto cur = pq.top();pq.pop();
		if(cur.X != d_st[cur.Y]) continue;
		
		for(auto nxt : adj[cur.Y]){
			if(nxt.X+ d_st[cur.Y] >= d_st[nxt.Y]) continue;
			
			pq.push({nxt.X+ d_st[cur.Y], nxt.Y});
			d_st[nxt.Y] = nxt.X+ d_st[cur.Y];
		}
	}
	
	pq.push({d_v1[v1], v1});
	while(!pq.empty()){
		auto cur = pq.top();pq.pop();
		if(cur.X != d_v1[cur.Y]) continue;
		
		for(auto nxt : adj[cur.Y]){
			if(nxt.X+ d_v1[cur.Y] >= d_v1[nxt.Y]) continue;
			
			pq.push({nxt.X+ d_v1[cur.Y], nxt.Y});
			d_v1[nxt.Y] = nxt.X + d_v1[cur.Y];
		}
	}
	
	pq.push({d_v2[v2], v2});
	while(!pq.empty()){
		auto cur = pq.top();pq.pop();
		if(cur.X != d_v2[cur.Y]) continue;
		
		for(auto nxt : adj[cur.Y]){
			if(nxt.X+ d_v2[cur.Y] >= d_v2[nxt.Y]) continue;
			
			pq.push({nxt.X+ d_v2[cur.Y], nxt.Y});
			d_v2[nxt.Y] = nxt.X+ d_v2[cur.Y];
		}
	}
	long long ans = 0;
	if((long long)d_st[v1] + d_v1[v2]+ d_v2[n] > d_st[v2] + d_v2[v1] + d_v1[n]){
		ans = (long long)d_st[v2] + d_v2[v1] + d_v1[n];		
	} 
	else
		ans = (long long)d_st[v1] + d_v1[v2] + d_v2[n];
	if(ans >= INF)
		ans = -1;
	cout << ans << '\n';
	return 0;
}