import os
|
|
import sys
|
|
# When this module is loaded from a file (``__file__`` is absent, e.g. in an
# interactive session), append the directory containing it to the module
# search path — presumably so sibling modules can be imported directly.
if '__file__' in globals():
    script_dir = os.path.split(os.path.abspath(__file__))[0]
    sys.path.append(script_dir)
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# Signal-state characters grouped by color (SUMO uses one character per link).
GREEN_SET = {'g', 'G'}
YELLOW_SET = {'y'}
RED_SET = {'r'}

_SIGNAL_TYPES = ('yellow', 'red')


def get_intermediate_signal_state(prev, cur, signal_type='yellow'):
    """Return the transition state shown between two signal states.

    Adapted from
    https://github.com/cts198859/deeprl_signal_control/blob/master/envs/env.py#L128-L152

    For every link index ``i`` (one character per link):

    * green -> red (``prev[i]`` in ``GREEN_SET``, ``cur[i]`` in ``RED_SET``):
      becomes ``'y'`` when ``signal_type == 'yellow'``; otherwise it keeps
      ``cur[i]`` (i.e. stays red).
    * red -> green: held at ``'r'`` for either signal type.
    * anything else: ``cur[i]``.

    Args:
        prev: Signal state before the transition.
        cur: Signal state after the transition; must be as long as ``prev``.
        signal_type: ``'yellow'`` or ``'red'``.

    Returns:
        ``(state, changed)``. ``(cur, False)`` when the two states are equal,
        or, for ``'yellow'``, when no link goes green -> red. Otherwise the
        transition state and ``True`` (for ``'red'`` the flag is ``True`` even
        if the transition state happens to equal ``cur``).

    Raises:
        ValueError: ``signal_type`` is not one of the two supported values,
            or ``prev`` and ``cur`` have different lengths.
    """
    if signal_type not in _SIGNAL_TYPES:
        raise ValueError(
            f"signal_type must be one of {_SIGNAL_TYPES}, got {signal_type!r}"
        )

    # Same signal as before: no intermediate state is needed.
    if prev == cur:
        return cur, False

    # zip() would silently truncate mismatched states and hide a caller bug.
    if len(prev) != len(cur):
        raise ValueError(
            f"prev and cur must have the same length ({len(prev)} != {len(cur)})."
        )

    # Compare the characters at each link index.
    switch_reds, switch_greens = [], []
    for i, (p0, p1) in enumerate(zip(prev, cur)):
        if p0 in GREEN_SET and p1 in RED_SET:
            # green -> red
            switch_reds.append(i)
        elif p0 in RED_SET and p1 in GREEN_SET:
            # red -> green
            switch_greens.append(i)

    # A yellow interval is only needed when some link goes green -> red.
    if signal_type == 'yellow' and not switch_reds:
        return cur, False

    mid = list(cur)
    if signal_type == 'yellow':
        for i in switch_reds:
            mid[i] = 'y'
    # Links about to turn green stay red during the intermediate interval.
    for i in switch_greens:
        mid[i] = 'r'

    return ''.join(mid), True
|
|
|
|
|
|
class SetupTLSProgram:
    """Build a SUMO static traffic-light program from a two-ring timing plan.

    The first half of every per-phase sequence (``durations``, ``reds``,
    ``yellows``, ``states``) describes ring A and the second half ring B.
    Within each phase the red interval comes first, then green, then yellow;
    the green time is ``duration - red - yellow``.
    """

    def __init__(self, durations, reds, yellows, offset, states, path_tll_xml, name=None):
        """Store the plan; validation happens in :meth:`main`.

        Args:
            durations: Total time of each phase (ring A phases, then ring B).
                Used as repetition counts, so the values must be integers.
            reds: Red time at the start of each phase.
            yellows: Yellow time at the end of each phase.
            offset: Written verbatim as the tlLogic ``offset`` attribute.
            states: Green-interval signal-state string of each phase.
            path_tll_xml: Output path of the generated XML file.
            name: tlLogic ``id``; ``None`` is written as ``'None'``.
        """
        self.durations = self.as_array(durations)
        self.reds = self.as_array(reds)
        self.yellows = self.as_array(yellows)
        self.offset = offset
        self.states = states
        self.path_tll_xml = path_tll_xml
        self.name = name
        # Number of phases per ring.
        self.num_phase = len(states) // 2

    def as_array(self, x):
        """Return ``x`` as a NumPy array (no copy if it already is one).

        Any array-like is accepted, not only ``list``: tuples would otherwise
        break the element-wise ``durations - reds - yellows`` arithmetic, and
        a pandas Series would be indexed by label after the ring-B slice.
        """
        return np.asarray(x)

    def _val(self):
        """Validate the plan, raising ``ValueError`` for the first problem found."""
        if len(self.states) == 0:
            raise ValueError("'states' must contain at least one phase per ring.")

        # Both rings are described, so the number of phases must be even.
        if len(self.states) % 2 != 0:
            raise ValueError("It must have an even number of elements. Odd length detected.")

        length = self.num_phase * 2
        if length != len(self.durations) or length != len(self.reds) or length != len(self.yellows):
            raise ValueError("The length must all match length.")

        # Every state must cover the same links (one character per link).
        if len({len(s) for s in self.states}) != 1:
            raise ValueError("All elements in 'states' must have the same length.")

        # Both rings must share the same cycle length.
        if np.sum(self.durations[:self.num_phase]) != np.sum(self.durations[self.num_phase:]):
            raise ValueError("The sums of the two halves of durations do not match")

        self._check_always_red(self.states)

        self._validate_non_negative_elements(self.durations)
        self._validate_non_negative_elements(self.reds)
        self._validate_non_negative_elements(self.yellows)
        # The derived green time must not be negative either.
        self._validate_non_negative_elements(self.durations - self.reds - self.yellows)

    def _validate_non_negative_elements(self, lst):
        """Raise ``ValueError`` if any element of ``lst`` is negative."""
        if any(x < 0 for x in lst):
            raise ValueError('The list contains values less than 0. Please ensure all elements in the list are non-negative.')

    def _check_always_red(self, states):
        """Raise ``ValueError`` if some link is red (``'r'``/``'R'``) in every state."""
        colors = ['r'] * len(states[0])
        for s in states:
            for i, char in enumerate(s.lower()):
                if char != 'r':
                    colors[i] = char
        # Any position still 'r' never received a non-red indication.
        if 'r' in colors:
            raise ValueError("'r' is not allowed in the colors collection.")

    def _merge(self, s1, s2):
        """Overlay two equal-length ring states link by link.

        Where one ring shows ``'r'`` the other ring's character wins; equal
        characters are kept.

        Raises:
            ValueError: The lengths differ, or the rings disagree on a link
                where neither shows ``'r'``.
        """
        if len(s1) != len(s2):
            raise ValueError("The lengths of s1 and s2 must be the same.")

        new_s = []
        for c1, c2 in zip(s1, s2):
            if c1 == 'r':
                new_s.append(c2)
            elif c2 == 'r':
                new_s.append(c1)
            elif c1 == c2:
                new_s.append(c1)
            else:
                raise ValueError(f"Unexpected characters encountered: c1={c1}, c2={c2}")

        return ''.join(new_s)

    def _add(self, durations, reds, yellows, states, ring_type='None'):
        """Expand one ring into a table with one row per time unit.

        Each phase contributes ``reds[i]`` red rows, then
        ``durations[i] - reds[i] - yellows[i]`` green rows, then
        ``yellows[i]`` yellow rows.

        Args:
            durations, reds, yellows: Per-phase timings of this ring (arrays).
            states: Green-interval state of each phase of this ring.
            ring_type: Prefix of the output column names (e.g. ``'A'``).

        Returns:
            pandas.DataFrame with columns ``{ring_type}_ring_phase_no``
            (1-based), ``{ring_type}_ring_color`` (1=red, 2=green, 3=yellow)
            and ``{ring_type}_ring_state``.

        Raises:
            ValueError: A phase needs a yellow interval but its yellow time is 0.
        """
        greens = durations - reds - yellows
        num_states = len(states)
        colors = []
        phase_numbers = []
        new_states = []
        for curr_i in range(num_states):
            # Neighbouring phases; the phase sequence of a ring is cyclic.
            prev_i = (curr_i - 1) % num_states
            next_i = (curr_i + 1) % num_states
            phase_no = curr_i + 1

            # Red interval: transition from the previous phase, with links
            # that are about to turn green still held red.
            r_state, _ = get_intermediate_signal_state(states[prev_i], states[curr_i], signal_type='red')
            # Yellow interval: transition towards the next phase.
            y_state, has_y = get_intermediate_signal_state(states[curr_i], states[next_i], signal_type='yellow')
            g_state = states[curr_i]

            # red
            if reds[curr_i] != 0:
                new_states += [r_state] * reds[curr_i]
                colors += [1] * reds[curr_i]  # 'r'
                phase_numbers += [phase_no] * reds[curr_i]

            # green
            new_states += [g_state] * greens[curr_i]
            colors += [2] * greens[curr_i]  # 'g'
            phase_numbers += [phase_no] * greens[curr_i]

            # yellow
            if has_y and yellows[curr_i] == 0:
                raise ValueError('Yellow signal is required, but the yellow duration is 0.')

            if not has_y and yellows[curr_i] != 0:
                # No link goes green -> red, so the "yellow" time simply
                # keeps showing the current green state.
                y_state = g_state

            new_states += [y_state] * yellows[curr_i]
            colors += [3] * yellows[curr_i]  # 'y'
            phase_numbers += [phase_no] * yellows[curr_i]

        return pd.DataFrame(
            {
                f'{ring_type}_ring_phase_no': phase_numbers,
                f'{ring_type}_ring_color': colors,
                f'{ring_type}_ring_state': new_states,
            }
        )

    def write_tll_xml(self, df_plan, path_tll_xml):
        """Write ``df_plan`` as a SUMO ``<tlLogics>`` XML file.

        Each row becomes one ``<phase>`` named
        ``<A phase no><A color>_<B phase no><B color>`` (e.g. ``1g_1g``).

        Args:
            df_plan: Table with ``A_ring_phase_no``, ``A_ring_color``,
                ``B_ring_phase_no``, ``B_ring_color``, ``duration`` and
                ``signal_state`` columns.
            path_tll_xml: Destination file path (overwritten).
        """
        from xml.sax.saxutils import escape

        # str(None) == 'None', matching the previous fallback; escape so an
        # id containing '&', '<' or '"' still yields well-formed XML.
        tllogic_id = escape(str(self.name), {'"': '&quot;'})

        strings = ['<tlLogics>\n']
        # NOTE(review): programID is written as the literal text "tllogic_id",
        # not the id value; kept unchanged in case configs refer to it.
        strings.append(f' <tlLogic id="{tllogic_id}" type="static" programID="tllogic_id" offset="{self.offset}">\n')
        for _, row in df_plan.iterrows():
            name = str(row['A_ring_phase_no']) + row['A_ring_color']
            name += '_' + str(row['B_ring_phase_no']) + row['B_ring_color']
            duration = row['duration']
            signal_state = row['signal_state']
            strings.append(f' <phase duration="{duration}" name="{name}" state="{signal_state}"/>\n')
        strings.append(' </tlLogic>\n')
        strings.append('</tlLogics>')

        with open(path_tll_xml, 'w', encoding='utf-8') as f:
            f.write(''.join(strings))

    def main(self):
        """Validate the plan, build the merged two-ring phase table and write it as XML.

        Returns:
            pandas.DataFrame: One row per combined phase in time order, with
            the A/B ring phase numbers, colors (``'r'``/``'g'``/``'y'``) and
            states, the phase ``duration`` and the merged ``signal_state``.
            The same table is written to ``self.path_tll_xml``.

        Raises:
            ValueError: The plan fails validation.
        """
        durations, reds, yellows, states = self.durations, self.reds, self.yellows, self.states
        n = self.num_phase

        self._val()

        # Expand each ring into one row per time unit.
        df_a = self._add(durations[:n], reds[:n], yellows[:n], states[:n], 'A')
        df_b = self._add(durations[n:], reds[n:], yellows[n:], states[n:], 'B')

        # Count how long each (A, B) combination lasts. Within a ring the
        # (phase_no, color code) key never decreases over time, so every
        # combination is one contiguous run and sorting by these keys
        # restores chronological order.
        group_cols = ['A_ring_phase_no', 'A_ring_color', 'A_ring_state', 'B_ring_phase_no', 'B_ring_color', 'B_ring_state']
        sort_cols = ['A_ring_phase_no', 'A_ring_color', 'B_ring_phase_no', 'B_ring_color']
        df_plan = (
            pd.concat([df_a, df_b], axis=1)
            .groupby(group_cols)
            .size()
            .reset_index(name='duration')
            .sort_values(by=sort_cols)
            .reset_index(drop=True)
        )

        mapping = {1: 'r', 2: 'g', 3: 'y'}
        df_plan['A_ring_color'] = df_plan['A_ring_color'].map(mapping)
        df_plan['B_ring_color'] = df_plan['B_ring_color'].map(mapping)
        # Overlay the two ring states link by link (see _merge).
        df_plan['signal_state'] = [
            self._merge(a_state, b_state)
            for a_state, b_state in zip(df_plan['A_ring_state'], df_plan['B_ring_state'])
        ]

        # Internal consistency: the merged plan covers exactly one cycle.
        assert df_plan['duration'].sum() == np.sum(durations[:n])
        assert df_plan['duration'].sum() == np.sum(durations[n:])

        self.write_tll_xml(df_plan, self.path_tll_xml)

        return df_plan
|
|
|
|
|
|
if __name__ == '__main__':
    # Test data
    #
    # - Intersection number: 5034
    # - Order: ring A phase 1, ring A phase 2, ..., ring B phase 1, ring B phase 2, ...

    # Overlap test
    durations = [51, 34, 52, 43, 46, 39, 52, 43]
    reds = [2] * 8
    yellows = [4] * 8
    offset = 5

    # # Original
    # durations = [53, 33, 51, 43, 26, 60, 51, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # # 1795 (2023-12-20, 07:00)
    # durations = [58, 29, 50, 43, 28, 59, 50, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # # 1888 (2023-12-21, 07:00)
    # durations = [61, 28, 48, 43, 27, 62, 48, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # Signal state of the movements served in each phase.
    signal_states = [
        'rrrrrrrrrGGGGrrrrrrrr',
        'rrrGrrrrrrrrrrrrrrrrr',
        'rrrrGGGrrrrrrrrrrrrrr',
        'rrrrrrrrrrrrrrrrrrrGG',
        'rrrrrrrrrrrrrGGrrrrrr',
        'GGGrrrrrrrrrrrrrrrrrr',
        'rrrrrrrGGrrrrrrrrrrrr',
        'rrrrrrrrrrrrrrrGGGGrr',
    ]

    # Output paths.
    path_xml = 'test.tll.xml'
    path_csv = 'test.tll.csv'

    # Node id (intersection number 5034).
    node_id = '11053'

    program = SetupTLSProgram(
        durations=durations,
        reds=reds,
        yellows=yellows,
        offset=offset,
        states=signal_states,
        path_tll_xml=path_xml,
        name=node_id,
    )
    df_plan = program.main()
    df_plan.to_csv(path_csv, index=False)

    print('hello')
|